Complete batch import v2 runtime and host capability recovery
This commit is contained in:
@@ -20,6 +20,10 @@ type ConfirmationItemStore interface {
|
||||
Upsert(ctx context.Context, item sqlite.ImportRunItem) error
|
||||
}
|
||||
|
||||
// ConfirmationLeaseClaimer is an optional capability of the item store.
//
// ConfirmationWorker.Tick type-asserts its ItemStore against this interface
// and, when it is implemented, calls TryAcquireLease for each candidate item
// before confirming it.
type ConfirmationLeaseClaimer interface {
|
||||
// TryAcquireLease attempts to claim itemID for workerID at time now for
// leaseDuration. The boolean reports whether the claim succeeded; Tick
// skips the item when it is false and otherwise continues with the
// returned item in place of the one it listed.
TryAcquireLease(ctx context.Context, itemID, workerID string, now time.Time, leaseDuration time.Duration) (sqlite.ImportRunItem, bool, error)
|
||||
}
|
||||
|
||||
// ConfirmationEventStore is the sink for per-item confirmation events.
type ConfirmationEventStore interface {
|
||||
// Append records a single import run item event.
Append(ctx context.Context, event sqlite.ImportRunItemEvent) error
|
||||
}
|
||||
@@ -53,6 +57,16 @@ func (w ConfirmationWorker) Tick(ctx context.Context, now time.Time) error {
|
||||
if !isConfirmationCandidate(item, now) {
|
||||
continue
|
||||
}
|
||||
if claimer, ok := w.ItemStore.(ConfirmationLeaseClaimer); ok {
|
||||
claimedItem, claimed, err := claimer.TryAcquireLease(ctx, item.ItemID, w.WorkerID, now, w.LeaseDuration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !claimed {
|
||||
continue
|
||||
}
|
||||
item = claimedItem
|
||||
}
|
||||
if err := w.ConfirmItem(ctx, item, now); err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -66,9 +80,11 @@ func (w ConfirmationWorker) ConfirmItem(ctx context.Context, item sqlite.ImportR
|
||||
return err
|
||||
}
|
||||
|
||||
item.ConfirmationAttempts++
|
||||
item.LeaseOwner = strings.TrimSpace(w.WorkerID)
|
||||
item.LeaseUntil = now.Add(defaultDuration(w.LeaseDuration, time.Minute)).Format(time.RFC3339)
|
||||
if strings.TrimSpace(item.LeaseOwner) == "" {
|
||||
item.ConfirmationAttempts++
|
||||
item.LeaseOwner = strings.TrimSpace(w.WorkerID)
|
||||
item.LeaseUntil = now.Add(defaultDuration(w.LeaseDuration, time.Minute)).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
switch {
|
||||
case result.StatusCode >= 200 && result.StatusCode < 300:
|
||||
|
||||
@@ -3,6 +3,8 @@ package batch
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -160,6 +162,50 @@ func TestConfirmationWorker(t *testing.T) {
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("concurrent workers do not both call confirmer before lease is persisted", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
now := time.Date(2026, 5, 22, 13, 3, 30, 0, time.UTC)
|
||||
store := newFakeConfirmationStore([]sqlite.ImportRunItem{
|
||||
{ItemID: "shared", RunID: "run-1", CurrentStage: "confirm", ConfirmationStatus: "pending"},
|
||||
})
|
||||
|
||||
started := make(chan struct{}, 2)
|
||||
release := make(chan struct{})
|
||||
var calls atomic.Int32
|
||||
|
||||
confirmer := func(ctx context.Context, item sqlite.ImportRunItem) (ConfirmationResult, error) {
|
||||
calls.Add(1)
|
||||
started <- struct{}{}
|
||||
<-release
|
||||
return ConfirmationResult{StatusCode: 200}, nil
|
||||
}
|
||||
|
||||
workerA := ConfirmationWorker{WorkerID: "worker-a", ItemStore: store, EventStore: store, LeaseDuration: time.Minute, RetryDelay: time.Second, Confirmer: confirmer}
|
||||
workerB := ConfirmationWorker{WorkerID: "worker-b", ItemStore: store, EventStore: store, LeaseDuration: time.Minute, RetryDelay: time.Second, Confirmer: confirmer}
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
go func() { errCh <- workerA.Tick(context.Background(), now) }()
|
||||
go func() { errCh <- workerB.Tick(context.Background(), now) }()
|
||||
|
||||
<-started
|
||||
select {
|
||||
case <-started:
|
||||
t.Fatal("second worker reached confirmer before lease was acquired")
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(release)
|
||||
for range 2 {
|
||||
if err := <-errCh; err != nil {
|
||||
t.Fatalf("Tick() error = %v", err)
|
||||
}
|
||||
}
|
||||
if got := calls.Load(); got != 1 {
|
||||
t.Fatalf("confirmer calls = %d, want 1", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("reactivated account metadata is preserved", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -200,6 +246,7 @@ func TestConfirmationWorker(t *testing.T) {
|
||||
}
|
||||
|
||||
type fakeConfirmationStore struct {
|
||||
mu sync.Mutex
|
||||
items map[string]sqlite.ImportRunItem
|
||||
processed []string
|
||||
events []sqlite.ImportRunItemEvent
|
||||
@@ -217,6 +264,9 @@ func newFakeConfirmationStore(items []sqlite.ImportRunItem) *fakeConfirmationSto
|
||||
}
|
||||
|
||||
func (f *fakeConfirmationStore) List(ctx context.Context) ([]sqlite.ImportRunItem, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
items := make([]sqlite.ImportRunItem, 0, len(f.items))
|
||||
for _, item := range f.items {
|
||||
items = append(items, item)
|
||||
@@ -225,12 +275,36 @@ func (f *fakeConfirmationStore) List(ctx context.Context) ([]sqlite.ImportRunIte
|
||||
}
|
||||
|
||||
func (f *fakeConfirmationStore) Upsert(ctx context.Context, item sqlite.ImportRunItem) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
f.items[item.ItemID] = item
|
||||
f.processed = append(f.processed, item.ItemID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeConfirmationStore) TryAcquireLease(ctx context.Context, itemID, workerID string, now time.Time, leaseDuration time.Duration) (sqlite.ImportRunItem, bool, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
item, ok := f.items[itemID]
|
||||
if !ok {
|
||||
return sqlite.ImportRunItem{}, false, nil
|
||||
}
|
||||
if !isConfirmationCandidate(item, now) {
|
||||
return sqlite.ImportRunItem{}, false, nil
|
||||
}
|
||||
item.ConfirmationAttempts++
|
||||
item.LeaseOwner = workerID
|
||||
item.LeaseUntil = now.Add(leaseDuration).Format(time.RFC3339)
|
||||
f.items[itemID] = item
|
||||
return item, true, nil
|
||||
}
|
||||
|
||||
func (f *fakeConfirmationStore) Append(ctx context.Context, event sqlite.ImportRunItemEvent) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
f.events = append(f.events, event)
|
||||
return nil
|
||||
}
|
||||
@@ -238,6 +312,9 @@ func (f *fakeConfirmationStore) Append(ctx context.Context, event sqlite.ImportR
|
||||
func (f *fakeConfirmationStore) mustItem(t *testing.T, itemID string) sqlite.ImportRunItem {
|
||||
t.Helper()
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
item, ok := f.items[itemID]
|
||||
if !ok {
|
||||
t.Fatalf("item %q not found", itemID)
|
||||
|
||||
@@ -19,12 +19,15 @@ type BatchImportEntry struct {
|
||||
}
|
||||
|
||||
type BatchImportRunRequest struct {
|
||||
RunID string
|
||||
Mode string
|
||||
AccessMode string
|
||||
HostID string
|
||||
HostBaseURL string
|
||||
Entries []BatchImportEntry
|
||||
RunID string
|
||||
Mode string
|
||||
AccessMode string
|
||||
HostID string
|
||||
HostBaseURL string
|
||||
SubscriptionUsers []string
|
||||
SubscriptionDays int
|
||||
ProbeAPIKey string
|
||||
Entries []BatchImportEntry
|
||||
}
|
||||
|
||||
type BatchImportRunResult struct {
|
||||
@@ -34,6 +37,7 @@ type BatchImportRunResult struct {
|
||||
|
||||
// RunStateStore persists import run records for BatchImportService.
type RunStateStore interface {
|
||||
// Create stores a new run; StartRun calls it with the run in the running
// state before processing entries.
Create(ctx context.Context, run sqlite.ImportRun) error
|
||||
// Update rewrites an existing run; failRun calls it to record the failed
// state, counters and finish time.
Update(ctx context.Context, run sqlite.ImportRun) error
|
||||
}
|
||||
|
||||
type ItemStateStore interface {
|
||||
@@ -114,11 +118,15 @@ func (s BatchImportService) StartRun(ctx context.Context, req BatchImportRunRequ
|
||||
}
|
||||
|
||||
if err := s.RunStore.Create(ctx, sqlite.ImportRun{
|
||||
RunID: runID,
|
||||
Mode: strings.TrimSpace(req.Mode),
|
||||
AccessMode: strings.TrimSpace(req.AccessMode),
|
||||
State: string(RunStateRunning),
|
||||
TotalItems: len(req.Entries),
|
||||
RunID: runID,
|
||||
HostID: strings.TrimSpace(req.HostID),
|
||||
Mode: strings.TrimSpace(req.Mode),
|
||||
AccessMode: strings.TrimSpace(req.AccessMode),
|
||||
SubscriptionUsersJSON: mustMarshalJSON(req.SubscriptionUsers, "[]"),
|
||||
SubscriptionDays: req.SubscriptionDays,
|
||||
ProbeAPIKey: strings.TrimSpace(req.ProbeAPIKey),
|
||||
State: string(RunStateRunning),
|
||||
TotalItems: len(req.Entries),
|
||||
}); err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
}
|
||||
@@ -152,17 +160,26 @@ func (s BatchImportService) StartRun(ctx context.Context, req BatchImportRunRequ
|
||||
|
||||
modelsResult, err := s.ProbeModels(ctx, entry.BaseURL, entry.APIKey)
|
||||
if err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
if failErr := s.failRun(ctx, req, initialItem, ItemStageProbe, err); failErr != nil {
|
||||
return BatchImportRunResult{}, failErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
rawModels := append([]string(nil), modelsResult.RawModels...)
|
||||
capabilityProfile, err := s.ProbeCapabilities(ctx, entry.BaseURL, entry.APIKey, rawModels)
|
||||
if err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
if failErr := s.failRun(ctx, req, initialItem, ItemStageProbe, err); failErr != nil {
|
||||
return BatchImportRunResult{}, failErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
routingStrategy := BuildImportRoutingStrategy(capabilityProfile)
|
||||
resolvedSmokeModel, recommendedModels, err := probe.ResolveSmokeModel(entry.RequestedModels, rawModels, capabilityProfile)
|
||||
if err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
if failErr := s.failRun(ctx, req, initialItem, ItemStageProbe, err); failErr != nil {
|
||||
return BatchImportRunResult{}, failErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
canonicalFamilies := uniqueCanonicalFamilies(rawModels)
|
||||
@@ -176,7 +193,10 @@ func (s BatchImportService) StartRun(ctx context.Context, req BatchImportRunRequ
|
||||
CanonicalModelFamilies: canonicalFamilies,
|
||||
})
|
||||
if err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
if failErr := s.failRun(ctx, req, initialItem, ItemStageProbe, err); failErr != nil {
|
||||
return BatchImportRunResult{}, failErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -217,18 +237,27 @@ func (s BatchImportService) StartRun(ctx context.Context, req BatchImportRunRequ
|
||||
patchContract := ModelMappingDelta(reuseLookup.ExistingModelMapping, probe.BuildAliasTable(rawModels))
|
||||
if shouldPatchAliases(reuseLookup.ExistingModelMapping, patchContract.ModelMapping) {
|
||||
if s.Provisioner == nil {
|
||||
return BatchImportRunResult{}, fmt.Errorf("provisioner is required for patch-only flow")
|
||||
if failErr := s.failRun(ctx, req, initialItem, ItemStageProvision, fmt.Errorf("provisioner is required for patch-only flow")); failErr != nil {
|
||||
return BatchImportRunResult{}, failErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
if err := s.Provisioner.Patch(ctx, PatchProvisionRequest{
|
||||
ProviderID: reuseDecision.ReusedFromProviderID,
|
||||
Contract: patchContract,
|
||||
}); err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
if failErr := s.failRun(ctx, req, initialItem, ItemStageProvision, err); failErr != nil {
|
||||
return BatchImportRunResult{}, failErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if s.Provisioner == nil {
|
||||
return BatchImportRunResult{}, fmt.Errorf("provisioner is required")
|
||||
if failErr := s.failRun(ctx, req, initialItem, ItemStageProvision, fmt.Errorf("provisioner is required")); failErr != nil {
|
||||
return BatchImportRunResult{}, failErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
provisionResult, err := s.Provisioner.Provision(ctx, ProvisionRequest{
|
||||
RunID: runID,
|
||||
@@ -240,7 +269,10 @@ func (s BatchImportService) StartRun(ctx context.Context, req BatchImportRunRequ
|
||||
CapabilityProfile: capabilityProfile,
|
||||
})
|
||||
if err != nil {
|
||||
return BatchImportRunResult{}, err
|
||||
if failErr := s.failRun(ctx, req, initialItem, ItemStageProvision, err); failErr != nil {
|
||||
return BatchImportRunResult{}, failErr
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
finalItem.LegacyBatchID = provisionResult.LegacyBatchID
|
||||
finalItem.LegacyProviderID = strings.TrimSpace(provisionResult.LegacyProviderID)
|
||||
@@ -254,6 +286,35 @@ func (s BatchImportService) StartRun(ctx context.Context, req BatchImportRunRequ
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s BatchImportService) failRun(ctx context.Context, req BatchImportRunRequest, item sqlite.ImportRunItem, stage ItemStage, cause error) error {
|
||||
item.CurrentStage = string(ItemStageDone)
|
||||
item.ConfirmationStatus = string(ConfirmationFailed)
|
||||
item.AccessStatus = string(AccessStatusBroken)
|
||||
item.LastErrorStage = string(stage)
|
||||
item.LastError = strings.TrimSpace(cause.Error())
|
||||
item.LeaseOwner = ""
|
||||
item.LeaseUntil = ""
|
||||
item.NextRetryAt = ""
|
||||
if err := s.ItemStore.Upsert(ctx, item); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.RunStore.Update(ctx, sqlite.ImportRun{
|
||||
RunID: strings.TrimSpace(req.RunID),
|
||||
HostID: strings.TrimSpace(req.HostID),
|
||||
Mode: strings.TrimSpace(req.Mode),
|
||||
AccessMode: strings.TrimSpace(req.AccessMode),
|
||||
SubscriptionUsersJSON: mustMarshalJSON(req.SubscriptionUsers, "[]"),
|
||||
SubscriptionDays: req.SubscriptionDays,
|
||||
ProbeAPIKey: strings.TrimSpace(req.ProbeAPIKey),
|
||||
State: string(RunStateFailed),
|
||||
TotalItems: len(req.Entries),
|
||||
CompletedItems: 1,
|
||||
BrokenItems: 1,
|
||||
FinishedAt: time.Now().UTC().Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
func uniqueCanonicalFamilies(rawModels []string) []string {
|
||||
seen := make(map[string]struct{}, len(rawModels))
|
||||
families := make([]string, 0, len(rawModels))
|
||||
|
||||
@@ -115,6 +115,7 @@ func TestBatchImport_StartRun(t *testing.T) {
|
||||
RunID: "run-2",
|
||||
Mode: "strict",
|
||||
AccessMode: "subscription",
|
||||
HostID: "host-1",
|
||||
Entries: []BatchImportEntry{
|
||||
{BaseURL: "https://api.kimi.com/v1", APIKey: "sk-live", RequestedModels: []string{"kimi 2.6"}},
|
||||
},
|
||||
@@ -170,6 +171,7 @@ func TestBatchImport_StartRun(t *testing.T) {
|
||||
RunID: "run-3",
|
||||
Mode: "strict",
|
||||
AccessMode: "subscription",
|
||||
HostID: "host-1",
|
||||
Entries: []BatchImportEntry{
|
||||
{BaseURL: "https://api.kimi.com/v1", APIKey: "sk-live", RequestedModels: []string{"kimi 2.6"}},
|
||||
},
|
||||
@@ -221,6 +223,7 @@ func TestBatchImport_StartRun(t *testing.T) {
|
||||
RunID: "run-4",
|
||||
Mode: "strict",
|
||||
AccessMode: "subscription",
|
||||
HostID: "host-1",
|
||||
Entries: []BatchImportEntry{
|
||||
{BaseURL: "https://api.kimi.com/v1", APIKey: "sk-live", RequestedModels: []string{"kimi 2.6"}},
|
||||
},
|
||||
@@ -243,10 +246,71 @@ func TestBatchImport_StartRun(t *testing.T) {
|
||||
t.Fatal("ProvisionReused = false, want true for patch-only flow")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("probe failure marks run failed instead of leaving running half state", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
runStore := &fakeRunStore{}
|
||||
itemStore := &fakeItemStore{}
|
||||
service := BatchImportService{
|
||||
RunStore: runStore,
|
||||
ItemStore: itemStore,
|
||||
ProbeModels: func(context.Context, string, string) (*probe.ModelsResult, error) {
|
||||
return nil, context.DeadlineExceeded
|
||||
},
|
||||
ProbeCapabilities: func(context.Context, string, string, []string) (*probe.CapabilityProfile, error) {
|
||||
t.Fatal("ProbeCapabilities should not be called after probe failure")
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
|
||||
result, err := service.StartRun(context.Background(), BatchImportRunRequest{
|
||||
RunID: "run-probe-fail",
|
||||
Mode: "strict",
|
||||
AccessMode: "self_service",
|
||||
HostID: "host-1",
|
||||
Entries: []BatchImportEntry{
|
||||
{BaseURL: "https://api.deepseek.com/v1", APIKey: "sk-live", RequestedModels: []string{"DeepSeek V4 Pro"}},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartRun() error = %v, want persisted failed run without transport error", err)
|
||||
}
|
||||
if result.RunID != "run-probe-fail" {
|
||||
t.Fatalf("result.RunID = %q, want run-probe-fail", result.RunID)
|
||||
}
|
||||
if len(runStore.updated) == 0 {
|
||||
t.Fatal("run store was not updated to failed state")
|
||||
}
|
||||
gotRun := runStore.updated[len(runStore.updated)-1]
|
||||
if gotRun.State != string(RunStateFailed) {
|
||||
t.Fatalf("run.State = %q, want failed", gotRun.State)
|
||||
}
|
||||
if gotRun.CompletedItems != 1 || gotRun.BrokenItems != 1 {
|
||||
t.Fatalf("run counters = %+v, want completed_items=1 broken_items=1", gotRun)
|
||||
}
|
||||
if len(itemStore.upserts) < 2 {
|
||||
t.Fatalf("item upserts = %d, want initial + failed terminal state", len(itemStore.upserts))
|
||||
}
|
||||
gotItem := itemStore.upserts[len(itemStore.upserts)-1]
|
||||
if gotItem.CurrentStage != string(ItemStageDone) {
|
||||
t.Fatalf("item.CurrentStage = %q, want done", gotItem.CurrentStage)
|
||||
}
|
||||
if gotItem.ConfirmationStatus != string(ConfirmationFailed) {
|
||||
t.Fatalf("item.ConfirmationStatus = %q, want failed", gotItem.ConfirmationStatus)
|
||||
}
|
||||
if gotItem.AccessStatus != string(AccessStatusBroken) {
|
||||
t.Fatalf("item.AccessStatus = %q, want broken", gotItem.AccessStatus)
|
||||
}
|
||||
if gotItem.LastErrorStage != string(ItemStageProbe) {
|
||||
t.Fatalf("item.LastErrorStage = %q, want probe", gotItem.LastErrorStage)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// fakeRunStore is an in-memory RunStateStore for tests; it keeps the runs it
// is given so assertions can inspect them.
type fakeRunStore struct {
|
||||
// created holds runs passed to Create (presumably appended in call order;
// the Create body is not fully visible here).
created []sqlite.ImportRun
|
||||
// updated holds every run passed to Update, in call order.
updated []sqlite.ImportRun
|
||||
}
|
||||
|
||||
func (f *fakeRunStore) Create(ctx context.Context, run sqlite.ImportRun) error {
|
||||
@@ -254,6 +318,11 @@ func (f *fakeRunStore) Create(ctx context.Context, run sqlite.ImportRun) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update records run at the end of f.updated and always succeeds. Tests read
// the last element to check the final persisted run state.
func (f *fakeRunStore) Update(ctx context.Context, run sqlite.ImportRun) error {
|
||||
f.updated = append(f.updated, run)
|
||||
return nil
|
||||
}
|
||||
|
||||
// fakeItemStore is an in-memory item store for tests.
type fakeItemStore struct {
|
||||
// upserts is read by tests as the ordered history of upserted items, with
// the last element treated as the final state (the Upsert method itself is
// not visible here).
upserts []sqlite.ImportRunItem
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user