diff --git a/internal/store/migrations/0007_batch_import_runs.sql b/internal/store/migrations/0007_batch_import_runs.sql index 6105db03..734a4d10 100644 --- a/internal/store/migrations/0007_batch_import_runs.sql +++ b/internal/store/migrations/0007_batch_import_runs.sql @@ -26,15 +26,22 @@ CREATE TABLE import_run_items ( run_id TEXT NOT NULL, base_url TEXT NOT NULL, provider_id TEXT NOT NULL, + api_key_fingerprint TEXT NOT NULL, requested_models_json TEXT NOT NULL DEFAULT '[]', raw_models_json TEXT NOT NULL DEFAULT '[]', normalized_models_json TEXT NOT NULL DEFAULT '[]', + canonical_model_families_json TEXT NOT NULL DEFAULT '[]', recommended_models_json TEXT NOT NULL DEFAULT '[]', resolved_smoke_model TEXT NULL, capability_profile_json TEXT NOT NULL DEFAULT '{}', current_stage TEXT NOT NULL, confirmation_status TEXT NOT NULL, access_status TEXT NOT NULL, + matched_account_state TEXT NOT NULL, + account_resolution TEXT NOT NULL, + provision_reused INTEGER NOT NULL DEFAULT 0, + reused_from_provider_id TEXT NULL, + reused_from_account_id INTEGER NULL, channel_id INTEGER NULL, account_id INTEGER NULL, retry_count INTEGER NOT NULL DEFAULT 0, @@ -53,11 +60,15 @@ CREATE TABLE import_run_items ( CONSTRAINT fk_import_run_items_run FOREIGN KEY (run_id) REFERENCES import_runs(run_id) ON DELETE CASCADE, CHECK (current_stage IN ('probe', 'provision', 'confirm', 'validate', 'done')), CHECK (confirmation_status IN ('pending', 'confirmed', 'advisory', 'failed')), - CHECK (access_status IN ('unknown', 'active', 'degraded', 'broken')) + CHECK (access_status IN ('unknown', 'active', 'degraded', 'broken')), + CHECK (matched_account_state IN ('none', 'active', 'disabled', 'deprecated', 'broken')), + CHECK (account_resolution IN ('created', 'reused', 'reactivated', 'replaced')), + CHECK (provision_reused IN (0, 1)) ); CREATE INDEX idx_import_run_items_run_id ON import_run_items(run_id); CREATE INDEX idx_import_run_items_provider_id ON import_run_items(provider_id); +CREATE INDEX idx_import_run_items_api_key_fingerprint ON import_run_items(api_key_fingerprint); CREATE INDEX idx_import_run_items_current_stage ON import_run_items(current_stage); CREATE INDEX idx_import_run_items_confirmation_status ON import_run_items(confirmation_status); CREATE INDEX idx_import_run_items_access_status ON import_run_items(access_status); diff --git a/internal/store/migrations/0008_batch_import_run_events.sql b/internal/store/migrations/0008_batch_import_run_events.sql index d82689e1..bfd24186 100644 --- a/internal/store/migrations/0008_batch_import_run_events.sql +++ b/internal/store/migrations/0008_batch_import_run_events.sql @@ -15,5 +15,6 @@ CREATE TABLE import_run_item_events ( CREATE INDEX idx_import_run_item_events_run_id ON import_run_item_events(run_id); CREATE INDEX idx_import_run_item_events_item_id ON import_run_item_events(item_id); CREATE INDEX idx_import_run_item_events_created_at ON import_run_item_events(created_at); +CREATE INDEX idx_import_run_item_events_item_created_at ON import_run_item_events(item_id, created_at); CREATE INDEX idx_import_run_item_events_stage ON import_run_item_events(stage); CREATE INDEX idx_import_run_item_events_type ON import_run_item_events(event_type); diff --git a/internal/store/sqlite/db.go b/internal/store/sqlite/db.go index b6ca4611..97f155ec 100644 --- a/internal/store/sqlite/db.go +++ b/internal/store/sqlite/db.go @@ -106,6 +106,10 @@ func (db *DB) ImportRunEvents() *ImportRunItemEventsRepo { return db.queries.ImportRunEvents } +func (db *DB) ImportRunItemEvents() *ImportRunItemEventsRepo { + return db.queries.ImportRunEvents +} + func (db *DB) ManagedResources() *ManagedResourcesRepo { return db.queries.ManagedResources } diff --git a/internal/store/sqlite/import_run_item_events_repo.go b/internal/store/sqlite/import_run_item_events_repo.go new file mode 100644 index 00000000..820c4f03 --- /dev/null +++ b/internal/store/sqlite/import_run_item_events_repo.go @@ -0,0 +1,88 @@ +package sqlite + +import ( + "context" + "fmt" + "strings" +) + +type ImportRunItemEvent struct { + EventID string + RunID string + ItemID string + EventType string + Stage string + Attempt int + Message string + PayloadJSON string + CreatedAt string +} + +type ImportRunItemEventsRepo struct { + db execQuerier +} + +func newImportRunItemEventsRepo(db execQuerier) *ImportRunItemEventsRepo { + return &ImportRunItemEventsRepo{db: db} +} + +func (r *ImportRunItemEventsRepo) Create(ctx context.Context, event ImportRunItemEvent) error { + return r.Append(ctx, event) +} + +func (r *ImportRunItemEventsRepo) Append(ctx context.Context, event ImportRunItemEvent) error { + eventID := strings.TrimSpace(event.EventID) + runID := strings.TrimSpace(event.RunID) + itemID := strings.TrimSpace(event.ItemID) + eventType := strings.TrimSpace(event.EventType) + stage := strings.TrimSpace(event.Stage) + message := strings.TrimSpace(event.Message) + payloadJSON := defaultJSON(event.PayloadJSON, "{}") + + switch { + case eventID == "": + return fmt.Errorf("event_id is required") + case runID == "": + return fmt.Errorf("run_id is required") + case itemID == "": + return fmt.Errorf("item_id is required") + case eventType == "": + return fmt.Errorf("event_type is required") + case stage == "": + return fmt.Errorf("stage is required") + case message == "": + return fmt.Errorf("message is required") + } + + if _, err := r.db.ExecContext(ctx, `INSERT INTO import_run_item_events (event_id, run_id, item_id, event_type, stage, attempt, message, payload_json) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + eventID, runID, itemID, eventType, stage, event.Attempt, message, payloadJSON); err != nil { + return fmt.Errorf("insert import run item event %q: %w", eventID, err) + } + return nil +} + +func (r *ImportRunItemEventsRepo) ListByItemID(ctx context.Context, itemID string) ([]ImportRunItemEvent, error) { + itemID = strings.TrimSpace(itemID) + if itemID == "" { + return nil, fmt.Errorf("item_id is required") + } + + rows, err := r.db.QueryContext(ctx, `SELECT event_id, run_id, item_id, event_type, stage, attempt, message, payload_json, created_at FROM import_run_item_events WHERE item_id = ? ORDER BY created_at, event_id`, itemID) + if err != nil { + return nil, fmt.Errorf("list import run item events by item_id %q: %w", itemID, err) + } + defer rows.Close() + + events := make([]ImportRunItemEvent, 0) + for rows.Next() { + var event ImportRunItemEvent + if err := rows.Scan(&event.EventID, &event.RunID, &event.ItemID, &event.EventType, &event.Stage, &event.Attempt, &event.Message, &event.PayloadJSON, &event.CreatedAt); err != nil { + return nil, fmt.Errorf("scan import run item event: %w", err) + } + events = append(events, event) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate import run item events by item_id %q: %w", itemID, err) + } + return events, nil +} diff --git a/internal/store/sqlite/import_run_items_repo.go b/internal/store/sqlite/import_run_items_repo.go new file mode 100644 index 00000000..d3ed5528 --- /dev/null +++ b/internal/store/sqlite/import_run_items_repo.go @@ -0,0 +1,252 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +type ImportRunItem struct { + ItemID string + RunID string + BaseURL string + ProviderID string + APIKeyFingerprint string + RequestedModelsJSON string + RawModelsJSON string + NormalizedModelsJSON string + CanonicalFamiliesJSON string + RecommendedModelsJSON string + ResolvedSmokeModel string + CapabilityProfileJSON string + CurrentStage string + ConfirmationStatus string + AccessStatus string + MatchedAccountState string + AccountResolution string + ProvisionReused bool + ReusedFromProviderID string + ReusedFromAccountID *int64 + ChannelID *int64 + AccountID *int64 + RetryCount int + ConfirmationAttempts int + LastRetryAt string + NextRetryAt string + LeaseOwner string + LeaseUntil string + AdvisoryMessagesJSON string + LastErrorStage string + LastError string + LegacyBatchID *int64 + LegacyProviderID string + CreatedAt string + UpdatedAt string +} + +type ImportRunItemsRepo struct { + db execQuerier +} + +func newImportRunItemsRepo(db execQuerier) *ImportRunItemsRepo { + return &ImportRunItemsRepo{db: db} +} + +func (r *ImportRunItemsRepo) Create(ctx context.Context, item ImportRunItem) error { + return r.Upsert(ctx, item) +} + +func (r *ImportRunItemsRepo) Update(ctx context.Context, item ImportRunItem) error { + return r.Upsert(ctx, item) +} + +func (r *ImportRunItemsRepo) Upsert(ctx context.Context, item ImportRunItem) error { + itemID := strings.TrimSpace(item.ItemID) + runID := strings.TrimSpace(item.RunID) + baseURL := strings.TrimSpace(item.BaseURL) + providerID := strings.TrimSpace(item.ProviderID) + apiKeyFingerprint := strings.TrimSpace(item.APIKeyFingerprint) + currentStage := strings.TrimSpace(item.CurrentStage) + confirmationStatus := strings.TrimSpace(item.ConfirmationStatus) + accessStatus := strings.TrimSpace(item.AccessStatus) + matchedAccountState := strings.TrimSpace(item.MatchedAccountState) + accountResolution := strings.TrimSpace(item.AccountResolution) + + switch { + case itemID == "": + return fmt.Errorf("item_id is required") + case runID == "": + return fmt.Errorf("run_id is required") + case baseURL == "": + return fmt.Errorf("base_url is required") + case providerID == "": + return fmt.Errorf("provider_id is required") + case apiKeyFingerprint == "": + return fmt.Errorf("api_key_fingerprint is required") + case currentStage == "": + return fmt.Errorf("current_stage is required") + case confirmationStatus == "": + return fmt.Errorf("confirmation_status is required") + case accessStatus == "": + return fmt.Errorf("access_status is required") + case matchedAccountState == "": + return fmt.Errorf("matched_account_state is required") + case accountResolution == "": + return fmt.Errorf("account_resolution is required") + } + + if _, err := r.db.ExecContext(ctx, `INSERT INTO import_run_items ( + item_id, run_id, base_url, provider_id, api_key_fingerprint, requested_models_json, raw_models_json, normalized_models_json, + canonical_model_families_json, recommended_models_json, resolved_smoke_model, capability_profile_json, current_stage, + confirmation_status, access_status, matched_account_state, account_resolution, provision_reused, reused_from_provider_id, + reused_from_account_id, channel_id, account_id, retry_count, confirmation_attempts, last_retry_at, next_retry_at, + lease_owner, lease_until, advisory_messages_json, last_error_stage, last_error, legacy_batch_id, legacy_provider_id + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(item_id) DO UPDATE SET + run_id = excluded.run_id, + base_url = excluded.base_url, + provider_id = excluded.provider_id, + api_key_fingerprint = excluded.api_key_fingerprint, + requested_models_json = excluded.requested_models_json, + raw_models_json = excluded.raw_models_json, + normalized_models_json = excluded.normalized_models_json, + canonical_model_families_json = excluded.canonical_model_families_json, + recommended_models_json = excluded.recommended_models_json, + resolved_smoke_model = excluded.resolved_smoke_model, + capability_profile_json = excluded.capability_profile_json, + current_stage = excluded.current_stage, + confirmation_status = excluded.confirmation_status, + access_status = excluded.access_status, + matched_account_state = excluded.matched_account_state, + account_resolution = excluded.account_resolution, + provision_reused = excluded.provision_reused, + reused_from_provider_id = excluded.reused_from_provider_id, + reused_from_account_id = excluded.reused_from_account_id, + channel_id = excluded.channel_id, + account_id = excluded.account_id, + retry_count = excluded.retry_count, + confirmation_attempts = excluded.confirmation_attempts, + last_retry_at = excluded.last_retry_at, + next_retry_at = excluded.next_retry_at, + lease_owner = excluded.lease_owner, + lease_until = excluded.lease_until, + advisory_messages_json = excluded.advisory_messages_json, + last_error_stage = excluded.last_error_stage, + last_error = excluded.last_error, + legacy_batch_id = excluded.legacy_batch_id, + legacy_provider_id = excluded.legacy_provider_id, + updated_at = CURRENT_TIMESTAMP`, + itemID, + runID, + baseURL, + providerID, + apiKeyFingerprint, + defaultJSON(item.RequestedModelsJSON, "[]"), + defaultJSON(item.RawModelsJSON, "[]"), + defaultJSON(item.NormalizedModelsJSON, "[]"), + defaultJSON(item.CanonicalFamiliesJSON, "[]"), + defaultJSON(item.RecommendedModelsJSON, "[]"), + nullableString(strings.TrimSpace(item.ResolvedSmokeModel)), + defaultJSON(item.CapabilityProfileJSON, "{}"), + currentStage, + confirmationStatus, + accessStatus, + matchedAccountState, + accountResolution, + boolToInt(item.ProvisionReused), + nullableString(strings.TrimSpace(item.ReusedFromProviderID)), + item.ReusedFromAccountID, + item.ChannelID, + item.AccountID, + item.RetryCount, + item.ConfirmationAttempts, + nullableString(strings.TrimSpace(item.LastRetryAt)), + nullableString(strings.TrimSpace(item.NextRetryAt)), + nullableString(strings.TrimSpace(item.LeaseOwner)), + nullableString(strings.TrimSpace(item.LeaseUntil)), + defaultJSON(item.AdvisoryMessagesJSON, "[]"), + nullableString(strings.TrimSpace(item.LastErrorStage)), + nullableString(strings.TrimSpace(item.LastError)), + item.LegacyBatchID, + nullableString(strings.TrimSpace(item.LegacyProviderID)), + ); err != nil { + return fmt.Errorf("upsert import run item %q: %w", itemID, err) + } + return nil +} + +func (r *ImportRunItemsRepo) GetByItemID(ctx context.Context, itemID string) (ImportRunItem, error) { + itemID = strings.TrimSpace(itemID) + if itemID == "" { + return ImportRunItem{}, fmt.Errorf("item_id is required") + } + + var item ImportRunItem + var reusedFromAccountID sql.NullInt64 + var channelID sql.NullInt64 + var accountID sql.NullInt64 + var legacyBatchID sql.NullInt64 + var provisionReused int + if err := r.db.QueryRowContext(ctx, `SELECT item_id, run_id, base_url, provider_id, api_key_fingerprint, requested_models_json, raw_models_json, normalized_models_json, canonical_model_families_json, recommended_models_json, COALESCE(resolved_smoke_model, ''), capability_profile_json, current_stage, confirmation_status, access_status, matched_account_state, account_resolution, provision_reused, COALESCE(reused_from_provider_id, ''), reused_from_account_id, channel_id, account_id, retry_count, confirmation_attempts, COALESCE(last_retry_at, ''), COALESCE(next_retry_at, ''), COALESCE(lease_owner, ''), COALESCE(lease_until, ''), advisory_messages_json, COALESCE(last_error_stage, ''), COALESCE(last_error, ''), legacy_batch_id, COALESCE(legacy_provider_id, ''), created_at, updated_at FROM import_run_items WHERE item_id = ?`, itemID). + Scan(&item.ItemID, &item.RunID, &item.BaseURL, &item.ProviderID, &item.APIKeyFingerprint, &item.RequestedModelsJSON, &item.RawModelsJSON, &item.NormalizedModelsJSON, &item.CanonicalFamiliesJSON, &item.RecommendedModelsJSON, &item.ResolvedSmokeModel, &item.CapabilityProfileJSON, &item.CurrentStage, &item.ConfirmationStatus, &item.AccessStatus, &item.MatchedAccountState, &item.AccountResolution, &provisionReused, &item.ReusedFromProviderID, &reusedFromAccountID, &channelID, &accountID, &item.RetryCount, &item.ConfirmationAttempts, &item.LastRetryAt, &item.NextRetryAt, &item.LeaseOwner, &item.LeaseUntil, &item.AdvisoryMessagesJSON, &item.LastErrorStage, &item.LastError, &legacyBatchID, &item.LegacyProviderID, &item.CreatedAt, &item.UpdatedAt); err != nil { + return ImportRunItem{}, err + } + item.ProvisionReused = provisionReused == 1 + item.ReusedFromAccountID = ptrFromNullInt64(reusedFromAccountID) + item.ChannelID = ptrFromNullInt64(channelID) + item.AccountID = ptrFromNullInt64(accountID) + item.LegacyBatchID = ptrFromNullInt64(legacyBatchID) + return item, nil +} + +func (r *ImportRunItemsRepo) ListByRunID(ctx context.Context, runID string) ([]ImportRunItem, error) { + runID = strings.TrimSpace(runID) + if runID == "" { + return nil, fmt.Errorf("run_id is required") + } + + rows, err := r.db.QueryContext(ctx, `SELECT item_id, run_id, base_url, provider_id, api_key_fingerprint, requested_models_json, raw_models_json, normalized_models_json, canonical_model_families_json, recommended_models_json, COALESCE(resolved_smoke_model, ''), capability_profile_json, current_stage, confirmation_status, access_status, matched_account_state, account_resolution, provision_reused, COALESCE(reused_from_provider_id, ''), reused_from_account_id, channel_id, account_id, retry_count, confirmation_attempts, COALESCE(last_retry_at, ''), COALESCE(next_retry_at, ''), COALESCE(lease_owner, ''), COALESCE(lease_until, ''), advisory_messages_json, COALESCE(last_error_stage, ''), COALESCE(last_error, ''), legacy_batch_id, COALESCE(legacy_provider_id, ''), created_at, updated_at FROM import_run_items WHERE run_id = ? ORDER BY created_at, item_id`, runID) + if err != nil { + return nil, fmt.Errorf("list import run items by run_id %q: %w", runID, err) + } + defer rows.Close() + + items := make([]ImportRunItem, 0) + for rows.Next() { + var item ImportRunItem + var reusedFromAccountID sql.NullInt64 + var channelID sql.NullInt64 + var accountID sql.NullInt64 + var legacyBatchID sql.NullInt64 + var provisionReused int + if err := rows.Scan(&item.ItemID, &item.RunID, &item.BaseURL, &item.ProviderID, &item.APIKeyFingerprint, &item.RequestedModelsJSON, &item.RawModelsJSON, &item.NormalizedModelsJSON, &item.CanonicalFamiliesJSON, &item.RecommendedModelsJSON, &item.ResolvedSmokeModel, &item.CapabilityProfileJSON, &item.CurrentStage, &item.ConfirmationStatus, &item.AccessStatus, &item.MatchedAccountState, &item.AccountResolution, &provisionReused, &item.ReusedFromProviderID, &reusedFromAccountID, &channelID, &accountID, &item.RetryCount, &item.ConfirmationAttempts, &item.LastRetryAt, &item.NextRetryAt, &item.LeaseOwner, &item.LeaseUntil, &item.AdvisoryMessagesJSON, &item.LastErrorStage, &item.LastError, &legacyBatchID, &item.LegacyProviderID, &item.CreatedAt, &item.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan import run item: %w", err) + } + item.ProvisionReused = provisionReused == 1 + item.ReusedFromAccountID = ptrFromNullInt64(reusedFromAccountID) + item.ChannelID = ptrFromNullInt64(channelID) + item.AccountID = ptrFromNullInt64(accountID) + item.LegacyBatchID = ptrFromNullInt64(legacyBatchID) + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate import run items by run_id %q: %w", runID, err) + } + return items, nil +} + +func boolToInt(value bool) int { + if value { + return 1 + } + return 0 +} + +func ptrFromNullInt64(value sql.NullInt64) *int64 { + if !value.Valid { + return nil + } + result := value.Int64 + return &result +} diff --git a/internal/store/sqlite/import_runs_repo.go b/internal/store/sqlite/import_runs_repo.go index c5220f7c..14c8b135 100644 --- a/internal/store/sqlite/import_runs_repo.go +++ b/internal/store/sqlite/import_runs_repo.go @@ -120,288 +120,6 @@ func (r *ImportRunsRepo) List(ctx context.Context, limit int) ([]ImportRun, erro return runs, nil } -type ImportRunItem struct { - ItemID string - RunID string - BaseURL string - ProviderID string - RequestedModelsJSON string - RawModelsJSON string - NormalizedModelsJSON string - RecommendedModelsJSON string - ResolvedSmokeModel string - CapabilityProfileJSON string - CurrentStage string - ConfirmationStatus string - AccessStatus string - ChannelID *int64 - AccountID *int64 - RetryCount int - ConfirmationAttempts int - LastRetryAt string - NextRetryAt string - LeaseOwner string - LeaseUntil string - AdvisoryMessagesJSON string - LastErrorStage string - LastError string - LegacyBatchID *int64 - LegacyProviderID string - CreatedAt string - UpdatedAt string -} - -type ImportRunItemsRepo struct { - db execQuerier -} - -func newImportRunItemsRepo(db execQuerier) *ImportRunItemsRepo { - return &ImportRunItemsRepo{db: db} -} - -func (r *ImportRunItemsRepo) Create(ctx context.Context, item ImportRunItem) error { - itemID := strings.TrimSpace(item.ItemID) - runID := strings.TrimSpace(item.RunID) - baseURL := strings.TrimSpace(item.BaseURL) - providerID := strings.TrimSpace(item.ProviderID) - currentStage := strings.TrimSpace(item.CurrentStage) - confirmationStatus := strings.TrimSpace(item.ConfirmationStatus) - accessStatus := strings.TrimSpace(item.AccessStatus) - - switch { - case itemID == "": - return fmt.Errorf("item_id is required") - case runID == "": - return fmt.Errorf("run_id is required") - case baseURL == "": - return fmt.Errorf("base_url is required") - case providerID == "": - return fmt.Errorf("provider_id is required") - case currentStage == "": - return fmt.Errorf("current_stage is required") - case confirmationStatus == "": - return fmt.Errorf("confirmation_status is required") - case accessStatus == "": - return fmt.Errorf("access_status is required") - } - - if _, err := r.db.ExecContext(ctx, `INSERT INTO import_run_items ( - item_id, run_id, base_url, provider_id, requested_models_json, raw_models_json, normalized_models_json, - recommended_models_json, resolved_smoke_model, capability_profile_json, current_stage, confirmation_status, - access_status, channel_id, account_id, retry_count, confirmation_attempts, last_retry_at, next_retry_at, - lease_owner, lease_until, advisory_messages_json, last_error_stage, last_error, legacy_batch_id, legacy_provider_id - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, - itemID, runID, baseURL, providerID, - defaultJSON(item.RequestedModelsJSON, "[]"), - defaultJSON(item.RawModelsJSON, "[]"), - defaultJSON(item.NormalizedModelsJSON, "[]"), - defaultJSON(item.RecommendedModelsJSON, "[]"), - nullableString(strings.TrimSpace(item.ResolvedSmokeModel)), - defaultJSON(item.CapabilityProfileJSON, "{}"), - currentStage, confirmationStatus, accessStatus, - item.ChannelID, item.AccountID, item.RetryCount, item.ConfirmationAttempts, - nullableString(strings.TrimSpace(item.LastRetryAt)), - nullableString(strings.TrimSpace(item.NextRetryAt)), - nullableString(strings.TrimSpace(item.LeaseOwner)), - nullableString(strings.TrimSpace(item.LeaseUntil)), - defaultJSON(item.AdvisoryMessagesJSON, "[]"), - nullableString(strings.TrimSpace(item.LastErrorStage)), - nullableString(strings.TrimSpace(item.LastError)), - item.LegacyBatchID, - nullableString(strings.TrimSpace(item.LegacyProviderID)), - ); err != nil { - return fmt.Errorf("insert import run item %q: %w", itemID, err) - } - return nil -} - -func (r *ImportRunItemsRepo) Update(ctx context.Context, item ImportRunItem) error { - itemID := strings.TrimSpace(item.ItemID) - runID := strings.TrimSpace(item.RunID) - baseURL := strings.TrimSpace(item.BaseURL) - providerID := strings.TrimSpace(item.ProviderID) - currentStage := strings.TrimSpace(item.CurrentStage) - confirmationStatus := strings.TrimSpace(item.ConfirmationStatus) - accessStatus := strings.TrimSpace(item.AccessStatus) - - switch { - case itemID == "": - return fmt.Errorf("item_id is required") - case runID == "": - return fmt.Errorf("run_id is required") - case baseURL == "": - return fmt.Errorf("base_url is required") - case providerID == "": - return fmt.Errorf("provider_id is required") - case currentStage == "": - return fmt.Errorf("current_stage is required") - case confirmationStatus == "": - return fmt.Errorf("confirmation_status is required") - case accessStatus == "": - return fmt.Errorf("access_status is required") - } - - if _, err := r.db.ExecContext(ctx, `UPDATE import_run_items SET - run_id = ?, base_url = ?, provider_id = ?, requested_models_json = ?, raw_models_json = ?, normalized_models_json = ?, - recommended_models_json = ?, resolved_smoke_model = ?, capability_profile_json = ?, current_stage = ?, confirmation_status = ?, - access_status = ?, channel_id = ?, account_id = ?, retry_count = ?, confirmation_attempts = ?, last_retry_at = ?, next_retry_at = ?, - lease_owner = ?, lease_until = ?, advisory_messages_json = ?, last_error_stage = ?, last_error = ?, legacy_batch_id = ?, legacy_provider_id = ?, - updated_at = CURRENT_TIMESTAMP - WHERE item_id = ?`, - runID, baseURL, providerID, - defaultJSON(item.RequestedModelsJSON, "[]"), - defaultJSON(item.RawModelsJSON, "[]"), - defaultJSON(item.NormalizedModelsJSON, "[]"), - defaultJSON(item.RecommendedModelsJSON, "[]"), - nullableString(strings.TrimSpace(item.ResolvedSmokeModel)), - defaultJSON(item.CapabilityProfileJSON, "{}"), - currentStage, confirmationStatus, accessStatus, - item.ChannelID, item.AccountID, item.RetryCount, item.ConfirmationAttempts, - nullableString(strings.TrimSpace(item.LastRetryAt)), - nullableString(strings.TrimSpace(item.NextRetryAt)), - nullableString(strings.TrimSpace(item.LeaseOwner)), - nullableString(strings.TrimSpace(item.LeaseUntil)), - defaultJSON(item.AdvisoryMessagesJSON, "[]"), - nullableString(strings.TrimSpace(item.LastErrorStage)), - nullableString(strings.TrimSpace(item.LastError)), - item.LegacyBatchID, - nullableString(strings.TrimSpace(item.LegacyProviderID)), - itemID, - ); err != nil { - return fmt.Errorf("update import run item %q: %w", itemID, err) - } - return nil -} - -func (r *ImportRunItemsRepo) GetByItemID(ctx context.Context, itemID string) (ImportRunItem, error) { - itemID = strings.TrimSpace(itemID) - if itemID == "" { - return ImportRunItem{}, fmt.Errorf("item_id is required") - } - - var item ImportRunItem - var channelID sqlNullInt64 - var accountID sqlNullInt64 - var legacyBatchID sqlNullInt64 - if err := r.db.QueryRowContext(ctx, `SELECT item_id, run_id, base_url, provider_id, requested_models_json, raw_models_json, normalized_models_json, recommended_models_json, COALESCE(resolved_smoke_model, ''), capability_profile_json, current_stage, confirmation_status, access_status, channel_id, account_id, retry_count, confirmation_attempts, COALESCE(last_retry_at, ''), COALESCE(next_retry_at, ''), COALESCE(lease_owner, ''), COALESCE(lease_until, ''), advisory_messages_json, COALESCE(last_error_stage, ''), COALESCE(last_error, ''), legacy_batch_id, COALESCE(legacy_provider_id, ''), created_at, updated_at FROM import_run_items WHERE item_id = ?`, itemID). - Scan(&item.ItemID, &item.RunID, &item.BaseURL, &item.ProviderID, &item.RequestedModelsJSON, &item.RawModelsJSON, &item.NormalizedModelsJSON, &item.RecommendedModelsJSON, &item.ResolvedSmokeModel, &item.CapabilityProfileJSON, &item.CurrentStage, &item.ConfirmationStatus, &item.AccessStatus, &channelID, &accountID, &item.RetryCount, &item.ConfirmationAttempts, &item.LastRetryAt, &item.NextRetryAt, &item.LeaseOwner, &item.LeaseUntil, &item.AdvisoryMessagesJSON, &item.LastErrorStage, &item.LastError, &legacyBatchID, &item.LegacyProviderID, &item.CreatedAt, &item.UpdatedAt); err != nil { - return ImportRunItem{}, err - } - item.ChannelID = channelID.ptr() - item.AccountID = accountID.ptr() - item.LegacyBatchID = legacyBatchID.ptr() - return item, nil -} - -func (r *ImportRunItemsRepo) ListByRunID(ctx context.Context, runID string) ([]ImportRunItem, error) { - runID = strings.TrimSpace(runID) - if runID == "" { - return nil, fmt.Errorf("run_id is required") - } - - rows, err := r.db.QueryContext(ctx, `SELECT item_id, run_id, base_url, provider_id, requested_models_json, raw_models_json, normalized_models_json, recommended_models_json, COALESCE(resolved_smoke_model, ''), capability_profile_json, current_stage, confirmation_status, access_status, channel_id, account_id, retry_count, confirmation_attempts, COALESCE(last_retry_at, ''), COALESCE(next_retry_at, ''), COALESCE(lease_owner, ''), COALESCE(lease_until, ''), advisory_messages_json, COALESCE(last_error_stage, ''), COALESCE(last_error, ''), legacy_batch_id, COALESCE(legacy_provider_id, ''), created_at, updated_at FROM import_run_items WHERE run_id = ? ORDER BY created_at, item_id`, runID) - if err != nil { - return nil, fmt.Errorf("list import run items by run_id %q: %w", runID, err) - } - defer rows.Close() - - items := make([]ImportRunItem, 0) - for rows.Next() { - var item ImportRunItem - var channelID sqlNullInt64 - var accountID sqlNullInt64 - var legacyBatchID sqlNullInt64 - if err := rows.Scan(&item.ItemID, &item.RunID, &item.BaseURL, &item.ProviderID, &item.RequestedModelsJSON, &item.RawModelsJSON, &item.NormalizedModelsJSON, &item.RecommendedModelsJSON, &item.ResolvedSmokeModel, &item.CapabilityProfileJSON, &item.CurrentStage, &item.ConfirmationStatus, &item.AccessStatus, &channelID, &accountID, &item.RetryCount, &item.ConfirmationAttempts, &item.LastRetryAt, &item.NextRetryAt, &item.LeaseOwner, &item.LeaseUntil, &item.AdvisoryMessagesJSON, &item.LastErrorStage, &item.LastError, &legacyBatchID, &item.LegacyProviderID, &item.CreatedAt, &item.UpdatedAt); err != nil { - return nil, fmt.Errorf("scan import run item: %w", err) - } - item.ChannelID = channelID.ptr() - item.AccountID = accountID.ptr() - item.LegacyBatchID = legacyBatchID.ptr() - items = append(items, item) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("iterate import run items by run_id %q: %w", runID, err) - } - return items, nil -} - -type ImportRunItemEvent struct { - EventID string - RunID string - ItemID string - EventType string - Stage string - Attempt int - Message string - PayloadJSON string - CreatedAt string -} - -type ImportRunItemEventsRepo struct { - db execQuerier -} - -func newImportRunItemEventsRepo(db execQuerier) *ImportRunItemEventsRepo { - return &ImportRunItemEventsRepo{db: db} -} - -func (r *ImportRunItemEventsRepo) Create(ctx context.Context, event ImportRunItemEvent) error { - eventID := strings.TrimSpace(event.EventID) - runID := strings.TrimSpace(event.RunID) - itemID := strings.TrimSpace(event.ItemID) - eventType := strings.TrimSpace(event.EventType) - stage := strings.TrimSpace(event.Stage) - message := strings.TrimSpace(event.Message) - payloadJSON := defaultJSON(event.PayloadJSON, "{}") - - switch { - case eventID == "": - return fmt.Errorf("event_id is required") - case runID == "": - return fmt.Errorf("run_id is required") - case itemID == "": - return fmt.Errorf("item_id is required") - case eventType == "": - return fmt.Errorf("event_type is required") - case stage == "": - return fmt.Errorf("stage is required") - case message == "": - return fmt.Errorf("message is required") - } - - if _, err := r.db.ExecContext(ctx, `INSERT INTO import_run_item_events (event_id, run_id, item_id, event_type, stage, attempt, message, payload_json) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, - eventID, runID, itemID, eventType, stage, event.Attempt, message, payloadJSON); err != nil { - return fmt.Errorf("insert import run item event %q: %w", eventID, err) - } - return nil -} - -func (r *ImportRunItemEventsRepo) ListByItemID(ctx context.Context, itemID string) ([]ImportRunItemEvent, error) { - itemID = strings.TrimSpace(itemID) - if itemID == "" { - return nil, fmt.Errorf("item_id is required") - } - - rows, err := r.db.QueryContext(ctx, `SELECT event_id, run_id, item_id, event_type, stage, attempt, message, payload_json, created_at FROM import_run_item_events WHERE item_id = ? ORDER BY created_at, event_id`, itemID) - if err != nil { - return nil, fmt.Errorf("list import run item events by item_id %q: %w", itemID, err) - } - defer rows.Close() - - events := make([]ImportRunItemEvent, 0) - for rows.Next() { - var event ImportRunItemEvent - if err := rows.Scan(&event.EventID, &event.RunID, &event.ItemID, &event.EventType, &event.Stage, &event.Attempt, &event.Message, &event.PayloadJSON, &event.CreatedAt); err != nil { - return nil, fmt.Errorf("scan import run item event: %w", err) - } - events = append(events, event) - } - if err := rows.Err(); err != nil { - return nil, fmt.Errorf("iterate import run item events by item_id %q: %w", itemID, err) - } - return events, nil -} - type sqlNullInt64 struct { Int64 int64 Valid bool diff --git a/internal/store/sqlite/import_runs_repo_test.go b/internal/store/sqlite/import_runs_repo_test.go new file mode 100644 index 00000000..41173413 --- /dev/null +++ b/internal/store/sqlite/import_runs_repo_test.go @@ -0,0 +1,155 @@ +package sqlite + +import ( + "context" + "reflect" + "testing" +) + +func TestRunStateStore(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestDB(t) + + run := ImportRun{ + RunID: "run-1", + Mode: "strict", + AccessMode: "subscription", + State: "running", + TotalItems: 1, + } + if err := store.ImportRuns().Create(ctx, run); err != nil { + t.Fatalf("ImportRuns().Create() error = %v", err) + } + + run.State = "completed_with_warnings" + run.CompletedItems = 1 + run.ActiveItems = 1 + run.WarningItems = 1 + run.FinishedAt = "2026-05-22T12:00:00Z" + if err := store.ImportRuns().Update(ctx, run); err != nil { + t.Fatalf("ImportRuns().Update() error = %v", err) + } + + gotRun, err := store.ImportRuns().GetByRunID(ctx, "run-1") + if err != nil { + t.Fatalf("ImportRuns().GetByRunID() error = %v", err) + } + if gotRun.State != "completed_with_warnings" { + t.Fatalf("run.State = %q, want completed_with_warnings", gotRun.State) + } + if gotRun.WarningItems != 1 { + t.Fatalf("run.WarningItems = %d, want 1", gotRun.WarningItems) + } + + legacyBatchID := int64(88) + reusedAccountID := int64(321) + channelID := int64(66) + accountID := int64(77) + + item := ImportRunItem{ + ItemID: "item-1", + RunID: "run-1", + BaseURL: "https://api.deepseek.com/v1", + ProviderID: "api-deepseek-12345678", + APIKeyFingerprint: "fp_abc123", + RequestedModelsJSON: `["kimi 2.6"]`, + RawModelsJSON: `["kimi-k2.6"]`, + NormalizedModelsJSON: `["kimi-k2.6"]`, + CanonicalFamiliesJSON: `["kimi-2.6"]`, + RecommendedModelsJSON: `["kimi-k2.6"]`, + CurrentStage: "confirm", + ConfirmationStatus: "pending", + AccessStatus: "unknown", + MatchedAccountState: "deprecated", + AccountResolution: "reactivated", + ProvisionReused: true, + ReusedFromProviderID: "api-deepseek-87654321", + ReusedFromAccountID: &reusedAccountID, + ChannelID: &channelID, + AccountID: &accountID, + RetryCount: 2, + ConfirmationAttempts: 3, + LastRetryAt: "2026-05-22T12:01:00Z", + NextRetryAt: "2026-05-22T12:02:00Z", + LeaseOwner: "worker-1", + LeaseUntil: "2026-05-22T12:03:00Z", + AdvisoryMessagesJSON: `["warmup"]`, + LastErrorStage: "confirm", + LastError: "no available accounts", + LegacyBatchID: &legacyBatchID, + LegacyProviderID: "legacy-provider", + CapabilityProfileJSON: `{"transport_profile":{"supports_openai_chat_completions":true}}`, + ResolvedSmokeModel: "kimi-k2.6", + } + + if err := store.ImportRunItems().Upsert(ctx, item); err != nil { + t.Fatalf("ImportRunItems().Upsert() error = %v", err) + } + + gotItem, err := store.ImportRunItems().GetByItemID(ctx, "item-1") + if err != nil { + t.Fatalf("ImportRunItems().GetByItemID() error = %v", err) + } + if gotItem.APIKeyFingerprint != "fp_abc123" { + t.Fatalf("item.APIKeyFingerprint = %q, want fp_abc123", gotItem.APIKeyFingerprint) + } + if gotItem.CanonicalFamiliesJSON != `["kimi-2.6"]` { + t.Fatalf("item.CanonicalFamiliesJSON = %q, want canonical families json", gotItem.CanonicalFamiliesJSON) + } + if gotItem.MatchedAccountState != "deprecated" { + t.Fatalf("item.MatchedAccountState = %q, want deprecated", gotItem.MatchedAccountState) + } + if gotItem.AccountResolution != "reactivated" { + t.Fatalf("item.AccountResolution = %q, want reactivated", gotItem.AccountResolution) + } + if !gotItem.ProvisionReused { + t.Fatal("item.ProvisionReused = false, want true") + } + if gotItem.ReusedFromProviderID != "api-deepseek-87654321" { + t.Fatalf("item.ReusedFromProviderID = %q, want reused provider id", gotItem.ReusedFromProviderID) + } + if gotItem.ReusedFromAccountID == nil || *gotItem.ReusedFromAccountID != reusedAccountID { + t.Fatalf("item.ReusedFromAccountID = %#v, want %d", gotItem.ReusedFromAccountID, reusedAccountID) + } + if gotItem.LeaseOwner != "worker-1" || gotItem.LeaseUntil != "2026-05-22T12:03:00Z" { + t.Fatalf("lease = (%q, %q), want persisted lease fields", gotItem.LeaseOwner, gotItem.LeaseUntil) + } + + event := ImportRunItemEvent{ + EventID: "evt-1", + RunID: "run-1", + ItemID: "item-1", + EventType: "retry_scheduled", + Stage: "confirm", + Attempt: 2, + Message: "retry after warmup", + PayloadJSON: `{"next_retry_at":"2026-05-22T12:02:00Z"}`, + } + if err := store.ImportRunEvents().Append(ctx, event); err != nil { + t.Fatalf("ImportRunEvents().Append() error = %v", err) + } + + events, err := store.ImportRunEvents().ListByItemID(ctx, "item-1") + if err != nil { + t.Fatalf("ImportRunEvents().ListByItemID() error = %v", err) + } + if len(events) != 1 { + t.Fatalf("len(events) = %d, want 1", len(events)) + } + if events[0].EventType != "retry_scheduled" { + t.Fatalf("events[0].EventType = %q, want retry_scheduled", events[0].EventType) + } + + items, err := store.ImportRunItems().ListByRunID(ctx, "run-1") + if err != nil { + t.Fatalf("ImportRunItems().ListByRunID() error = %v", err) + } + if len(items) != 1 { + t.Fatalf("len(items) = %d, want 1", len(items)) + } + if !reflect.DeepEqual(items[0].AdvisoryMessagesJSON, `["warmup"]`) { + t.Fatalf("items[0].AdvisoryMessagesJSON = %q, want advisory json", items[0].AdvisoryMessagesJSON) + } +} diff --git a/tests/integration/store_init_test.go b/tests/integration/store_init_test.go index a1bc1fff..c35eb44d 100644 --- a/tests/integration/store_init_test.go +++ b/tests/integration/store_init_test.go @@ -129,6 +129,33 @@ func TestStoreInitRecordsMigrationLedgerOnce(t *testing.T) { } } +func TestStoreAppliesLatestMigration(t *testing.T) { + store := openTestStore(t) + defer closeTestStore(t, store) + + for _, table := range []string{"import_runs", "import_run_items", "import_run_item_events"} { + if !tableExists(t, store.SQLDB(), table) { + t.Fatalf("table %q does not exist after latest migration", table) + } + } + + for _, column := range []string{ + "api_key_fingerprint", + "canonical_model_families_json", + "matched_account_state", + "account_resolution", + "provision_reused", + "reused_from_provider_id", + "reused_from_account_id", + "lease_owner", + "lease_until", + } { + if !tableColumnExists(t, store.SQLDB(), "import_run_items", column) { + t.Fatalf("column %q missing from import_run_items", column) + } + } +} + func TestStoreInitBackfillsLedgerForCompletePreLedgerSchema(t *testing.T) { dbPath := filepath.Join(t.TempDir(), "state.db") dsn := fmt.Sprintf("file:%s?_busy_timeout=5000", filepath.ToSlash(dbPath)) @@ -303,3 +330,34 @@ func countRows(t *testing.T, db *sql.DB, table string) int { return count } + +func tableColumnExists(t *testing.T, db *sql.DB, table, column string) bool { + t.Helper() + + rows, err := db.QueryContext(context.Background(), "PRAGMA table_info("+table+")") + if err != nil { + t.Fatalf("table_info(%q) query error = %v", table, err) + } + defer rows.Close() + + for rows.Next() { + var ( + cid int + name string + columnType string + notNull int + defaultVal sql.NullString + pk int + ) + if err := rows.Scan(&cid, &name, &columnType, ¬Null, &defaultVal, &pk); err != nil { + t.Fatalf("table_info(%q) scan error = %v", table, err) + } + if name == column { + return true + } + } + if err := rows.Err(); err != nil { + t.Fatalf("table_info(%q) rows error = %v", table, err) + } + return false +}