diff --git a/internal/app/http_api.go b/internal/app/http_api.go index 67ce1578..e5e2a6ec 100644 --- a/internal/app/http_api.go +++ b/internal/app/http_api.go @@ -61,6 +61,10 @@ type ActionSet struct { GetRouteFailure func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error) SetRouteCooldown func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error) GetRouteCooldown func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error) + ListProviderAccounts func(context.Context, ListProviderAccountsRequest) ([]ProviderAccountInfo, error) + EnableProviderAccount func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error) + DisableProviderAccount func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error) + RetireProviderAccount func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error) CreateProviderDraft func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error) ListProviderDrafts func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error) GetProviderDraft func(context.Context, string) (ProviderDraftInfo, error) @@ -432,6 +436,18 @@ func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet) http.Ha mux.Handle("GET /api/routing/sticky/cooldowns", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handleGetRouteCooldown(w, r, actions.GetRouteCooldown) }))) + mux.Handle("GET /api/provider-accounts", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleListProviderAccounts(w, r, actions.ListProviderAccounts) + }))) + mux.Handle("POST /api/provider-accounts/{accountID}/enable", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleEnableProviderAccount(w, r, actions.EnableProviderAccount) + }))) + mux.Handle("POST /api/provider-accounts/{accountID}/disable", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleDisableProviderAccount(w, r, actions.DisableProviderAccount) + }))) + mux.Handle("POST /api/provider-accounts/{accountID}/retire", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleRetireProviderAccount(w, r, actions.RetireProviderAccount) + }))) mux.Handle("POST /api/provider-drafts", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handleCreateProviderDraft(w, r, actions.CreateProviderDraft) }))) @@ -1280,6 +1296,10 @@ func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRu GetRouteFailure: buildGetRouteFailureAction(stickyRuntime), SetRouteCooldown: buildSetRouteCooldownAction(stickyRuntime), GetRouteCooldown: buildGetRouteCooldownAction(stickyRuntime), + ListProviderAccounts: buildListProviderAccountsAction(sqliteDSN), + EnableProviderAccount: buildUpdateProviderAccountStatusAction(sqliteDSN, sqlite.ProviderAccountStatusActive), + DisableProviderAccount: buildUpdateProviderAccountStatusAction(sqliteDSN, sqlite.ProviderAccountStatusDisabled), + RetireProviderAccount: buildUpdateProviderAccountStatusAction(sqliteDSN, sqlite.ProviderAccountStatusDeprecated), CreateProviderDraft: func(ctx context.Context, req CreateProviderDraftRequest) (ProviderDraftInfo, error) { store, err := sqlite.Open(ctx, sqliteDSN) if err != nil { diff --git a/internal/app/provider_accounts_api.go b/internal/app/provider_accounts_api.go new file mode 100644 index 00000000..aaad71ad --- /dev/null +++ b/internal/app/provider_accounts_api.go @@ -0,0 +1,188 @@ +package app + +import ( + "context" + "database/sql" + "fmt" + "net/http" + "strconv" + "strings" + + "sub2api-cn-relay-manager/internal/store/sqlite" +) + +type ListProviderAccountsRequest struct { + HostID string + ProviderID string + RouteID string + ShadowGroupID string + AccountStatus string + Query string + Limit int +} + +type UpdateProviderAccountStatusRequest struct { + AccountID int64 `json:"-"` + AccountStatus string `json:"-"` + DisabledReason string `json:"reason,omitempty"` +} + +type ProviderAccountInfo struct { + ID int64 `json:"id"` + HostID string `json:"host_id"` + ProviderID string `json:"provider_id"` + ProviderName string `json:"provider_name"` + RouteID string `json:"route_id,omitempty"` + LogicalGroupID string `json:"logical_group_id,omitempty"` + ShadowGroupID string `json:"shadow_group_id,omitempty"` + HostAccountID string `json:"host_account_id"` + KeyFingerprint string `json:"key_fingerprint"` + AccountName string `json:"account_name"` + AccountStatus string `json:"account_status"` + LastProbeStatus string `json:"last_probe_status,omitempty"` + LastProbeAt string `json:"last_probe_at,omitempty"` + DisabledReason string `json:"disabled_reason,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +func handleListProviderAccounts(w http.ResponseWriter, r *http.Request, fn func(context.Context, ListProviderAccountsRequest) ([]ProviderAccountInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "list-provider-accounts action is not configured"}) + return + } + accounts, err := fn(r.Context(), ListProviderAccountsRequest{ + HostID: strings.TrimSpace(r.URL.Query().Get("host_id")), + ProviderID: strings.TrimSpace(r.URL.Query().Get("provider_id")), + RouteID: strings.TrimSpace(r.URL.Query().Get("route_id")), + ShadowGroupID: strings.TrimSpace(r.URL.Query().Get("shadow_group_id")), + AccountStatus: strings.TrimSpace(r.URL.Query().Get("account_status")), + Query: strings.TrimSpace(r.URL.Query().Get("q")), + Limit: parsePositiveInt(r.URL.Query().Get("limit")), + }) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + if accounts == nil { + accounts = []ProviderAccountInfo{} + } + writeJSON(w, http.StatusOK, map[string]any{"provider_accounts": accounts}) +} + +func handleEnableProviderAccount(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error)) { + handleUpdateProviderAccountStatus(w, r, fn, sqlite.ProviderAccountStatusActive) +} + +func handleDisableProviderAccount(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error)) { + handleUpdateProviderAccountStatus(w, r, fn, sqlite.ProviderAccountStatusDisabled) +} + +func handleRetireProviderAccount(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error)) { + handleUpdateProviderAccountStatus(w, r, fn, sqlite.ProviderAccountStatusDeprecated) +} + +func handleUpdateProviderAccountStatus(w http.ResponseWriter, r *http.Request, fn func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error), accountStatus string) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "update-provider-account-status action is not configured"}) + return + } + rawID := strings.TrimSpace(r.PathValue("accountID")) + accountID, err := strconv.ParseInt(rawID, 10, 64) + if err != nil || accountID <= 0 { + writeHTTPError(w, &httpError{StatusCode: http.StatusBadRequest, Code: "invalid_request", Message: "account_id must be a positive integer"}) + return + } + req := UpdateProviderAccountStatusRequest{ + AccountID: accountID, + AccountStatus: accountStatus, + } + if r.ContentLength != 0 { + if err := decodeJSON(r, &req); err != nil { + writeHTTPError(w, err) + return + } + req.AccountID = accountID + req.AccountStatus = accountStatus + } + account, err := fn(r.Context(), req) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"provider_account": account}) +} + +func buildListProviderAccountsAction(sqliteDSN string) func(context.Context, ListProviderAccountsRequest) ([]ProviderAccountInfo, error) { + return func(ctx context.Context, req ListProviderAccountsRequest) ([]ProviderAccountInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return nil, err + } + defer store.Close() + + if err := sqlite.SyncProviderAccountsFromLatestImportBatches(ctx, store); err != nil { + return nil, err + } + rows, err := store.ProviderAccounts().List(ctx, sqlite.ProviderAccountListFilter{ + HostID: req.HostID, + ProviderID: req.ProviderID, + RouteID: req.RouteID, + ShadowGroupID: req.ShadowGroupID, + AccountStatus: req.AccountStatus, + Query: req.Query, + Limit: req.Limit, + }) + if err != nil { + return nil, err + } + result := make([]ProviderAccountInfo, 0, len(rows)) + for _, row := range rows { + result = append(result, providerAccountViewToInfo(row)) + } + return result, nil + } +} + +func buildUpdateProviderAccountStatusAction(sqliteDSN, accountStatus string) func(context.Context, UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error) { + return func(ctx context.Context, req UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error) { + store, err := sqlite.Open(ctx, sqliteDSN) + if err != nil { + return ProviderAccountInfo{}, err + } + defer store.Close() + + if err := store.ProviderAccounts().UpdateStatusByID(ctx, req.AccountID, accountStatus, strings.TrimSpace(req.DisabledReason)); err != nil { + if err == sql.ErrNoRows { + return ProviderAccountInfo{}, fmt.Errorf("provider account %d not found", req.AccountID) + } + return ProviderAccountInfo{}, err + } + updated, err := store.ProviderAccounts().GetViewByID(ctx, req.AccountID) + if err != nil { + return ProviderAccountInfo{}, err + } + return providerAccountViewToInfo(updated), nil + } +} + +func providerAccountViewToInfo(row sqlite.ProviderAccountView) ProviderAccountInfo { + return ProviderAccountInfo{ + ID: row.ID, + HostID: row.HostID, + ProviderID: row.ProviderID, + ProviderName: row.ProviderName, + RouteID: row.RouteID, + LogicalGroupID: row.LogicalGroupID, + ShadowGroupID: row.ShadowGroupID, + HostAccountID: row.HostAccountID, + KeyFingerprint: row.KeyFingerprint, + AccountName: row.AccountName, + AccountStatus: row.AccountStatus, + LastProbeStatus: row.LastProbeStatus, + LastProbeAt: row.LastProbeAt, + DisabledReason: row.DisabledReason, + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + } +} diff --git a/internal/app/provider_accounts_api_test.go b/internal/app/provider_accounts_api_test.go new file mode 100644 index 00000000..69108ee9 --- /dev/null +++ b/internal/app/provider_accounts_api_test.go @@ -0,0 +1,151 @@ +package app + +import ( + "context" + "encoding/json" + "path/filepath" + "testing" + + "sub2api-cn-relay-manager/internal/store/sqlite" +) + +func TestAPIListProviderAccountsReturnsRows(t *testing.T) { + handler := NewAPIHandler("secret-token", ActionSet{ + ListProviderAccounts: func(_ context.Context, req ListProviderAccountsRequest) ([]ProviderAccountInfo, error) { + if req.ProviderID != "deepseek-official" { + t.Fatalf("ProviderID = %q, want deepseek-official", req.ProviderID) + } + if req.AccountStatus != "disabled" { + t.Fatalf("AccountStatus = %q, want disabled", req.AccountStatus) + } + return []ProviderAccountInfo{{ + ID: 7, + HostID: "remote43", + ProviderID: "deepseek-official", + ProviderName: "DeepSeek Official", + HostAccountID: "9", + AccountName: "deepseek-01", + AccountStatus: "disabled", + DisabledReason: "manual_disable", + }}, nil + }, + }) + + request := httptestRequest(t, "GET", "/api/provider-accounts?provider_id=deepseek-official&account_status=disabled", nil, "secret-token") + response := httptestRecorder(handler, request) + assertStatusCode(t, response, 200) + var payload map[string][]ProviderAccountInfo + if err := json.Unmarshal(response.Body().Bytes(), &payload); err != nil { + t.Fatalf("json.Unmarshal() error = %v", err) + } + accounts := payload["provider_accounts"] + if len(accounts) != 1 || accounts[0].ID != 7 || accounts[0].AccountStatus != "disabled" { + t.Fatalf("provider_accounts = %+v, want one disabled row id=7", accounts) + } +} + +func TestAPIDisableProviderAccountUsesPathID(t *testing.T) { + handler := NewAPIHandler("secret-token", ActionSet{ + DisableProviderAccount: func(_ context.Context, req UpdateProviderAccountStatusRequest) (ProviderAccountInfo, error) { + if req.AccountID != 42 { + t.Fatalf("AccountID = %d, want 42", req.AccountID) + } + if req.AccountStatus != "disabled" { + t.Fatalf("AccountStatus = %q, want disabled", req.AccountStatus) + } + if req.DisabledReason != "manual_disable" { + t.Fatalf("DisabledReason = %q, want manual_disable", req.DisabledReason) + } + return ProviderAccountInfo{ID: req.AccountID, AccountStatus: req.AccountStatus, DisabledReason: req.DisabledReason}, nil + }, + }) + + request := httptestRequest(t, "POST", "/api/provider-accounts/42/disable", map[string]any{"reason": "manual_disable"}, "secret-token") + response := httptestRecorder(handler, request) + assertStatusCode(t, response, 200) + assertJSONContains(t, response.Body().Bytes(), "provider_account.id", float64(42)) + assertJSONContains(t, response.Body().Bytes(), "provider_account.account_status", "disabled") +} + +func TestNewActionSetProviderAccountListAndStatusFlow(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "provider-accounts.db") + dsn := "file:" + filepath.ToSlash(dbPath) + "?_busy_timeout=5000" + actions := NewActionSet(dsn) + ctx := context.Background() + + store, err := sqlite.Open(ctx, dsn) + if err != nil { + t.Fatalf("sqlite.Open() error = %v", err) + } + defer store.Close() + + hostID, err := store.Hosts().Create(ctx, sqlite.Host{ + HostID: "remote43", + BaseURL: "https://host.example.com", + HostVersion: "0.1.129", + CapabilityProbeJSON: `{"accounts":true}`, + AuthType: "apikey", + AuthToken: "host-key", + }) + if err != nil { + t.Fatalf("Hosts().Create() error = %v", err) + } + hostRow, err := store.Hosts().GetByID(ctx, hostID) + if err != nil { + t.Fatalf("Hosts().GetByID() error = %v", err) + } + packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "chk"}) + if err != nil { + t.Fatalf("Packs().Create() error = %v", err) + } + providerRowID, err := store.Providers().Create(ctx, sqlite.Provider{ + PackID: packID, + ProviderID: "deepseek-official", + DisplayName: "DeepSeek Official", + BaseURL: "https://api.deepseek.com", + Platform: "openai", + }) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + providerAccountID, err := store.ProviderAccounts().Create(ctx, sqlite.ProviderAccount{ + HostID: hostRow.ID, + ProviderID: providerRowID, + HostAccountID: "9", + KeyFingerprint: "sha256:abc", + AccountName: "deepseek-01", + AccountStatus: sqlite.ProviderAccountStatusActive, + LastProbeStatus: "passed", + LastProbeAt: "2026-05-29T00:00:00Z", + }) + if err != nil { + t.Fatalf("ProviderAccounts().Create() error = %v", err) + } + + listed, err := actions.ListProviderAccounts(ctx, ListProviderAccountsRequest{HostID: "remote43", ProviderID: "deepseek-official"}) + if err != nil { + t.Fatalf("ListProviderAccounts() error = %v", err) + } + if len(listed) != 1 || listed[0].ID != providerAccountID { + t.Fatalf("ListProviderAccounts() = %+v, want one row for id %d", listed, providerAccountID) + } + + disabled, err := actions.DisableProviderAccount(ctx, UpdateProviderAccountStatusRequest{ + AccountID: providerAccountID, + DisabledReason: "manual_disable", + }) + if err != nil { + t.Fatalf("DisableProviderAccount() error = %v", err) + } + if disabled.AccountStatus != sqlite.ProviderAccountStatusDisabled || disabled.DisabledReason != "manual_disable" { + t.Fatalf("DisableProviderAccount() = %+v", disabled) + } + + enabled, err := actions.EnableProviderAccount(ctx, UpdateProviderAccountStatusRequest{AccountID: providerAccountID}) + if err != nil { + t.Fatalf("EnableProviderAccount() error = %v", err) + } + if enabled.AccountStatus != sqlite.ProviderAccountStatusActive { + t.Fatalf("EnableProviderAccount() = %+v, want active", enabled) + } +} diff --git a/internal/provision/runtime_import_service.go b/internal/provision/runtime_import_service.go index 4ddd666b..81db2ccb 100644 --- a/internal/provision/runtime_import_service.go +++ b/internal/provision/runtime_import_service.go @@ -121,6 +121,9 @@ func (s *RuntimeImportService) Import(ctx context.Context, req RuntimeImportRequ if err := s.store.ImportBatches().UpdateStatus(ctx, batchID, report.BatchStatus, report.AccessStatus); err != nil { return RuntimeImportResult{}, err } + if err := sqlite.SyncProviderAccountsFromImportBatch(ctx, s.store, batchID); err != nil { + return RuntimeImportResult{}, err + } if importErr != nil { return RuntimeImportResult{BatchID: batchID, Report: report}, importErr } diff --git a/internal/provision/runtime_import_service_test.go b/internal/provision/runtime_import_service_test.go index 2fd72ca0..8d7fda7c 100644 --- a/internal/provision/runtime_import_service_test.go +++ b/internal/provision/runtime_import_service_test.go @@ -87,6 +87,9 @@ func TestRuntimeImportServicePersistsOperationalState(t *testing.T) { if got := queryCount(t, store.SQLDB(), "access_closure_records"); got != 1 { t.Fatalf("access_closure_records row count = %d, want 1", got) } + if got := queryCount(t, store.SQLDB(), "provider_accounts"); got != 2 { + t.Fatalf("provider_accounts row count = %d, want 2", got) + } var batchStatus string var accessStatus string @@ -111,6 +114,18 @@ func TestRuntimeImportServicePersistsOperationalState(t *testing.T) { if accountStatus != "passed" { t.Fatalf("account_status = %q, want passed", accountStatus) } + + var inventoryStatus string + var inventoryShadowGroup string + if err := store.SQLDB().QueryRowContext(context.Background(), "SELECT account_status, shadow_group_id FROM provider_accounts WHERE host_account_id = ? ORDER BY id LIMIT 1", "account_1").Scan(&inventoryStatus, &inventoryShadowGroup); err != nil { + t.Fatalf("query provider account inventory: %v", err) + } + if inventoryStatus != sqlite.ProviderAccountStatusActive { + t.Fatalf("provider_accounts.account_status = %q, want %q", inventoryStatus, sqlite.ProviderAccountStatusActive) + } + if inventoryShadowGroup == "" { + t.Fatal("provider_accounts.shadow_group_id = empty, want group id") + } } func TestRuntimeImportServiceIncludesMatchingHostOverlaysInReport(t *testing.T) { diff --git a/internal/store/migrations/0012_provider_accounts.sql b/internal/store/migrations/0012_provider_accounts.sql new file mode 100644 index 00000000..80a00b8e --- /dev/null +++ b/internal/store/migrations/0012_provider_accounts.sql @@ -0,0 +1,25 @@ +CREATE TABLE provider_accounts ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + host_id INTEGER NOT NULL, + provider_id INTEGER NOT NULL, + route_id TEXT NOT NULL DEFAULT '', + shadow_group_id TEXT NOT NULL DEFAULT '', + host_account_id TEXT NOT NULL, + key_fingerprint TEXT NOT NULL, + account_name TEXT NOT NULL DEFAULT '', + account_status TEXT NOT NULL, + last_probe_status TEXT NOT NULL DEFAULT '', + last_probe_at TEXT NOT NULL DEFAULT '', + disabled_reason TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, + CONSTRAINT fk_provider_accounts_host FOREIGN KEY (host_id) REFERENCES hosts(id) ON DELETE CASCADE, + CONSTRAINT fk_provider_accounts_provider FOREIGN KEY (provider_id) REFERENCES providers(id) ON DELETE CASCADE, + CONSTRAINT uq_provider_accounts_host_account UNIQUE (host_id, host_account_id), + CHECK (account_status IN ('active', 'disabled', 'deprecated', 'broken')) +); + +CREATE INDEX idx_provider_accounts_provider_host ON provider_accounts(provider_id, host_id); +CREATE INDEX idx_provider_accounts_status ON provider_accounts(account_status); +CREATE INDEX idx_provider_accounts_route_id ON provider_accounts(route_id); +CREATE INDEX idx_provider_accounts_shadow_group_id ON provider_accounts(shadow_group_id); diff --git a/internal/store/sqlite/db.go b/internal/store/sqlite/db.go index e92de045..d71cc701 100644 --- a/internal/store/sqlite/db.go +++ b/internal/store/sqlite/db.go @@ -30,6 +30,7 @@ type Queries struct { RouteDecisionLogs *RouteDecisionLogsRepo RouteFailoverEvents *RouteFailoverEventsRepo RouteStickyAudit *RouteStickyAuditRepo + ProviderAccounts *ProviderAccountsRepo ProviderDrafts *ProviderDraftsRepo ImportBatches *ImportBatchesRepo ImportBatchItems *ImportBatchItemsRepo @@ -127,6 +128,10 @@ func (db *DB) RouteStickyAudit() *RouteStickyAuditRepo { return db.queries.RouteStickyAudit } +func (db *DB) ProviderAccounts() *ProviderAccountsRepo { + return db.queries.ProviderAccounts +} + func (db *DB) ProviderDrafts() *ProviderDraftsRepo { return db.queries.ProviderDrafts } @@ -206,6 +211,7 @@ func newQueries(db execQuerier) *Queries { RouteDecisionLogs: newRouteDecisionLogsRepo(db), RouteFailoverEvents: newRouteFailoverEventsRepo(db), RouteStickyAudit: newRouteStickyAuditRepo(db), + ProviderAccounts: newProviderAccountsRepo(db), ProviderDrafts: newProviderDraftsRepo(db), ImportBatches: newImportBatchesRepo(db), ImportBatchItems: newImportBatchItemsRepo(db), diff --git a/internal/store/sqlite/db_test.go b/internal/store/sqlite/db_test.go index 879ad416..33b0436a 100644 --- a/internal/store/sqlite/db_test.go +++ b/internal/store/sqlite/db_test.go @@ -114,6 +114,7 @@ func TestOpenAppliesLogicalRoutingTables(t *testing.T) { "route_decision_logs", "route_failover_events", "route_sticky_audit", + "provider_accounts", } { found, err := tableExists(context.Background(), db, table) if err != nil { diff --git a/internal/store/sqlite/provider_accounts_repo.go b/internal/store/sqlite/provider_accounts_repo.go new file mode 100644 index 00000000..b097eec4 --- /dev/null +++ b/internal/store/sqlite/provider_accounts_repo.go @@ -0,0 +1,438 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "strings" +) + +const ( + ProviderAccountStatusActive = "active" + ProviderAccountStatusDisabled = "disabled" + ProviderAccountStatusDeprecated = "deprecated" + ProviderAccountStatusBroken = "broken" +) + +type ProviderAccount struct { + ID int64 + HostID int64 + ProviderID int64 + RouteID string + ShadowGroupID string + HostAccountID string + KeyFingerprint string + AccountName string + AccountStatus string + LastProbeStatus string + LastProbeAt string + DisabledReason string + CreatedAt string + UpdatedAt string +} + +type ProviderAccountListFilter struct { + HostID string + ProviderID string + RouteID string + ShadowGroupID string + AccountStatus string + Query string + Limit int +} + +type ProviderAccountView struct { + ID int64 `json:"id"` + HostID string `json:"host_id"` + ProviderID string `json:"provider_id"` + ProviderName string `json:"provider_name"` + RouteID string `json:"route_id,omitempty"` + LogicalGroupID string `json:"logical_group_id,omitempty"` + ShadowGroupID string `json:"shadow_group_id,omitempty"` + HostAccountID string `json:"host_account_id"` + KeyFingerprint string `json:"key_fingerprint"` + AccountName string `json:"account_name"` + AccountStatus string `json:"account_status"` + LastProbeStatus string `json:"last_probe_status,omitempty"` + LastProbeAt string `json:"last_probe_at,omitempty"` + DisabledReason string `json:"disabled_reason,omitempty"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type ProviderAccountsRepo struct { + db execQuerier +} + +func newProviderAccountsRepo(db execQuerier) *ProviderAccountsRepo { + return &ProviderAccountsRepo{db: db} +} + +func (r *ProviderAccountsRepo) Create(ctx context.Context, account ProviderAccount) (int64, error) { + account, err := normalizeProviderAccount(account) + if err != nil { + return 0, err + } + + result, err := r.db.ExecContext(ctx, `INSERT INTO provider_accounts ( + host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint, + account_name, account_status, last_probe_status, last_probe_at, disabled_reason + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, + account.HostID, + account.ProviderID, + account.RouteID, + account.ShadowGroupID, + account.HostAccountID, + account.KeyFingerprint, + account.AccountName, + account.AccountStatus, + account.LastProbeStatus, + account.LastProbeAt, + account.DisabledReason, + ) + if err != nil { + return 0, fmt.Errorf("insert provider account %q: %w", account.HostAccountID, err) + } + id, err := result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("read inserted provider account id for %q: %w", account.HostAccountID, err) + } + return id, nil +} + +func (r *ProviderAccountsRepo) Upsert(ctx context.Context, account ProviderAccount) (int64, error) { + account, err := normalizeProviderAccount(account) + if err != nil { + return 0, err + } + + result, err := r.db.ExecContext(ctx, `INSERT INTO provider_accounts ( + host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint, + account_name, account_status, last_probe_status, last_probe_at, disabled_reason + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(host_id, host_account_id) DO UPDATE SET + provider_id = excluded.provider_id, + route_id = excluded.route_id, + shadow_group_id = excluded.shadow_group_id, + key_fingerprint = excluded.key_fingerprint, + account_name = excluded.account_name, + account_status = excluded.account_status, + last_probe_status = excluded.last_probe_status, + last_probe_at = excluded.last_probe_at, + disabled_reason = excluded.disabled_reason, + updated_at = CURRENT_TIMESTAMP`, + account.HostID, + account.ProviderID, + account.RouteID, + account.ShadowGroupID, + account.HostAccountID, + account.KeyFingerprint, + account.AccountName, + account.AccountStatus, + account.LastProbeStatus, + account.LastProbeAt, + account.DisabledReason, + ) + if err != nil { + return 0, fmt.Errorf("upsert provider account %q: %w", account.HostAccountID, err) + } + id, err := result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("read provider account last insert id for %q: %w", account.HostAccountID, err) + } + if existing, err := r.GetByHostIDAndAccountID(ctx, account.HostID, account.HostAccountID); err == nil { + return existing.ID, nil + } + return id, nil +} + +func (r *ProviderAccountsRepo) GetByID(ctx context.Context, id int64) (ProviderAccount, error) { + if id <= 0 { + return ProviderAccount{}, fmt.Errorf("id is required") + } + return r.scanOne(ctx, `SELECT id, host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint, account_name, account_status, last_probe_status, last_probe_at, disabled_reason, created_at, updated_at FROM provider_accounts WHERE id = ?`, id) +} + +func (r *ProviderAccountsRepo) GetByHostIDAndAccountID(ctx context.Context, hostID int64, hostAccountID string) (ProviderAccount, error) { + if hostID <= 0 { + return ProviderAccount{}, fmt.Errorf("host_id is required") + } + hostAccountID = strings.TrimSpace(hostAccountID) + if hostAccountID == "" { + return ProviderAccount{}, fmt.Errorf("host_account_id is required") + } + return r.scanOne(ctx, `SELECT id, host_id, provider_id, route_id, shadow_group_id, host_account_id, key_fingerprint, account_name, account_status, last_probe_status, last_probe_at, disabled_reason, created_at, updated_at FROM provider_accounts WHERE host_id = ? AND host_account_id = ?`, hostID, hostAccountID) +} + +func (r *ProviderAccountsRepo) GetViewByID(ctx context.Context, id int64) (ProviderAccountView, error) { + if id <= 0 { + return ProviderAccountView{}, fmt.Errorf("id is required") + } + return r.scanViewOne(ctx, `SELECT + pa.id, + h.host_id, + p.provider_id, + p.display_name, + COALESCE(pa.route_id, ''), + COALESCE(lgr.logical_group_id, ''), + COALESCE(pa.shadow_group_id, ''), + pa.host_account_id, + pa.key_fingerprint, + pa.account_name, + pa.account_status, + COALESCE(pa.last_probe_status, ''), + COALESCE(pa.last_probe_at, ''), + COALESCE(pa.disabled_reason, ''), + pa.created_at, + pa.updated_at + FROM provider_accounts pa + JOIN hosts h ON h.id = pa.host_id + JOIN providers p ON p.id = pa.provider_id + LEFT JOIN logical_group_routes lgr ON lgr.route_id = pa.route_id + WHERE pa.id = ?`, id) +} + +func (r *ProviderAccountsRepo) UpdateStatusByID(ctx context.Context, id int64, accountStatus, disabledReason string) error { + if id <= 0 { + return fmt.Errorf("id is required") + } + accountStatus = normalizeProviderAccountStatus(accountStatus) + if accountStatus == "" { + return fmt.Errorf("account_status is required") + } + disabledReason = strings.TrimSpace(disabledReason) + result, err := r.db.ExecContext(ctx, `UPDATE provider_accounts SET account_status = ?, disabled_reason = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, accountStatus, disabledReason, id) + if err != nil { + return fmt.Errorf("update provider account %d status: %w", id, err) + } + rows, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("provider account %d rows affected: %w", id, err) + } + if rows == 0 { + return sql.ErrNoRows + } + return nil +} + +func (r *ProviderAccountsRepo) DeprecateMissingForScope(ctx context.Context, providerID, hostID int64, keepHostAccountIDs []string, reason string) error { + if providerID <= 0 { + return fmt.Errorf("provider_id is required") + } + if hostID <= 0 { + return fmt.Errorf("host_id is required") + } + reason = strings.TrimSpace(reason) + args := []any{ProviderAccountStatusDeprecated, reason, providerID, hostID} + query := `UPDATE provider_accounts + SET account_status = ?, disabled_reason = ?, updated_at = CURRENT_TIMESTAMP + WHERE provider_id = ? AND host_id = ?` + if len(keepHostAccountIDs) > 0 { + placeholders := make([]string, 0, len(keepHostAccountIDs)) + for _, id := range keepHostAccountIDs { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + continue + } + placeholders = append(placeholders, "?") + args = append(args, trimmed) + } + if len(placeholders) > 0 { + query += ` AND host_account_id NOT IN (` + strings.Join(placeholders, ",") + `)` + } + } + query += ` AND account_status IN ('active', 'broken')` + if _, err := r.db.ExecContext(ctx, query, args...); err != nil { + return fmt.Errorf("deprecate missing provider accounts for provider_id=%d host_id=%d: %w", providerID, hostID, err) + } + return nil +} + +func (r *ProviderAccountsRepo) List(ctx context.Context, filter ProviderAccountListFilter) ([]ProviderAccountView, error) { + query := `SELECT + pa.id, + h.host_id, + p.provider_id, + p.display_name, + COALESCE(pa.route_id, ''), + COALESCE(lgr.logical_group_id, ''), + COALESCE(pa.shadow_group_id, ''), + pa.host_account_id, + pa.key_fingerprint, + pa.account_name, + pa.account_status, + COALESCE(pa.last_probe_status, ''), + COALESCE(pa.last_probe_at, ''), + COALESCE(pa.disabled_reason, ''), + pa.created_at, + pa.updated_at + FROM provider_accounts pa + JOIN hosts h ON h.id = pa.host_id + JOIN providers p ON p.id = pa.provider_id + LEFT JOIN logical_group_routes lgr ON lgr.route_id = pa.route_id + WHERE 1 = 1` + args := make([]any, 0) + if value := strings.TrimSpace(filter.HostID); value != "" { + query += ` AND h.host_id = ?` + args = append(args, value) + } + if value := strings.TrimSpace(filter.ProviderID); value != "" { + query += ` AND p.provider_id = ?` + args = append(args, value) + } + if value := strings.TrimSpace(filter.RouteID); value != "" { + query += ` AND pa.route_id = ?` + args = append(args, value) + } + if value := strings.TrimSpace(filter.ShadowGroupID); value != "" { + query += ` AND pa.shadow_group_id = ?` + args = append(args, value) + } + if value := normalizeProviderAccountStatus(filter.AccountStatus); value != "" { + query += ` AND pa.account_status = ?` + args = append(args, value) + } + if value := strings.TrimSpace(filter.Query); value != "" { + like := "%" + strings.ToLower(value) + "%" + query += ` AND ( + LOWER(pa.host_account_id) LIKE ? OR + LOWER(pa.account_name) LIKE ? OR + LOWER(pa.key_fingerprint) LIKE ? OR + LOWER(p.provider_id) LIKE ? OR + LOWER(h.host_id) LIKE ? + )` + args = append(args, like, like, like, like, like) + } + query += ` ORDER BY pa.updated_at DESC, pa.id DESC` + limit := filter.Limit + if limit <= 0 { + limit = 200 + } + query += ` LIMIT ?` + args = append(args, limit) + + rows, err := r.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, fmt.Errorf("list provider accounts: %w", err) + } + defer rows.Close() + + views := make([]ProviderAccountView, 0) + for rows.Next() { + var view ProviderAccountView + if err := rows.Scan( + &view.ID, + &view.HostID, + &view.ProviderID, + &view.ProviderName, + &view.RouteID, + &view.LogicalGroupID, + &view.ShadowGroupID, + &view.HostAccountID, + &view.KeyFingerprint, + &view.AccountName, + &view.AccountStatus, + &view.LastProbeStatus, + &view.LastProbeAt, + &view.DisabledReason, + &view.CreatedAt, + &view.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan provider account view: %w", err) + } + views = append(views, view) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate provider accounts: %w", err) + } + return views, nil +} + +func (r *ProviderAccountsRepo) scanOne(ctx context.Context, query string, args ...any) (ProviderAccount, error) { + var row ProviderAccount + if err := r.db.QueryRowContext(ctx, query, args...).Scan( + &row.ID, + &row.HostID, + &row.ProviderID, + &row.RouteID, + &row.ShadowGroupID, + &row.HostAccountID, + &row.KeyFingerprint, + &row.AccountName, + &row.AccountStatus, + &row.LastProbeStatus, + &row.LastProbeAt, + &row.DisabledReason, + &row.CreatedAt, + &row.UpdatedAt, + ); err != nil { + return ProviderAccount{}, err + } + return row, nil +} + +func (r *ProviderAccountsRepo) scanViewOne(ctx context.Context, query string, args ...any) (ProviderAccountView, error) { + var view ProviderAccountView + if err := r.db.QueryRowContext(ctx, query, args...).Scan( + &view.ID, + &view.HostID, + &view.ProviderID, + &view.ProviderName, + &view.RouteID, + &view.LogicalGroupID, + &view.ShadowGroupID, + &view.HostAccountID, + &view.KeyFingerprint, + &view.AccountName, + &view.AccountStatus, + &view.LastProbeStatus, + &view.LastProbeAt, + &view.DisabledReason, + &view.CreatedAt, + &view.UpdatedAt, + ); err != nil { + return ProviderAccountView{}, err + } + return view, nil +} + +func normalizeProviderAccount(account ProviderAccount) (ProviderAccount, error) { + account.RouteID = strings.TrimSpace(account.RouteID) + account.ShadowGroupID = strings.TrimSpace(account.ShadowGroupID) + account.HostAccountID = strings.TrimSpace(account.HostAccountID) + account.KeyFingerprint = strings.TrimSpace(account.KeyFingerprint) + account.AccountName = strings.TrimSpace(account.AccountName) + account.AccountStatus = normalizeProviderAccountStatus(account.AccountStatus) + account.LastProbeStatus = strings.TrimSpace(account.LastProbeStatus) + account.LastProbeAt = strings.TrimSpace(account.LastProbeAt) + account.DisabledReason = strings.TrimSpace(account.DisabledReason) + + switch { + case account.HostID <= 0: + return ProviderAccount{}, fmt.Errorf("host_id is required") + case account.ProviderID <= 0: + return ProviderAccount{}, fmt.Errorf("provider_id is required") + case account.HostAccountID == "": + return ProviderAccount{}, fmt.Errorf("host_account_id is required") + case account.KeyFingerprint == "": + return ProviderAccount{}, fmt.Errorf("key_fingerprint is required") + case account.AccountStatus == "": + return ProviderAccount{}, fmt.Errorf("account_status is required") + } + return account, nil +} + +func normalizeProviderAccountStatus(status string) string { + switch strings.TrimSpace(status) { + case ProviderAccountStatusActive: + return ProviderAccountStatusActive + case ProviderAccountStatusDisabled: + return ProviderAccountStatusDisabled + case ProviderAccountStatusDeprecated: + return ProviderAccountStatusDeprecated + case ProviderAccountStatusBroken: + return ProviderAccountStatusBroken + default: + return "" + } +} diff --git a/internal/store/sqlite/provider_accounts_repo_test.go b/internal/store/sqlite/provider_accounts_repo_test.go new file mode 100644 index 00000000..86831fd1 --- /dev/null +++ b/internal/store/sqlite/provider_accounts_repo_test.go @@ -0,0 +1,286 @@ +package sqlite + +import ( + "context" + "testing" +) + +func TestProviderAccountsRepoCRUDAndFilters(t *testing.T) { + t.Parallel() + + store := openTestDBWithFK(t) + ctx := context.Background() + hostID := createTestHost(t, store) + packID := createTestPack(t, store) + providerID, err := store.Providers().Create(ctx, Provider{ + PackID: packID, + ProviderID: "deepseek-official", + DisplayName: "DeepSeek Official", + BaseURL: "https://api.deepseek.com", + Platform: "openai", + }) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + if _, err := store.LogicalGroups().Create(ctx, LogicalGroup{LogicalGroupID: "lg-1", DisplayName: "LG 1", Status: "active"}); err != nil { + t.Fatalf("LogicalGroups().Create() error = %v", err) + } + if _, err := store.LogicalGroupRoutes().Create(ctx, LogicalGroupRoute{ + RouteID: "route-1", + LogicalGroupID: "lg-1", + Name: "Route 1", + Status: "active", + Priority: 10, + Weight: 100, + ShadowGroupID: "shadow-group-1", + ShadowHostID: "shadow-host-1", + }); err != nil { + t.Fatalf("LogicalGroupRoutes().Create() error = %v", err) + } + + accountRepo := store.ProviderAccounts() + accountID, err := accountRepo.Create(ctx, ProviderAccount{ + HostID: hostID, + ProviderID: providerID, + RouteID: "route-1", + ShadowGroupID: "shadow-group-1", + HostAccountID: "account-1", + KeyFingerprint: "sha256:abc", + AccountName: "deepseek-01", + AccountStatus: ProviderAccountStatusActive, + LastProbeStatus: "passed", + LastProbeAt: "2026-05-29T00:00:00Z", + }) + if err != nil { + t.Fatalf("ProviderAccounts().Create() error = %v", err) + } + + got, err := accountRepo.GetByID(ctx, accountID) + if err != nil { + t.Fatalf("ProviderAccounts().GetByID() error = %v", err) + } + if got.HostAccountID != "account-1" || got.AccountStatus != ProviderAccountStatusActive { + t.Fatalf("ProviderAccounts().GetByID() = %+v", got) + } + + if _, err := accountRepo.Upsert(ctx, ProviderAccount{ + HostID: hostID, + ProviderID: providerID, + RouteID: "route-1", + ShadowGroupID: "shadow-group-1", + HostAccountID: "account-1", + KeyFingerprint: "sha256:abc", + AccountName: "deepseek-01", + AccountStatus: ProviderAccountStatusBroken, + LastProbeStatus: "failed", + LastProbeAt: "2026-05-29T01:00:00Z", + }); err != nil { + t.Fatalf("ProviderAccounts().Upsert() error = %v", err) + } + + view, err := accountRepo.GetViewByID(ctx, accountID) + if err != nil { + t.Fatalf("ProviderAccounts().GetViewByID() error = %v", err) + } + if view.ProviderID != "deepseek-official" || view.LogicalGroupID != "lg-1" || view.AccountStatus != ProviderAccountStatusBroken { + t.Fatalf("ProviderAccounts().GetViewByID() = %+v", view) + } + + rows, err := accountRepo.List(ctx, ProviderAccountListFilter{ + HostID: "host-" + sanitizeTestName(t.Name()), + ProviderID: "deepseek-official", + RouteID: "route-1", + ShadowGroupID: "shadow-group-1", + AccountStatus: ProviderAccountStatusBroken, + Query: "deepseek", + }) + if err != nil { + t.Fatalf("ProviderAccounts().List() error = %v", err) + } + if len(rows) != 1 || rows[0].ID != accountID { + t.Fatalf("ProviderAccounts().List() = %+v, want one row for account_id %d", rows, accountID) + } + + if err := accountRepo.UpdateStatusByID(ctx, accountID, ProviderAccountStatusDisabled, "manual_disable"); err != nil { + t.Fatalf("ProviderAccounts().UpdateStatusByID() error = %v", err) + } + got, err = accountRepo.GetByID(ctx, accountID) + if err != nil { + t.Fatalf("ProviderAccounts().GetByID() after status update error = %v", err) + } + if got.AccountStatus != ProviderAccountStatusDisabled || got.DisabledReason != "manual_disable" { + t.Fatalf("ProviderAccounts().GetByID() after status update = %+v", got) + } +} + +func TestSyncProviderAccountsFromImportBatchCreatesAndDeprecatesInventory(t *testing.T) { + t.Parallel() + + store := openTestDBWithFK(t) + ctx := context.Background() + hostID := createTestHost(t, store) + packID := createTestPack(t, store) + providerID, err := store.Providers().Create(ctx, Provider{ + PackID: packID, + ProviderID: "asxs-provider", + DisplayName: "ASXS Provider", + BaseURL: "https://api.asxs.top/v1", + Platform: "openai", + }) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + batch1, err := store.ImportBatches().Create(ctx, ImportBatch{ + HostID: hostID, + PackID: packID, + ProviderID: providerID, + Mode: "strict", + BatchStatus: "succeeded", + AccessStatus: "subscription_ready", + }) + if err != nil { + t.Fatalf("ImportBatches().Create(batch1) error = %v", err) + } + if _, err := store.ImportBatchItems().Create(ctx, ImportBatchItem{ + BatchID: batch1, + KeyFingerprint: "sha256:key1", + AccountStatus: "passed", + ProbeSummaryJSON: `{"account_id":"account-1","probe_status":"passed"}`, + }); err != nil { + t.Fatalf("ImportBatchItems().Create(batch1) error = %v", err) + } + for _, resource := range []ManagedResource{ + {BatchID: batch1, HostID: hostID, ResourceType: "group", HostResourceID: "group-1", ResourceName: "ASXS Group"}, + {BatchID: batch1, HostID: hostID, ResourceType: "account", HostResourceID: "account-1", ResourceName: "asxs-01"}, + } { + if _, err := store.ManagedResources().Create(ctx, resource); err != nil { + t.Fatalf("ManagedResources().Create(batch1/%s) error = %v", resource.ResourceType, err) + } + } + if err := SyncProviderAccountsFromImportBatch(ctx, store, batch1); err != nil { + t.Fatalf("SyncProviderAccountsFromImportBatch(batch1) error = %v", err) + } + + account1, err := store.ProviderAccounts().GetByHostIDAndAccountID(ctx, hostID, "account-1") + if err != nil { + t.Fatalf("ProviderAccounts().GetByHostIDAndAccountID(account-1) error = %v", err) + } + if account1.AccountStatus != ProviderAccountStatusActive || account1.ShadowGroupID != "group-1" { + t.Fatalf("account-1 = %+v, want active shadow group-1", account1) + } + + batch2, err := store.ImportBatches().Create(ctx, ImportBatch{ + HostID: hostID, + PackID: packID, + ProviderID: providerID, + Mode: "strict", + BatchStatus: "succeeded", + AccessStatus: "subscription_ready", + }) + if err != nil { + t.Fatalf("ImportBatches().Create(batch2) error = %v", err) + } + if _, err := store.ImportBatchItems().Create(ctx, ImportBatchItem{ + BatchID: batch2, + KeyFingerprint: "sha256:key2", + AccountStatus: "failed", + ProbeSummaryJSON: `{"account_id":"account-2","probe_status":"failed"}`, + }); err != nil { + t.Fatalf("ImportBatchItems().Create(batch2) error = %v", err) + } + for _, resource := range []ManagedResource{ + {BatchID: batch2, HostID: hostID, ResourceType: "group", HostResourceID: "group-2", ResourceName: "ASXS Group 2"}, + {BatchID: batch2, HostID: hostID, ResourceType: "account", HostResourceID: "account-2", ResourceName: "asxs-02"}, + } { + if _, err := store.ManagedResources().Create(ctx, resource); err != nil { + t.Fatalf("ManagedResources().Create(batch2/%s) error = %v", resource.ResourceType, err) + } + } + if err := SyncProviderAccountsFromImportBatch(ctx, store, batch2); err != nil { + t.Fatalf("SyncProviderAccountsFromImportBatch(batch2) error = %v", err) + } + + account1, err = store.ProviderAccounts().GetByHostIDAndAccountID(ctx, hostID, "account-1") + if err != nil { + t.Fatalf("ProviderAccounts().GetByHostIDAndAccountID(account-1 after batch2) error = %v", err) + } + if account1.AccountStatus != ProviderAccountStatusDeprecated || account1.DisabledReason != providerAccountDeprecatedMissingReason { + t.Fatalf("account-1 after batch2 = %+v, want deprecated missing_from_latest_batch", account1) + } + + account2, err := store.ProviderAccounts().GetByHostIDAndAccountID(ctx, hostID, "account-2") + if err != nil { + t.Fatalf("ProviderAccounts().GetByHostIDAndAccountID(account-2) error = %v", err) + } + if account2.AccountStatus != ProviderAccountStatusBroken || account2.LastProbeStatus != "failed" { + t.Fatalf("account-2 = %+v, want broken failed", account2) + } +} + +func TestSyncProviderAccountsFromImportBatchPreservesManualDisabledStatus(t *testing.T) { + t.Parallel() + + store := openTestDBWithFK(t) + ctx := context.Background() + hostID := createTestHost(t, store) + packID := createTestPack(t, store) + providerID, err := store.Providers().Create(ctx, Provider{ + PackID: packID, + ProviderID: "asxs-provider", + DisplayName: "ASXS Provider", + BaseURL: "https://api.asxs.top/v1", + Platform: "openai", + }) + if err != nil { + t.Fatalf("Providers().Create() error = %v", err) + } + batchID, err := store.ImportBatches().Create(ctx, ImportBatch{ + HostID: hostID, + PackID: packID, + ProviderID: providerID, + Mode: "strict", + BatchStatus: "succeeded", + AccessStatus: "subscription_ready", + }) + if err != nil { + t.Fatalf("ImportBatches().Create() error = %v", err) + } + if _, err := store.ImportBatchItems().Create(ctx, ImportBatchItem{ + BatchID: batchID, + KeyFingerprint: "sha256:key1", + AccountStatus: "passed", + ProbeSummaryJSON: `{"account_id":"account-1","probe_status":"passed"}`, + }); err != nil { + t.Fatalf("ImportBatchItems().Create() error = %v", err) + } + for _, resource := range []ManagedResource{ + {BatchID: batchID, HostID: hostID, ResourceType: "group", HostResourceID: "group-1", ResourceName: "ASXS Group"}, + {BatchID: batchID, HostID: hostID, ResourceType: "account", HostResourceID: "account-1", ResourceName: "asxs-01"}, + } { + if _, err := store.ManagedResources().Create(ctx, resource); err != nil { + t.Fatalf("ManagedResources().Create(%s) error = %v", resource.ResourceType, err) + } + } + if err := SyncProviderAccountsFromImportBatch(ctx, store, batchID); err != nil { + t.Fatalf("SyncProviderAccountsFromImportBatch() error = %v", err) + } + + account, err := store.ProviderAccounts().GetByHostIDAndAccountID(ctx, hostID, "account-1") + if err != nil { + t.Fatalf("ProviderAccounts().GetByHostIDAndAccountID() error = %v", err) + } + if err := store.ProviderAccounts().UpdateStatusByID(ctx, account.ID, ProviderAccountStatusDisabled, "manual_disable"); err != nil { + t.Fatalf("ProviderAccounts().UpdateStatusByID() error = %v", err) + } + + if err := SyncProviderAccountsFromImportBatch(ctx, store, batchID); err != nil { + t.Fatalf("SyncProviderAccountsFromImportBatch(second) error = %v", err) + } + account, err = store.ProviderAccounts().GetByHostIDAndAccountID(ctx, hostID, "account-1") + if err != nil { + t.Fatalf("ProviderAccounts().GetByHostIDAndAccountID(second) error = %v", err) + } + if account.AccountStatus != ProviderAccountStatusDisabled || account.DisabledReason != "manual_disable" { + t.Fatalf("account after resync = %+v, want disabled manual_disable preserved", account) + } +} diff --git a/internal/store/sqlite/provider_accounts_sync.go b/internal/store/sqlite/provider_accounts_sync.go new file mode 100644 index 00000000..1da7cd20 --- /dev/null +++ b/internal/store/sqlite/provider_accounts_sync.go @@ -0,0 +1,181 @@ +package sqlite + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "time" +) + +const providerAccountDeprecatedMissingReason = "missing_from_latest_batch" + +func SyncProviderAccountsFromLatestImportBatches(ctx context.Context, store *DB) error { + if store == nil { + return fmt.Errorf("store is required") + } + batches, err := store.ImportBatches().ListLatestReconcilable(ctx) + if err != nil { + return err + } + for _, batch := range batches { + if err := SyncProviderAccountsFromImportBatch(ctx, store, batch.ID); err != nil { + return err + } + } + return nil +} + +func SyncProviderAccountsFromImportBatch(ctx context.Context, store *DB, batchID int64) error { + if store == nil { + return fmt.Errorf("store is required") + } + if batchID <= 0 { + return fmt.Errorf("batch_id is required") + } + + batch, err := store.ImportBatches().GetByID(ctx, batchID) + if err != nil { + return fmt.Errorf("get import batch %d: %w", batchID, err) + } + switch strings.TrimSpace(batch.BatchStatus) { + case "succeeded", "partially_succeeded": + default: + return nil + } + resources, err := store.ManagedResources().GetByBatchID(ctx, batchID) + if err != nil { + return fmt.Errorf("get managed resources for batch %d: %w", batchID, err) + } + items, err := store.ImportBatchItems().GetByBatchID(ctx, batchID) + if err != nil { + return fmt.Errorf("get import batch items for batch %d: %w", batchID, err) + } + + nowText := time.Now().UTC().Format(time.RFC3339) + shadowGroupID := "" + for _, resource := range resources { + if strings.TrimSpace(resource.ResourceType) == "group" { + shadowGroupID = strings.TrimSpace(resource.HostResourceID) + break + } + } + + accountResources := make([]ManagedResource, 0) + for _, resource := range resources { + if strings.TrimSpace(resource.ResourceType) == "account" { + accountResources = append(accountResources, resource) + } + } + itemByAccountID, unmatchedItems := indexBatchItemsByAccountID(items) + keepAccountIDs := make([]string, 0, len(accountResources)) + for index, resource := range accountResources { + hostAccountID := strings.TrimSpace(resource.HostResourceID) + if hostAccountID == "" { + continue + } + keepAccountIDs = append(keepAccountIDs, hostAccountID) + match, ok := itemByAccountID[hostAccountID] + if !ok && index < len(unmatchedItems) { + match = unmatchedItems[index] + } + row := ProviderAccount{ + HostID: batch.HostID, + ProviderID: batch.ProviderID, + ShadowGroupID: shadowGroupID, + HostAccountID: hostAccountID, + KeyFingerprint: fallbackString(match.KeyFingerprint, "legacy:"+hostAccountID), + AccountName: fallbackString(resource.ResourceName, hostAccountID), + AccountStatus: providerAccountStatusFromLegacy(match.AccountStatus), + LastProbeStatus: strings.TrimSpace(match.ProbeStatus), + LastProbeAt: nowText, + } + if existing, err := store.ProviderAccounts().GetByHostIDAndAccountID(ctx, batch.HostID, hostAccountID); err == nil { + if strings.TrimSpace(existing.RouteID) != "" { + row.RouteID = existing.RouteID + } + if strings.TrimSpace(existing.ShadowGroupID) != "" { + row.ShadowGroupID = existing.ShadowGroupID + } + preserveManagedProviderAccountStatus(&row, existing) + } + if _, err := store.ProviderAccounts().Upsert(ctx, row); err != nil { + return fmt.Errorf("upsert provider account %q from batch %d: %w", hostAccountID, batchID, err) + } + } + if err := store.ProviderAccounts().DeprecateMissingForScope(ctx, batch.ProviderID, batch.HostID, keepAccountIDs, providerAccountDeprecatedMissingReason); err != nil { + return err + } + return nil +} + +type legacyBatchAccountProjection struct { + KeyFingerprint string + AccountStatus string + ProbeStatus string + AccountID string +} + +func indexBatchItemsByAccountID(items []ImportBatchItem) (map[string]legacyBatchAccountProjection, []legacyBatchAccountProjection) { + indexed := make(map[string]legacyBatchAccountProjection, len(items)) + unmatched := make([]legacyBatchAccountProjection, 0, len(items)) + for _, item := range items { + projection := legacyBatchAccountProjection{ + KeyFingerprint: strings.TrimSpace(item.KeyFingerprint), + AccountStatus: strings.TrimSpace(item.AccountStatus), + } + var payload map[string]any + if err := json.Unmarshal([]byte(defaultJSON(strings.TrimSpace(item.ProbeSummaryJSON), "{}")), &payload); err == nil { + if value, ok := payload["probe_status"].(string); ok { + projection.ProbeStatus = strings.TrimSpace(value) + } + if value, ok := payload["account_id"].(string); ok { + projection.AccountID = strings.TrimSpace(value) + } + } + if projection.AccountID != "" { + indexed[projection.AccountID] = projection + continue + } + unmatched = append(unmatched, projection) + } + return indexed, unmatched +} + +func providerAccountStatusFromLegacy(accountStatus string) string { + switch strings.TrimSpace(accountStatus) { + case "passed", "warning": + return ProviderAccountStatusActive + case ProviderAccountStatusDisabled: + return ProviderAccountStatusDisabled + case ProviderAccountStatusDeprecated: + return ProviderAccountStatusDeprecated + default: + return ProviderAccountStatusBroken + } +} + +func fallbackString(values ...string) string { + for _, value := range values { + if trimmed := strings.TrimSpace(value); trimmed != "" { + return trimmed + } + } + return "" +} + +func preserveManagedProviderAccountStatus(row *ProviderAccount, existing ProviderAccount) { + if row == nil { + return + } + switch strings.TrimSpace(existing.AccountStatus) { + case ProviderAccountStatusDisabled: + row.AccountStatus = ProviderAccountStatusDisabled + row.DisabledReason = strings.TrimSpace(existing.DisabledReason) + case ProviderAccountStatusDeprecated: + if strings.TrimSpace(existing.DisabledReason) != providerAccountDeprecatedMissingReason { + row.AccountStatus = ProviderAccountStatusDeprecated + row.DisabledReason = strings.TrimSpace(existing.DisabledReason) + } + } +} diff --git a/tests/integration/store_init_test.go b/tests/integration/store_init_test.go index d2c83852..7e36b50e 100644 --- a/tests/integration/store_init_test.go +++ b/tests/integration/store_init_test.go @@ -144,6 +144,7 @@ func TestStoreAppliesLatestMigration(t *testing.T) { "route_decision_logs", "route_failover_events", "route_sticky_audit", + "provider_accounts", } { if !tableExists(t, store.SQLDB(), table) { t.Fatalf("table %q does not exist after latest migration", table) @@ -272,6 +273,23 @@ func TestStoreAppliesLatestMigration(t *testing.T) { t.Fatalf("column %q missing from route_sticky_audit", column) } } + + for _, column := range []string{ + "host_id", + "provider_id", + "route_id", + "shadow_group_id", + "host_account_id", + "key_fingerprint", + "account_status", + "last_probe_status", + "last_probe_at", + "disabled_reason", + } { + if !tableColumnExists(t, store.SQLDB(), "provider_accounts", column) { + t.Fatalf("column %q missing from provider_accounts", column) + } + } } func TestStoreInitEnforcesLogicalRoutingConstraints(t *testing.T) {