From bcc67c4a8a0b373235668e96bd6ab9dc72ecefef Mon Sep 17 00:00:00 2001 From: phamnazage-jpg Date: Sat, 23 May 2026 10:55:57 +0800 Subject: [PATCH] Expand coverage for runtime and sqlite paths --- internal/app/coverage_helpers_test.go | 1840 +++++++++++++++++ internal/app/http_batch_import_test.go | 58 + internal/app/http_batch_runs_test.go | 66 + internal/app/reconcile_background_test.go | 189 ++ internal/provision/import_service_test.go | 13 + .../provision/provider_status_service_test.go | 100 + internal/provision/rollback_service_test.go | 57 + .../provision/runtime_import_service_test.go | 95 + internal/reconcile/service_runtime_test.go | 136 +- .../access_closure_records_repo_test.go | 57 + internal/store/sqlite/db_test.go | 112 + internal/store/sqlite/hosts_repo_test.go | 70 + .../store/sqlite/import_batches_repo_test.go | 238 +++ .../import_run_item_events_repo_test.go | 88 + .../store/sqlite/import_runs_repo_test.go | 180 ++ internal/store/sqlite/packs_repo_test.go | 30 + internal/store/sqlite/providers_repo_test.go | 65 + 17 files changed, 3393 insertions(+), 1 deletion(-) create mode 100644 internal/store/sqlite/access_closure_records_repo_test.go create mode 100644 internal/store/sqlite/import_run_item_events_repo_test.go diff --git a/internal/app/coverage_helpers_test.go b/internal/app/coverage_helpers_test.go index 9150b65d..cefc1448 100644 --- a/internal/app/coverage_helpers_test.go +++ b/internal/app/coverage_helpers_test.go @@ -3,10 +3,16 @@ package app import ( "context" "database/sql" + "fmt" + "net/http" + "net/http/httptest" + "path/filepath" "strings" "testing" + "time" "sub2api-cn-relay-manager/internal/batch" + "sub2api-cn-relay-manager/internal/config" "sub2api-cn-relay-manager/internal/host/sub2api" "sub2api-cn-relay-manager/internal/provision" "sub2api-cn-relay-manager/internal/store/sqlite" @@ -65,6 +71,588 @@ func TestDefaultBackgroundSchedulersAndNewActionSet(t *testing.T) { } } +func TestStartBackgroundSchedulersAndBootstrap(t *testing.T) { + var batchCalls int + var reconcileCalls int + startBackgroundSchedulers(context.Background(), config.StartupConfig{ + Database: config.DatabaseConfig{SQLiteDSN: "file:test.db"}, + Reconcile: config.ReconcileConfig{ + WorkerEnabled: false, + }, + }, backgroundSchedulers{ + runBatchImport: func(context.Context, string) { batchCalls++ }, + runReconcile: func(context.Context, string, time.Duration) { reconcileCalls++ }, + }) + if batchCalls != 1 || reconcileCalls != 0 { + t.Fatalf("startBackgroundSchedulers(disabled) calls = (%d, %d), want (1, 0)", batchCalls, reconcileCalls) + } + + startBackgroundSchedulers(context.Background(), config.StartupConfig{ + Database: config.DatabaseConfig{SQLiteDSN: "file:test.db"}, + Reconcile: config.ReconcileConfig{ + WorkerEnabled: true, + PollInterval: time.Minute, + }, + }, backgroundSchedulers{ + runBatchImport: func(context.Context, string) { batchCalls++ }, + runReconcile: func(context.Context, string, time.Duration) { reconcileCalls++ }, + }) + if batchCalls != 2 || reconcileCalls != 1 { + t.Fatalf("startBackgroundSchedulers(enabled) calls = (%d, %d), want (2, 1)", batchCalls, reconcileCalls) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + t.Setenv(config.EnvListenAddr, "127.0.0.1:19090") + t.Setenv(config.EnvSQLiteDSN, "file:bootstrap-test.db?_foreign_keys=on&_busy_timeout=5000") + t.Setenv(config.EnvAdminToken, "bootstrap-token") + t.Setenv(config.EnvReconcileWorkerEnabled, "false") + server, err := Bootstrap(ctx) + if err != nil { + t.Fatalf("Bootstrap() error = %v", err) + } + if server.Addr() != "127.0.0.1:19090" { + t.Fatalf("Bootstrap() server.Addr() = %q, want 127.0.0.1:19090", server.Addr()) + } +} + +func TestBackgroundSchedulerEntryPoints(t *testing.T) { + t.Parallel() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + runBatchImportBackgroundScheduler(ctx, appTestDSN(t, store)) + runReconcileBackgroundScheduler(ctx, appTestDSN(t, store), 0) + + job := reconcileSweepJob{sqliteDSN: appTestDSN(t, store), interval: time.Minute} + if err := job.Run(context.Background()); err != nil { + t.Fatalf("reconcileSweepJob.Run() error = %v", err) + } +} + +func TestBatchImportProvisionerPatchAndSleepWithContext(t *testing.T) { + t.Parallel() + + if err := (batchImportProvisioner{}).Patch(context.Background(), batch.PatchProvisionRequest{}); err != nil { + t.Fatalf("Patch() error = %v, want nil", err) + } + if err := sleepWithContext(context.Background(), 0); err != nil { + t.Fatalf("sleepWithContext() error = %v, want nil", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := sleepWithContext(ctx, time.Second); err != context.Canceled { + t.Fatalf("sleepWithContext(canceled) error = %v, want %v", err, context.Canceled) + } +} + +func TestProbeHostSnapshotAndResolveValidationAPIKey(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(newBatchImportActionStubServer(t)) + defer server.Close() + + client, err := newSub2APIClient(server.URL, CreateHostAuth{Type: "apikey", Token: "host-token"}) + if err != nil { + t.Fatalf("newSub2APIClient() error = %v", err) + } + hostVersion, capabilities, err := probeHostSnapshot(context.Background(), client) + if err != nil { + t.Fatalf("probeHostSnapshot() error = %v", err) + } + if hostVersion != "0.1.126" || !capabilities.Groups || !capabilities.Subscriptions { + t.Fatalf("probeHostSnapshot() = (%q, %+v), want supported host snapshot", hostVersion, capabilities) + } + + selfServiceRunner := batchImportRuntimeRunner{ + request: CreateBatchImportRunRequest{ + AccessMode: provision.AccessModeSelfService, + ProbeAPIKey: " probe-key ", + }, + } + apiKey, err := selfServiceRunner.resolveValidationAPIKey(context.Background(), sqlite.ImportRunItem{}) + if err != nil { + t.Fatalf("resolveValidationAPIKey(self_service) error = %v", err) + } + if apiKey != "probe-key" { + t.Fatalf("resolveValidationAPIKey(self_service) = %q, want probe-key", apiKey) + } + + subscriptionRunner := batchImportRuntimeRunner{ + request: CreateBatchImportRunRequest{ + AccessMode: provision.AccessModeSubscription, + SubscriptionDays: 30, + }, + } + if _, err := subscriptionRunner.resolveValidationAPIKey(context.Background(), sqlite.ImportRunItem{}); err == nil || err.Error() != "subscription_users is required" { + t.Fatalf("resolveValidationAPIKey(subscription missing users) error = %v, want subscription_users is required", err) + } + + unsupportedRunner := batchImportRuntimeRunner{ + request: CreateBatchImportRunRequest{ + AccessMode: "other", + }, + } + if _, err := unsupportedRunner.resolveValidationAPIKey(context.Background(), sqlite.ImportRunItem{}); err == nil || !strings.Contains(err.Error(), `unsupported access mode "other"`) { + t.Fatalf("resolveValidationAPIKey(unsupported) error = %v, want unsupported access mode", err) + } +} + +func TestResolveManagedResourceHostIDAndConfirmItem(t *testing.T) { + t.Parallel() + + if _, err := resolveManagedResourceHostID(context.Background(), nil, sqlite.ImportRunItem{}, "account"); err == nil || err.Error() != "store is required" { + t.Fatalf("resolveManagedResourceHostID(nil store) error = %v, want store is required", err) + } + if _, err := resolveManagedResourceHostID(context.Background(), openAppTestStore(t), sqlite.ImportRunItem{}, "account"); err == nil || err.Error() != "legacy_batch_id is required for account lookup" { + t.Fatalf("resolveManagedResourceHostID(missing batch) error = %v, want legacy_batch_id required", err) + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/admin/accounts/account-ok/test": + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"type\":\"test_complete\",\"ok\":true,\"status\":\"passed\",\"message\":\"smoke passed\"}\n\n")) + case "/api/v1/admin/accounts/account-http/test": + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + case "/api/v1/admin/accounts/account-busy/test": + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"type\":\"test_complete\",\"ok\":false,\"status\":\"failed\",\"message\":\"No available accounts\"}\n\n")) + case "/api/v1/admin/accounts/account-forbidden/test": + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"type\":\"test_complete\",\"ok\":false,\"status\":\"failed\",\"message\":\"Forbidden by upstream\"}\n\n")) + case "/api/v1/admin/accounts/account-bad/test": + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"type\":\"test_complete\",\"ok\":false,\"status\":\"failed\",\"message\":\"model mismatch\"}\n\n")) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client, err := newSub2APIClient(server.URL, CreateHostAuth{Type: "apikey", Token: "host-token"}) + if err != nil { + t.Fatalf("newSub2APIClient() error = %v", err) + } + + tests := []struct { + name string + accountID string + wantStatus int + wantMsg string + }{ + {name: "probe ok", accountID: "account-ok", wantStatus: http.StatusOK, wantMsg: "smoke passed"}, + {name: "http error passthrough", accountID: "account-http", wantStatus: http.StatusForbidden, wantMsg: `{"error":"forbidden"}`}, + {name: "busy advisory", accountID: "account-busy", wantStatus: http.StatusServiceUnavailable, wantMsg: "No available accounts"}, + {name: "forbidden advisory", accountID: "account-forbidden", wantStatus: http.StatusForbidden, wantMsg: "Forbidden by upstream"}, + {name: "generic bad request", accountID: "account-bad", wantStatus: http.StatusBadRequest, wantMsg: "model mismatch"}, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + hostPK := createAppHostRecord(t, store, server.URL) + packPK := createAppPackRecord(t, store) + providerPK := createAppProviderRecord(t, store, packPK) + batchID := createAppBatchRecord(t, store, hostPK, packPK, providerPK) + if _, err := store.ManagedResources().Create(context.Background(), sqlite.ManagedResource{ + BatchID: batchID, + HostID: hostPK, + ResourceType: "account", + HostResourceID: tc.accountID, + ResourceName: tc.accountID, + }); err != nil { + t.Fatalf("ManagedResources().Create() error = %v", err) + } + + runner := batchImportRuntimeRunner{ + store: store, + hostClient: client, + } + result, err := runner.confirmItem(context.Background(), sqlite.ImportRunItem{ + LegacyBatchID: &batchID, + ResolvedSmokeModel: "kimi-k2.6", + }) + if err != nil { + t.Fatalf("confirmItem() error = %v", err) + } + if result.StatusCode != tc.wantStatus || result.Message != tc.wantMsg { + t.Fatalf("confirmItem() = %+v, want status=%d message=%q", result, tc.wantStatus, tc.wantMsg) + } + }) + } +} + +func TestBatchImportRunItemStoreTryAcquireLease(t *testing.T) { + t.Parallel() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + mustCreateAppImportRun(t, store, sqlite.ImportRun{ + RunID: "run-lease-1", + HostID: "host-1", + Mode: "partial", + AccessMode: "self_service", + State: "running", + }) + mustCreateAppImportRunItem(t, store, sqlite.ImportRunItem{ + ItemID: "item-lease-1", + RunID: "run-lease-1", + BaseURL: "https://deepseek.example.com/v1", + ProviderID: "deepseek", + APIKeyFingerprint: "sha256:lease", + ResolvedSmokeModel: "deepseek-v4-pro", + CurrentStage: "confirm", + ConfirmationStatus: "pending", + AccessStatus: "unknown", + MatchedAccountState: "active", + AccountResolution: "created", + }) + + leaseStore := batchImportRunItemStore{store: store, runID: "run-lease-1"} + now := time.Date(2026, 5, 23, 10, 0, 0, 0, time.UTC) + + if _, _, err := leaseStore.TryAcquireLease(context.Background(), "", "worker-1", now, time.Minute); err == nil || err.Error() != "item_id is required" { + t.Fatalf("TryAcquireLease(missing item) error = %v, want item_id is required", err) + } + if _, _, err := leaseStore.TryAcquireLease(context.Background(), "item-lease-1", "", now, time.Minute); err == nil || err.Error() != "worker_id is required" { + t.Fatalf("TryAcquireLease(missing worker) error = %v, want worker_id is required", err) + } + + item, claimed, err := leaseStore.TryAcquireLease(context.Background(), "item-lease-1", "worker-1", now, time.Minute) + if err != nil { + t.Fatalf("TryAcquireLease(first) error = %v", err) + } + if !claimed || item.LeaseOwner != "worker-1" { + t.Fatalf("TryAcquireLease(first) = (%+v, %v), want claimed by worker-1", item, claimed) + } + + item, claimed, err = leaseStore.TryAcquireLease(context.Background(), "item-lease-1", "worker-2", now.Add(30*time.Second), time.Minute) + if err != nil { + t.Fatalf("TryAcquireLease(second) error = %v", err) + } + if claimed || item.ItemID != "" { + t.Fatalf("TryAcquireLease(second) = (%+v, %v), want unclaimed empty item", item, claimed) + } +} + +func TestDriveRunValidationAndRetryBranches(t *testing.T) { + t.Parallel() + + t.Run("validate stage completes run", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/v1/chat/completions": + writeJSON(w, http.StatusOK, map[string]any{ + "id": "chatcmpl_validate", + "choices": []map[string]any{{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": "pong", + }, + }}, + }) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client, err := newSub2APIClient(server.URL, CreateHostAuth{Type: "bearer", Token: "gateway-key"}) + if err != nil { + t.Fatalf("newSub2APIClient() error = %v", err) + } + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + mustCreateAppImportRun(t, store, sqlite.ImportRun{ + RunID: "run-validate-1", + HostID: "host-1", + Mode: "partial", + AccessMode: "self_service", + State: "running", + TotalItems: 1, + }) + mustCreateAppImportRunItem(t, store, sqlite.ImportRunItem{ + ItemID: "item-validate-1", + RunID: "run-validate-1", + BaseURL: server.URL, + ProviderID: "deepseek", + APIKeyFingerprint: "sha256:validate", + ResolvedSmokeModel: "deepseek-v4-pro", + CurrentStage: "validate", + ConfirmationStatus: "confirmed", + AccessStatus: "unknown", + MatchedAccountState: "active", + AccountResolution: "created", + }) + + runner := batchImportRuntimeRunner{ + store: store, + hostClient: client, + request: CreateBatchImportRunRequest{ + AccessMode: provision.AccessModeSelfService, + ProbeAPIKey: "gateway-key", + }, + } + if err := runner.driveRun(context.Background(), "run-validate-1", 0); err != nil { + t.Fatalf("driveRun(validate) error = %v", err) + } + + run, err := store.ImportRuns().GetByRunID(context.Background(), "run-validate-1") + if err != nil { + t.Fatalf("ImportRuns().GetByRunID() error = %v", err) + } + if run.State != "completed" || run.CompletedItems != 1 || run.ActiveItems != 1 { + t.Fatalf("run = %+v, want completed with one active item", run) + } + + item, err := store.ImportRunItems().GetByItemID(context.Background(), "item-validate-1") + if err != nil { + t.Fatalf("ImportRunItems().GetByItemID() error = %v", err) + } + if item.CurrentStage != "done" || item.AccessStatus != "active" { + t.Fatalf("item = %+v, want done/active", item) + } + }) + + t.Run("confirm stage schedules retry when wait budget is zero", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/admin/accounts/account-busy/test": + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("data: {\"type\":\"test_complete\",\"ok\":false,\"status\":\"failed\",\"message\":\"No available accounts\"}\n\n")) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client, err := newSub2APIClient(server.URL, CreateHostAuth{Type: "bearer", Token: "gateway-key"}) + if err != nil { + t.Fatalf("newSub2APIClient() error = %v", err) + } + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + hostPK := createAppHostRecord(t, store, server.URL) + packPK := createAppPackRecord(t, store) + providerPK := createAppProviderRecord(t, store, packPK) + legacyBatchID := createAppBatchRecord(t, store, hostPK, packPK, providerPK) + if _, err := store.ManagedResources().Create(context.Background(), sqlite.ManagedResource{ + BatchID: legacyBatchID, + HostID: hostPK, + ResourceType: "account", + HostResourceID: "account-busy", + ResourceName: "account-busy", + }); err != nil { + t.Fatalf("ManagedResources().Create(account) error = %v", err) + } + + mustCreateAppImportRun(t, store, sqlite.ImportRun{ + RunID: "run-confirm-1", + HostID: "host-1", + Mode: "partial", + AccessMode: "self_service", + State: "running", + TotalItems: 1, + }) + mustCreateAppImportRunItem(t, store, sqlite.ImportRunItem{ + ItemID: "item-confirm-1", + RunID: "run-confirm-1", + BaseURL: server.URL, + ProviderID: "deepseek", + APIKeyFingerprint: "sha256:confirm", + ResolvedSmokeModel: "deepseek-v4-pro", + CurrentStage: "confirm", + ConfirmationStatus: "pending", + AccessStatus: "unknown", + MatchedAccountState: "active", + AccountResolution: "created", + LegacyBatchID: &legacyBatchID, + CapabilityProfileJSON: `{"transport_profile":{"known_advisories":[]}}`, + }) + + runner := batchImportRuntimeRunner{ + store: store, + hostClient: client, + request: CreateBatchImportRunRequest{ + AccessMode: provision.AccessModeSelfService, + ProbeAPIKey: "gateway-key", + }, + } + if err := runner.driveRun(context.Background(), "run-confirm-1", 0); err != nil { + t.Fatalf("driveRun(confirm retry) error = %v", err) + } + + run, err := store.ImportRuns().GetByRunID(context.Background(), "run-confirm-1") + if err != nil { + t.Fatalf("ImportRuns().GetByRunID() error = %v", err) + } + if run.State != "running" || run.CompletedItems != 0 { + t.Fatalf("run = %+v, want still running with no completed items", run) + } + + item, err := store.ImportRunItems().GetByItemID(context.Background(), "item-confirm-1") + if err != nil { + t.Fatalf("ImportRunItems().GetByItemID() error = %v", err) + } + if item.CurrentStage != "confirm" || item.RetryCount != 1 || strings.TrimSpace(item.NextRetryAt) == "" { + t.Fatalf("item = %+v, want confirm stage with scheduled retry", item) + } + }) +} + +func TestResolveValidationAPIKeySubscriptionSuccess(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"): + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{}}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84, "email": "relay-sub-user-1@sub2api.local"}}) + case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 401}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"access_token": "user-jwt"}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 501, "key": "sk-relay-key", "name": "managed-key"}}) + case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"api_key": map[string]any{"id": 501}}}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client, err := newSub2APIClient(server.URL, CreateHostAuth{Type: "bearer", Token: "admin-token"}) + if err != nil { + t.Fatalf("newSub2APIClient() error = %v", err) + } + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + hostPK := createAppHostRecord(t, store, server.URL) + packPK := createAppPackRecord(t, store) + providerPK := createAppProviderRecord(t, store, packPK) + batchID := createAppBatchRecord(t, store, hostPK, packPK, providerPK) + if _, err := store.ManagedResources().Create(context.Background(), sqlite.ManagedResource{ + BatchID: batchID, + HostID: hostPK, + ResourceType: "group", + HostResourceID: "101", + ResourceName: "group-101", + }); err != nil { + t.Fatalf("ManagedResources().Create(group) error = %v", err) + } + + runner := batchImportRuntimeRunner{ + store: store, + hostClient: client, + request: CreateBatchImportRunRequest{ + AccessMode: provision.AccessModeSubscription, + SubscriptionUsers: []string{"crm-user-1"}, + SubscriptionDays: 30, + }, + } + + apiKey, err := runner.resolveValidationAPIKey(context.Background(), sqlite.ImportRunItem{ + LegacyBatchID: &batchID, + }) + if err != nil { + t.Fatalf("resolveValidationAPIKey(subscription) error = %v", err) + } + if !strings.HasPrefix(apiKey, "sk-relay-") { + t.Fatalf("resolveValidationAPIKey(subscription) = %q, want managed sk-relay-* key", apiKey) + } +} + +func TestResolveValidationAPIKeySubscriptionRequiresStoredGroup(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"): + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{}}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84, "email": "relay-sub-user-1@sub2api.local"}}) + case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 401}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"access_token": "user-jwt"}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 501, "key": "sk-relay-key", "name": "managed-key"}}) + case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"api_key": map[string]any{"id": 501}}}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + client, err := newSub2APIClient(server.URL, CreateHostAuth{Type: "bearer", Token: "admin-token"}) + if err != nil { + t.Fatalf("newSub2APIClient() error = %v", err) + } + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + hostPK := createAppHostRecord(t, store, server.URL) + packPK := createAppPackRecord(t, store) + providerPK := createAppProviderRecord(t, store, packPK) + batchID := createAppBatchRecord(t, store, hostPK, packPK, providerPK) + + runner := batchImportRuntimeRunner{ + store: store, + hostClient: client, + request: CreateBatchImportRunRequest{ + AccessMode: provision.AccessModeSubscription, + SubscriptionUsers: []string{"crm-user-1"}, + SubscriptionDays: 30, + }, + } + + _, err = runner.resolveValidationAPIKey(context.Background(), sqlite.ImportRunItem{ + LegacyBatchID: &batchID, + }) + if err == nil || err.Error() != fmt.Sprintf("%s resource not found for batch %d", "group", batchID) { + t.Fatalf("resolveValidationAPIKey(subscription missing group) error = %v, want missing group", err) + } +} + func TestCreateHostAuthFromLegacyFields(t *testing.T) { t.Parallel() @@ -474,6 +1062,81 @@ func mustExecSQL(t *testing.T, store *sqlite.DB, query string, args ...any) { } } +func createAppHostRecord(t *testing.T, store *sqlite.DB, baseURL string) int64 { + t.Helper() + id, err := store.Hosts().Create(context.Background(), sqlite.Host{ + HostID: "host-" + sanitizeAppTestName(t.Name()), + BaseURL: baseURL, + HostVersion: "0.1.126", + AuthType: "apikey", + AuthToken: "host-token", + }) + if err != nil { + t.Fatalf("Hosts().Create() error = %v", err) + } + return id +} + +func createAppPackRecord(t *testing.T, store *sqlite.DB) int64 { + t.Helper() + id, err := store.Packs().Create(context.Background(), sqlite.Pack{ + PackID: "pack-" + sanitizeAppTestName(t.Name()), + Version: "1.0.0", + Checksum: "checksum-" + sanitizeAppTestName(t.Name()), + }) + if err != nil { + t.Fatalf("Packs().Create() error = %v", err) + } + return id +} + +func createAppProviderRecord(t *testing.T, store *sqlite.DB, packID int64) int64 { + t.Helper() + id, err := store.Providers().Create(context.Background(), sqlite.Provider{ + PackID: packID, + ProviderID: "provider-" + sanitizeAppTestName(t.Name()), + DisplayName: "Provider", + BaseURL: "https://provider.example.com", + Platform: "openai", + }) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + return id +} + +func createAppBatchRecord(t *testing.T, store *sqlite.DB, hostID, packID, providerID int64) int64 { + t.Helper() + id, err := store.ImportBatches().Create(context.Background(), sqlite.ImportBatch{ + HostID: hostID, + PackID: packID, + ProviderID: providerID, + Mode: provision.ImportModePartial, + BatchStatus: provision.BatchStatusSucceeded, + AccessStatus: provision.AccessStatusSelfServiceReady, + }) + if err != nil { + t.Fatalf("ImportBatches().Create() error = %v", err) + } + return id +} + +func sanitizeAppTestName(name string) string { + var b strings.Builder + for _, c := range name { + switch { + case c >= 'a' && c <= 'z', c >= '0' && c <= '9': + b.WriteRune(c) + case c >= 'A' && c <= 'Z': + b.WriteRune(c + ('a' - 'A')) + } + } + if b.Len() == 0 { + return "default" + } + return b.String() +} + func TestResolveManagedHostAndNewSub2APIClient(t *testing.T) { t.Parallel() @@ -515,6 +1178,423 @@ func TestResolveManagedHostAndNewSub2APIClient(t *testing.T) { } } +func TestHandleAssignAccessSubscriptionsAndAccessPreview(t *testing.T) { + t.Parallel() + + t.Run("handleAssignAccessSubscriptions returns success", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, "POST", "/providers/deepseek/access/subscriptions", map[string]any{ + "host_id": "host-1", + "access_api_key": "user-key", + "subscription_users": []string{"user-1"}, + }, "") + req.SetPathValue("providerID", "deepseek") + rec := &responseRecorder{header: map[string][]string{}} + handleAssignAccessSubscriptions(rec, req, func(_ context.Context, got AssignAccessSubscriptionsRequest) (AssignAccessSubscriptionsResult, error) { + if got.ProviderID != "deepseek" || got.HostID != "host-1" || got.AccessAPIKey != "user-key" || len(got.SubscriptionUsers) != 1 || got.SubscriptionUsers[0] != "user-1" { + t.Fatalf("AssignAccessSubscriptionsRequest = %+v, want projected request", got) + } + return AssignAccessSubscriptionsResult{Assigned: 1}, nil + }) + assertStatusCode(t, rec, 200) + assertJSONContains(t, rec.Body().Bytes(), "assigned", float64(1)) + }) + + t.Run("handleAccessPreview falls back to query values", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, "POST", "/providers/deepseek/access/preview?pack_id=openai-cn-pack&host_id=host-1", map[string]any{ + "mode": "self_service", + }, "") + req.SetPathValue("providerID", "deepseek") + rec := &responseRecorder{header: map[string][]string{}} + handleAccessPreview(rec, req, func(_ context.Context, got AccessPreviewRequest) (AccessPreviewResult, error) { + if got.ProviderID != "deepseek" || got.PackID != "openai-cn-pack" || got.HostID != "host-1" { + t.Fatalf("AccessPreviewRequest = %+v, want query fallback values", got) + } + return AccessPreviewResult{Available: true, Mode: got.Mode}, nil + }) + assertStatusCode(t, rec, 200) + assertJSONContains(t, rec.Body().Bytes(), "available", true) + assertJSONContains(t, rec.Body().Bytes(), "mode", "self_service") + }) +} + +func TestAdditionalHTTPWrappers(t *testing.T) { + t.Parallel() + + t.Run("handleListProviderImportBatches returns query-scoped payload", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, "GET", "/providers/deepseek/import-batches?pack_id=openai-cn-pack&host_id=host-1", map[string]any{}, "") + req.SetPathValue("providerID", "deepseek") + rec := &responseRecorder{header: map[string][]string{}} + handleListProviderImportBatches(rec, req, func(_ context.Context, got ProviderQueryRequest) ([]ImportBatchInfo, error) { + if got.ProviderID != "deepseek" || got.PackID != "openai-cn-pack" || got.HostID != "host-1" { + t.Fatalf("ProviderQueryRequest = %+v, want provider/pack/host filters", got) + } + return nil, nil + }) + assertStatusCode(t, rec, 200) + batches, ok := decodeTopLevelArray(t, rec.Body().Bytes(), "batches") + if !ok || len(batches) != 0 { + t.Fatalf("batches = %#v, want empty array", batches) + } + }) + + t.Run("handleRollbackBatch validates batch id", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, "POST", "/import-batches/bad/rollback", map[string]any{}, "") + req.SetPathValue("batchID", "bad") + rec := &responseRecorder{header: map[string][]string{}} + handleRollbackBatch(rec, req, func(context.Context, RollbackBatchRequest) (provision.RollbackReport, error) { + t.Fatal("rollback fn should not be called for invalid batch id") + return provision.RollbackReport{}, nil + }) + assertStatusCode(t, rec, 400) + assertJSONContains(t, rec.Body().Bytes(), "error.message", "batch_id must be a positive integer") + }) + + t.Run("handleRollbackBatch returns summary", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, "POST", "/import-batches/17/rollback", map[string]any{}, "") + req.SetPathValue("batchID", "17") + rec := &responseRecorder{header: map[string][]string{}} + handleRollbackBatch(rec, req, func(_ context.Context, got RollbackBatchRequest) (provision.RollbackReport, error) { + if got.BatchID != 17 { + t.Fatalf("RollbackBatchRequest.BatchID = %d, want 17", got.BatchID) + } + return provision.RollbackReport{AccountsDeleted: 2, PlansDeleted: 1, ChannelsDeleted: 1, GroupsDeleted: 1}, nil + }) + assertStatusCode(t, rec, 200) + assertJSONContains(t, rec.Body().Bytes(), "batch_id", float64(17)) + assertJSONContains(t, rec.Body().Bytes(), "deleted_accounts", float64(2)) + }) + + t.Run("handleCreateHost returns payload", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, "POST", "/hosts", map[string]any{ + "name": "host-1", + "base_url": "https://sub2api.example.com", + "auth": map[string]any{"type": "apikey", "token": "host-token"}, + }, "") + rec := &responseRecorder{header: map[string][]string{}} + handleCreateHost(rec, req, func(_ context.Context, got CreateHostRequest) (HostInfo, error) { + if got.Name != "host-1" || got.BaseURL != "https://sub2api.example.com" || got.Auth.Token != "host-token" { + t.Fatalf("CreateHostRequest = %+v, want decoded request", got) + } + return HostInfo{HostID: got.Name, BaseURL: got.BaseURL, HostVersion: "0.1.126", AuthType: "apikey"}, nil + }) + assertStatusCode(t, rec, 200) + assertJSONContains(t, rec.Body().Bytes(), "host_id", "host-1") + assertJSONContains(t, rec.Body().Bytes(), "base_url", "https://sub2api.example.com") + }) +} + +func TestActionSetPackClosures(t *testing.T) { + t.Parallel() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + packID := createAppPackRecord(t, store) + if _, err := store.Providers().Create(context.Background(), sqlite.Provider{ + PackID: packID, + ProviderID: "provider-a", + DisplayName: "Provider A", + BaseURL: "https://provider-a.example.com", + Platform: "openai", + }); err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + + actions := NewActionSet(appTestDSN(t, store)) + + packs, err := actions.ListPacks(context.Background()) + if err != nil { + t.Fatalf("ListPacks() error = %v", err) + } + if len(packs) != 1 || packs[0].PackID == "" { + t.Fatalf("ListPacks() = %+v, want single persisted pack", packs) + } + + packInfo, err := actions.GetPack(context.Background(), packs[0].PackID) + if err != nil { + t.Fatalf("GetPack() error = %v", err) + } + if packInfo.PackID != packs[0].PackID { + t.Fatalf("GetPack() = %+v, want pack_id %q", packInfo, packs[0].PackID) + } + + providers, err := actions.ListPackProviders(context.Background(), packs[0].PackID) + if err != nil { + t.Fatalf("ListPackProviders() error = %v", err) + } + if len(providers) != 1 || providers[0].ProviderID != "provider-a" { + t.Fatalf("ListPackProviders() = %+v, want provider-a", providers) + } +} + +func TestActionSetCreateHostClosure(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(newBatchImportActionStubServer(t)) + defer server.Close() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + actions := NewActionSet(appTestDSN(t, store)) + host, err := actions.CreateHost(context.Background(), CreateHostRequest{ + Name: "prod-sub2api", + BaseURL: server.URL, + Auth: CreateHostAuth{Type: "apikey", Token: "host-token"}, + }) + if err != nil { + t.Fatalf("CreateHost() error = %v", err) + } + if host.HostID != "prod-sub2api" || host.HostVersion != "0.1.126" || host.Status != "supported" { + t.Fatalf("CreateHost() = %+v, want stored supported host", host) + } + + stored, err := store.Hosts().GetByHostID(context.Background(), "prod-sub2api") + if err != nil { + t.Fatalf("Hosts().GetByHostID() error = %v", err) + } + if stored.BaseURL != server.URL || stored.AuthToken != "host-token" { + t.Fatalf("stored host = %+v, want persisted connection details", stored) + } +} + +func TestActionSetListProviderImportBatchesClosure(t *testing.T) { + t.Parallel() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + hostA, err := store.Hosts().Create(context.Background(), sqlite.Host{ + HostID: "host-a", + BaseURL: "https://host-a.example.com", + HostVersion: "0.1.126", + AuthToken: "token-a", + }) + if err != nil { + t.Fatalf("Hosts().Create(host-a) error = %v", err) + } + hostB, err := store.Hosts().Create(context.Background(), sqlite.Host{ + HostID: "host-b", + BaseURL: "https://host-b.example.com", + HostVersion: "0.1.126", + AuthToken: "token-b", + }) + if err != nil { + t.Fatalf("Hosts().Create(host-b) error = %v", err) + } + packID := createAppPackRecord(t, store) + providerID, err := store.Providers().Create(context.Background(), sqlite.Provider{ + PackID: packID, + ProviderID: "shared-provider", + DisplayName: "Shared Provider", + BaseURL: "https://provider.example.com", + Platform: "openai", + }) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + if _, err := store.ImportBatches().Create(context.Background(), sqlite.ImportBatch{ + HostID: hostA, + PackID: packID, + ProviderID: providerID, + Mode: provision.ImportModePartial, + BatchStatus: provision.BatchStatusSucceeded, + AccessStatus: provision.AccessStatusSelfServiceReady, + }); err != nil { + t.Fatalf("ImportBatches().Create(host-a) error = %v", err) + } + if _, err := store.ImportBatches().Create(context.Background(), sqlite.ImportBatch{ + HostID: hostB, + PackID: packID, + ProviderID: providerID, + Mode: provision.ImportModeStrict, + BatchStatus: provision.BatchStatusPartial, + AccessStatus: provision.AccessStatusBroken, + }); err != nil { + t.Fatalf("ImportBatches().Create(host-b) error = %v", err) + } + + actions := NewActionSet(appTestDSN(t, store)) + if _, err := actions.ListProviderImportBatches(context.Background(), ProviderQueryRequest{ProviderID: "shared-provider"}); err == nil || err.Error() != "provider exists on multiple hosts; host_id is required" { + t.Fatalf("ListProviderImportBatches(multi-host) error = %v, want host_id required", err) + } + + batches, err := actions.ListProviderImportBatches(context.Background(), ProviderQueryRequest{ + ProviderID: "shared-provider", + HostID: "host-a", + }) + if err != nil { + t.Fatalf("ListProviderImportBatches(host filter) error = %v", err) + } + if len(batches) != 1 || batches[0].BatchStatus != provision.BatchStatusSucceeded { + t.Fatalf("ListProviderImportBatches(host filter) = %+v, want single succeeded batch", batches) + } +} + +func TestActionSetHostClosuresAndAccessPreview(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(newBatchImportActionStubServer(t)) + defer server.Close() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + hostID, err := store.Hosts().Create(context.Background(), sqlite.Host{ + HostID: "host-main", + BaseURL: server.URL, + HostVersion: "0.0.1", + CapabilityProbeJSON: "{}", + AuthType: "apikey", + AuthToken: "host-token", + }) + if err != nil { + t.Fatalf("Hosts().Create() error = %v", err) + } + packID := createAppPackRecord(t, store) + providerID, err := store.Providers().Create(context.Background(), sqlite.Provider{ + PackID: packID, + ProviderID: "preview-provider", + DisplayName: "Preview Provider", + BaseURL: "https://provider.example.com", + Platform: "openai", + }) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + batchID, err := store.ImportBatches().Create(context.Background(), sqlite.ImportBatch{ + HostID: hostID, + PackID: packID, + ProviderID: providerID, + Mode: provision.ImportModePartial, + BatchStatus: provision.BatchStatusSucceeded, + AccessStatus: provision.AccessStatusSelfServiceReady, + }) + if err != nil { + t.Fatalf("ImportBatches().Create() error = %v", err) + } + if _, err := store.AccessClosures().Create(context.Background(), sqlite.AccessClosureRecord{ + BatchID: batchID, + ClosureType: provision.AccessModeSelfService, + Status: provision.AccessStatusSelfServiceReady, + DetailsJSON: `{}`, + }); err != nil { + t.Fatalf("AccessClosures().Create() error = %v", err) + } + + actions := NewActionSet(appTestDSN(t, store)) + + hosts, err := actions.ListHosts(context.Background()) + if err != nil { + t.Fatalf("ListHosts() error = %v", err) + } + if len(hosts) != 1 || hosts[0].HostID != "host-main" { + t.Fatalf("ListHosts() = %+v, want host-main", hosts) + } + + hostInfo, err := actions.GetHost(context.Background(), "host-main") + if err != nil { + t.Fatalf("GetHost() error = %v", err) + } + if hostInfo.HostID != "host-main" || hostInfo.BaseURL != server.URL { + t.Fatalf("GetHost() = %+v, want stored host", hostInfo) + } + + probed, err := actions.ProbeHost(context.Background(), ProbeHostRequest{HostID: "host-main"}) + if err != nil { + t.Fatalf("ProbeHost() error = %v", err) + } + if probed.Status != "supported" || probed.HostVersion != "0.1.126" { + t.Fatalf("ProbeHost() = %+v, want supported host 0.1.126", probed) + } + + packRow, err := store.Packs().GetByID(context.Background(), packID) + if err != nil { + t.Fatalf("Packs().GetByID() error = %v", err) + } + preview, err := actions.AccessPreview(context.Background(), AccessPreviewRequest{ + ProviderID: "preview-provider", + PackID: packRow.PackID, + HostID: "host-main", + Mode: provision.AccessModeSelfService, + }) + if err != nil { + t.Fatalf("AccessPreview(self_service) error = %v", err) + } + if !preview.Available { + t.Fatalf("AccessPreview(self_service) = %+v, want available=true", preview) + } + + preview, err = actions.AccessPreview(context.Background(), AccessPreviewRequest{ + ProviderID: "preview-provider", + PackID: packRow.PackID, + HostID: "host-main", + Mode: provision.AccessModeSubscription, + }) + if err != nil { + t.Fatalf("AccessPreview(subscription) error = %v", err) + } + if preview.Available { + t.Fatalf("AccessPreview(subscription) = %+v, want available=false", preview) + } + + if err := actions.DeleteHost(context.Background(), "host-main"); err != nil { + t.Fatalf("DeleteHost() error = %v", err) + } + if _, err := store.Hosts().GetByHostID(context.Background(), "host-main"); err == nil { + t.Fatal("DeleteHost() did not remove host-main") + } +} + +func TestAdditionalHostHTTPWrappers(t *testing.T) { + t.Parallel() + + t.Run("handleListHosts returns empty array", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, "GET", "/hosts", map[string]any{}, "") + rec := &responseRecorder{header: map[string][]string{}} + handleListHosts(rec, req, func(context.Context) ([]HostInfo, error) { return nil, nil }) + assertStatusCode(t, rec, 200) + hosts, ok := decodeTopLevelArray(t, rec.Body().Bytes(), "hosts") + if !ok || len(hosts) != 0 { + t.Fatalf("hosts = %#v, want empty array", hosts) + } + }) + + t.Run("handleGetHost requires host id", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, "GET", "/hosts/", map[string]any{}, "") + rec := &responseRecorder{header: map[string][]string{}} + handleGetHost(rec, req, func(context.Context, string) (HostInfo, error) { return HostInfo{}, nil }) + assertStatusCode(t, rec, 400) + assertJSONContains(t, rec.Body().Bytes(), "error.message", "host_id is required") + }) + + t.Run("handleProbeHost returns payload", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, "POST", "/hosts/host-1/probe", map[string]any{ + "auth": map[string]any{"type": "apikey", "token": "token"}, + }, "") + req.SetPathValue("hostID", "host-1") + rec := &responseRecorder{header: map[string][]string{}} + handleProbeHost(rec, req, func(_ context.Context, got ProbeHostRequest) (HostInfo, error) { + if got.HostID != "host-1" || got.Auth.Token != "token" { + t.Fatalf("ProbeHostRequest = %+v, want host-1/token", got) + } + return HostInfo{HostID: "host-1", HostVersion: "0.1.126", Status: "supported"}, nil + }) + assertStatusCode(t, rec, 200) + assertJSONContains(t, rec.Body().Bytes(), "host_id", "host-1") + assertJSONContains(t, rec.Body().Bytes(), "status", "supported") + }) +} + func TestHandlerWrappersForPackAndHostRoutes(t *testing.T) { t.Parallel() @@ -602,6 +1682,766 @@ func TestHandlerWrappersForPackAndHostRoutes(t *testing.T) { }) } +func TestAdditionalHTTPWrapperErrorBranches(t *testing.T) { + t.Parallel() + + t.Run("list and pack wrappers handle missing actions", func(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + path string + run func(*responseRecorder, *http.Request) + wantStatus int + wantCode string + }{ + { + name: "list packs", + path: "/packs", + run: func(rec *responseRecorder, req *http.Request) { + handleListPacks(rec, req, nil) + }, + wantStatus: http.StatusInternalServerError, + wantCode: "server_misconfigured", + }, + { + name: "list provider import batches", + path: "/providers/deepseek/import-batches", + run: func(rec *responseRecorder, req *http.Request) { + req.SetPathValue("providerID", "deepseek") + handleListProviderImportBatches(rec, req, nil) + }, + wantStatus: http.StatusInternalServerError, + wantCode: "server_misconfigured", + }, + { + name: "get pack", + path: "/packs/openai-cn-pack", + run: func(rec *responseRecorder, req *http.Request) { + req.SetPathValue("packID", "openai-cn-pack") + handleGetPack(rec, req, nil) + }, + wantStatus: http.StatusInternalServerError, + wantCode: "server_misconfigured", + }, + { + name: "list pack providers", + path: "/packs/openai-cn-pack/providers", + run: func(rec *responseRecorder, req *http.Request) { + req.SetPathValue("packID", "openai-cn-pack") + handleListPackProviders(rec, req, nil) + }, + wantStatus: http.StatusInternalServerError, + wantCode: "server_misconfigured", + }, + { + name: "rollback batch", + path: "/import-batches/11/rollback", + run: func(rec *responseRecorder, req *http.Request) { + req.SetPathValue("batchID", "11") + handleRollbackBatch(rec, req, nil) + }, + wantStatus: http.StatusInternalServerError, + wantCode: "server_misconfigured", + }, + { + name: "probe host", + path: "/hosts/host-1/probe", + run: func(rec *responseRecorder, req *http.Request) { + req.SetPathValue("hostID", "host-1") + handleProbeHost(rec, req, nil) + }, + wantStatus: http.StatusInternalServerError, + wantCode: "server_misconfigured", + }, + { + name: "create host", + path: "/hosts", + run: func(rec *responseRecorder, req *http.Request) { + handleCreateHost(rec, req, nil) + }, + wantStatus: http.StatusInternalServerError, + wantCode: "server_misconfigured", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + req := httptestRequest(t, "POST", tc.path, map[string]any{}, "") + rec := &responseRecorder{header: map[string][]string{}} + tc.run(rec, req) + assertStatusCode(t, rec, tc.wantStatus) + assertJSONContains(t, rec.Body().Bytes(), "error.code", tc.wantCode) + }) + } + }) + + t.Run("access wrappers handle missing actions and decode errors", func(t *testing.T) { + t.Parallel() + + req := httptestRequest(t, "POST", "/providers/deepseek/access/subscriptions", map[string]any{}, "") + req.SetPathValue("providerID", "deepseek") + rec := &responseRecorder{header: map[string][]string{}} + handleAssignAccessSubscriptions(rec, req, nil) + assertStatusCode(t, rec, http.StatusInternalServerError) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "server_misconfigured") + + req = httptestRequest(t, "POST", "/providers/deepseek/access/preview", map[string]any{}, "") + req.SetPathValue("providerID", "deepseek") + rec = &responseRecorder{header: map[string][]string{}} + handleAccessPreview(rec, req, nil) + assertStatusCode(t, rec, http.StatusInternalServerError) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "server_misconfigured") + + badReq, err := http.NewRequest(http.MethodPost, "/providers/deepseek/access/subscriptions", strings.NewReader("{")) + if err != nil { + t.Fatalf("http.NewRequest(assign bad json) error = %v", err) + } + badReq.SetPathValue("providerID", "deepseek") + badReq.Header.Set("Content-Type", "application/json") + rec = &responseRecorder{header: map[string][]string{}} + handleAssignAccessSubscriptions(rec, badReq, func(context.Context, AssignAccessSubscriptionsRequest) (AssignAccessSubscriptionsResult, error) { + t.Fatal("assign action should not be called for bad json") + return AssignAccessSubscriptionsResult{}, nil + }) + assertStatusCode(t, rec, http.StatusBadRequest) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "bad_request") + + badReq, err = http.NewRequest(http.MethodPost, "/providers/deepseek/access/preview", strings.NewReader("{")) + if err != nil { + t.Fatalf("http.NewRequest(preview bad json) error = %v", err) + } + badReq.SetPathValue("providerID", "deepseek") + badReq.Header.Set("Content-Type", "application/json") + rec = &responseRecorder{header: map[string][]string{}} + handleAccessPreview(rec, badReq, func(context.Context, AccessPreviewRequest) (AccessPreviewResult, error) { + t.Fatal("access preview action should not be called for bad json") + return AccessPreviewResult{}, nil + }) + assertStatusCode(t, rec, http.StatusBadRequest) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "bad_request") + }) + + t.Run("pack and host wrappers classify action errors", func(t *testing.T) { + t.Parallel() + + req := httptestRequest(t, "GET", "/packs/openai-cn-pack", map[string]any{}, "") + req.SetPathValue("packID", "openai-cn-pack") + rec := &responseRecorder{header: map[string][]string{}} + handleGetPack(rec, req, func(context.Context, string) (PackInfo, error) { + return PackInfo{}, sql.ErrNoRows + }) + assertStatusCode(t, rec, http.StatusInternalServerError) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "internal_error") + + req = httptestRequest(t, "POST", "/import-batches/11/rollback", map[string]any{}, "") + req.SetPathValue("batchID", "11") + rec = &responseRecorder{header: map[string][]string{}} + handleRollbackBatch(rec, req, func(context.Context, RollbackBatchRequest) (provision.RollbackReport, error) { + return provision.RollbackReport{}, fmt.Errorf("batch 11 not found") + }) + assertStatusCode(t, rec, http.StatusNotFound) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "not_found") + + req = httptestRequest(t, "POST", "/hosts/host-1/probe", map[string]any{ + "auth": map[string]any{"type": "apikey", "token": "host-token"}, + }, "") + req.SetPathValue("hostID", "host-1") + rec = &responseRecorder{header: map[string][]string{}} + handleProbeHost(rec, req, func(context.Context, ProbeHostRequest) (HostInfo, error) { + return HostInfo{}, fmt.Errorf("host probe decode failed") + }) + assertStatusCode(t, rec, http.StatusBadRequest) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "bad_request") + + req = httptestRequest(t, "POST", "/hosts", map[string]any{ + "name": "host-1", + }, "") + rec = &responseRecorder{header: map[string][]string{}} + handleCreateHost(rec, req, func(context.Context, CreateHostRequest) (HostInfo, error) { + return HostInfo{}, fmt.Errorf("base_url is required") + }) + assertStatusCode(t, rec, http.StatusBadRequest) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "bad_request") + }) +} + +func TestActionSetBatchDetailAndProviderSnapshotClosures(t *testing.T) { + t.Parallel() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + hostPK := createAppHostRecord(t, store, "https://sub2api.example.com") + packPK := createAppPackRecord(t, store) + providerPK := createAppProviderRecord(t, store, packPK) + batchID := createAppBatchRecord(t, store, hostPK, packPK, providerPK) + + if _, err := store.ImportBatchItems().Create(context.Background(), sqlite.ImportBatchItem{ + BatchID: batchID, + KeyFingerprint: "sha256:test", + AccountStatus: "passed", + }); err != nil { + t.Fatalf("ImportBatchItems().Create() error = %v", err) + } + for _, resource := range []sqlite.ManagedResource{ + {BatchID: batchID, HostID: hostPK, ResourceType: "group", HostResourceID: "group-1", ResourceName: "group-1"}, + {BatchID: batchID, HostID: hostPK, ResourceType: "account", HostResourceID: "account-1", ResourceName: "account-1"}, + } { + if _, err := store.ManagedResources().Create(context.Background(), resource); err != nil { + t.Fatalf("ManagedResources().Create(%s) error = %v", resource.ResourceType, err) + } + } + if _, err := store.AccessClosures().Create(context.Background(), sqlite.AccessClosureRecord{ + BatchID: batchID, + ClosureType: provision.AccessModeSelfService, + Status: provision.AccessStatusSelfServiceReady, + DetailsJSON: `{"ok":true}`, + }); err != nil { + t.Fatalf("AccessClosures().Create() error = %v", err) + } + if _, err := store.ReconcileRuns().Create(context.Background(), sqlite.ReconcileRun{ + BatchID: batchID, + HostID: hostPK, + ProviderID: providerPK, + Status: "drifted", + SummaryJSON: `{"missing_count":1}`, + }); err != nil { + t.Fatalf("ReconcileRuns().Create() error = %v", err) + } + + packRow, err := store.Packs().GetByID(context.Background(), packPK) + if err != nil { + t.Fatalf("Packs().GetByID() error = %v", err) + } + providerRow, err := store.Providers().GetByID(context.Background(), providerPK) + if err != nil { + t.Fatalf("Providers().GetByID() error = %v", err) + } + hostRow, err := store.Hosts().GetByID(context.Background(), hostPK) + if err != nil { + t.Fatalf("Hosts().GetByID() error = %v", err) + } + + actions := NewActionSet(appTestDSN(t, store)) + + detail, err := actions.BatchDetail(context.Background(), BatchDetailRequest{BatchID: batchID}) + if err != nil { + t.Fatalf("BatchDetail() error = %v", err) + } + if detail.Batch.ID != batchID || len(detail.Items) != 1 || len(detail.ManagedResources) != 2 || len(detail.AccessClosures) != 1 || len(detail.ReconcileRuns) != 1 { + t.Fatalf("BatchDetail() = %+v, want populated batch detail", detail) + } + + snapshotReq := ProviderQueryRequest{ + ProviderID: providerRow.ProviderID, + PackID: packRow.PackID, + HostID: hostRow.HostID, + } + statusSnapshot, err := actions.GetProviderStatus(context.Background(), snapshotReq) + if err != nil { + t.Fatalf("GetProviderStatus() error = %v", err) + } + if statusSnapshot.Batch.ID != batchID || statusSnapshot.LatestAccessStatus != provision.AccessStatusSelfServiceReady || statusSnapshot.LatestReconcileStatus != "drifted" || statusSnapshot.ProviderStatus != provision.ProviderStatusActive { + t.Fatalf("GetProviderStatus() = %+v, want active snapshot with drifted reconcile metadata", statusSnapshot) + } + if got := statusSnapshot.LatestReconcileSummary["missing_count"]; got != float64(1) { + t.Fatalf("LatestReconcileSummary[missing_count] = %#v, want 1", got) + } + + resourceSnapshot, err := actions.GetProviderResources(context.Background(), snapshotReq) + if err != nil { + t.Fatalf("GetProviderResources() error = %v", err) + } + if len(resourceSnapshot.ManagedResources) != 2 || len(resourceSnapshot.AccessClosures) != 1 || len(resourceSnapshot.ReconcileRuns) != 1 { + t.Fatalf("GetProviderResources() = %+v, want persisted resources/closures/runs", resourceSnapshot) + } + + accessSnapshot, err := actions.GetProviderAccessStatus(context.Background(), snapshotReq) + if err != nil { + t.Fatalf("GetProviderAccessStatus() error = %v", err) + } + if accessSnapshot.Batch.ID != batchID || accessSnapshot.LatestAccessStatus != provision.AccessStatusSelfServiceReady || len(accessSnapshot.AccessClosures) != 1 { + t.Fatalf("GetProviderAccessStatus() = %+v, want latest access closure", accessSnapshot) + } +} + +func TestActionSetRollbackBatchClosure(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodDelete { + t.Fatalf("method = %s, want DELETE", r.Method) + } + switch r.URL.Path { + case "/api/v1/admin/accounts/account-1", + "/api/v1/admin/payment/plans/plan-1", + "/api/v1/admin/channels/channel-1", + "/api/v1/admin/groups/group-1": + w.WriteHeader(http.StatusNoContent) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + hostPK := createAppHostRecord(t, store, server.URL) + packPK := createAppPackRecord(t, store) + providerPK := createAppProviderRecord(t, store, packPK) + batchID := createAppBatchRecord(t, store, hostPK, packPK, providerPK) + for _, resource := range []sqlite.ManagedResource{ + {BatchID: batchID, HostID: hostPK, ResourceType: "group", HostResourceID: "group-1", ResourceName: "group-1"}, + {BatchID: batchID, HostID: hostPK, ResourceType: "channel", HostResourceID: "channel-1", ResourceName: "channel-1"}, + {BatchID: batchID, HostID: hostPK, ResourceType: "plan", HostResourceID: "plan-1", ResourceName: "plan-1"}, + {BatchID: batchID, HostID: hostPK, ResourceType: "account", HostResourceID: "account-1", ResourceName: "account-1"}, + } { + if _, err := store.ManagedResources().Create(context.Background(), resource); err != nil { + t.Fatalf("ManagedResources().Create(%s) error = %v", resource.ResourceType, err) + } + } + + actions := NewActionSet(appTestDSN(t, store)) + report, err := actions.RollbackBatch(context.Background(), RollbackBatchRequest{BatchID: batchID}) + if err != nil { + t.Fatalf("RollbackBatch() error = %v", err) + } + if report.AccountsDeleted != 1 || report.PlansDeleted != 1 || report.ChannelsDeleted != 1 || report.GroupsDeleted != 1 { + t.Fatalf("RollbackBatch() = %+v, want one deleted resource per type", report) + } + + batchRow, err := store.ImportBatches().GetByID(context.Background(), batchID) + if err != nil { + t.Fatalf("ImportBatches().GetByID() error = %v", err) + } + if batchRow.BatchStatus != provision.BatchStatusRolledBack { + t.Fatalf("batch status = %q, want %q", batchRow.BatchStatus, provision.BatchStatusRolledBack) + } +} + +func TestActionSetCreateHostUpdateAndConflict(t *testing.T) { + t.Parallel() + + t.Run("update existing host connection", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(newBatchImportActionStubServer(t)) + defer server.Close() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + if _, err := store.Hosts().Create(context.Background(), sqlite.Host{ + HostID: "prod-sub2api", + BaseURL: server.URL, + HostVersion: "0.0.1", + CapabilityProbeJSON: "{}", + AuthType: "apikey", + AuthToken: "old-token", + }); err != nil { + t.Fatalf("Hosts().Create() error = %v", err) + } + + actions := NewActionSet(appTestDSN(t, store)) + host, err := actions.CreateHost(context.Background(), CreateHostRequest{ + Name: "prod-sub2api", + BaseURL: server.URL, + Auth: CreateHostAuth{Type: "apikey", Token: "host-token"}, + }) + if err != nil { + t.Fatalf("CreateHost(update) error = %v", err) + } + if host.HostVersion != "0.1.126" || host.Status != "supported" { + t.Fatalf("CreateHost(update) = %+v, want reprobed supported host", host) + } + + stored, err := store.Hosts().GetByHostID(context.Background(), "prod-sub2api") + if err != nil { + t.Fatalf("Hosts().GetByHostID() error = %v", err) + } + if stored.AuthToken != "host-token" || stored.HostVersion != "0.1.126" { + t.Fatalf("stored host = %+v, want updated token/version", stored) + } + }) + + t.Run("base url conflict returns 409", func(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(newBatchImportActionStubServer(t)) + defer server.Close() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + if _, err := store.Hosts().Create(context.Background(), sqlite.Host{ + HostID: "existing-host", + BaseURL: server.URL, + HostVersion: "0.1.126", + CapabilityProbeJSON: "{}", + AuthType: "apikey", + AuthToken: "host-token", + }); err != nil { + t.Fatalf("Hosts().Create() error = %v", err) + } + + actions := NewActionSet(appTestDSN(t, store)) + _, err := actions.CreateHost(context.Background(), CreateHostRequest{ + Name: "new-host", + BaseURL: server.URL, + Auth: CreateHostAuth{Type: "apikey", Token: "host-token"}, + }) + if err == nil { + t.Fatal("CreateHost(conflict) error = nil, want conflict") + } + httpErr, ok := err.(*httpError) + if !ok || httpErr.StatusCode != http.StatusConflict { + t.Fatalf("CreateHost(conflict) error = %T %v, want *httpError 409", err, err) + } + }) +} + +func TestActionSetInstallPreviewAndRollbackProviderClosures(t *testing.T) { + t.Parallel() + + baseHandler := newBatchImportActionStubServer(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodDelete { + switch r.URL.Path { + case "/api/v1/admin/accounts/account-1", + "/api/v1/admin/payment/plans/plan-1", + "/api/v1/admin/channels/channel-1", + "/api/v1/admin/groups/group-1": + w.WriteHeader(http.StatusNoContent) + return + } + } + baseHandler.ServeHTTP(w, r) + })) + defer server.Close() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + hostPK := createAppHostRecord(t, store, server.URL) + actions := NewActionSet(appTestDSN(t, store)) + packPath := filepath.Join("..", "..", "packs", "openai-cn-pack") + + installResult, err := actions.InstallPack(context.Background(), InstallPackRequest{ + HostBaseURL: server.URL, + HostAPIKey: "host-token", + PackPath: packPath, + }) + if err != nil { + t.Fatalf("InstallPack() error = %v", err) + } + if installResult.Pack.PackID != "openai-cn-pack" || len(installResult.Providers) == 0 { + t.Fatalf("InstallPack() = %+v, want persisted pack with providers", installResult) + } + + preview, err := actions.PreviewProvider(context.Background(), PreviewProviderRequest{ + HostBaseURL: server.URL, + PackPath: packPath, + ProviderID: "deepseek", + Mode: provision.ImportModePartial, + Keys: []string{" key-1 ", "key-2"}, + }) + if err != nil { + t.Fatalf("PreviewProvider() error = %v", err) + } + if len(preview.AcceptedKeys) != 2 || preview.Decisions["group"].Action != provision.PreviewActionCreate { + t.Fatalf("PreviewProvider() = %+v, want accepted keys and create decisions", preview) + } + + packRow, err := store.Packs().GetByPackID(context.Background(), "openai-cn-pack") + if err != nil { + t.Fatalf("Packs().GetByPackID() error = %v", err) + } + providerRow, err := store.Providers().GetByPackIDAndProviderID(context.Background(), packRow.ID, "deepseek") + if err != nil { + t.Fatalf("Providers().GetByPackIDAndProviderID() error = %v", err) + } + batchID, err := store.ImportBatches().Create(context.Background(), sqlite.ImportBatch{ + HostID: hostPK, + PackID: packRow.ID, + ProviderID: providerRow.ID, + Mode: provision.ImportModePartial, + BatchStatus: provision.BatchStatusSucceeded, + AccessStatus: provision.AccessStatusSelfServiceReady, + }) + if err != nil { + t.Fatalf("ImportBatches().Create() error = %v", err) + } + for _, resource := range []sqlite.ManagedResource{ + {BatchID: batchID, HostID: hostPK, ResourceType: "group", HostResourceID: "group-1", ResourceName: "group-1"}, + {BatchID: batchID, HostID: hostPK, ResourceType: "channel", HostResourceID: "channel-1", ResourceName: "channel-1"}, + {BatchID: batchID, HostID: hostPK, ResourceType: "plan", HostResourceID: "plan-1", ResourceName: "plan-1"}, + {BatchID: batchID, HostID: hostPK, ResourceType: "account", HostResourceID: "account-1", ResourceName: "account-1"}, + } { + if _, err := store.ManagedResources().Create(context.Background(), resource); err != nil { + t.Fatalf("ManagedResources().Create(%s) error = %v", resource.ResourceType, err) + } + } + + report, err := actions.RollbackProvider(context.Background(), RollbackProviderRequest{ + HostBaseURL: server.URL, + PackPath: packPath, + ProviderID: "deepseek", + }) + if err != nil { + t.Fatalf("RollbackProvider() error = %v", err) + } + if report.AccountsDeleted != 1 || report.PlansDeleted != 1 || report.ChannelsDeleted != 1 || report.GroupsDeleted != 1 { + t.Fatalf("RollbackProvider() = %+v, want one deleted resource per type", report) + } + + batchRow, err := store.ImportBatches().GetByID(context.Background(), batchID) + if err != nil { + t.Fatalf("ImportBatches().GetByID() error = %v", err) + } + if batchRow.BatchStatus != provision.BatchStatusRolledBack { + t.Fatalf("batch status = %q, want %q", batchRow.BatchStatus, provision.BatchStatusRolledBack) + } +} + +func TestActionSetAssignAccessSubscriptionsClosure(t *testing.T) { + t.Parallel() + + baseHandler := newBatchImportActionStubServer(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"): + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{}}}) + return + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84, "email": "relay-sub-user-1@sub2api.local"}}) + return + case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84}}) + return + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84}}) + return + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"access_token": "user-jwt"}}) + return + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 501, "key": "sk-relay-test", "name": "managed-key"}}) + return + case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"api_key": map[string]any{"id": 501}}}) + return + case r.URL.Path == "/v1/models" && strings.HasPrefix(strings.TrimSpace(r.Header.Get("Authorization")), "Bearer sk-relay-"): + writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{{"id": "deepseek-v4-pro"}}}) + return + case r.URL.Path == "/v1/chat/completions" && strings.HasPrefix(strings.TrimSpace(r.Header.Get("Authorization")), "Bearer sk-relay-"): + writeJSON(w, http.StatusOK, map[string]any{ + "id": "chatcmpl_subscription", + "choices": []map[string]any{{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": "pong", + }, + }}, + }) + return + } + baseHandler.ServeHTTP(w, r) + })) + defer server.Close() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + hostPK := createAppHostRecord(t, store, server.URL) + actions := NewActionSet(appTestDSN(t, store)) + packPath := filepath.Join("..", "..", "packs", "openai-cn-pack") + + if _, err := actions.InstallPack(context.Background(), InstallPackRequest{ + HostBaseURL: server.URL, + HostAPIKey: "host-token", + PackPath: packPath, + }); err != nil { + t.Fatalf("InstallPack() error = %v", err) + } + packRow, err := store.Packs().GetByPackID(context.Background(), "openai-cn-pack") + if err != nil { + t.Fatalf("Packs().GetByPackID() error = %v", err) + } + providerRow, err := store.Providers().GetByPackIDAndProviderID(context.Background(), packRow.ID, "deepseek") + if err != nil { + t.Fatalf("Providers().GetByPackIDAndProviderID() error = %v", err) + } + batchID, err := store.ImportBatches().Create(context.Background(), sqlite.ImportBatch{ + HostID: hostPK, + PackID: packRow.ID, + ProviderID: providerRow.ID, + Mode: provision.ImportModePartial, + BatchStatus: provision.BatchStatusSucceeded, + AccessStatus: provision.AccessStatusSelfServiceReady, + }) + if err != nil { + t.Fatalf("ImportBatches().Create() error = %v", err) + } + if _, err := store.ManagedResources().Create(context.Background(), sqlite.ManagedResource{ + BatchID: batchID, + HostID: hostPK, + ResourceType: "group", + HostResourceID: "101", + ResourceName: "group-101", + }); err != nil { + t.Fatalf("ManagedResources().Create(group) error = %v", err) + } + + result, err := actions.AssignAccessSubscriptions(context.Background(), AssignAccessSubscriptionsRequest{ + HostBaseURL: server.URL, + PackPath: packPath, + ProviderID: "deepseek", + SubscriptionUsers: []string{"crm-user-1"}, + SubscriptionDays: 30, + }) + if err != nil { + t.Fatalf("AssignAccessSubscriptions() error = %v", err) + } + if result.Assigned != 1 || result.AccessStatus != provision.AccessStatusSubscriptionReady { + t.Fatalf("AssignAccessSubscriptions() = %+v, want one assigned subscription_ready result", result) + } + + closures, err := store.AccessClosures().GetByBatchID(context.Background(), batchID) + if err != nil { + t.Fatalf("AccessClosures().GetByBatchID() error = %v", err) + } + if len(closures) != 1 || closures[0].Status != provision.AccessStatusSubscriptionReady { + t.Fatalf("AccessClosures() = %+v, want persisted subscription_ready closure", closures) + } + batchRow, err := store.ImportBatches().GetByID(context.Background(), batchID) + if err != nil { + t.Fatalf("ImportBatches().GetByID() error = %v", err) + } + if batchRow.AccessStatus != provision.AccessStatusSubscriptionReady { + t.Fatalf("batch access_status = %q, want %q", batchRow.AccessStatus, provision.AccessStatusSubscriptionReady) + } +} + +func TestActionSetImportAndReconcileProviderClosures(t *testing.T) { + t.Parallel() + + baseHandler := newBatchImportActionStubServer(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/admin/groups": + if r.Method == http.MethodGet { + writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{{"id": "group_1", "name": "DeepSeek 默认分组-self-service"}}}) + return + } + if r.Method == http.MethodPut { + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": "group_1", "name": "DeepSeek 默认分组-self-service"}}) + return + } + case "/api/v1/admin/channels": + if r.Method == http.MethodGet { + writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{{"id": "channel_1", "name": "DeepSeek 默认渠道-self-service"}}}) + return + } + case "/api/v1/admin/groups/group_1": + if r.Method == http.MethodPut { + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": "group_1", "name": "DeepSeek 默认分组-self-service"}}) + return + } + case "/api/v1/admin/channels/channel_1": + if r.Method == http.MethodPut { + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": "channel_1", "name": "DeepSeek 默认渠道-self-service"}}) + return + } + case "/api/v1/admin/accounts": + if r.Method == http.MethodGet { + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{{"id": "account_1", "name": "deepseek-01"}}, "pages": 1}}) + return + } + case "/api/v1/admin/accounts/account_1/models": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{{"id": "deepseek-v4-pro", "display_name": "DeepSeek V4 Pro", "type": "chat"}}}}) + return + case "/api/v1/admin/accounts/account_1": + if r.Method == http.MethodDelete { + w.WriteHeader(http.StatusNoContent) + return + } + case "/v1/models": + switch strings.TrimSpace(r.Header.Get("Authorization")) { + case "Bearer entry-key", "Bearer gateway-key": + writeJSON(w, http.StatusOK, map[string]any{"data": []map[string]any{{"id": "deepseek-v4-pro"}}}) + return + } + case "/v1/chat/completions": + switch strings.TrimSpace(r.Header.Get("Authorization")) { + case "Bearer entry-key", "Bearer gateway-key": + writeJSON(w, http.StatusOK, map[string]any{ + "id": "chatcmpl_deepseek", + "choices": []map[string]any{{ + "index": 0, + "message": map[string]any{ + "role": "assistant", + "content": "pong", + }, + }}, + }) + return + } + } + baseHandler.ServeHTTP(w, r) + })) + defer server.Close() + + store := openAppTestStore(t) + defer closeAppTestStore(t, store) + + if _, err := store.Hosts().Create(context.Background(), sqlite.Host{ + HostID: "host-deepseek", + BaseURL: server.URL, + HostVersion: "0.1.126", + CapabilityProbeJSON: "{}", + AuthType: "apikey", + AuthToken: "host-token", + }); err != nil { + t.Fatalf("Hosts().Create() error = %v", err) + } + + actions := NewActionSet(appTestDSN(t, store)) + packPath := filepath.Join("..", "..", "packs", "openai-cn-pack") + + importResult, err := actions.ImportProvider(context.Background(), ImportProviderRequest{ + HostID: "host-deepseek", + PackPath: packPath, + ProviderID: "deepseek", + Keys: []string{"entry-key"}, + Mode: provision.ImportModePartial, + AccessMode: provision.AccessModeSelfService, + AccessAPIKey: "gateway-key", + }) + if err != nil { + t.Fatalf("ImportProvider() error = %v", err) + } + if importResult.BatchID <= 0 || importResult.Report.BatchStatus != provision.BatchStatusSucceeded || importResult.Report.AccessStatus != provision.AccessStatusSelfServiceReady { + t.Fatalf("ImportProvider() = %+v, want succeeded self_service_ready batch", importResult) + } + + reconcileResult, err := actions.ReconcileProvider(context.Background(), ReconcileProviderRequest{ + HostID: "host-deepseek", + PackPath: packPath, + ProviderID: "deepseek", + AccessAPIKey: "gateway-key", + }) + if err != nil { + t.Fatalf("ReconcileProvider() error = %v", err) + } + if reconcileResult.BatchID != importResult.BatchID || reconcileResult.Status != "active" || reconcileResult.AccessStatus != provision.AccessStatusSelfServiceReady { + t.Fatalf("ReconcileProvider() = %+v, want active reconcile for imported batch", reconcileResult) + } +} + func TestBuildListBatchImportRunItemsActionCursor(t *testing.T) { t.Parallel() diff --git a/internal/app/http_batch_import_test.go b/internal/app/http_batch_import_test.go index 7fa4b0c2..3ceb3118 100644 --- a/internal/app/http_batch_import_test.go +++ b/internal/app/http_batch_import_test.go @@ -202,6 +202,64 @@ func TestBatchImportHTTP(t *testing.T) { }) } +func TestBatchImportWrapperFunctions(t *testing.T) { + t.Parallel() + + t.Run("handleCreateBatchImportRun requires action", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, http.MethodPost, "/api/batch-import/runs", map[string]any{}, "") + rec := &responseRecorder{header: map[string][]string{}} + handleCreateBatchImportRun(rec, req, nil) + assertStatusCode(t, rec, http.StatusInternalServerError) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "server_misconfigured") + }) + + t.Run("handleCreateBatchImportRun classifies action error", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, http.MethodPost, "/api/batch-import/runs", map[string]any{ + "host_id": "host-1", + "mode": "strict", + "access_mode": "self_service", + "probe_api_key": "probe-key", + "entries": []map[string]any{ + {"base_url": "https://kimi.example.com/v1", "api_key": "sk-test"}, + }, + }, "") + rec := &responseRecorder{header: map[string][]string{}} + handleCreateBatchImportRun(rec, req, func(context.Context, CreateBatchImportRunRequest) (BatchImportRunCreateResponse, error) { + return BatchImportRunCreateResponse{}, fmt.Errorf("host x not found") + }) + assertStatusCode(t, rec, http.StatusNotFound) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "not_found") + }) + + t.Run("handleListBatchImportRuns requires action", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, http.MethodGet, "/api/batch-import/runs", nil, "") + rec := &responseRecorder{header: map[string][]string{}} + handleListBatchImportRuns(rec, req, nil) + assertStatusCode(t, rec, http.StatusInternalServerError) + assertJSONContains(t, rec.Body().Bytes(), "error.code", "server_misconfigured") + }) + + t.Run("handleListBatchImportRuns returns empty array", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, http.MethodGet, "/api/batch-import/runs?limit=5", nil, "") + rec := &responseRecorder{header: map[string][]string{}} + handleListBatchImportRuns(rec, req, func(_ context.Context, got ListBatchImportRunsRequest) (ListBatchImportRunsResponse, error) { + if got.Limit != 5 { + t.Fatalf("ListBatchImportRunsRequest.Limit = %d, want 5", got.Limit) + } + return ListBatchImportRunsResponse{}, nil + }) + assertStatusCode(t, rec, http.StatusOK) + runs, ok := decodeTopLevelArray(t, rec.Body().Bytes(), "runs") + if !ok || len(runs) != 0 { + t.Fatalf("runs = %#v, want empty array", runs) + } + }) +} + func newBatchImportActionStubServer(t *testing.T) http.Handler { t.Helper() diff --git a/internal/app/http_batch_runs_test.go b/internal/app/http_batch_runs_test.go index a9aff0b3..b7ce5af8 100644 --- a/internal/app/http_batch_runs_test.go +++ b/internal/app/http_batch_runs_test.go @@ -173,6 +173,72 @@ func TestBatchRunsHTTP(t *testing.T) { }) } +func TestBatchRunWrapperFunctions(t *testing.T) { + t.Parallel() + + t.Run("handleGetBatchImportRun validates inputs", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, http.MethodGet, "/api/batch-import/runs/run-1", nil, "") + rec := &responseRecorder{header: map[string][]string{}} + handleGetBatchImportRun(rec, req, nil) + assertStatusCode(t, rec, http.StatusInternalServerError) + + req = httptestRequest(t, http.MethodGet, "/api/batch-import/runs/", nil, "") + rec = &responseRecorder{header: map[string][]string{}} + handleGetBatchImportRun(rec, req, func(context.Context, string) (batch.RunSummaryProjection, error) { + return batch.RunSummaryProjection{}, nil + }) + assertStatusCode(t, rec, http.StatusBadRequest) + assertJSONContains(t, rec.Body().Bytes(), "error.message", "run_id is required") + }) + + t.Run("handleListBatchImportRunItems validates run id and empty result", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, http.MethodGet, "/api/batch-import/runs/run-1/items?has_warning=true", nil, "") + rec := &responseRecorder{header: map[string][]string{}} + handleListBatchImportRunItems(rec, req, nil) + assertStatusCode(t, rec, http.StatusInternalServerError) + + req = httptestRequest(t, http.MethodGet, "/api/batch-import/runs//items?has_warning=true", nil, "") + rec = &responseRecorder{header: map[string][]string{}} + handleListBatchImportRunItems(rec, req, func(context.Context, ListBatchImportRunItemsRequest) (ListBatchImportRunItemsResponse, error) { + return ListBatchImportRunItemsResponse{}, nil + }) + assertStatusCode(t, rec, http.StatusBadRequest) + + req = httptestRequest(t, http.MethodGet, "/api/batch-import/runs/run-1/items?has_warning=true&limit=3", nil, "") + req.SetPathValue("run_id", "run-1") + rec = &responseRecorder{header: map[string][]string{}} + handleListBatchImportRunItems(rec, req, func(_ context.Context, got ListBatchImportRunItemsRequest) (ListBatchImportRunItemsResponse, error) { + if got.RunID != "run-1" || got.Limit != 3 || got.HasWarning == nil || !*got.HasWarning { + t.Fatalf("ListBatchImportRunItemsRequest = %+v, want parsed filters", got) + } + return ListBatchImportRunItemsResponse{}, nil + }) + assertStatusCode(t, rec, http.StatusOK) + items, ok := decodeTopLevelArray(t, rec.Body().Bytes(), "items") + if !ok || len(items) != 0 { + t.Fatalf("items = %#v, want empty array", items) + } + }) + + t.Run("handleGetBatchImportRunItem validates ids", func(t *testing.T) { + t.Parallel() + req := httptestRequest(t, http.MethodGet, "/api/batch-import/runs/run-1/items/item-1", nil, "") + rec := &responseRecorder{header: map[string][]string{}} + handleGetBatchImportRunItem(rec, req, nil) + assertStatusCode(t, rec, http.StatusInternalServerError) + + req = httptestRequest(t, http.MethodGet, "/api/batch-import/runs//items/", nil, "") + rec = &responseRecorder{header: map[string][]string{}} + handleGetBatchImportRunItem(rec, req, func(context.Context, GetBatchImportRunItemRequest) (batch.ItemDetailProjection, error) { + return batch.ItemDetailProjection{}, nil + }) + assertStatusCode(t, rec, http.StatusBadRequest) + assertJSONContains(t, rec.Body().Bytes(), "error.message", "run_id and item_id are required") + }) +} + func batchProbeProfile() probe.CapabilityProfile { return probe.CapabilityProfile{ TransportProfile: probe.TransportProfile{ diff --git a/internal/app/reconcile_background_test.go b/internal/app/reconcile_background_test.go index b4bfbc52..475d5204 100644 --- a/internal/app/reconcile_background_test.go +++ b/internal/app/reconcile_background_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "net/http" "net/http/httptest" "path/filepath" "strings" @@ -467,6 +468,194 @@ func TestReconcileProbeAPIKeyRejectsUnsupportedClosureType(t *testing.T) { } } +func TestReconcileProbeAPIKeyRequiresAccessClosure(t *testing.T) { + t.Parallel() + + store := openReconcileBackgroundTestStore(t) + defer closeAppTestStore(t, store) + + batchID, _, _ := seedReconcileBackgroundBatch(t, store) + hostRow := mustGetBackgroundHost(t, store) + + _, err := reconcileProbeAPIKey(context.Background(), store, hostRow, sqlite.ImportBatch{ID: batchID}, nil) + if err == nil || err.Error() != fmt.Sprintf("access closure not found for batch %d", batchID) { + t.Fatalf("reconcileProbeAPIKey() error = %v, want missing access closure", err) + } +} + +func TestReconcileProbeAPIKeySubscriptionSuccess(t *testing.T) { + t.Parallel() + + var assignCalls int + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"): + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{}}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84, "email": "relay-sub-user-1@sub2api.local"}}) + case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign": + assignCalls++ + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 401}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"access_token": "user-jwt"}}) + case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 501, "key": "sk-relay-key", "name": "managed-key"}}) + case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501": + writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"api_key": map[string]any{"id": 501}}}) + default: + http.NotFound(w, r) + } + })) + defer server.Close() + + store := openReconcileBackgroundTestStore(t) + defer closeAppTestStore(t, store) + + hostPK, err := store.Hosts().Create(context.Background(), sqlite.Host{ + HostID: "host-subscription", + BaseURL: server.URL, + HostVersion: "0.1.126", + CapabilityProbeJSON: "{}", + AuthType: "bearer", + AuthToken: "admin-token", + }) + if err != nil { + t.Fatalf("Hosts().Create() error = %v", err) + } + packPK := createBackgroundPack(t, store) + providerPK := createBackgroundProvider(t, store, packPK) + batchID, err := store.ImportBatches().Create(context.Background(), sqlite.ImportBatch{ + HostID: hostPK, + PackID: packPK, + ProviderID: providerPK, + Mode: provision.ImportModePartial, + BatchStatus: provision.BatchStatusSucceeded, + AccessStatus: provision.AccessStatusSubscriptionReady, + }) + if err != nil { + t.Fatalf("ImportBatches().Create() error = %v", err) + } + if _, err := store.ManagedResources().Create(context.Background(), sqlite.ManagedResource{ + BatchID: batchID, + HostID: hostPK, + ResourceType: "group", + HostResourceID: "101", + ResourceName: "group-101", + }); err != nil { + t.Fatalf("ManagedResources().Create(group) error = %v", err) + } + + hostRow := sqlite.Host{HostID: "host-subscription", BaseURL: server.URL, AuthType: "bearer", AuthToken: "admin-token"} + got, err := reconcileProbeAPIKey(context.Background(), store, hostRow, sqlite.ImportBatch{ID: batchID}, []sqlite.AccessClosureRecord{{ + BatchID: batchID, + ClosureType: provision.AccessModeSubscription, + Status: provision.AccessStatusSubscriptionReady, + DetailsJSON: `{"subscription_users":["crm-user-1"],"subscription_days":0}`, + }}) + if err != nil { + t.Fatalf("reconcileProbeAPIKey() error = %v", err) + } + if !strings.HasPrefix(got, "sk-relay-") { + t.Fatalf("reconcileProbeAPIKey() = %q, want sk-relay-*", got) + } + if assignCalls != 1 { + t.Fatalf("subscription assign calls = %d, want 1 (EnsureSubscriptionAccess only)", assignCalls) + } +} + +func TestReconcileProbeAPIKeySubscriptionRequiresHostAuth(t *testing.T) { + t.Parallel() + + store := openReconcileBackgroundTestStore(t) + defer closeAppTestStore(t, store) + + hostPK, err := store.Hosts().Create(context.Background(), sqlite.Host{ + HostID: "host-subscription", + BaseURL: "https://sub2api.example.com", + HostVersion: "0.1.126", + CapabilityProbeJSON: "{}", + AuthType: "bearer", + AuthToken: "admin-token", + }) + if err != nil { + t.Fatalf("Hosts().Create() error = %v", err) + } + packPK := createBackgroundPack(t, store) + providerPK := createBackgroundProvider(t, store, packPK) + batchID, err := store.ImportBatches().Create(context.Background(), sqlite.ImportBatch{ + HostID: hostPK, + PackID: packPK, + ProviderID: providerPK, + Mode: provision.ImportModePartial, + BatchStatus: provision.BatchStatusSucceeded, + AccessStatus: provision.AccessStatusSubscriptionReady, + }) + if err != nil { + t.Fatalf("ImportBatches().Create() error = %v", err) + } + if _, err := store.ManagedResources().Create(context.Background(), sqlite.ManagedResource{ + BatchID: batchID, + HostID: hostPK, + ResourceType: "group", + HostResourceID: "101", + ResourceName: "group-101", + }); err != nil { + t.Fatalf("ManagedResources().Create(group) error = %v", err) + } + + _, err = reconcileProbeAPIKey(context.Background(), store, sqlite.Host{BaseURL: "https://sub2api.example.com", AuthType: "bearer"}, sqlite.ImportBatch{ID: batchID}, []sqlite.AccessClosureRecord{{ + BatchID: batchID, + ClosureType: provision.AccessModeSubscription, + Status: provision.AccessStatusSubscriptionReady, + DetailsJSON: `{"subscription_users":["crm-user-1"],"subscription_days":30}`, + }}) + if err == nil || !strings.Contains(err.Error(), "auth.token is required") { + t.Fatalf("reconcileProbeAPIKey() error = %v, want auth.token is required", err) + } +} + +func createBackgroundPack(t *testing.T, store *sqlite.DB) int64 { + t.Helper() + + packPK, err := store.Packs().Create(context.Background(), sqlite.Pack{ + PackID: "openai-cn-pack", + Version: "1.0.0", + Checksum: "checksum-1", + Vendor: "OpenAI CN", + TargetHost: "sub2api", + MinHostVersion: "0.1.126", + MaxHostVersion: "0.2.x", + ManifestJSON: `{"pack_id":"openai-cn-pack","version":"1.0.0","target_host":"sub2api"}`, + }) + if err != nil { + t.Fatalf("Packs().Create() error = %v", err) + } + return packPK +} + +func createBackgroundProvider(t *testing.T, store *sqlite.DB, packPK int64) int64 { + t.Helper() + + providerPK, err := store.Providers().Create(context.Background(), sqlite.Provider{ + PackID: packPK, + ProviderID: "deepseek", + DisplayName: "DeepSeek", + BaseURL: "https://api.example.com", + Platform: "openai", + AccountType: "openai", + SmokeTestModel: "deepseek-chat", + ManifestJSON: `{"provider_id":"deepseek","base_url":"https://api.example.com","platform":"openai","account_type":"openai","smoke_test_model":"deepseek-chat"}`, + }) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + return providerPK +} + func mustGetBackgroundHost(t *testing.T, store *sqlite.DB) sqlite.Host { t.Helper() diff --git a/internal/provision/import_service_test.go b/internal/provision/import_service_test.go index c71b8f6b..96efb90b 100644 --- a/internal/provision/import_service_test.go +++ b/internal/provision/import_service_test.go @@ -671,6 +671,7 @@ type fakeHostAdapter struct { testedModels map[string]string disableResponsesCalls int disabledResponsesAccountIDs []string + deleteErrors map[string]error } func (f *fakeHostAdapter) GetHostVersion(context.Context) (string, error) { @@ -688,6 +689,9 @@ func (f *fakeHostAdapter) CreateGroup(_ context.Context, req sub2api.CreateGroup return sub2api.GroupRef{ID: "group_1", Name: "g"}, nil } func (f *fakeHostAdapter) DeleteGroup(_ context.Context, groupID string) error { + if err := f.deleteErrors["group:"+groupID]; err != nil { + return err + } f.deletedResources = append(f.deletedResources, "group:"+groupID) return nil } @@ -703,6 +707,9 @@ func (f *fakeHostAdapter) UpdateChannel(_ context.Context, channelID string, req return nil } func (f *fakeHostAdapter) DeleteChannel(_ context.Context, channelID string) error { + if err := f.deleteErrors["channel:"+channelID]; err != nil { + return err + } f.deletedResources = append(f.deletedResources, "channel:"+channelID) return nil } @@ -711,6 +718,9 @@ func (f *fakeHostAdapter) CreatePlan(context.Context, sub2api.CreatePlanRequest) return sub2api.PlanRef{ID: "plan_1", Name: "p"}, nil } func (f *fakeHostAdapter) DeletePlan(_ context.Context, planID string) error { + if err := f.deleteErrors["plan:"+planID]; err != nil { + return err + } f.deletedResources = append(f.deletedResources, "plan:"+planID) return nil } @@ -725,6 +735,9 @@ func (f *fakeHostAdapter) BatchCreateAccounts(_ context.Context, req sub2api.Bat return f.batchAccounts, nil } func (f *fakeHostAdapter) DeleteAccount(_ context.Context, accountID string) error { + if err := f.deleteErrors["account:"+accountID]; err != nil { + return err + } f.callSequence = append(f.callSequence, "deleteAccount:"+accountID) f.deletedResources = append(f.deletedResources, "account:"+accountID) return nil diff --git a/internal/provision/provider_status_service_test.go b/internal/provision/provider_status_service_test.go index 283077be..deb3f30d 100644 --- a/internal/provision/provider_status_service_test.go +++ b/internal/provision/provider_status_service_test.go @@ -188,3 +188,103 @@ func TestProviderStatusServiceRequiresPackIDWhenProviderIDIsAmbiguous(t *testing t.Fatalf("GetStatus() error = %v, want ambiguous provider error", err) } } + +func TestProviderStatusServiceRequiresLatestBatchWhenProviderHasNoImports(t *testing.T) { + store := openProvisionTestStore(t) + defer closeProvisionTestStore(t, store) + + ctx := context.Background() + packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"}) + if err != nil { + t.Fatalf("Packs().Create() error = %v", err) + } + if _, err := store.Providers().Create(ctx, sqlite.Provider{ + PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai", + }); err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + + _, err = NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek"}) + if err == nil || err.Error() != "latest import batch not found for provider" { + t.Fatalf("GetStatus() error = %v, want latest batch not found", err) + } +} + +func TestProviderStatusServiceRequiresHostWhenProviderExistsOnMultipleHosts(t *testing.T) { + store := openProvisionTestStore(t) + defer closeProvisionTestStore(t, store) + + ctx := context.Background() + hostA, err := store.Hosts().Create(ctx, sqlite.Host{HostID: "host-a", BaseURL: "https://a.example.com", HostVersion: "0.1.126", CapabilityProbeJSON: `{}`}) + if err != nil { + t.Fatalf("Hosts().Create(host-a) error = %v", err) + } + hostB, err := store.Hosts().Create(ctx, sqlite.Host{HostID: "host-b", BaseURL: "https://b.example.com", HostVersion: "0.1.126", CapabilityProbeJSON: `{}`}) + if err != nil { + t.Fatalf("Hosts().Create(host-b) error = %v", err) + } + packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"}) + if err != nil { + t.Fatalf("Packs().Create() error = %v", err) + } + providerID, err := store.Providers().Create(ctx, sqlite.Provider{ + PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai", + }) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + if _, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{ + HostID: hostA, PackID: packID, ProviderID: providerID, Mode: ImportModePartial, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady, + }); err != nil { + t.Fatalf("ImportBatches().Create(host-a) error = %v", err) + } + if _, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{ + HostID: hostB, PackID: packID, ProviderID: providerID, Mode: ImportModePartial, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady, + }); err != nil { + t.Fatalf("ImportBatches().Create(host-b) error = %v", err) + } + + _, err = NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek"}) + if err == nil || err.Error() != "provider exists on multiple hosts; host_id is required" { + t.Fatalf("GetStatus() error = %v, want host_id required", err) + } +} + +func TestProviderStatusServiceFailsOnInvalidReconcileSummary(t *testing.T) { + store := openProvisionTestStore(t) + defer closeProvisionTestStore(t, store) + + ctx := context.Background() + hostID := seedProvisionHost(t, store, "host-1", "https://sub2api.example.com") + packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"}) + if err != nil { + t.Fatalf("Packs().Create() error = %v", err) + } + providerID, err := store.Providers().Create(ctx, sqlite.Provider{ + PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai", + }) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + batchID, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{ + HostID: hostID, PackID: packID, ProviderID: providerID, Mode: ImportModePartial, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady, + }) + if err != nil { + t.Fatalf("ImportBatches().Create() error = %v", err) + } + if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{ + BatchID: batchID, ClosureType: AccessModeSelfService, Status: AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`, + }); err != nil { + t.Fatalf("AccessClosures().Create() error = %v", err) + } + if _, err := store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{ + BatchID: batchID, HostID: hostID, ProviderID: providerID, Status: "active", SummaryJSON: `{"missing_count":`, + }); err != nil { + t.Fatalf("ReconcileRuns().Create() error = %v", err) + } + + _, err = NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek", HostID: "host-1"}) + if err == nil || err.Error() == "" { + t.Fatal("GetStatus() error = nil, want decode reconcile summary failure") + } +} diff --git a/internal/provision/rollback_service_test.go b/internal/provision/rollback_service_test.go index 761758ca..9ea9768f 100644 --- a/internal/provision/rollback_service_test.go +++ b/internal/provision/rollback_service_test.go @@ -2,7 +2,9 @@ package provision import ( "context" + "errors" "reflect" + "strings" "testing" "sub2api-cn-relay-manager/internal/host/sub2api" @@ -69,3 +71,58 @@ func TestRollbackServiceRollbackStoredResourcesDeletesOnlyProvidedIDs(t *testing t.Fatalf("deleted resources = %#v, want %#v", host.deletedResources, want) } } + +func TestRollbackServiceRequiresHost(t *testing.T) { + svc := NewRollbackService(nil) + + if _, err := svc.Rollback(context.Background(), RollbackRequest{Provider: sampleProviderManifest()}); err == nil || err.Error() != "rollback host is required" { + t.Fatalf("Rollback() error = %v, want rollback host is required", err) + } + if _, err := svc.RollbackStoredResources(context.Background(), nil); err == nil || err.Error() != "rollback host is required" { + t.Fatalf("RollbackStoredResources() error = %v, want rollback host is required", err) + } +} + +func TestRollbackServiceCollectsDeleteErrors(t *testing.T) { + host := &fakeHostAdapter{ + managedSnapshot: sub2api.ManagedResourceSnapshot{ + Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}}, + Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}}, + Plans: []sub2api.NamedResource{{ID: "plan_1", Name: "DeepSeek 默认套餐"}}, + Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}}, + }, + deleteErrors: map[string]error{ + "account:account_2": errors.New("account blocked"), + "channel:channel_1": errors.New("channel blocked"), + }, + } + + report, err := NewRollbackService(host).Rollback(context.Background(), RollbackRequest{Provider: sampleProviderManifest()}) + if err == nil { + t.Fatal("Rollback() error = nil, want joined delete errors") + } + if report.AccountsDeleted != 1 || report.PlansDeleted != 1 || report.ChannelsDeleted != 0 || report.GroupsDeleted != 1 { + t.Fatalf("Rollback() report = %+v, want partial success counts", report) + } + if !strings.Contains(err.Error(), "delete account account_2") || !strings.Contains(err.Error(), "delete channel channel_1") { + t.Fatalf("Rollback() error = %v, want joined account/channel errors", err) + } +} + +func TestNamedResourceSnapshotFromStoredIgnoresUnknownTypes(t *testing.T) { + snapshot := namedResourceSnapshotFromStored([]sqlite.ManagedResource{ + {ResourceType: "group", HostResourceID: "group_1", ResourceName: "g1"}, + {ResourceType: "mystery", HostResourceID: "mystery_1", ResourceName: "m1"}, + {ResourceType: "account", HostResourceID: "account_1", ResourceName: "a1"}, + }) + + if len(snapshot.Groups) != 1 || snapshot.Groups[0].ID != "group_1" { + t.Fatalf("snapshot.Groups = %#v, want group_1 only", snapshot.Groups) + } + if len(snapshot.Accounts) != 1 || snapshot.Accounts[0].ID != "account_1" { + t.Fatalf("snapshot.Accounts = %#v, want account_1 only", snapshot.Accounts) + } + if len(snapshot.Plans) != 0 || len(snapshot.Channels) != 0 { + t.Fatalf("snapshot = %#v, want unknown type ignored", snapshot) + } +} diff --git a/internal/provision/runtime_import_service_test.go b/internal/provision/runtime_import_service_test.go index b1079fc2..254bf5e5 100644 --- a/internal/provision/runtime_import_service_test.go +++ b/internal/provision/runtime_import_service_test.go @@ -514,6 +514,101 @@ func TestRuntimeImportServiceRepeatedImportReusesManagedResources(t *testing.T) } } +func TestRuntimeImportServiceResolvesHostByBaseURL(t *testing.T) { + store := openProvisionTestStore(t) + defer closeProvisionTestStore(t, store) + + seedProvisionHost(t, store, "host-1", "https://sub2api.example.com") + + host := &fakeHostAdapter{ + batchAccounts: []sub2api.AccountRef{{ID: "account_1"}}, + testResults: map[string]sub2api.ProbeResult{ + "account_1": {OK: true, Status: "passed"}, + }, + models: map[string][]sub2api.AccountModel{ + "account_1": {{ID: "deepseek-chat"}}, + }, + gatewayResult: sub2api.GatewayAccessResult{ + OK: true, + StatusCode: 200, + HasExpectedModel: true, + Models: []string{"deepseek-chat"}, + CompletionOK: true, + CompletionStatus: 200, + }, + } + + result, err := NewRuntimeImportService(store, host).Import(context.Background(), RuntimeImportRequest{ + HostBaseURL: "https://sub2api.example.com", + Pack: pack.LoadedPack{ + Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}, + Checksum: "checksum-1", + }, + Provider: sampleProviderManifest(), + Mode: ImportModePartial, + Keys: []string{"key-1"}, + Access: AccessRequest{ + Mode: AccessModeSelfService, + ProbeAPIKey: "user-key", + }, + }) + if err != nil { + t.Fatalf("RuntimeImportService.Import() error = %v", err) + } + if result.BatchID <= 0 { + t.Fatalf("BatchID = %d, want positive id", result.BatchID) + } +} + +func TestRuntimeImportServiceRejectsUnregisteredHostBaseURL(t *testing.T) { + store := openProvisionTestStore(t) + defer closeProvisionTestStore(t, store) + + _, err := NewRuntimeImportService(store, &fakeHostAdapter{}).Import(context.Background(), RuntimeImportRequest{ + HostBaseURL: "https://missing.example.com", + Pack: pack.LoadedPack{ + Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}, + Checksum: "checksum-1", + }, + Provider: sampleProviderManifest(), + Mode: ImportModePartial, + Keys: []string{"key-1"}, + Access: AccessRequest{ + Mode: AccessModeSelfService, + ProbeAPIKey: "user-key", + }, + }) + if err == nil || !strings.Contains(err.Error(), `host_id is required for unregistered host_base_url "https://missing.example.com"`) { + t.Fatalf("RuntimeImportService.Import() error = %v, want unregistered host_base_url error", err) + } +} + +func TestRuntimeImportServiceRejectsHostBaseURLMismatch(t *testing.T) { + store := openProvisionTestStore(t) + defer closeProvisionTestStore(t, store) + + seedProvisionHost(t, store, "host-1", "https://sub2api.example.com") + + _, err := NewRuntimeImportService(store, &fakeHostAdapter{}).Import(context.Background(), RuntimeImportRequest{ + HostID: "host-1", + HostBaseURL: "https://other.example.com", + Pack: pack.LoadedPack{ + Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"}, + Checksum: "checksum-1", + }, + Provider: sampleProviderManifest(), + Mode: ImportModePartial, + Keys: []string{"key-1"}, + Access: AccessRequest{ + Mode: AccessModeSelfService, + ProbeAPIKey: "user-key", + }, + }) + if err == nil || err.Error() != `host "host-1" base_url mismatch: registered=https://sub2api.example.com runtime=https://other.example.com` { + t.Fatalf("RuntimeImportService.Import() error = %v, want base_url mismatch", err) + } +} + func TestRuntimeImportServiceImportReconcilesExistingChannelConfiguration(t *testing.T) { store := openProvisionTestStore(t) defer closeProvisionTestStore(t, store) diff --git a/internal/reconcile/service_runtime_test.go b/internal/reconcile/service_runtime_test.go index ad3a5c58..b9059373 100644 --- a/internal/reconcile/service_runtime_test.go +++ b/internal/reconcile/service_runtime_test.go @@ -8,6 +8,7 @@ import ( "testing" "sub2api-cn-relay-manager/internal/host/sub2api" + "sub2api-cn-relay-manager/internal/pack" "sub2api-cn-relay-manager/internal/store/sqlite" ) @@ -382,6 +383,127 @@ func TestStoredResourcesForReconcileMergesCurrentBatchAndSharedScaffoldingOnly(t } } +func TestReconcileValidatesTopLevelRequest(t *testing.T) { + t.Parallel() + + req := Request{} + if _, err := NewService(nil, &reconcileHostStub{}).Reconcile(context.Background(), req); err == nil || err.Error() != "store is required" { + t.Fatalf("Reconcile(nil store) error = %v, want store is required", err) + } + + store := openReconcileTestStore(t) + defer closeReconcileTestStore(t, store) + + if _, err := NewService(store, nil).Reconcile(context.Background(), req); err == nil || err.Error() != "host adapter is required" { + t.Fatalf("Reconcile(nil host) error = %v, want host adapter is required", err) + } + if _, err := NewService(store, &reconcileHostStub{}).Reconcile(context.Background(), req); err == nil || err.Error() != "host_id is required" { + t.Fatalf("Reconcile(missing host_id) error = %v, want host_id is required", err) + } + if _, err := NewService(store, &reconcileHostStub{}).Reconcile(context.Background(), Request{HostID: "host-1"}); err == nil || err.Error() != "host_base_url is required" { + t.Fatalf("Reconcile(missing host_base_url) error = %v, want host_base_url is required", err) + } +} + +func TestReconcileRejectsNonReconcilableLatestBatch(t *testing.T) { + t.Parallel() + + store := openReconcileTestStore(t) + defer closeReconcileTestStore(t, store) + + fixture := seedReconcileFixture(t, store) + mustExecReconcileSQL(t, store, `UPDATE import_batches SET batch_status = 'failed' WHERE id = ?`, fixture.batchID) + + _, err := NewService(store, &reconcileHostStub{}).Reconcile(context.Background(), Request{ + HostID: "host-1", + HostBaseURL: "https://sub2api.example.com", + AccessProbeAPIKey: "user-key", + Pack: pack.LoadedPack{ + Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api"}, + }, + Provider: pack.ProviderManifest{ + ProviderID: "deepseek", + SmokeTestModel: "deepseek-chat", + }, + }) + if err == nil || err.Error() != "latest import batch is failed; run import again before reconcile" { + t.Fatalf("Reconcile() error = %v, want non-reconcilable batch error", err) + } +} + +func TestReconcilePersistsActiveRunForHealthySnapshot(t *testing.T) { + t.Parallel() + + store := openReconcileTestStore(t) + defer closeReconcileTestStore(t, store) + + fixture := seedReconcileFixture(t, store) + mustCreateImportBatchItem(t, store, fixture.batchID, "fp-1", `{"account_id":"account-1"}`) + mustCreateManagedResource(t, store, fixture.batchID, fixture.hostPK, "group", "group-1", "group one") + mustCreateManagedResource(t, store, fixture.batchID, fixture.hostPK, "account", "account-1", "account one") + mustCreateAndLoadAccessClosures(t, store, fixture.batchID, sqlite.AccessClosureRecord{ + BatchID: fixture.batchID, + ClosureType: accessModeSelfService, + Status: accessStatusSelfServiceReady, + }) + + host := &reconcileHostStub{ + testResults: map[string]sub2api.ProbeResult{ + "account-1": {OK: true, Status: accountStatusPassed, Message: "ok"}, + }, + models: map[string][]sub2api.AccountModel{ + "account-1": {{ID: "deepseek-chat"}}, + }, + gatewayResult: sub2api.GatewayAccessResult{ + OK: true, + StatusCode: 200, + HasExpectedModel: true, + CompletionOK: true, + Models: []string{"deepseek-chat"}, + }, + completionResults: []sub2api.GatewayCompletionResult{ + {OK: true, StatusCode: 200}, + }, + managedResourceSnapshot: sub2api.ManagedResourceSnapshot{ + Groups: []sub2api.NamedResource{{ID: "group-1", Name: "group one"}}, + Accounts: []sub2api.NamedResource{{ID: "account-1", Name: "account one"}}, + }, + } + + result, err := NewService(store, host).Reconcile(context.Background(), Request{ + HostID: "host-1", + HostBaseURL: "https://sub2api.example.com", + AccessProbeAPIKey: "user-key", + Pack: pack.LoadedPack{ + Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api"}, + }, + Provider: pack.ProviderManifest{ + ProviderID: "deepseek", + SmokeTestModel: "deepseek-chat", + }, + }) + if err != nil { + t.Fatalf("Reconcile() error = %v", err) + } + if result.Status != "active" { + t.Fatalf("result.Status = %q, want active", result.Status) + } + if result.MissingCount != 0 || result.ExtraCount != 0 || result.ProbeFailureCount != 0 { + t.Fatalf("result = %+v, want no drift and no probe failures", result) + } + if result.AccessStatus != accessStatusSelfServiceReady { + t.Fatalf("result.AccessStatus = %q, want %q", result.AccessStatus, accessStatusSelfServiceReady) + } + + runs, err := store.ReconcileRuns().GetByBatchID(context.Background(), fixture.batchID) + if err != nil { + t.Fatalf("ReconcileRuns().GetByBatchID() error = %v", err) + } + if len(runs) != 1 || runs[0].Status != "active" { + t.Fatalf("persisted reconcile runs = %+v, want single active run", runs) + } +} + type reconcileFixture struct { hostPK int64 packPK int64 @@ -512,6 +634,13 @@ func mustCreateAndLoadAccessClosures(t *testing.T, store *sqlite.DB, batchID int return loaded } +func mustExecReconcileSQL(t *testing.T, store *sqlite.DB, query string, args ...any) { + t.Helper() + if _, err := store.SQLDB().Exec(query, args...); err != nil { + t.Fatalf("Exec(%q) error = %v", query, err) + } +} + type reconcileHostStub struct { disableResponsesErr error gatewayErr error @@ -525,6 +654,8 @@ type reconcileHostStub struct { completionCalls int disableResponsesCalls int disabledResponsesAccounts []string + managedResourceSnapshot sub2api.ManagedResourceSnapshot + listManagedResourcesErr error } func (h *reconcileHostStub) GetHostVersion(context.Context) (string, error) { @@ -632,5 +763,8 @@ func (h *reconcileHostStub) DisableOpenAIResponsesAPI(_ context.Context, account } func (h *reconcileHostStub) ListManagedResources(context.Context, sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error) { - return sub2api.ManagedResourceSnapshot{}, nil + if h.listManagedResourcesErr != nil { + return sub2api.ManagedResourceSnapshot{}, h.listManagedResourcesErr + } + return h.managedResourceSnapshot, nil } diff --git a/internal/store/sqlite/access_closure_records_repo_test.go b/internal/store/sqlite/access_closure_records_repo_test.go new file mode 100644 index 00000000..ec41d2e5 --- /dev/null +++ b/internal/store/sqlite/access_closure_records_repo_test.go @@ -0,0 +1,57 @@ +package sqlite + +import ( + "context" + "testing" +) + +func TestAccessClosureRecordsRepoCreateValidationAndDefaults(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestDB(t) + repo := store.AccessClosures() + + if _, err := repo.Create(ctx, AccessClosureRecord{}); err == nil || err.Error() != "batch_id is required" { + t.Fatalf("Create() error = %v, want batch_id is required", err) + } + if _, err := repo.Create(ctx, AccessClosureRecord{BatchID: 1}); err == nil || err.Error() != "closure_type is required" { + t.Fatalf("Create() error = %v, want closure_type is required", err) + } + if _, err := repo.Create(ctx, AccessClosureRecord{BatchID: 1, ClosureType: "self_service"}); err == nil || err.Error() != "status is required" { + t.Fatalf("Create() error = %v, want status is required", err) + } + + batchID := createTestBatch(t, store) + recordID, err := repo.Create(ctx, AccessClosureRecord{ + BatchID: batchID, + ClosureType: "self_service", + Status: "ready", + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + if recordID <= 0 { + t.Fatalf("recordID = %d, want positive id", recordID) + } + + records, err := repo.GetByBatchID(ctx, batchID) + if err != nil { + t.Fatalf("GetByBatchID() error = %v", err) + } + if len(records) != 1 { + t.Fatalf("len(records) = %d, want 1", len(records)) + } + if records[0].DetailsJSON != "{}" { + t.Fatalf("DetailsJSON = %q, want {}", records[0].DetailsJSON) + } +} + +func TestAccessClosureRecordsRepoGetByBatchIDValidation(t *testing.T) { + t.Parallel() + + _, err := openTestDB(t).AccessClosures().GetByBatchID(context.Background(), 0) + if err == nil || err.Error() != "batch_id is required" { + t.Fatalf("GetByBatchID() error = %v, want batch_id is required", err) + } +} diff --git a/internal/store/sqlite/db_test.go b/internal/store/sqlite/db_test.go index 9727f1ef..07b44e01 100644 --- a/internal/store/sqlite/db_test.go +++ b/internal/store/sqlite/db_test.go @@ -159,3 +159,115 @@ func TestReadMigrationNotFound(t *testing.T) { t.Fatal("readMigration('nonexistent.sql') error = nil, want error") } } + +func TestEnsureMigrationLedgerAndLoadAppliedMigrations(t *testing.T) { + store := openTestDB(t) + db := store.SQLDB() + + tx, err := db.BeginTx(context.Background(), nil) + if err != nil { + t.Fatalf("BeginTx() error = %v", err) + } + defer tx.Rollback() + + if err := ensureMigrationLedger(context.Background(), tx); err != nil { + t.Fatalf("ensureMigrationLedger() error = %v", err) + } + if err := ensureMigrationLedger(context.Background(), tx); err != nil { + t.Fatalf("ensureMigrationLedger() second call error = %v", err) + } + if _, err := tx.ExecContext(context.Background(), "INSERT INTO schema_migrations (version) VALUES (?)", "9999_test.sql"); err != nil { + t.Fatalf("insert schema_migrations error = %v", err) + } + + applied, err := loadAppliedMigrations(context.Background(), tx) + if err != nil { + t.Fatalf("loadAppliedMigrations() error = %v", err) + } + if !applied["9999_test.sql"] { + t.Fatalf("loadAppliedMigrations() = %#v, want 9999_test.sql=true", applied) + } +} + +func TestBackfillLegacySchemaIfNeededRecordsCompleteLegacySchema(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy.db") + rawDB, err := sql.Open("sqlite", "file:"+filepath.ToSlash(dbPath)) + if err != nil { + t.Fatalf("sql.Open() error = %v", err) + } + defer rawDB.Close() + + for _, ddl := range []string{ + "CREATE TABLE hosts (id INTEGER PRIMARY KEY, host_id TEXT NOT NULL, base_url TEXT NOT NULL, host_version TEXT NOT NULL, capability_probe_json TEXT NOT NULL DEFAULT '', created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP)", + "CREATE TABLE packs (id INTEGER PRIMARY KEY, pack_id TEXT NOT NULL, version TEXT NOT NULL, checksum TEXT NOT NULL, manifest_json TEXT NOT NULL DEFAULT '{}', metadata_json TEXT NOT NULL DEFAULT '{}', created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP)", + "CREATE TABLE providers (id INTEGER PRIMARY KEY, pack_id INTEGER NOT NULL, provider_id TEXT NOT NULL, display_name TEXT NOT NULL, base_url TEXT NOT NULL, platform TEXT NOT NULL, account_type TEXT NOT NULL DEFAULT 'apikey', smoke_test_model TEXT NOT NULL DEFAULT '', provider_manifest_json TEXT NOT NULL DEFAULT '{}', created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP)", + } { + if _, err := rawDB.ExecContext(context.Background(), ddl); err != nil { + t.Fatalf("Exec legacy ddl error = %v", err) + } + } + + tx, err := rawDB.BeginTx(context.Background(), nil) + if err != nil { + t.Fatalf("BeginTx() error = %v", err) + } + defer tx.Rollback() + + if err := ensureMigrationLedger(context.Background(), tx); err != nil { + t.Fatalf("ensureMigrationLedger() error = %v", err) + } + applied, err := loadAppliedMigrations(context.Background(), tx) + if err != nil { + t.Fatalf("loadAppliedMigrations() error = %v", err) + } + if err := backfillLegacySchemaIfNeeded(context.Background(), tx, []string{"0001_init.sql"}, applied); err != nil { + t.Fatalf("backfillLegacySchemaIfNeeded() error = %v", err) + } + if !applied["0001_init.sql"] { + t.Fatalf("applied = %#v, want 0001_init.sql marked", applied) + } +} + +func TestBackfillLegacySchemaIfNeededRejectsPartialLegacySchema(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "legacy-partial.db") + rawDB, err := sql.Open("sqlite", "file:"+filepath.ToSlash(dbPath)) + if err != nil { + t.Fatalf("sql.Open() error = %v", err) + } + defer rawDB.Close() + + if _, err := rawDB.ExecContext(context.Background(), "CREATE TABLE hosts (id INTEGER PRIMARY KEY)"); err != nil { + t.Fatalf("Exec partial legacy ddl error = %v", err) + } + + tx, err := rawDB.BeginTx(context.Background(), nil) + if err != nil { + t.Fatalf("BeginTx() error = %v", err) + } + defer tx.Rollback() + + if err := ensureMigrationLedger(context.Background(), tx); err != nil { + t.Fatalf("ensureMigrationLedger() error = %v", err) + } + applied, err := loadAppliedMigrations(context.Background(), tx) + if err != nil { + t.Fatalf("loadAppliedMigrations() error = %v", err) + } + err = backfillLegacySchemaIfNeeded(context.Background(), tx, []string{"0001_init.sql"}, applied) + if err == nil || err.Error() != "legacy sqlite schema is partially applied without schema_migrations" { + t.Fatalf("backfillLegacySchemaIfNeeded() error = %v, want partial legacy schema error", err) + } +} + +func TestRollbackMigrationReturnsOriginalErrorWhenRollbackSucceeds(t *testing.T) { + store := openTestDB(t) + tx, err := store.SQLDB().BeginTx(context.Background(), nil) + if err != nil { + t.Fatalf("BeginTx() error = %v", err) + } + + baseErr := errors.New("apply failed") + if err := rollbackMigration(tx, baseErr); !errors.Is(err, baseErr) { + t.Fatalf("rollbackMigration() error = %v, want original error", err) + } +} diff --git a/internal/store/sqlite/hosts_repo_test.go b/internal/store/sqlite/hosts_repo_test.go index 7e8522d3..00ba4936 100644 --- a/internal/store/sqlite/hosts_repo_test.go +++ b/internal/store/sqlite/hosts_repo_test.go @@ -210,6 +210,30 @@ func TestHostsRepoGetByHostIDNotFound(t *testing.T) { } } +func TestHostsRepoGetByBaseURL(t *testing.T) { + store := openTestDB(t) + createTestHostWithBaseURL(t, store, "host-base-url", "https://base-url.example.com") + + got, err := store.Hosts().GetByBaseURL(context.Background(), "https://base-url.example.com") + if err != nil { + t.Fatalf("GetByBaseURL() error = %v", err) + } + if got.HostID != "host-base-url" { + t.Fatalf("GetByBaseURL() host_id = %q, want host-base-url", got.HostID) + } +} + +func TestHostsRepoGetByBaseURLErrors(t *testing.T) { + store := openTestDB(t) + + if _, err := store.Hosts().GetByBaseURL(context.Background(), ""); err == nil { + t.Fatal("GetByBaseURL(\"\") error = nil, want error") + } + if _, err := store.Hosts().GetByBaseURL(context.Background(), "https://missing.example.com"); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("GetByBaseURL(missing) error = %v, want sql.ErrNoRows", err) + } +} + func TestHostsRepoListAll(t *testing.T) { store := openTestDB(t) @@ -273,6 +297,52 @@ func TestHostsRepoUpdateProbeByHostID(t *testing.T) { } } +func TestHostsRepoUpdateConnectionByHostID(t *testing.T) { + store := openTestDB(t) + createTestHost(t, store) + + if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "host-"+sanitizeTestName(t.Name()), "https://updated.example.com", "0.3.0", "", "", "token-1"); err != nil { + t.Fatalf("UpdateConnectionByHostID() error = %v", err) + } + + host, err := store.Hosts().GetByHostID(context.Background(), "host-"+sanitizeTestName(t.Name())) + if err != nil { + t.Fatalf("GetByHostID() error = %v", err) + } + if host.BaseURL != "https://updated.example.com" { + t.Fatalf("BaseURL = %q, want updated URL", host.BaseURL) + } + if host.HostVersion != "0.3.0" { + t.Fatalf("HostVersion = %q, want 0.3.0", host.HostVersion) + } + if host.CapabilityProbeJSON != "{}" { + t.Fatalf("CapabilityProbeJSON = %q, want {}", host.CapabilityProbeJSON) + } + if host.AuthType != "apikey" { + t.Fatalf("AuthType = %q, want apikey", host.AuthType) + } + if host.AuthToken != "token-1" { + t.Fatalf("AuthToken = %q, want token-1", host.AuthToken) + } +} + +func TestHostsRepoUpdateConnectionByHostIDErrors(t *testing.T) { + store := openTestDB(t) + + if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "", "https://example.com", "0.2.0", "{}", "apikey", "token"); err == nil { + t.Fatal("UpdateConnectionByHostID() empty host_id error = nil") + } + if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "", "0.2.0", "{}", "apikey", "token"); err == nil { + t.Fatal("UpdateConnectionByHostID() empty base_url error = nil") + } + if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "https://example.com", "", "{}", "apikey", "token"); err == nil { + t.Fatal("UpdateConnectionByHostID() empty host_version error = nil") + } + if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "https://example.com", "0.2.0", "{}", "apikey", "token"); err == nil { + t.Fatal("UpdateConnectionByHostID() missing host error = nil") + } +} + func TestHostsRepoDeleteByHostIDNotFound(t *testing.T) { store := openTestDB(t) err := store.Hosts().DeleteByHostID(context.Background(), "nonexistent") diff --git a/internal/store/sqlite/import_batches_repo_test.go b/internal/store/sqlite/import_batches_repo_test.go index 875433ca..d80e1902 100644 --- a/internal/store/sqlite/import_batches_repo_test.go +++ b/internal/store/sqlite/import_batches_repo_test.go @@ -102,6 +102,148 @@ func TestImportBatchesRepoListByProviderID(t *testing.T) { } } +func TestImportBatchesRepoGetLatestByProviderIDAndHostID(t *testing.T) { + store := openTestDB(t) + hostA := createTestHost(t, store) + hostB := createTestHostWithBaseURL(t, store, "host-b-"+sanitizeTestName(t.Name()), "https://host-b.example.com") + packID := createTestPack(t, store) + providerID := createTestProviderWithPack(t, store, packID) + + if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostA, PackID: packID, ProviderID: providerID, + Mode: "partial", BatchStatus: "running", AccessStatus: "pending", + }); err != nil { + t.Fatalf("Create(hostA older) error = %v", err) + } + latestA, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostA, PackID: packID, ProviderID: providerID, + Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready", + }) + if err != nil { + t.Fatalf("Create(hostA newer) error = %v", err) + } + if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostB, PackID: packID, ProviderID: providerID, + Mode: "partial", BatchStatus: "partially_succeeded", AccessStatus: "degraded", + }); err != nil { + t.Fatalf("Create(hostB) error = %v", err) + } + + got, err := store.ImportBatches().GetLatestByProviderIDAndHostID(context.Background(), providerID, hostA) + if err != nil { + t.Fatalf("GetLatestByProviderIDAndHostID() error = %v", err) + } + if got.ID != latestA { + t.Fatalf("GetLatestByProviderIDAndHostID() id = %d, want %d", got.ID, latestA) + } +} + +func TestImportBatchesRepoListByProviderIDAndHostID(t *testing.T) { + store := openTestDB(t) + hostA := createTestHost(t, store) + hostB := createTestHostWithBaseURL(t, store, "host-b-"+sanitizeTestName(t.Name()), "https://host-b2.example.com") + packID := createTestPack(t, store) + providerID := createTestProviderWithPack(t, store, packID) + + olderA, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostA, PackID: packID, ProviderID: providerID, + Mode: "partial", BatchStatus: "running", AccessStatus: "pending", + }) + if err != nil { + t.Fatalf("Create(hostA older) error = %v", err) + } + newerA, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostA, PackID: packID, ProviderID: providerID, + Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready", + }) + if err != nil { + t.Fatalf("Create(hostA newer) error = %v", err) + } + if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostB, PackID: packID, ProviderID: providerID, + Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready", + }); err != nil { + t.Fatalf("Create(hostB) error = %v", err) + } + + batches, err := store.ImportBatches().ListByProviderIDAndHostID(context.Background(), providerID, hostA) + if err != nil { + t.Fatalf("ListByProviderIDAndHostID() error = %v", err) + } + if len(batches) != 2 { + t.Fatalf("ListByProviderIDAndHostID() len = %d, want 2", len(batches)) + } + if batches[0].ID != newerA || batches[1].ID != olderA { + t.Fatalf("ListByProviderIDAndHostID() ids = [%d %d], want [%d %d]", batches[0].ID, batches[1].ID, newerA, olderA) + } +} + +func TestImportBatchesRepoListLatestReconcilable(t *testing.T) { + store := openTestDB(t) + hostA := createTestHost(t, store) + hostB := createTestHostWithBaseURL(t, store, "host-b-"+sanitizeTestName(t.Name()), "https://host-b3.example.com") + packAID := createTestPack(t, store) + packBID := createTestPackWithSuffix(t, store, "reconcile") + providerA := createTestProviderWithPack(t, store, packAID) + providerB := createTestProviderWithPack(t, store, packBID) + + if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostA, PackID: packAID, ProviderID: providerA, + Mode: "partial", BatchStatus: "running", AccessStatus: "pending", + }); err != nil { + t.Fatalf("Create(providerA older) error = %v", err) + } + latestA, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostA, PackID: packAID, ProviderID: providerA, + Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready", + }) + if err != nil { + t.Fatalf("Create(providerA latest) error = %v", err) + } + if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostB, PackID: packBID, ProviderID: providerB, + Mode: "strict", BatchStatus: "partially_succeeded", AccessStatus: "degraded", + }); err != nil { + t.Fatalf("Create(providerB latest) error = %v", err) + } + if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostB, PackID: packBID, ProviderID: providerB, + Mode: "strict", BatchStatus: "failed", AccessStatus: "broken", + }); err != nil { + t.Fatalf("Create(providerB newer failed) error = %v", err) + } + + batches, err := store.ImportBatches().ListLatestReconcilable(context.Background()) + if err != nil { + t.Fatalf("ListLatestReconcilable() error = %v", err) + } + if len(batches) != 1 { + t.Fatalf("ListLatestReconcilable() len = %d, want 1 because providerB latest batch is not reconcilable", len(batches)) + } + if batches[0].ID != latestA { + t.Fatalf("ListLatestReconcilable() id = %d, want %d", batches[0].ID, latestA) + } + + batchID, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostB, PackID: packBID, ProviderID: providerB, + Mode: "strict", BatchStatus: "partially_succeeded", AccessStatus: "degraded", + }) + if err != nil { + t.Fatalf("Create(providerB new reconcilable) error = %v", err) + } + + batches, err = store.ImportBatches().ListLatestReconcilable(context.Background()) + if err != nil { + t.Fatalf("ListLatestReconcilable() second call error = %v", err) + } + if len(batches) != 2 { + t.Fatalf("ListLatestReconcilable() second len = %d, want 2", len(batches)) + } + if batches[0].ID != batchID || batches[1].ID != latestA { + t.Fatalf("ListLatestReconcilable() ids = [%d %d], want [%d %d]", batches[0].ID, batches[1].ID, batchID, latestA) + } +} + func TestImportBatchesRepoGetByIDNotFound(t *testing.T) { store := openTestDB(t) _, err := store.ImportBatches().GetByID(context.Background(), 999) @@ -265,6 +407,102 @@ func TestManagedResourcesRepoGetByBatchIDEmpty(t *testing.T) { } } +func TestManagedResourcesRepoGetByResourceIdentity(t *testing.T) { + store := openTestDB(t) + batchID := createTestBatch(t, store) + batch, err := store.ImportBatches().GetByID(context.Background(), batchID) + if err != nil { + t.Fatalf("GetByID() error = %v", err) + } + if _, err := store.ManagedResources().Create(context.Background(), ManagedResource{ + BatchID: batchID, HostID: batch.HostID, ResourceType: "channel", HostResourceID: "channel-1", ResourceName: "Channel 1", + }); err != nil { + t.Fatalf("Create() error = %v", err) + } + + resource, err := store.ManagedResources().GetByResourceIdentity(context.Background(), batch.HostID, "channel", "channel-1") + if err != nil { + t.Fatalf("GetByResourceIdentity() error = %v", err) + } + if resource.ResourceName != "Channel 1" { + t.Fatalf("GetByResourceIdentity() resource_name = %q, want Channel 1", resource.ResourceName) + } +} + +func TestManagedResourcesRepoListByProviderScopes(t *testing.T) { + store := openTestDB(t) + hostA := createTestHost(t, store) + hostB := createTestHostWithBaseURL(t, store, "host-b-"+sanitizeTestName(t.Name()), "https://managed-host-b.example.com") + packID := createTestPack(t, store) + providerID := createTestProviderWithPack(t, store, packID) + otherProviderID := createTestProviderWithPack(t, store, createTestPackWithSuffix(t, store, "managed")) + + batchA, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostA, PackID: packID, ProviderID: providerID, + Mode: "partial", BatchStatus: "succeeded", AccessStatus: "subscription_ready", + }) + if err != nil { + t.Fatalf("Create(batchA) error = %v", err) + } + batchB, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostB, PackID: packID, ProviderID: providerID, + Mode: "partial", BatchStatus: "partially_succeeded", AccessStatus: "degraded", + }) + if err != nil { + t.Fatalf("Create(batchB) error = %v", err) + } + otherBatch, err := store.ImportBatches().Create(context.Background(), ImportBatch{ + HostID: hostB, PackID: packID, ProviderID: otherProviderID, + Mode: "partial", BatchStatus: "succeeded", AccessStatus: "subscription_ready", + }) + if err != nil { + t.Fatalf("Create(otherBatch) error = %v", err) + } + + for _, resource := range []ManagedResource{ + {BatchID: batchA, HostID: hostA, ResourceType: "group", HostResourceID: "g-a", ResourceName: "Group A"}, + {BatchID: batchB, HostID: hostB, ResourceType: "account", HostResourceID: "a-b", ResourceName: "Account B"}, + {BatchID: otherBatch, HostID: hostB, ResourceType: "channel", HostResourceID: "c-other", ResourceName: "Other"}, + } { + if _, err := store.ManagedResources().Create(context.Background(), resource); err != nil { + t.Fatalf("Create(%s) error = %v", resource.HostResourceID, err) + } + } + + byProvider, err := store.ManagedResources().ListByProviderID(context.Background(), providerID) + if err != nil { + t.Fatalf("ListByProviderID() error = %v", err) + } + if len(byProvider) != 2 { + t.Fatalf("ListByProviderID() len = %d, want 2", len(byProvider)) + } + + byProviderHost, err := store.ManagedResources().ListByProviderIDAndHostID(context.Background(), providerID, hostB) + if err != nil { + t.Fatalf("ListByProviderIDAndHostID() error = %v", err) + } + if len(byProviderHost) != 1 { + t.Fatalf("ListByProviderIDAndHostID() len = %d, want 1", len(byProviderHost)) + } + if byProviderHost[0].HostResourceID != "a-b" { + t.Fatalf("ListByProviderIDAndHostID() resource = %q, want a-b", byProviderHost[0].HostResourceID) + } +} + +func TestManagedResourcesRepoQueryValidationErrors(t *testing.T) { + store := openTestDB(t) + + if _, err := store.ManagedResources().GetByResourceIdentity(context.Background(), 0, "group", "g-1"); err == nil { + t.Fatal("GetByResourceIdentity() host_id=0 error = nil") + } + if _, err := store.ManagedResources().ListByProviderID(context.Background(), 0); err == nil { + t.Fatal("ListByProviderID() provider_id=0 error = nil") + } + if _, err := store.ManagedResources().ListByProviderIDAndHostID(context.Background(), 1, 0); err == nil { + t.Fatal("ListByProviderIDAndHostID() host_id=0 error = nil") + } +} + func TestManagedResourcesRepoValidationErrors(t *testing.T) { store := openTestDB(t) for _, tt := range []struct { diff --git a/internal/store/sqlite/import_run_item_events_repo_test.go b/internal/store/sqlite/import_run_item_events_repo_test.go new file mode 100644 index 00000000..3e994ce3 --- /dev/null +++ b/internal/store/sqlite/import_run_item_events_repo_test.go @@ -0,0 +1,88 @@ +package sqlite + +import ( + "context" + "testing" +) + +func TestImportRunItemEventsRepoAppendValidationAndDefaults(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestDB(t) + repo := store.ImportRunEvents() + + if err := repo.Append(ctx, ImportRunItemEvent{}); err == nil || err.Error() != "event_id is required" { + t.Fatalf("Append() error = %v, want event_id is required", err) + } + if err := repo.Append(ctx, ImportRunItemEvent{EventID: "evt-1"}); err == nil || err.Error() != "run_id is required" { + t.Fatalf("Append() error = %v, want run_id is required", err) + } + if err := repo.Append(ctx, ImportRunItemEvent{EventID: "evt-1", RunID: "run-1"}); err == nil || err.Error() != "item_id is required" { + t.Fatalf("Append() error = %v, want item_id is required", err) + } + if err := repo.Append(ctx, ImportRunItemEvent{EventID: "evt-1", RunID: "run-1", ItemID: "item-1"}); err == nil || err.Error() != "event_type is required" { + t.Fatalf("Append() error = %v, want event_type is required", err) + } + if err := repo.Append(ctx, ImportRunItemEvent{EventID: "evt-1", RunID: "run-1", ItemID: "item-1", EventType: "confirm"}); err == nil || err.Error() != "stage is required" { + t.Fatalf("Append() error = %v, want stage is required", err) + } + if err := repo.Append(ctx, ImportRunItemEvent{EventID: "evt-1", RunID: "run-1", ItemID: "item-1", EventType: "confirm", Stage: "confirm"}); err == nil || err.Error() != "message is required" { + t.Fatalf("Append() error = %v, want message is required", err) + } + + if err := store.ImportRuns().Create(ctx, ImportRun{ + RunID: "run-1", + HostID: "host-1", + Mode: "strict", + AccessMode: "subscription", + State: "running", + }); err != nil { + t.Fatalf("ImportRuns().Create() error = %v", err) + } + if err := store.ImportRunItems().Create(ctx, ImportRunItem{ + ItemID: "item-1", + RunID: "run-1", + BaseURL: "https://api.example.com/v1", + ProviderID: "provider-1", + APIKeyFingerprint: "fp-1", + CurrentStage: "confirm", + ConfirmationStatus: "pending", + AccessStatus: "unknown", + MatchedAccountState: "active", + AccountResolution: "created", + }); err != nil { + t.Fatalf("ImportRunItems().Create() error = %v", err) + } + + if err := repo.Append(ctx, ImportRunItemEvent{ + EventID: "evt-1", + RunID: "run-1", + ItemID: "item-1", + EventType: "confirm", + Stage: "confirm", + Message: "scheduled", + }); err != nil { + t.Fatalf("Append() error = %v", err) + } + + events, err := repo.ListByItemID(ctx, "item-1") + if err != nil { + t.Fatalf("ListByItemID() error = %v", err) + } + if len(events) != 1 { + t.Fatalf("len(events) = %d, want 1", len(events)) + } + if events[0].PayloadJSON != "{}" { + t.Fatalf("PayloadJSON = %q, want {}", events[0].PayloadJSON) + } +} + +func TestImportRunItemEventsRepoListByItemIDValidation(t *testing.T) { + t.Parallel() + + _, err := openTestDB(t).ImportRunEvents().ListByItemID(context.Background(), "") + if err == nil || err.Error() != "item_id is required" { + t.Fatalf("ListByItemID() error = %v, want item_id is required", err) + } +} diff --git a/internal/store/sqlite/import_runs_repo_test.go b/internal/store/sqlite/import_runs_repo_test.go index 62c67907..dea737e3 100644 --- a/internal/store/sqlite/import_runs_repo_test.go +++ b/internal/store/sqlite/import_runs_repo_test.go @@ -4,6 +4,7 @@ import ( "context" "reflect" "testing" + "time" ) func TestRunStateStore(t *testing.T) { @@ -169,3 +170,182 @@ func TestRunStateStore(t *testing.T) { t.Fatalf("items[0].AdvisoryMessagesJSON = %q, want advisory json", items[0].AdvisoryMessagesJSON) } } + +func TestImportRunsRepoList(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestDB(t) + + for _, run := range []ImportRun{ + {RunID: "run-a", HostID: "host-a", Mode: "strict", AccessMode: "subscription", State: "running"}, + {RunID: "run-b", HostID: "host-b", Mode: "partial", AccessMode: "self_service", State: "completed"}, + } { + if err := store.ImportRuns().Create(ctx, run); err != nil { + t.Fatalf("ImportRuns().Create(%q) error = %v", run.RunID, err) + } + } + + runs, err := store.ImportRuns().List(ctx, 1) + if err != nil { + t.Fatalf("ImportRuns().List(limit=1) error = %v", err) + } + if len(runs) != 1 { + t.Fatalf("ImportRuns().List(limit=1) len = %d, want 1", len(runs)) + } + + runs, err = store.ImportRuns().List(ctx, 0) + if err != nil { + t.Fatalf("ImportRuns().List(limit=0) error = %v", err) + } + if len(runs) != 2 { + t.Fatalf("ImportRuns().List(limit=0) len = %d, want 2", len(runs)) + } +} + +func TestImportRunItemsRepoCreateUpdateAndLease(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestDB(t) + + run := ImportRun{RunID: "run-lease", HostID: "host-lease", Mode: "strict", AccessMode: "subscription", State: "running"} + if err := store.ImportRuns().Create(ctx, run); err != nil { + t.Fatalf("ImportRuns().Create() error = %v", err) + } + + item := ImportRunItem{ + ItemID: "item-lease", + RunID: "run-lease", + BaseURL: "https://api.example.com/v1", + ProviderID: "provider-lease", + APIKeyFingerprint: "fp_lease", + CurrentStage: "confirm", + ConfirmationStatus: "pending", + AccessStatus: "unknown", + MatchedAccountState: "active", + AccountResolution: "created", + RequestedModelsJSON: `["model-a"]`, + AdvisoryMessagesJSON: `["first"]`, + } + if err := store.ImportRunItems().Create(ctx, item); err != nil { + t.Fatalf("ImportRunItems().Create() error = %v", err) + } + + item.ConfirmationStatus = "confirmed" + item.CurrentStage = "validate" + item.AccessStatus = "active" + item.AdvisoryMessagesJSON = `["updated"]` + if err := store.ImportRunItems().Update(ctx, item); err != nil { + t.Fatalf("ImportRunItems().Update() error = %v", err) + } + + got, err := store.ImportRunItems().GetByItemID(ctx, "item-lease") + if err != nil { + t.Fatalf("ImportRunItems().GetByItemID() error = %v", err) + } + if got.CurrentStage != "validate" || got.ConfirmationStatus != "confirmed" { + t.Fatalf("updated item = %+v, want validate/confirmed", got) + } + + leaseItem := item + leaseItem.ItemID = "item-pending" + leaseItem.CurrentStage = "confirm" + leaseItem.ConfirmationStatus = "pending" + leaseItem.AccessStatus = "unknown" + if err := store.ImportRunItems().Create(ctx, leaseItem); err != nil { + t.Fatalf("ImportRunItems().Create(item-pending) error = %v", err) + } + + now := time.Date(2026, 5, 23, 1, 2, 3, 0, time.UTC) + leased, ok, err := store.ImportRunItems().TryAcquireConfirmationLease(ctx, "item-pending", "worker-1", now, 2*time.Minute) + if err != nil { + t.Fatalf("TryAcquireConfirmationLease() error = %v", err) + } + if !ok { + t.Fatal("TryAcquireConfirmationLease() ok = false, want true") + } + if leased.LeaseOwner != "worker-1" { + t.Fatalf("LeaseOwner = %q, want worker-1", leased.LeaseOwner) + } + if leased.ConfirmationAttempts != 1 { + t.Fatalf("ConfirmationAttempts = %d, want 1", leased.ConfirmationAttempts) + } + + _, ok, err = store.ImportRunItems().TryAcquireConfirmationLease(ctx, "item-pending", "worker-2", now.Add(30*time.Second), time.Minute) + if err != nil { + t.Fatalf("TryAcquireConfirmationLease(second) error = %v", err) + } + if ok { + t.Fatal("TryAcquireConfirmationLease(second) ok = true, want false while lease active") + } +} + +func TestImportRunEventsRepoCreateAndHelpers(t *testing.T) { + t.Parallel() + + ctx := context.Background() + store := openTestDB(t) + + if err := store.ImportRuns().Create(ctx, ImportRun{ + RunID: "run-events", + HostID: "host-events", + Mode: "strict", + AccessMode: "subscription", + State: "running", + }); err != nil { + t.Fatalf("ImportRuns().Create() error = %v", err) + } + if err := store.ImportRunItems().Create(ctx, ImportRunItem{ + ItemID: "item-events", + RunID: "run-events", + BaseURL: "https://api.example.com/v1", + ProviderID: "provider-events", + APIKeyFingerprint: "fp_events", + CurrentStage: "confirm", + ConfirmationStatus: "pending", + AccessStatus: "unknown", + MatchedAccountState: "active", + AccountResolution: "created", + }); err != nil { + t.Fatalf("ImportRunItems().Create() error = %v", err) + } + + event := ImportRunItemEvent{ + EventID: "evt-create", + RunID: "run-events", + ItemID: "item-events", + EventType: "confirmation", + Stage: "confirm", + Message: "created by wrapper", + } + if err := store.ImportRunEvents().Create(ctx, event); err != nil { + t.Fatalf("ImportRunEvents().Create() error = %v", err) + } + if err := store.ImportRunItemEvents().Create(ctx, ImportRunItemEvent{ + EventID: "evt-alias", + RunID: "run-events", + ItemID: "item-events", + EventType: "alias", + Stage: "confirm", + Message: "created by alias accessor", + }); err != nil { + t.Fatalf("ImportRunItemEvents().Create() error = %v", err) + } + + events, err := store.ImportRunEvents().ListByItemID(ctx, "item-events") + if err != nil { + t.Fatalf("ImportRunEvents().ListByItemID() error = %v", err) + } + if len(events) != 2 { + t.Fatalf("ImportRunEvents().ListByItemID() len = %d, want 2", len(events)) + } + + nullable := sqlNullInt64{Int64: 42, Valid: true} + if ptr := nullable.ptr(); ptr == nil || *ptr != 42 { + t.Fatalf("sqlNullInt64.ptr() = %#v, want 42", ptr) + } + if ptr := (sqlNullInt64{}).ptr(); ptr != nil { + t.Fatalf("sqlNullInt64{}.ptr() = %#v, want nil", ptr) + } +} diff --git a/internal/store/sqlite/packs_repo_test.go b/internal/store/sqlite/packs_repo_test.go index 8028497e..8277af38 100644 --- a/internal/store/sqlite/packs_repo_test.go +++ b/internal/store/sqlite/packs_repo_test.go @@ -157,3 +157,33 @@ func TestPacksRepoGetByPackIDNotFound(t *testing.T) { t.Fatalf("GetByPackID() error = %v, want sql.ErrNoRows", err) } } + +func TestPacksRepoListAll(t *testing.T) { + store := openTestDB(t) + + packs, err := store.Packs().ListAll(context.Background()) + if err != nil { + t.Fatalf("ListAll() empty error = %v", err) + } + if len(packs) != 0 { + t.Fatalf("ListAll() empty len = %d, want 0", len(packs)) + } + + if _, err := store.Packs().Create(context.Background(), Pack{PackID: "pack-a", Version: "1.0.0", Checksum: "chk-a"}); err != nil { + t.Fatalf("Create(pack-a) error = %v", err) + } + if _, err := store.Packs().Create(context.Background(), Pack{PackID: "pack-b", Version: "1.1.0", Checksum: "chk-b"}); err != nil { + t.Fatalf("Create(pack-b) error = %v", err) + } + + packs, err = store.Packs().ListAll(context.Background()) + if err != nil { + t.Fatalf("ListAll() error = %v", err) + } + if len(packs) != 2 { + t.Fatalf("ListAll() len = %d, want 2", len(packs)) + } + if packs[0].PackID != "pack-a" || packs[1].PackID != "pack-b" { + t.Fatalf("ListAll() ids = [%q %q], want [pack-a pack-b]", packs[0].PackID, packs[1].PackID) + } +} diff --git a/internal/store/sqlite/providers_repo_test.go b/internal/store/sqlite/providers_repo_test.go index 857ea414..64cdf904 100644 --- a/internal/store/sqlite/providers_repo_test.go +++ b/internal/store/sqlite/providers_repo_test.go @@ -55,6 +55,33 @@ func TestProvidersRepoListByProviderID(t *testing.T) { } } +func TestProvidersRepoListByPackID(t *testing.T) { + store := openTestDB(t) + packID := createTestPack(t, store) + otherPackID := createTestPackWithSuffix(t, store, "other") + + if _, err := store.Providers().Create(context.Background(), Provider{PackID: packID, ProviderID: "provider-a", DisplayName: "A", BaseURL: "https://a.example.com", Platform: "openai"}); err != nil { + t.Fatalf("Create(provider-a) error = %v", err) + } + if _, err := store.Providers().Create(context.Background(), Provider{PackID: packID, ProviderID: "provider-b", DisplayName: "B", BaseURL: "https://b.example.com", Platform: "openai"}); err != nil { + t.Fatalf("Create(provider-b) error = %v", err) + } + if _, err := store.Providers().Create(context.Background(), Provider{PackID: otherPackID, ProviderID: "provider-c", DisplayName: "C", BaseURL: "https://c.example.com", Platform: "openai"}); err != nil { + t.Fatalf("Create(provider-c) error = %v", err) + } + + providers, err := store.Providers().ListByPackID(context.Background(), packID) + if err != nil { + t.Fatalf("ListByPackID() error = %v", err) + } + if len(providers) != 2 { + t.Fatalf("ListByPackID() count = %d, want 2", len(providers)) + } + if providers[0].ProviderID != "provider-a" || providers[1].ProviderID != "provider-b" { + t.Fatalf("ListByPackID() provider ids = [%q %q], want [provider-a provider-b]", providers[0].ProviderID, providers[1].ProviderID) + } +} + func createTestPackWithSuffix(t *testing.T, store *DB, suffix string) int64 { t.Helper() id, err := store.Packs().Create(context.Background(), Pack{ @@ -145,6 +172,44 @@ func TestProvidersRepoGetByPackIDAndProviderIDNotFound(t *testing.T) { } } +func TestProvidersRepoGetByID(t *testing.T) { + store := openTestDB(t) + + packID := createTestPack(t, store) + providerID, err := store.Providers().Create(context.Background(), Provider{ + PackID: packID, + ProviderID: "provider-id-lookup", + DisplayName: "Lookup", + BaseURL: "https://lookup.example.com", + Platform: "openai", + }) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + + got, err := store.Providers().GetByID(context.Background(), providerID) + if err != nil { + t.Fatalf("GetByID() error = %v", err) + } + if got.ProviderID != "provider-id-lookup" { + t.Fatalf("GetByID() provider_id = %q, want provider-id-lookup", got.ProviderID) + } +} + +func TestProvidersRepoGetByIDErrors(t *testing.T) { + store := openTestDB(t) + + if _, err := store.Providers().GetByID(context.Background(), 0); err == nil { + t.Fatal("GetByID(0) error = nil, want error") + } + if _, err := store.Providers().ListByPackID(context.Background(), 0); err == nil { + t.Fatal("ListByPackID(0) error = nil, want error") + } + if _, err := store.Providers().GetByID(context.Background(), 999); !errors.Is(err, sql.ErrNoRows) { + t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err) + } +} + func TestProvidersRepoGetByPackIDEmptyError(t *testing.T) { store := openTestDB(t) _, err := store.Providers().GetByPackIDAndProviderID(context.Background(), 0, "p")