diff --git a/docs/EXECUTION_BOARD.md b/docs/EXECUTION_BOARD.md index 537acfd9..735e85ba 100644 --- a/docs/EXECUTION_BOARD.md +++ b/docs/EXECUTION_BOARD.md @@ -173,7 +173,9 @@ **当前剩余项**: - [x] 按收口后的 canonical contract 输出数据库 migration 草案 - [x] 补齐 run/item API response schema 细稿 -- [ ] 按收口后的 OpenAPI、migration、projection 字段开始实现 +- [~] 按收口后的 OpenAPI、migration、projection 字段开始实现 + - 已落地 `0007_batch_import_runs.sql` / `0008_batch_import_run_events.sql` + - 已补 `internal/store/sqlite` 下 run/item/event repo 骨架,并完成 migration ledger 测试同步 - [ ] 进入实现前再做一次实现前审阅,确认没有新增分叉 **实现前 Gate**:文档级 review 问题已收口,当前可以进入“按文档写 migration / 接口 / worker”的实现准备阶段 diff --git a/internal/store/sqlite/db.go b/internal/store/sqlite/db.go index 44c29e19..b6ca4611 100644 --- a/internal/store/sqlite/db.go +++ b/internal/store/sqlite/db.go @@ -25,6 +25,9 @@ type Queries struct { Providers *ProvidersRepo ImportBatches *ImportBatchesRepo ImportBatchItems *ImportBatchItemsRepo + ImportRuns *ImportRunsRepo + ImportRunItems *ImportRunItemsRepo + ImportRunEvents *ImportRunItemEventsRepo ManagedResources *ManagedResourcesRepo ProbeResults *ProbeResultsRepo AccessClosures *AccessClosureRecordsRepo @@ -91,6 +94,18 @@ func (db *DB) ImportBatchItems() *ImportBatchItemsRepo { return db.queries.ImportBatchItems } +func (db *DB) ImportRuns() *ImportRunsRepo { + return db.queries.ImportRuns +} + +func (db *DB) ImportRunItems() *ImportRunItemsRepo { + return db.queries.ImportRunItems +} + +func (db *DB) ImportRunEvents() *ImportRunItemEventsRepo { + return db.queries.ImportRunEvents +} + func (db *DB) ManagedResources() *ManagedResourcesRepo { return db.queries.ManagedResources } @@ -137,6 +152,9 @@ func newQueries(db execQuerier) *Queries { Providers: newProvidersRepo(db), ImportBatches: newImportBatchesRepo(db), ImportBatchItems: newImportBatchItemsRepo(db), + ImportRuns: newImportRunsRepo(db), + ImportRunItems: newImportRunItemsRepo(db), + ImportRunEvents: newImportRunItemEventsRepo(db), ManagedResources: newManagedResourcesRepo(db), ProbeResults: newProbeResultsRepo(db), AccessClosures: newAccessClosureRecordsRepo(db), diff --git a/internal/store/sqlite/import_runs_repo.go b/internal/store/sqlite/import_runs_repo.go new file mode 100644 index 00000000..c5220f7c --- /dev/null +++ b/internal/store/sqlite/import_runs_repo.go @@ -0,0 +1,431 @@ +package sqlite + +import ( + "context" + "fmt" + "strings" +) + +type ImportRun struct { + RunID string + Mode string + AccessMode string + State string + TotalItems int + CompletedItems int + ActiveItems int + DegradedItems int + BrokenItems int + WarningItems int + StartedAt string + UpdatedAt string + FinishedAt string +} + +type ImportRunsRepo struct { + db execQuerier +} + +func newImportRunsRepo(db execQuerier) *ImportRunsRepo { + return &ImportRunsRepo{db: db} +} + +func (r *ImportRunsRepo) Create(ctx context.Context, run ImportRun) error { + runID := strings.TrimSpace(run.RunID) + mode := strings.TrimSpace(run.Mode) + accessMode := strings.TrimSpace(run.AccessMode) + state := strings.TrimSpace(run.State) + + switch { + case runID == "": + return fmt.Errorf("run_id is required") + case mode == "": + return fmt.Errorf("mode is required") + case accessMode == "": + return fmt.Errorf("access_mode is required") + case state == "": + return fmt.Errorf("state is required") + } + + if _, err := r.db.ExecContext(ctx, `INSERT INTO import_runs (run_id, mode, access_mode, state, total_items, completed_items, active_items, degraded_items, broken_items, warning_items) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + runID, mode, accessMode, state, run.TotalItems, run.CompletedItems, run.ActiveItems, run.DegradedItems, run.BrokenItems, run.WarningItems); err != nil { + return fmt.Errorf("insert import run %q: %w", runID, err) + } + return nil +} + +func (r *ImportRunsRepo) Update(ctx context.Context, run ImportRun) error { + runID := strings.TrimSpace(run.RunID) + mode := strings.TrimSpace(run.Mode) + accessMode := strings.TrimSpace(run.AccessMode) + state := strings.TrimSpace(run.State) + finishedAt := strings.TrimSpace(run.FinishedAt) + + switch { + case runID == "": + return fmt.Errorf("run_id is required") + case mode == "": + return fmt.Errorf("mode is required") + case accessMode == "": + return fmt.Errorf("access_mode is required") + case state == "": + return fmt.Errorf("state is required") + } + + if _, err := r.db.ExecContext(ctx, `UPDATE import_runs + SET mode = ?, access_mode = ?, state = ?, total_items = ?, completed_items = ?, active_items = ?, degraded_items = ?, broken_items = ?, warning_items = ?, finished_at = ?, updated_at = CURRENT_TIMESTAMP + WHERE run_id = ?`, + mode, accessMode, state, run.TotalItems, run.CompletedItems, run.ActiveItems, run.DegradedItems, run.BrokenItems, run.WarningItems, nullableString(finishedAt), runID); err != nil { + return fmt.Errorf("update import run %q: %w", runID, err) + } + return nil +} + +func (r *ImportRunsRepo) GetByRunID(ctx context.Context, runID string) (ImportRun, error) { + runID = strings.TrimSpace(runID) + if runID == "" { + return ImportRun{}, fmt.Errorf("run_id is required") + } + + var run ImportRun + if err := r.db.QueryRowContext(ctx, `SELECT run_id, mode, access_mode, state, total_items, completed_items, active_items, degraded_items, broken_items, warning_items, started_at, updated_at, COALESCE(finished_at, '') FROM import_runs WHERE run_id = ?`, runID). + Scan(&run.RunID, &run.Mode, &run.AccessMode, &run.State, &run.TotalItems, &run.CompletedItems, &run.ActiveItems, &run.DegradedItems, &run.BrokenItems, &run.WarningItems, &run.StartedAt, &run.UpdatedAt, &run.FinishedAt); err != nil { + return ImportRun{}, err + } + return run, nil +} + +func (r *ImportRunsRepo) List(ctx context.Context, limit int) ([]ImportRun, error) { + if limit <= 0 { + limit = 50 + } + + rows, err := r.db.QueryContext(ctx, `SELECT run_id, mode, access_mode, state, total_items, completed_items, active_items, degraded_items, broken_items, warning_items, started_at, updated_at, COALESCE(finished_at, '') FROM import_runs ORDER BY started_at DESC LIMIT ?`, limit) + if err != nil { + return nil, fmt.Errorf("list import runs: %w", err) + } + defer rows.Close() + + runs := make([]ImportRun, 0) + for rows.Next() { + var run ImportRun + if err := rows.Scan(&run.RunID, &run.Mode, &run.AccessMode, &run.State, &run.TotalItems, &run.CompletedItems, &run.ActiveItems, &run.DegradedItems, &run.BrokenItems, &run.WarningItems, &run.StartedAt, &run.UpdatedAt, &run.FinishedAt); err != nil { + return nil, fmt.Errorf("scan import run: %w", err) + } + runs = append(runs, run) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate import runs: %w", err) + } + return runs, nil +} + +type ImportRunItem struct { + ItemID string + RunID string + BaseURL string + ProviderID string + RequestedModelsJSON string + RawModelsJSON string + NormalizedModelsJSON string + RecommendedModelsJSON string + ResolvedSmokeModel string + CapabilityProfileJSON string + CurrentStage string + ConfirmationStatus string + AccessStatus string + ChannelID *int64 + AccountID *int64 + RetryCount int + ConfirmationAttempts int + LastRetryAt string + NextRetryAt string + LeaseOwner string + LeaseUntil string + AdvisoryMessagesJSON string + LastErrorStage string + LastError string + LegacyBatchID *int64 + LegacyProviderID string + CreatedAt string + UpdatedAt string +} + +type ImportRunItemsRepo struct { + db execQuerier +} + +func newImportRunItemsRepo(db execQuerier) *ImportRunItemsRepo { + return &ImportRunItemsRepo{db: db} +} + +func (r *ImportRunItemsRepo) Create(ctx context.Context, item ImportRunItem) error { + itemID := strings.TrimSpace(item.ItemID) + runID := strings.TrimSpace(item.RunID) + baseURL := strings.TrimSpace(item.BaseURL) + providerID := strings.TrimSpace(item.ProviderID) + currentStage := strings.TrimSpace(item.CurrentStage) + confirmationStatus := strings.TrimSpace(item.ConfirmationStatus) + accessStatus := strings.TrimSpace(item.AccessStatus) + + switch { + case itemID == "": + return fmt.Errorf("item_id is required") + case runID == "": + return fmt.Errorf("run_id is required") + case baseURL == "": + return fmt.Errorf("base_url is required") + case providerID == "": + return fmt.Errorf("provider_id is required") + case currentStage == "": + return fmt.Errorf("current_stage is required") + case confirmationStatus == "": + return fmt.Errorf("confirmation_status is required") + case accessStatus == "": + return fmt.Errorf("access_status is required") + } + + if _, err := r.db.ExecContext(ctx, `INSERT INTO import_run_items ( + item_id, run_id, base_url, provider_id, requested_models_json, raw_models_json, normalized_models_json, + recommended_models_json, resolved_smoke_model, capability_profile_json, current_stage, confirmation_status, + access_status, channel_id, account_id, retry_count, confirmation_attempts, last_retry_at, next_retry_at, + lease_owner, lease_until, advisory_messages_json, last_error_stage, last_error, legacy_batch_id, legacy_provider_id + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + itemID, runID, baseURL, providerID, + defaultJSON(item.RequestedModelsJSON, "[]"), + defaultJSON(item.RawModelsJSON, "[]"), + defaultJSON(item.NormalizedModelsJSON, "[]"), + defaultJSON(item.RecommendedModelsJSON, "[]"), + nullableString(strings.TrimSpace(item.ResolvedSmokeModel)), + defaultJSON(item.CapabilityProfileJSON, "{}"), + currentStage, confirmationStatus, accessStatus, + item.ChannelID, item.AccountID, item.RetryCount, item.ConfirmationAttempts, + nullableString(strings.TrimSpace(item.LastRetryAt)), + nullableString(strings.TrimSpace(item.NextRetryAt)), + nullableString(strings.TrimSpace(item.LeaseOwner)), + nullableString(strings.TrimSpace(item.LeaseUntil)), + defaultJSON(item.AdvisoryMessagesJSON, "[]"), + nullableString(strings.TrimSpace(item.LastErrorStage)), + nullableString(strings.TrimSpace(item.LastError)), + item.LegacyBatchID, + nullableString(strings.TrimSpace(item.LegacyProviderID)), + ); err != nil { + return fmt.Errorf("insert import run item %q: %w", itemID, err) + } + return nil +} + +func (r *ImportRunItemsRepo) Update(ctx context.Context, item ImportRunItem) error { + itemID := strings.TrimSpace(item.ItemID) + runID := strings.TrimSpace(item.RunID) + baseURL := strings.TrimSpace(item.BaseURL) + providerID := strings.TrimSpace(item.ProviderID) + currentStage := strings.TrimSpace(item.CurrentStage) + confirmationStatus := strings.TrimSpace(item.ConfirmationStatus) + accessStatus := strings.TrimSpace(item.AccessStatus) + + switch { + case itemID == "": + return fmt.Errorf("item_id is required") + case runID == "": + return fmt.Errorf("run_id is required") + case baseURL == "": + return fmt.Errorf("base_url is required") + case providerID == "": + return fmt.Errorf("provider_id is required") + case currentStage == "": + return fmt.Errorf("current_stage is required") + case confirmationStatus == "": + return fmt.Errorf("confirmation_status is required") + case accessStatus == "": + return fmt.Errorf("access_status is required") + } + + if _, err := r.db.ExecContext(ctx, `UPDATE import_run_items SET + run_id = ?, base_url = ?, provider_id = ?, requested_models_json = ?, raw_models_json = ?, normalized_models_json = ?, + recommended_models_json = ?, resolved_smoke_model = ?, capability_profile_json = ?, current_stage = ?, confirmation_status = ?, + access_status = ?, channel_id = ?, account_id = ?, retry_count = ?, confirmation_attempts = ?, last_retry_at = ?, next_retry_at = ?, + lease_owner = ?, lease_until = ?, advisory_messages_json = ?, last_error_stage = ?, last_error = ?, legacy_batch_id = ?, legacy_provider_id = ?, + updated_at = CURRENT_TIMESTAMP + WHERE item_id = ?`, + runID, baseURL, providerID, + defaultJSON(item.RequestedModelsJSON, "[]"), + defaultJSON(item.RawModelsJSON, "[]"), + defaultJSON(item.NormalizedModelsJSON, "[]"), + defaultJSON(item.RecommendedModelsJSON, "[]"), + nullableString(strings.TrimSpace(item.ResolvedSmokeModel)), + defaultJSON(item.CapabilityProfileJSON, "{}"), + currentStage, confirmationStatus, accessStatus, + item.ChannelID, item.AccountID, item.RetryCount, item.ConfirmationAttempts, + nullableString(strings.TrimSpace(item.LastRetryAt)), + nullableString(strings.TrimSpace(item.NextRetryAt)), + nullableString(strings.TrimSpace(item.LeaseOwner)), + nullableString(strings.TrimSpace(item.LeaseUntil)), + defaultJSON(item.AdvisoryMessagesJSON, "[]"), + nullableString(strings.TrimSpace(item.LastErrorStage)), + nullableString(strings.TrimSpace(item.LastError)), + item.LegacyBatchID, + nullableString(strings.TrimSpace(item.LegacyProviderID)), + itemID, + ); err != nil { + return fmt.Errorf("update import run item %q: %w", itemID, err) + } + return nil +} + +func (r *ImportRunItemsRepo) GetByItemID(ctx context.Context, itemID string) (ImportRunItem, error) { + itemID = strings.TrimSpace(itemID) + if itemID == "" { + return ImportRunItem{}, fmt.Errorf("item_id is required") + } + + var item ImportRunItem + var channelID sqlNullInt64 + var accountID sqlNullInt64 + var legacyBatchID sqlNullInt64 + if err := r.db.QueryRowContext(ctx, `SELECT item_id, run_id, base_url, provider_id, requested_models_json, raw_models_json, normalized_models_json, recommended_models_json, COALESCE(resolved_smoke_model, ''), capability_profile_json, current_stage, confirmation_status, access_status, channel_id, account_id, retry_count, confirmation_attempts, COALESCE(last_retry_at, ''), COALESCE(next_retry_at, ''), COALESCE(lease_owner, ''), COALESCE(lease_until, ''), advisory_messages_json, COALESCE(last_error_stage, ''), COALESCE(last_error, ''), legacy_batch_id, COALESCE(legacy_provider_id, ''), created_at, updated_at FROM import_run_items WHERE item_id = ?`, itemID). + Scan(&item.ItemID, &item.RunID, &item.BaseURL, &item.ProviderID, &item.RequestedModelsJSON, &item.RawModelsJSON, &item.NormalizedModelsJSON, &item.RecommendedModelsJSON, &item.ResolvedSmokeModel, &item.CapabilityProfileJSON, &item.CurrentStage, &item.ConfirmationStatus, &item.AccessStatus, &channelID, &accountID, &item.RetryCount, &item.ConfirmationAttempts, &item.LastRetryAt, &item.NextRetryAt, &item.LeaseOwner, &item.LeaseUntil, &item.AdvisoryMessagesJSON, &item.LastErrorStage, &item.LastError, &legacyBatchID, &item.LegacyProviderID, &item.CreatedAt, &item.UpdatedAt); err != nil { + return ImportRunItem{}, err + } + item.ChannelID = channelID.ptr() + item.AccountID = accountID.ptr() + item.LegacyBatchID = legacyBatchID.ptr() + return item, nil +} + +func (r *ImportRunItemsRepo) ListByRunID(ctx context.Context, runID string) ([]ImportRunItem, error) { + runID = strings.TrimSpace(runID) + if runID == "" { + return nil, fmt.Errorf("run_id is required") + } + + rows, err := r.db.QueryContext(ctx, `SELECT item_id, run_id, base_url, provider_id, requested_models_json, raw_models_json, normalized_models_json, recommended_models_json, COALESCE(resolved_smoke_model, ''), capability_profile_json, current_stage, confirmation_status, access_status, channel_id, account_id, retry_count, confirmation_attempts, COALESCE(last_retry_at, ''), COALESCE(next_retry_at, ''), COALESCE(lease_owner, ''), COALESCE(lease_until, ''), advisory_messages_json, COALESCE(last_error_stage, ''), COALESCE(last_error, ''), legacy_batch_id, COALESCE(legacy_provider_id, ''), created_at, updated_at FROM import_run_items WHERE run_id = ? ORDER BY created_at, item_id`, runID) + if err != nil { + return nil, fmt.Errorf("list import run items by run_id %q: %w", runID, err) + } + defer rows.Close() + + items := make([]ImportRunItem, 0) + for rows.Next() { + var item ImportRunItem + var channelID sqlNullInt64 + var accountID sqlNullInt64 + var legacyBatchID sqlNullInt64 + if err := rows.Scan(&item.ItemID, &item.RunID, &item.BaseURL, &item.ProviderID, &item.RequestedModelsJSON, &item.RawModelsJSON, &item.NormalizedModelsJSON, &item.RecommendedModelsJSON, &item.ResolvedSmokeModel, &item.CapabilityProfileJSON, &item.CurrentStage, &item.ConfirmationStatus, &item.AccessStatus, &channelID, &accountID, &item.RetryCount, &item.ConfirmationAttempts, &item.LastRetryAt, &item.NextRetryAt, &item.LeaseOwner, &item.LeaseUntil, &item.AdvisoryMessagesJSON, &item.LastErrorStage, &item.LastError, &legacyBatchID, &item.LegacyProviderID, &item.CreatedAt, &item.UpdatedAt); err != nil { + return nil, fmt.Errorf("scan import run item: %w", err) + } + item.ChannelID = channelID.ptr() + item.AccountID = accountID.ptr() + item.LegacyBatchID = legacyBatchID.ptr() + items = append(items, item) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate import run items by run_id %q: %w", runID, err) + } + return items, nil +} + +type ImportRunItemEvent struct { + EventID string + RunID string + ItemID string + EventType string + Stage string + Attempt int + Message string + PayloadJSON string + CreatedAt string +} + +type ImportRunItemEventsRepo struct { + db execQuerier +} + +func newImportRunItemEventsRepo(db execQuerier) *ImportRunItemEventsRepo { + return &ImportRunItemEventsRepo{db: db} +} + +func (r *ImportRunItemEventsRepo) Create(ctx context.Context, event ImportRunItemEvent) error { + eventID := strings.TrimSpace(event.EventID) + runID := strings.TrimSpace(event.RunID) + itemID := strings.TrimSpace(event.ItemID) + eventType := strings.TrimSpace(event.EventType) + stage := strings.TrimSpace(event.Stage) + message := strings.TrimSpace(event.Message) + payloadJSON := defaultJSON(event.PayloadJSON, "{}") + + switch { + case eventID == "": + return fmt.Errorf("event_id is required") + case runID == "": + return fmt.Errorf("run_id is required") + case itemID == "": + return fmt.Errorf("item_id is required") + case eventType == "": + return fmt.Errorf("event_type is required") + case stage == "": + return fmt.Errorf("stage is required") + case message == "": + return fmt.Errorf("message is required") + } + + if _, err := r.db.ExecContext(ctx, `INSERT INTO import_run_item_events (event_id, run_id, item_id, event_type, stage, attempt, message, payload_json) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + eventID, runID, itemID, eventType, stage, event.Attempt, message, payloadJSON); err != nil { + return fmt.Errorf("insert import run item event %q: %w", eventID, err) + } + return nil +} + +func (r *ImportRunItemEventsRepo) ListByItemID(ctx context.Context, itemID string) ([]ImportRunItemEvent, error) { + itemID = strings.TrimSpace(itemID) + if itemID == "" { + return nil, fmt.Errorf("item_id is required") + } + + rows, err := r.db.QueryContext(ctx, `SELECT event_id, run_id, item_id, event_type, stage, attempt, message, payload_json, created_at FROM import_run_item_events WHERE item_id = ? ORDER BY created_at, event_id`, itemID) + if err != nil { + return nil, fmt.Errorf("list import run item events by item_id %q: %w", itemID, err) + } + defer rows.Close() + + events := make([]ImportRunItemEvent, 0) + for rows.Next() { + var event ImportRunItemEvent + if err := rows.Scan(&event.EventID, &event.RunID, &event.ItemID, &event.EventType, &event.Stage, &event.Attempt, &event.Message, &event.PayloadJSON, &event.CreatedAt); err != nil { + return nil, fmt.Errorf("scan import run item event: %w", err) + } + events = append(events, event) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate import run item events by item_id %q: %w", itemID, err) + } + return events, nil +} + +type sqlNullInt64 struct { + Int64 int64 + Valid bool +} + +func (n sqlNullInt64) ptr() *int64 { + if !n.Valid { + return nil + } + value := n.Int64 + return &value +} + +func nullableString(value string) any { + if strings.TrimSpace(value) == "" { + return nil + } + return value +} + +func defaultJSON(value, fallback string) string { + value = strings.TrimSpace(value) + if value == "" { + return fallback + } + return value +} diff --git a/tests/integration/store_init_test.go b/tests/integration/store_init_test.go index 05487f08..a1bc1fff 100644 --- a/tests/integration/store_init_test.go +++ b/tests/integration/store_init_test.go @@ -5,10 +5,12 @@ import ( "database/sql" "errors" "fmt" + "io/fs" "path/filepath" "testing" _ "modernc.org/sqlite" + "sub2api-cn-relay-manager/internal/store/migrations" "sub2api-cn-relay-manager/internal/store/sqlite" ) @@ -103,13 +105,14 @@ func TestStoreInitRollsBackTransaction(t *testing.T) { func TestStoreInitRecordsMigrationLedgerOnce(t *testing.T) { dbPath := filepath.Join(t.TempDir(), "state.db") dsn := fmt.Sprintf("file:%s?_busy_timeout=5000", filepath.ToSlash(dbPath)) + wantMigrations := migrationCount(t) store1, err := sqlite.Open(context.Background(), dsn) if err != nil { t.Fatalf("first sqlite.Open() error = %v", err) } - if got := countRows(t, store1.SQLDB(), "schema_migrations"); got != 6 { - t.Fatalf("schema_migrations row count after first open = %d, want 6", got) + if got := countRows(t, store1.SQLDB(), "schema_migrations"); got != wantMigrations { + t.Fatalf("schema_migrations row count after first open = %d, want %d", got, wantMigrations) } if err := store1.Close(); err != nil { t.Fatalf("first store.Close() error = %v", err) @@ -121,14 +124,15 @@ func TestStoreInitRecordsMigrationLedgerOnce(t *testing.T) { } defer closeTestStore(t, store2) - if got := countRows(t, store2.SQLDB(), "schema_migrations"); got != 6 { - t.Fatalf("schema_migrations row count after second open = %d, want 6", got) + if got := countRows(t, store2.SQLDB(), "schema_migrations"); got != wantMigrations { + t.Fatalf("schema_migrations row count after second open = %d, want %d", got, wantMigrations) } } func TestStoreInitBackfillsLedgerForCompletePreLedgerSchema(t *testing.T) { dbPath := filepath.Join(t.TempDir(), "state.db") dsn := fmt.Sprintf("file:%s?_busy_timeout=5000", filepath.ToSlash(dbPath)) + wantMigrations := migrationCount(t) rawDB := openRawSQLiteDB(t, dsn) createLegacy0001Schema(t, rawDB) @@ -140,8 +144,8 @@ func TestStoreInitBackfillsLedgerForCompletePreLedgerSchema(t *testing.T) { } defer closeTestStore(t, store) - if got := countRows(t, store.SQLDB(), "schema_migrations"); got != 6 { - t.Fatalf("schema_migrations row count after backfill = %d, want 6", got) + if got := countRows(t, store.SQLDB(), "schema_migrations"); got != wantMigrations { + t.Fatalf("schema_migrations row count after backfill = %d, want %d", got, wantMigrations) } } @@ -205,6 +209,16 @@ func closeRawSQLiteDB(t *testing.T, db *sql.DB) { } } +func migrationCount(t *testing.T) int { + t.Helper() + + names, err := fs.Glob(migrations.Files, "*.sql") + if err != nil { + t.Fatalf("fs.Glob(migrations.Files) error = %v", err) + } + return len(names) +} + func createLegacy0001Schema(t *testing.T, db *sql.DB) { t.Helper()