Expand coverage for runtime and sqlite paths

This commit is contained in:
phamnazage-jpg
2026-05-23 10:55:57 +08:00
parent 2ad277743d
commit bcc67c4a8a
17 changed files with 3393 additions and 1 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -202,6 +202,64 @@ func TestBatchImportHTTP(t *testing.T) {
})
}
func TestBatchImportWrapperFunctions(t *testing.T) {
t.Parallel()
t.Run("handleCreateBatchImportRun requires action", func(t *testing.T) {
t.Parallel()
req := httptestRequest(t, http.MethodPost, "/api/batch-import/runs", map[string]any{}, "")
rec := &responseRecorder{header: map[string][]string{}}
handleCreateBatchImportRun(rec, req, nil)
assertStatusCode(t, rec, http.StatusInternalServerError)
assertJSONContains(t, rec.Body().Bytes(), "error.code", "server_misconfigured")
})
t.Run("handleCreateBatchImportRun classifies action error", func(t *testing.T) {
t.Parallel()
req := httptestRequest(t, http.MethodPost, "/api/batch-import/runs", map[string]any{
"host_id": "host-1",
"mode": "strict",
"access_mode": "self_service",
"probe_api_key": "probe-key",
"entries": []map[string]any{
{"base_url": "https://kimi.example.com/v1", "api_key": "sk-test"},
},
}, "")
rec := &responseRecorder{header: map[string][]string{}}
handleCreateBatchImportRun(rec, req, func(context.Context, CreateBatchImportRunRequest) (BatchImportRunCreateResponse, error) {
return BatchImportRunCreateResponse{}, fmt.Errorf("host x not found")
})
assertStatusCode(t, rec, http.StatusNotFound)
assertJSONContains(t, rec.Body().Bytes(), "error.code", "not_found")
})
t.Run("handleListBatchImportRuns requires action", func(t *testing.T) {
t.Parallel()
req := httptestRequest(t, http.MethodGet, "/api/batch-import/runs", nil, "")
rec := &responseRecorder{header: map[string][]string{}}
handleListBatchImportRuns(rec, req, nil)
assertStatusCode(t, rec, http.StatusInternalServerError)
assertJSONContains(t, rec.Body().Bytes(), "error.code", "server_misconfigured")
})
t.Run("handleListBatchImportRuns returns empty array", func(t *testing.T) {
t.Parallel()
req := httptestRequest(t, http.MethodGet, "/api/batch-import/runs?limit=5", nil, "")
rec := &responseRecorder{header: map[string][]string{}}
handleListBatchImportRuns(rec, req, func(_ context.Context, got ListBatchImportRunsRequest) (ListBatchImportRunsResponse, error) {
if got.Limit != 5 {
t.Fatalf("ListBatchImportRunsRequest.Limit = %d, want 5", got.Limit)
}
return ListBatchImportRunsResponse{}, nil
})
assertStatusCode(t, rec, http.StatusOK)
runs, ok := decodeTopLevelArray(t, rec.Body().Bytes(), "runs")
if !ok || len(runs) != 0 {
t.Fatalf("runs = %#v, want empty array", runs)
}
})
}
func newBatchImportActionStubServer(t *testing.T) http.Handler {
t.Helper()

View File

@@ -173,6 +173,72 @@ func TestBatchRunsHTTP(t *testing.T) {
})
}
func TestBatchRunWrapperFunctions(t *testing.T) {
t.Parallel()
t.Run("handleGetBatchImportRun validates inputs", func(t *testing.T) {
t.Parallel()
req := httptestRequest(t, http.MethodGet, "/api/batch-import/runs/run-1", nil, "")
rec := &responseRecorder{header: map[string][]string{}}
handleGetBatchImportRun(rec, req, nil)
assertStatusCode(t, rec, http.StatusInternalServerError)
req = httptestRequest(t, http.MethodGet, "/api/batch-import/runs/", nil, "")
rec = &responseRecorder{header: map[string][]string{}}
handleGetBatchImportRun(rec, req, func(context.Context, string) (batch.RunSummaryProjection, error) {
return batch.RunSummaryProjection{}, nil
})
assertStatusCode(t, rec, http.StatusBadRequest)
assertJSONContains(t, rec.Body().Bytes(), "error.message", "run_id is required")
})
t.Run("handleListBatchImportRunItems validates run id and empty result", func(t *testing.T) {
t.Parallel()
req := httptestRequest(t, http.MethodGet, "/api/batch-import/runs/run-1/items?has_warning=true", nil, "")
rec := &responseRecorder{header: map[string][]string{}}
handleListBatchImportRunItems(rec, req, nil)
assertStatusCode(t, rec, http.StatusInternalServerError)
req = httptestRequest(t, http.MethodGet, "/api/batch-import/runs//items?has_warning=true", nil, "")
rec = &responseRecorder{header: map[string][]string{}}
handleListBatchImportRunItems(rec, req, func(context.Context, ListBatchImportRunItemsRequest) (ListBatchImportRunItemsResponse, error) {
return ListBatchImportRunItemsResponse{}, nil
})
assertStatusCode(t, rec, http.StatusBadRequest)
req = httptestRequest(t, http.MethodGet, "/api/batch-import/runs/run-1/items?has_warning=true&limit=3", nil, "")
req.SetPathValue("run_id", "run-1")
rec = &responseRecorder{header: map[string][]string{}}
handleListBatchImportRunItems(rec, req, func(_ context.Context, got ListBatchImportRunItemsRequest) (ListBatchImportRunItemsResponse, error) {
if got.RunID != "run-1" || got.Limit != 3 || got.HasWarning == nil || !*got.HasWarning {
t.Fatalf("ListBatchImportRunItemsRequest = %+v, want parsed filters", got)
}
return ListBatchImportRunItemsResponse{}, nil
})
assertStatusCode(t, rec, http.StatusOK)
items, ok := decodeTopLevelArray(t, rec.Body().Bytes(), "items")
if !ok || len(items) != 0 {
t.Fatalf("items = %#v, want empty array", items)
}
})
t.Run("handleGetBatchImportRunItem validates ids", func(t *testing.T) {
t.Parallel()
req := httptestRequest(t, http.MethodGet, "/api/batch-import/runs/run-1/items/item-1", nil, "")
rec := &responseRecorder{header: map[string][]string{}}
handleGetBatchImportRunItem(rec, req, nil)
assertStatusCode(t, rec, http.StatusInternalServerError)
req = httptestRequest(t, http.MethodGet, "/api/batch-import/runs//items/", nil, "")
rec = &responseRecorder{header: map[string][]string{}}
handleGetBatchImportRunItem(rec, req, func(context.Context, GetBatchImportRunItemRequest) (batch.ItemDetailProjection, error) {
return batch.ItemDetailProjection{}, nil
})
assertStatusCode(t, rec, http.StatusBadRequest)
assertJSONContains(t, rec.Body().Bytes(), "error.message", "run_id and item_id are required")
})
}
func batchProbeProfile() probe.CapabilityProfile {
return probe.CapabilityProfile{
TransportProfile: probe.TransportProfile{

View File

@@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
@@ -467,6 +468,194 @@ func TestReconcileProbeAPIKeyRejectsUnsupportedClosureType(t *testing.T) {
}
}
func TestReconcileProbeAPIKeyRequiresAccessClosure(t *testing.T) {
t.Parallel()
store := openReconcileBackgroundTestStore(t)
defer closeAppTestStore(t, store)
batchID, _, _ := seedReconcileBackgroundBatch(t, store)
hostRow := mustGetBackgroundHost(t, store)
_, err := reconcileProbeAPIKey(context.Background(), store, hostRow, sqlite.ImportBatch{ID: batchID}, nil)
if err == nil || err.Error() != fmt.Sprintf("access closure not found for batch %d", batchID) {
t.Fatalf("reconcileProbeAPIKey() error = %v, want missing access closure", err)
}
}
func TestReconcileProbeAPIKeySubscriptionSuccess(t *testing.T) {
t.Parallel()
var assignCalls int
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && strings.HasPrefix(r.URL.RequestURI(), "/api/v1/admin/users?"):
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"items": []map[string]any{}}})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users":
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84, "email": "relay-sub-user-1@sub2api.local"}})
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/users/84":
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84}})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/users/84/balance":
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 84}})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/admin/subscriptions/assign":
assignCalls++
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 401}})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/auth/login":
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"access_token": "user-jwt"}})
case r.Method == http.MethodPost && r.URL.Path == "/api/v1/keys":
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"id": 501, "key": "sk-relay-key", "name": "managed-key"}})
case r.Method == http.MethodPut && r.URL.Path == "/api/v1/admin/api-keys/501":
writeJSON(w, http.StatusOK, map[string]any{"data": map[string]any{"api_key": map[string]any{"id": 501}}})
default:
http.NotFound(w, r)
}
}))
defer server.Close()
store := openReconcileBackgroundTestStore(t)
defer closeAppTestStore(t, store)
hostPK, err := store.Hosts().Create(context.Background(), sqlite.Host{
HostID: "host-subscription",
BaseURL: server.URL,
HostVersion: "0.1.126",
CapabilityProbeJSON: "{}",
AuthType: "bearer",
AuthToken: "admin-token",
})
if err != nil {
t.Fatalf("Hosts().Create() error = %v", err)
}
packPK := createBackgroundPack(t, store)
providerPK := createBackgroundProvider(t, store, packPK)
batchID, err := store.ImportBatches().Create(context.Background(), sqlite.ImportBatch{
HostID: hostPK,
PackID: packPK,
ProviderID: providerPK,
Mode: provision.ImportModePartial,
BatchStatus: provision.BatchStatusSucceeded,
AccessStatus: provision.AccessStatusSubscriptionReady,
})
if err != nil {
t.Fatalf("ImportBatches().Create() error = %v", err)
}
if _, err := store.ManagedResources().Create(context.Background(), sqlite.ManagedResource{
BatchID: batchID,
HostID: hostPK,
ResourceType: "group",
HostResourceID: "101",
ResourceName: "group-101",
}); err != nil {
t.Fatalf("ManagedResources().Create(group) error = %v", err)
}
hostRow := sqlite.Host{HostID: "host-subscription", BaseURL: server.URL, AuthType: "bearer", AuthToken: "admin-token"}
got, err := reconcileProbeAPIKey(context.Background(), store, hostRow, sqlite.ImportBatch{ID: batchID}, []sqlite.AccessClosureRecord{{
BatchID: batchID,
ClosureType: provision.AccessModeSubscription,
Status: provision.AccessStatusSubscriptionReady,
DetailsJSON: `{"subscription_users":["crm-user-1"],"subscription_days":0}`,
}})
if err != nil {
t.Fatalf("reconcileProbeAPIKey() error = %v", err)
}
if !strings.HasPrefix(got, "sk-relay-") {
t.Fatalf("reconcileProbeAPIKey() = %q, want sk-relay-*", got)
}
if assignCalls != 1 {
t.Fatalf("subscription assign calls = %d, want 1 (EnsureSubscriptionAccess only)", assignCalls)
}
}
func TestReconcileProbeAPIKeySubscriptionRequiresHostAuth(t *testing.T) {
t.Parallel()
store := openReconcileBackgroundTestStore(t)
defer closeAppTestStore(t, store)
hostPK, err := store.Hosts().Create(context.Background(), sqlite.Host{
HostID: "host-subscription",
BaseURL: "https://sub2api.example.com",
HostVersion: "0.1.126",
CapabilityProbeJSON: "{}",
AuthType: "bearer",
AuthToken: "admin-token",
})
if err != nil {
t.Fatalf("Hosts().Create() error = %v", err)
}
packPK := createBackgroundPack(t, store)
providerPK := createBackgroundProvider(t, store, packPK)
batchID, err := store.ImportBatches().Create(context.Background(), sqlite.ImportBatch{
HostID: hostPK,
PackID: packPK,
ProviderID: providerPK,
Mode: provision.ImportModePartial,
BatchStatus: provision.BatchStatusSucceeded,
AccessStatus: provision.AccessStatusSubscriptionReady,
})
if err != nil {
t.Fatalf("ImportBatches().Create() error = %v", err)
}
if _, err := store.ManagedResources().Create(context.Background(), sqlite.ManagedResource{
BatchID: batchID,
HostID: hostPK,
ResourceType: "group",
HostResourceID: "101",
ResourceName: "group-101",
}); err != nil {
t.Fatalf("ManagedResources().Create(group) error = %v", err)
}
_, err = reconcileProbeAPIKey(context.Background(), store, sqlite.Host{BaseURL: "https://sub2api.example.com", AuthType: "bearer"}, sqlite.ImportBatch{ID: batchID}, []sqlite.AccessClosureRecord{{
BatchID: batchID,
ClosureType: provision.AccessModeSubscription,
Status: provision.AccessStatusSubscriptionReady,
DetailsJSON: `{"subscription_users":["crm-user-1"],"subscription_days":30}`,
}})
if err == nil || !strings.Contains(err.Error(), "auth.token is required") {
t.Fatalf("reconcileProbeAPIKey() error = %v, want auth.token is required", err)
}
}
func createBackgroundPack(t *testing.T, store *sqlite.DB) int64 {
t.Helper()
packPK, err := store.Packs().Create(context.Background(), sqlite.Pack{
PackID: "openai-cn-pack",
Version: "1.0.0",
Checksum: "checksum-1",
Vendor: "OpenAI CN",
TargetHost: "sub2api",
MinHostVersion: "0.1.126",
MaxHostVersion: "0.2.x",
ManifestJSON: `{"pack_id":"openai-cn-pack","version":"1.0.0","target_host":"sub2api"}`,
})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
return packPK
}
func createBackgroundProvider(t *testing.T, store *sqlite.DB, packPK int64) int64 {
t.Helper()
providerPK, err := store.Providers().Create(context.Background(), sqlite.Provider{
PackID: packPK,
ProviderID: "deepseek",
DisplayName: "DeepSeek",
BaseURL: "https://api.example.com",
Platform: "openai",
AccountType: "openai",
SmokeTestModel: "deepseek-chat",
ManifestJSON: `{"provider_id":"deepseek","base_url":"https://api.example.com","platform":"openai","account_type":"openai","smoke_test_model":"deepseek-chat"}`,
})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
return providerPK
}
func mustGetBackgroundHost(t *testing.T, store *sqlite.DB) sqlite.Host {
t.Helper()

View File

@@ -671,6 +671,7 @@ type fakeHostAdapter struct {
testedModels map[string]string
disableResponsesCalls int
disabledResponsesAccountIDs []string
deleteErrors map[string]error
}
func (f *fakeHostAdapter) GetHostVersion(context.Context) (string, error) {
@@ -688,6 +689,9 @@ func (f *fakeHostAdapter) CreateGroup(_ context.Context, req sub2api.CreateGroup
return sub2api.GroupRef{ID: "group_1", Name: "g"}, nil
}
func (f *fakeHostAdapter) DeleteGroup(_ context.Context, groupID string) error {
if err := f.deleteErrors["group:"+groupID]; err != nil {
return err
}
f.deletedResources = append(f.deletedResources, "group:"+groupID)
return nil
}
@@ -703,6 +707,9 @@ func (f *fakeHostAdapter) UpdateChannel(_ context.Context, channelID string, req
return nil
}
func (f *fakeHostAdapter) DeleteChannel(_ context.Context, channelID string) error {
if err := f.deleteErrors["channel:"+channelID]; err != nil {
return err
}
f.deletedResources = append(f.deletedResources, "channel:"+channelID)
return nil
}
@@ -711,6 +718,9 @@ func (f *fakeHostAdapter) CreatePlan(context.Context, sub2api.CreatePlanRequest)
return sub2api.PlanRef{ID: "plan_1", Name: "p"}, nil
}
func (f *fakeHostAdapter) DeletePlan(_ context.Context, planID string) error {
if err := f.deleteErrors["plan:"+planID]; err != nil {
return err
}
f.deletedResources = append(f.deletedResources, "plan:"+planID)
return nil
}
@@ -725,6 +735,9 @@ func (f *fakeHostAdapter) BatchCreateAccounts(_ context.Context, req sub2api.Bat
return f.batchAccounts, nil
}
func (f *fakeHostAdapter) DeleteAccount(_ context.Context, accountID string) error {
if err := f.deleteErrors["account:"+accountID]; err != nil {
return err
}
f.callSequence = append(f.callSequence, "deleteAccount:"+accountID)
f.deletedResources = append(f.deletedResources, "account:"+accountID)
return nil

View File

@@ -188,3 +188,103 @@ func TestProviderStatusServiceRequiresPackIDWhenProviderIDIsAmbiguous(t *testing
t.Fatalf("GetStatus() error = %v, want ambiguous provider error", err)
}
}
func TestProviderStatusServiceRequiresLatestBatchWhenProviderHasNoImports(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
ctx := context.Background()
packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
if _, err := store.Providers().Create(ctx, sqlite.Provider{
PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai",
}); err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
_, err = NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek"})
if err == nil || err.Error() != "latest import batch not found for provider" {
t.Fatalf("GetStatus() error = %v, want latest batch not found", err)
}
}
func TestProviderStatusServiceRequiresHostWhenProviderExistsOnMultipleHosts(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
ctx := context.Background()
hostA, err := store.Hosts().Create(ctx, sqlite.Host{HostID: "host-a", BaseURL: "https://a.example.com", HostVersion: "0.1.126", CapabilityProbeJSON: `{}`})
if err != nil {
t.Fatalf("Hosts().Create(host-a) error = %v", err)
}
hostB, err := store.Hosts().Create(ctx, sqlite.Host{HostID: "host-b", BaseURL: "https://b.example.com", HostVersion: "0.1.126", CapabilityProbeJSON: `{}`})
if err != nil {
t.Fatalf("Hosts().Create(host-b) error = %v", err)
}
packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
providerID, err := store.Providers().Create(ctx, sqlite.Provider{
PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai",
})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
if _, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{
HostID: hostA, PackID: packID, ProviderID: providerID, Mode: ImportModePartial, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady,
}); err != nil {
t.Fatalf("ImportBatches().Create(host-a) error = %v", err)
}
if _, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{
HostID: hostB, PackID: packID, ProviderID: providerID, Mode: ImportModePartial, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady,
}); err != nil {
t.Fatalf("ImportBatches().Create(host-b) error = %v", err)
}
_, err = NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek"})
if err == nil || err.Error() != "provider exists on multiple hosts; host_id is required" {
t.Fatalf("GetStatus() error = %v, want host_id required", err)
}
}
func TestProviderStatusServiceFailsOnInvalidReconcileSummary(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
ctx := context.Background()
hostID := seedProvisionHost(t, store, "host-1", "https://sub2api.example.com")
packID, err := store.Packs().Create(ctx, sqlite.Pack{PackID: "openai-cn-pack", Version: "1.0.0", Checksum: "checksum-1"})
if err != nil {
t.Fatalf("Packs().Create() error = %v", err)
}
providerID, err := store.Providers().Create(ctx, sqlite.Provider{
PackID: packID, ProviderID: "deepseek", DisplayName: "DeepSeek", BaseURL: "https://api.deepseek.com", Platform: "openai",
})
if err != nil {
t.Fatalf("Providers().Create() error = %v", err)
}
batchID, err := store.ImportBatches().Create(ctx, sqlite.ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID, Mode: ImportModePartial, BatchStatus: BatchStatusSucceeded, AccessStatus: AccessStatusSelfServiceReady,
})
if err != nil {
t.Fatalf("ImportBatches().Create() error = %v", err)
}
if _, err := store.AccessClosures().Create(ctx, sqlite.AccessClosureRecord{
BatchID: batchID, ClosureType: AccessModeSelfService, Status: AccessStatusSelfServiceReady, DetailsJSON: `{"ok":true}`,
}); err != nil {
t.Fatalf("AccessClosures().Create() error = %v", err)
}
if _, err := store.ReconcileRuns().Create(ctx, sqlite.ReconcileRun{
BatchID: batchID, HostID: hostID, ProviderID: providerID, Status: "active", SummaryJSON: `{"missing_count":`,
}); err != nil {
t.Fatalf("ReconcileRuns().Create() error = %v", err)
}
_, err = NewProviderStatusService(store).GetStatus(ctx, ProviderQuery{ProviderID: "deepseek", HostID: "host-1"})
if err == nil || err.Error() == "" {
t.Fatal("GetStatus() error = nil, want decode reconcile summary failure")
}
}

View File

@@ -2,7 +2,9 @@ package provision
import (
"context"
"errors"
"reflect"
"strings"
"testing"
"sub2api-cn-relay-manager/internal/host/sub2api"
@@ -69,3 +71,58 @@ func TestRollbackServiceRollbackStoredResourcesDeletesOnlyProvidedIDs(t *testing
t.Fatalf("deleted resources = %#v, want %#v", host.deletedResources, want)
}
}
func TestRollbackServiceRequiresHost(t *testing.T) {
svc := NewRollbackService(nil)
if _, err := svc.Rollback(context.Background(), RollbackRequest{Provider: sampleProviderManifest()}); err == nil || err.Error() != "rollback host is required" {
t.Fatalf("Rollback() error = %v, want rollback host is required", err)
}
if _, err := svc.RollbackStoredResources(context.Background(), nil); err == nil || err.Error() != "rollback host is required" {
t.Fatalf("RollbackStoredResources() error = %v, want rollback host is required", err)
}
}
func TestRollbackServiceCollectsDeleteErrors(t *testing.T) {
host := &fakeHostAdapter{
managedSnapshot: sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group_1", Name: "DeepSeek 默认分组"}},
Channels: []sub2api.NamedResource{{ID: "channel_1", Name: "DeepSeek 默认渠道"}},
Plans: []sub2api.NamedResource{{ID: "plan_1", Name: "DeepSeek 默认套餐"}},
Accounts: []sub2api.NamedResource{{ID: "account_1", Name: "deepseek-01"}, {ID: "account_2", Name: "deepseek-02"}},
},
deleteErrors: map[string]error{
"account:account_2": errors.New("account blocked"),
"channel:channel_1": errors.New("channel blocked"),
},
}
report, err := NewRollbackService(host).Rollback(context.Background(), RollbackRequest{Provider: sampleProviderManifest()})
if err == nil {
t.Fatal("Rollback() error = nil, want joined delete errors")
}
if report.AccountsDeleted != 1 || report.PlansDeleted != 1 || report.ChannelsDeleted != 0 || report.GroupsDeleted != 1 {
t.Fatalf("Rollback() report = %+v, want partial success counts", report)
}
if !strings.Contains(err.Error(), "delete account account_2") || !strings.Contains(err.Error(), "delete channel channel_1") {
t.Fatalf("Rollback() error = %v, want joined account/channel errors", err)
}
}
func TestNamedResourceSnapshotFromStoredIgnoresUnknownTypes(t *testing.T) {
snapshot := namedResourceSnapshotFromStored([]sqlite.ManagedResource{
{ResourceType: "group", HostResourceID: "group_1", ResourceName: "g1"},
{ResourceType: "mystery", HostResourceID: "mystery_1", ResourceName: "m1"},
{ResourceType: "account", HostResourceID: "account_1", ResourceName: "a1"},
})
if len(snapshot.Groups) != 1 || snapshot.Groups[0].ID != "group_1" {
t.Fatalf("snapshot.Groups = %#v, want group_1 only", snapshot.Groups)
}
if len(snapshot.Accounts) != 1 || snapshot.Accounts[0].ID != "account_1" {
t.Fatalf("snapshot.Accounts = %#v, want account_1 only", snapshot.Accounts)
}
if len(snapshot.Plans) != 0 || len(snapshot.Channels) != 0 {
t.Fatalf("snapshot = %#v, want unknown type ignored", snapshot)
}
}

View File

@@ -514,6 +514,101 @@ func TestRuntimeImportServiceRepeatedImportReusesManagedResources(t *testing.T)
}
}
func TestRuntimeImportServiceResolvesHostByBaseURL(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
seedProvisionHost(t, store, "host-1", "https://sub2api.example.com")
host := &fakeHostAdapter{
batchAccounts: []sub2api.AccountRef{{ID: "account_1"}},
testResults: map[string]sub2api.ProbeResult{
"account_1": {OK: true, Status: "passed"},
},
models: map[string][]sub2api.AccountModel{
"account_1": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{
OK: true,
StatusCode: 200,
HasExpectedModel: true,
Models: []string{"deepseek-chat"},
CompletionOK: true,
CompletionStatus: 200,
},
}
result, err := NewRuntimeImportService(store, host).Import(context.Background(), RuntimeImportRequest{
HostBaseURL: "https://sub2api.example.com",
Pack: pack.LoadedPack{
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
Checksum: "checksum-1",
},
Provider: sampleProviderManifest(),
Mode: ImportModePartial,
Keys: []string{"key-1"},
Access: AccessRequest{
Mode: AccessModeSelfService,
ProbeAPIKey: "user-key",
},
})
if err != nil {
t.Fatalf("RuntimeImportService.Import() error = %v", err)
}
if result.BatchID <= 0 {
t.Fatalf("BatchID = %d, want positive id", result.BatchID)
}
}
func TestRuntimeImportServiceRejectsUnregisteredHostBaseURL(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
_, err := NewRuntimeImportService(store, &fakeHostAdapter{}).Import(context.Background(), RuntimeImportRequest{
HostBaseURL: "https://missing.example.com",
Pack: pack.LoadedPack{
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
Checksum: "checksum-1",
},
Provider: sampleProviderManifest(),
Mode: ImportModePartial,
Keys: []string{"key-1"},
Access: AccessRequest{
Mode: AccessModeSelfService,
ProbeAPIKey: "user-key",
},
})
if err == nil || !strings.Contains(err.Error(), `host_id is required for unregistered host_base_url "https://missing.example.com"`) {
t.Fatalf("RuntimeImportService.Import() error = %v, want unregistered host_base_url error", err)
}
}
func TestRuntimeImportServiceRejectsHostBaseURLMismatch(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)
seedProvisionHost(t, store, "host-1", "https://sub2api.example.com")
_, err := NewRuntimeImportService(store, &fakeHostAdapter{}).Import(context.Background(), RuntimeImportRequest{
HostID: "host-1",
HostBaseURL: "https://other.example.com",
Pack: pack.LoadedPack{
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api", MinHostVersion: "0.1.126", MaxHostVersion: "0.2.x"},
Checksum: "checksum-1",
},
Provider: sampleProviderManifest(),
Mode: ImportModePartial,
Keys: []string{"key-1"},
Access: AccessRequest{
Mode: AccessModeSelfService,
ProbeAPIKey: "user-key",
},
})
if err == nil || err.Error() != `host "host-1" base_url mismatch: registered=https://sub2api.example.com runtime=https://other.example.com` {
t.Fatalf("RuntimeImportService.Import() error = %v, want base_url mismatch", err)
}
}
func TestRuntimeImportServiceImportReconcilesExistingChannelConfiguration(t *testing.T) {
store := openProvisionTestStore(t)
defer closeProvisionTestStore(t, store)

View File

@@ -8,6 +8,7 @@ import (
"testing"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
@@ -382,6 +383,127 @@ func TestStoredResourcesForReconcileMergesCurrentBatchAndSharedScaffoldingOnly(t
}
}
func TestReconcileValidatesTopLevelRequest(t *testing.T) {
t.Parallel()
req := Request{}
if _, err := NewService(nil, &reconcileHostStub{}).Reconcile(context.Background(), req); err == nil || err.Error() != "store is required" {
t.Fatalf("Reconcile(nil store) error = %v, want store is required", err)
}
store := openReconcileTestStore(t)
defer closeReconcileTestStore(t, store)
if _, err := NewService(store, nil).Reconcile(context.Background(), req); err == nil || err.Error() != "host adapter is required" {
t.Fatalf("Reconcile(nil host) error = %v, want host adapter is required", err)
}
if _, err := NewService(store, &reconcileHostStub{}).Reconcile(context.Background(), req); err == nil || err.Error() != "host_id is required" {
t.Fatalf("Reconcile(missing host_id) error = %v, want host_id is required", err)
}
if _, err := NewService(store, &reconcileHostStub{}).Reconcile(context.Background(), Request{HostID: "host-1"}); err == nil || err.Error() != "host_base_url is required" {
t.Fatalf("Reconcile(missing host_base_url) error = %v, want host_base_url is required", err)
}
}
func TestReconcileRejectsNonReconcilableLatestBatch(t *testing.T) {
t.Parallel()
store := openReconcileTestStore(t)
defer closeReconcileTestStore(t, store)
fixture := seedReconcileFixture(t, store)
mustExecReconcileSQL(t, store, `UPDATE import_batches SET batch_status = 'failed' WHERE id = ?`, fixture.batchID)
_, err := NewService(store, &reconcileHostStub{}).Reconcile(context.Background(), Request{
HostID: "host-1",
HostBaseURL: "https://sub2api.example.com",
AccessProbeAPIKey: "user-key",
Pack: pack.LoadedPack{
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api"},
},
Provider: pack.ProviderManifest{
ProviderID: "deepseek",
SmokeTestModel: "deepseek-chat",
},
})
if err == nil || err.Error() != "latest import batch is failed; run import again before reconcile" {
t.Fatalf("Reconcile() error = %v, want non-reconcilable batch error", err)
}
}
func TestReconcilePersistsActiveRunForHealthySnapshot(t *testing.T) {
t.Parallel()
store := openReconcileTestStore(t)
defer closeReconcileTestStore(t, store)
fixture := seedReconcileFixture(t, store)
mustCreateImportBatchItem(t, store, fixture.batchID, "fp-1", `{"account_id":"account-1"}`)
mustCreateManagedResource(t, store, fixture.batchID, fixture.hostPK, "group", "group-1", "group one")
mustCreateManagedResource(t, store, fixture.batchID, fixture.hostPK, "account", "account-1", "account one")
mustCreateAndLoadAccessClosures(t, store, fixture.batchID, sqlite.AccessClosureRecord{
BatchID: fixture.batchID,
ClosureType: accessModeSelfService,
Status: accessStatusSelfServiceReady,
})
host := &reconcileHostStub{
testResults: map[string]sub2api.ProbeResult{
"account-1": {OK: true, Status: accountStatusPassed, Message: "ok"},
},
models: map[string][]sub2api.AccountModel{
"account-1": {{ID: "deepseek-chat"}},
},
gatewayResult: sub2api.GatewayAccessResult{
OK: true,
StatusCode: 200,
HasExpectedModel: true,
CompletionOK: true,
Models: []string{"deepseek-chat"},
},
completionResults: []sub2api.GatewayCompletionResult{
{OK: true, StatusCode: 200},
},
managedResourceSnapshot: sub2api.ManagedResourceSnapshot{
Groups: []sub2api.NamedResource{{ID: "group-1", Name: "group one"}},
Accounts: []sub2api.NamedResource{{ID: "account-1", Name: "account one"}},
},
}
result, err := NewService(store, host).Reconcile(context.Background(), Request{
HostID: "host-1",
HostBaseURL: "https://sub2api.example.com",
AccessProbeAPIKey: "user-key",
Pack: pack.LoadedPack{
Manifest: pack.Manifest{PackID: "openai-cn-pack", Version: "1.0.0", TargetHost: "sub2api"},
},
Provider: pack.ProviderManifest{
ProviderID: "deepseek",
SmokeTestModel: "deepseek-chat",
},
})
if err != nil {
t.Fatalf("Reconcile() error = %v", err)
}
if result.Status != "active" {
t.Fatalf("result.Status = %q, want active", result.Status)
}
if result.MissingCount != 0 || result.ExtraCount != 0 || result.ProbeFailureCount != 0 {
t.Fatalf("result = %+v, want no drift and no probe failures", result)
}
if result.AccessStatus != accessStatusSelfServiceReady {
t.Fatalf("result.AccessStatus = %q, want %q", result.AccessStatus, accessStatusSelfServiceReady)
}
runs, err := store.ReconcileRuns().GetByBatchID(context.Background(), fixture.batchID)
if err != nil {
t.Fatalf("ReconcileRuns().GetByBatchID() error = %v", err)
}
if len(runs) != 1 || runs[0].Status != "active" {
t.Fatalf("persisted reconcile runs = %+v, want single active run", runs)
}
}
type reconcileFixture struct {
hostPK int64
packPK int64
@@ -512,6 +634,13 @@ func mustCreateAndLoadAccessClosures(t *testing.T, store *sqlite.DB, batchID int
return loaded
}
func mustExecReconcileSQL(t *testing.T, store *sqlite.DB, query string, args ...any) {
t.Helper()
if _, err := store.SQLDB().Exec(query, args...); err != nil {
t.Fatalf("Exec(%q) error = %v", query, err)
}
}
type reconcileHostStub struct {
disableResponsesErr error
gatewayErr error
@@ -525,6 +654,8 @@ type reconcileHostStub struct {
completionCalls int
disableResponsesCalls int
disabledResponsesAccounts []string
managedResourceSnapshot sub2api.ManagedResourceSnapshot
listManagedResourcesErr error
}
func (h *reconcileHostStub) GetHostVersion(context.Context) (string, error) {
@@ -632,5 +763,8 @@ func (h *reconcileHostStub) DisableOpenAIResponsesAPI(_ context.Context, account
}
func (h *reconcileHostStub) ListManagedResources(context.Context, sub2api.ListManagedResourcesRequest) (sub2api.ManagedResourceSnapshot, error) {
return sub2api.ManagedResourceSnapshot{}, nil
if h.listManagedResourcesErr != nil {
return sub2api.ManagedResourceSnapshot{}, h.listManagedResourcesErr
}
return h.managedResourceSnapshot, nil
}

View File

@@ -0,0 +1,57 @@
package sqlite
import (
"context"
"testing"
)
func TestAccessClosureRecordsRepoCreateValidationAndDefaults(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := openTestDB(t)
repo := store.AccessClosures()
if _, err := repo.Create(ctx, AccessClosureRecord{}); err == nil || err.Error() != "batch_id is required" {
t.Fatalf("Create() error = %v, want batch_id is required", err)
}
if _, err := repo.Create(ctx, AccessClosureRecord{BatchID: 1}); err == nil || err.Error() != "closure_type is required" {
t.Fatalf("Create() error = %v, want closure_type is required", err)
}
if _, err := repo.Create(ctx, AccessClosureRecord{BatchID: 1, ClosureType: "self_service"}); err == nil || err.Error() != "status is required" {
t.Fatalf("Create() error = %v, want status is required", err)
}
batchID := createTestBatch(t, store)
recordID, err := repo.Create(ctx, AccessClosureRecord{
BatchID: batchID,
ClosureType: "self_service",
Status: "ready",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if recordID <= 0 {
t.Fatalf("recordID = %d, want positive id", recordID)
}
records, err := repo.GetByBatchID(ctx, batchID)
if err != nil {
t.Fatalf("GetByBatchID() error = %v", err)
}
if len(records) != 1 {
t.Fatalf("len(records) = %d, want 1", len(records))
}
if records[0].DetailsJSON != "{}" {
t.Fatalf("DetailsJSON = %q, want {}", records[0].DetailsJSON)
}
}
func TestAccessClosureRecordsRepoGetByBatchIDValidation(t *testing.T) {
t.Parallel()
_, err := openTestDB(t).AccessClosures().GetByBatchID(context.Background(), 0)
if err == nil || err.Error() != "batch_id is required" {
t.Fatalf("GetByBatchID() error = %v, want batch_id is required", err)
}
}

View File

@@ -159,3 +159,115 @@ func TestReadMigrationNotFound(t *testing.T) {
t.Fatal("readMigration('nonexistent.sql') error = nil, want error")
}
}
func TestEnsureMigrationLedgerAndLoadAppliedMigrations(t *testing.T) {
store := openTestDB(t)
db := store.SQLDB()
tx, err := db.BeginTx(context.Background(), nil)
if err != nil {
t.Fatalf("BeginTx() error = %v", err)
}
defer tx.Rollback()
if err := ensureMigrationLedger(context.Background(), tx); err != nil {
t.Fatalf("ensureMigrationLedger() error = %v", err)
}
if err := ensureMigrationLedger(context.Background(), tx); err != nil {
t.Fatalf("ensureMigrationLedger() second call error = %v", err)
}
if _, err := tx.ExecContext(context.Background(), "INSERT INTO schema_migrations (version) VALUES (?)", "9999_test.sql"); err != nil {
t.Fatalf("insert schema_migrations error = %v", err)
}
applied, err := loadAppliedMigrations(context.Background(), tx)
if err != nil {
t.Fatalf("loadAppliedMigrations() error = %v", err)
}
if !applied["9999_test.sql"] {
t.Fatalf("loadAppliedMigrations() = %#v, want 9999_test.sql=true", applied)
}
}
func TestBackfillLegacySchemaIfNeededRecordsCompleteLegacySchema(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "legacy.db")
rawDB, err := sql.Open("sqlite", "file:"+filepath.ToSlash(dbPath))
if err != nil {
t.Fatalf("sql.Open() error = %v", err)
}
defer rawDB.Close()
for _, ddl := range []string{
"CREATE TABLE hosts (id INTEGER PRIMARY KEY, host_id TEXT NOT NULL, base_url TEXT NOT NULL, host_version TEXT NOT NULL, capability_probe_json TEXT NOT NULL DEFAULT '', created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP)",
"CREATE TABLE packs (id INTEGER PRIMARY KEY, pack_id TEXT NOT NULL, version TEXT NOT NULL, checksum TEXT NOT NULL, manifest_json TEXT NOT NULL DEFAULT '{}', metadata_json TEXT NOT NULL DEFAULT '{}', created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP)",
"CREATE TABLE providers (id INTEGER PRIMARY KEY, pack_id INTEGER NOT NULL, provider_id TEXT NOT NULL, display_name TEXT NOT NULL, base_url TEXT NOT NULL, platform TEXT NOT NULL, account_type TEXT NOT NULL DEFAULT 'apikey', smoke_test_model TEXT NOT NULL DEFAULT '', provider_manifest_json TEXT NOT NULL DEFAULT '{}', created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP, updated_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP)",
} {
if _, err := rawDB.ExecContext(context.Background(), ddl); err != nil {
t.Fatalf("Exec legacy ddl error = %v", err)
}
}
tx, err := rawDB.BeginTx(context.Background(), nil)
if err != nil {
t.Fatalf("BeginTx() error = %v", err)
}
defer tx.Rollback()
if err := ensureMigrationLedger(context.Background(), tx); err != nil {
t.Fatalf("ensureMigrationLedger() error = %v", err)
}
applied, err := loadAppliedMigrations(context.Background(), tx)
if err != nil {
t.Fatalf("loadAppliedMigrations() error = %v", err)
}
if err := backfillLegacySchemaIfNeeded(context.Background(), tx, []string{"0001_init.sql"}, applied); err != nil {
t.Fatalf("backfillLegacySchemaIfNeeded() error = %v", err)
}
if !applied["0001_init.sql"] {
t.Fatalf("applied = %#v, want 0001_init.sql marked", applied)
}
}
func TestBackfillLegacySchemaIfNeededRejectsPartialLegacySchema(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "legacy-partial.db")
rawDB, err := sql.Open("sqlite", "file:"+filepath.ToSlash(dbPath))
if err != nil {
t.Fatalf("sql.Open() error = %v", err)
}
defer rawDB.Close()
if _, err := rawDB.ExecContext(context.Background(), "CREATE TABLE hosts (id INTEGER PRIMARY KEY)"); err != nil {
t.Fatalf("Exec partial legacy ddl error = %v", err)
}
tx, err := rawDB.BeginTx(context.Background(), nil)
if err != nil {
t.Fatalf("BeginTx() error = %v", err)
}
defer tx.Rollback()
if err := ensureMigrationLedger(context.Background(), tx); err != nil {
t.Fatalf("ensureMigrationLedger() error = %v", err)
}
applied, err := loadAppliedMigrations(context.Background(), tx)
if err != nil {
t.Fatalf("loadAppliedMigrations() error = %v", err)
}
err = backfillLegacySchemaIfNeeded(context.Background(), tx, []string{"0001_init.sql"}, applied)
if err == nil || err.Error() != "legacy sqlite schema is partially applied without schema_migrations" {
t.Fatalf("backfillLegacySchemaIfNeeded() error = %v, want partial legacy schema error", err)
}
}
func TestRollbackMigrationReturnsOriginalErrorWhenRollbackSucceeds(t *testing.T) {
store := openTestDB(t)
tx, err := store.SQLDB().BeginTx(context.Background(), nil)
if err != nil {
t.Fatalf("BeginTx() error = %v", err)
}
baseErr := errors.New("apply failed")
if err := rollbackMigration(tx, baseErr); !errors.Is(err, baseErr) {
t.Fatalf("rollbackMigration() error = %v, want original error", err)
}
}

View File

@@ -210,6 +210,30 @@ func TestHostsRepoGetByHostIDNotFound(t *testing.T) {
}
}
func TestHostsRepoGetByBaseURL(t *testing.T) {
store := openTestDB(t)
createTestHostWithBaseURL(t, store, "host-base-url", "https://base-url.example.com")
got, err := store.Hosts().GetByBaseURL(context.Background(), "https://base-url.example.com")
if err != nil {
t.Fatalf("GetByBaseURL() error = %v", err)
}
if got.HostID != "host-base-url" {
t.Fatalf("GetByBaseURL() host_id = %q, want host-base-url", got.HostID)
}
}
func TestHostsRepoGetByBaseURLErrors(t *testing.T) {
store := openTestDB(t)
if _, err := store.Hosts().GetByBaseURL(context.Background(), ""); err == nil {
t.Fatal("GetByBaseURL(\"\") error = nil, want error")
}
if _, err := store.Hosts().GetByBaseURL(context.Background(), "https://missing.example.com"); !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByBaseURL(missing) error = %v, want sql.ErrNoRows", err)
}
}
func TestHostsRepoListAll(t *testing.T) {
store := openTestDB(t)
@@ -273,6 +297,52 @@ func TestHostsRepoUpdateProbeByHostID(t *testing.T) {
}
}
func TestHostsRepoUpdateConnectionByHostID(t *testing.T) {
store := openTestDB(t)
createTestHost(t, store)
if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "host-"+sanitizeTestName(t.Name()), "https://updated.example.com", "0.3.0", "", "", "token-1"); err != nil {
t.Fatalf("UpdateConnectionByHostID() error = %v", err)
}
host, err := store.Hosts().GetByHostID(context.Background(), "host-"+sanitizeTestName(t.Name()))
if err != nil {
t.Fatalf("GetByHostID() error = %v", err)
}
if host.BaseURL != "https://updated.example.com" {
t.Fatalf("BaseURL = %q, want updated URL", host.BaseURL)
}
if host.HostVersion != "0.3.0" {
t.Fatalf("HostVersion = %q, want 0.3.0", host.HostVersion)
}
if host.CapabilityProbeJSON != "{}" {
t.Fatalf("CapabilityProbeJSON = %q, want {}", host.CapabilityProbeJSON)
}
if host.AuthType != "apikey" {
t.Fatalf("AuthType = %q, want apikey", host.AuthType)
}
if host.AuthToken != "token-1" {
t.Fatalf("AuthToken = %q, want token-1", host.AuthToken)
}
}
func TestHostsRepoUpdateConnectionByHostIDErrors(t *testing.T) {
store := openTestDB(t)
if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "", "https://example.com", "0.2.0", "{}", "apikey", "token"); err == nil {
t.Fatal("UpdateConnectionByHostID() empty host_id error = nil")
}
if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "", "0.2.0", "{}", "apikey", "token"); err == nil {
t.Fatal("UpdateConnectionByHostID() empty base_url error = nil")
}
if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "https://example.com", "", "{}", "apikey", "token"); err == nil {
t.Fatal("UpdateConnectionByHostID() empty host_version error = nil")
}
if err := store.Hosts().UpdateConnectionByHostID(context.Background(), "missing", "https://example.com", "0.2.0", "{}", "apikey", "token"); err == nil {
t.Fatal("UpdateConnectionByHostID() missing host error = nil")
}
}
func TestHostsRepoDeleteByHostIDNotFound(t *testing.T) {
store := openTestDB(t)
err := store.Hosts().DeleteByHostID(context.Background(), "nonexistent")

View File

@@ -102,6 +102,148 @@ func TestImportBatchesRepoListByProviderID(t *testing.T) {
}
}
func TestImportBatchesRepoGetLatestByProviderIDAndHostID(t *testing.T) {
store := openTestDB(t)
hostA := createTestHost(t, store)
hostB := createTestHostWithBaseURL(t, store, "host-b-"+sanitizeTestName(t.Name()), "https://host-b.example.com")
packID := createTestPack(t, store)
providerID := createTestProviderWithPack(t, store, packID)
if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostA, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
}); err != nil {
t.Fatalf("Create(hostA older) error = %v", err)
}
latestA, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostA, PackID: packID, ProviderID: providerID,
Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready",
})
if err != nil {
t.Fatalf("Create(hostA newer) error = %v", err)
}
if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostB, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "partially_succeeded", AccessStatus: "degraded",
}); err != nil {
t.Fatalf("Create(hostB) error = %v", err)
}
got, err := store.ImportBatches().GetLatestByProviderIDAndHostID(context.Background(), providerID, hostA)
if err != nil {
t.Fatalf("GetLatestByProviderIDAndHostID() error = %v", err)
}
if got.ID != latestA {
t.Fatalf("GetLatestByProviderIDAndHostID() id = %d, want %d", got.ID, latestA)
}
}
func TestImportBatchesRepoListByProviderIDAndHostID(t *testing.T) {
store := openTestDB(t)
hostA := createTestHost(t, store)
hostB := createTestHostWithBaseURL(t, store, "host-b-"+sanitizeTestName(t.Name()), "https://host-b2.example.com")
packID := createTestPack(t, store)
providerID := createTestProviderWithPack(t, store, packID)
olderA, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostA, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
})
if err != nil {
t.Fatalf("Create(hostA older) error = %v", err)
}
newerA, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostA, PackID: packID, ProviderID: providerID,
Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready",
})
if err != nil {
t.Fatalf("Create(hostA newer) error = %v", err)
}
if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostB, PackID: packID, ProviderID: providerID,
Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready",
}); err != nil {
t.Fatalf("Create(hostB) error = %v", err)
}
batches, err := store.ImportBatches().ListByProviderIDAndHostID(context.Background(), providerID, hostA)
if err != nil {
t.Fatalf("ListByProviderIDAndHostID() error = %v", err)
}
if len(batches) != 2 {
t.Fatalf("ListByProviderIDAndHostID() len = %d, want 2", len(batches))
}
if batches[0].ID != newerA || batches[1].ID != olderA {
t.Fatalf("ListByProviderIDAndHostID() ids = [%d %d], want [%d %d]", batches[0].ID, batches[1].ID, newerA, olderA)
}
}
func TestImportBatchesRepoListLatestReconcilable(t *testing.T) {
store := openTestDB(t)
hostA := createTestHost(t, store)
hostB := createTestHostWithBaseURL(t, store, "host-b-"+sanitizeTestName(t.Name()), "https://host-b3.example.com")
packAID := createTestPack(t, store)
packBID := createTestPackWithSuffix(t, store, "reconcile")
providerA := createTestProviderWithPack(t, store, packAID)
providerB := createTestProviderWithPack(t, store, packBID)
if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostA, PackID: packAID, ProviderID: providerA,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
}); err != nil {
t.Fatalf("Create(providerA older) error = %v", err)
}
latestA, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostA, PackID: packAID, ProviderID: providerA,
Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready",
})
if err != nil {
t.Fatalf("Create(providerA latest) error = %v", err)
}
if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostB, PackID: packBID, ProviderID: providerB,
Mode: "strict", BatchStatus: "partially_succeeded", AccessStatus: "degraded",
}); err != nil {
t.Fatalf("Create(providerB latest) error = %v", err)
}
if _, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostB, PackID: packBID, ProviderID: providerB,
Mode: "strict", BatchStatus: "failed", AccessStatus: "broken",
}); err != nil {
t.Fatalf("Create(providerB newer failed) error = %v", err)
}
batches, err := store.ImportBatches().ListLatestReconcilable(context.Background())
if err != nil {
t.Fatalf("ListLatestReconcilable() error = %v", err)
}
if len(batches) != 1 {
t.Fatalf("ListLatestReconcilable() len = %d, want 1 because providerB latest batch is not reconcilable", len(batches))
}
if batches[0].ID != latestA {
t.Fatalf("ListLatestReconcilable() id = %d, want %d", batches[0].ID, latestA)
}
batchID, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostB, PackID: packBID, ProviderID: providerB,
Mode: "strict", BatchStatus: "partially_succeeded", AccessStatus: "degraded",
})
if err != nil {
t.Fatalf("Create(providerB new reconcilable) error = %v", err)
}
batches, err = store.ImportBatches().ListLatestReconcilable(context.Background())
if err != nil {
t.Fatalf("ListLatestReconcilable() second call error = %v", err)
}
if len(batches) != 2 {
t.Fatalf("ListLatestReconcilable() second len = %d, want 2", len(batches))
}
if batches[0].ID != batchID || batches[1].ID != latestA {
t.Fatalf("ListLatestReconcilable() ids = [%d %d], want [%d %d]", batches[0].ID, batches[1].ID, batchID, latestA)
}
}
func TestImportBatchesRepoGetByIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.ImportBatches().GetByID(context.Background(), 999)
@@ -265,6 +407,102 @@ func TestManagedResourcesRepoGetByBatchIDEmpty(t *testing.T) {
}
}
func TestManagedResourcesRepoGetByResourceIdentity(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
batch, err := store.ImportBatches().GetByID(context.Background(), batchID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if _, err := store.ManagedResources().Create(context.Background(), ManagedResource{
BatchID: batchID, HostID: batch.HostID, ResourceType: "channel", HostResourceID: "channel-1", ResourceName: "Channel 1",
}); err != nil {
t.Fatalf("Create() error = %v", err)
}
resource, err := store.ManagedResources().GetByResourceIdentity(context.Background(), batch.HostID, "channel", "channel-1")
if err != nil {
t.Fatalf("GetByResourceIdentity() error = %v", err)
}
if resource.ResourceName != "Channel 1" {
t.Fatalf("GetByResourceIdentity() resource_name = %q, want Channel 1", resource.ResourceName)
}
}
func TestManagedResourcesRepoListByProviderScopes(t *testing.T) {
store := openTestDB(t)
hostA := createTestHost(t, store)
hostB := createTestHostWithBaseURL(t, store, "host-b-"+sanitizeTestName(t.Name()), "https://managed-host-b.example.com")
packID := createTestPack(t, store)
providerID := createTestProviderWithPack(t, store, packID)
otherProviderID := createTestProviderWithPack(t, store, createTestPackWithSuffix(t, store, "managed"))
batchA, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostA, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "succeeded", AccessStatus: "subscription_ready",
})
if err != nil {
t.Fatalf("Create(batchA) error = %v", err)
}
batchB, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostB, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "partially_succeeded", AccessStatus: "degraded",
})
if err != nil {
t.Fatalf("Create(batchB) error = %v", err)
}
otherBatch, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostB, PackID: packID, ProviderID: otherProviderID,
Mode: "partial", BatchStatus: "succeeded", AccessStatus: "subscription_ready",
})
if err != nil {
t.Fatalf("Create(otherBatch) error = %v", err)
}
for _, resource := range []ManagedResource{
{BatchID: batchA, HostID: hostA, ResourceType: "group", HostResourceID: "g-a", ResourceName: "Group A"},
{BatchID: batchB, HostID: hostB, ResourceType: "account", HostResourceID: "a-b", ResourceName: "Account B"},
{BatchID: otherBatch, HostID: hostB, ResourceType: "channel", HostResourceID: "c-other", ResourceName: "Other"},
} {
if _, err := store.ManagedResources().Create(context.Background(), resource); err != nil {
t.Fatalf("Create(%s) error = %v", resource.HostResourceID, err)
}
}
byProvider, err := store.ManagedResources().ListByProviderID(context.Background(), providerID)
if err != nil {
t.Fatalf("ListByProviderID() error = %v", err)
}
if len(byProvider) != 2 {
t.Fatalf("ListByProviderID() len = %d, want 2", len(byProvider))
}
byProviderHost, err := store.ManagedResources().ListByProviderIDAndHostID(context.Background(), providerID, hostB)
if err != nil {
t.Fatalf("ListByProviderIDAndHostID() error = %v", err)
}
if len(byProviderHost) != 1 {
t.Fatalf("ListByProviderIDAndHostID() len = %d, want 1", len(byProviderHost))
}
if byProviderHost[0].HostResourceID != "a-b" {
t.Fatalf("ListByProviderIDAndHostID() resource = %q, want a-b", byProviderHost[0].HostResourceID)
}
}
func TestManagedResourcesRepoQueryValidationErrors(t *testing.T) {
store := openTestDB(t)
if _, err := store.ManagedResources().GetByResourceIdentity(context.Background(), 0, "group", "g-1"); err == nil {
t.Fatal("GetByResourceIdentity() host_id=0 error = nil")
}
if _, err := store.ManagedResources().ListByProviderID(context.Background(), 0); err == nil {
t.Fatal("ListByProviderID() provider_id=0 error = nil")
}
if _, err := store.ManagedResources().ListByProviderIDAndHostID(context.Background(), 1, 0); err == nil {
t.Fatal("ListByProviderIDAndHostID() host_id=0 error = nil")
}
}
func TestManagedResourcesRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
for _, tt := range []struct {

View File

@@ -0,0 +1,88 @@
package sqlite
import (
"context"
"testing"
)
func TestImportRunItemEventsRepoAppendValidationAndDefaults(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := openTestDB(t)
repo := store.ImportRunEvents()
if err := repo.Append(ctx, ImportRunItemEvent{}); err == nil || err.Error() != "event_id is required" {
t.Fatalf("Append() error = %v, want event_id is required", err)
}
if err := repo.Append(ctx, ImportRunItemEvent{EventID: "evt-1"}); err == nil || err.Error() != "run_id is required" {
t.Fatalf("Append() error = %v, want run_id is required", err)
}
if err := repo.Append(ctx, ImportRunItemEvent{EventID: "evt-1", RunID: "run-1"}); err == nil || err.Error() != "item_id is required" {
t.Fatalf("Append() error = %v, want item_id is required", err)
}
if err := repo.Append(ctx, ImportRunItemEvent{EventID: "evt-1", RunID: "run-1", ItemID: "item-1"}); err == nil || err.Error() != "event_type is required" {
t.Fatalf("Append() error = %v, want event_type is required", err)
}
if err := repo.Append(ctx, ImportRunItemEvent{EventID: "evt-1", RunID: "run-1", ItemID: "item-1", EventType: "confirm"}); err == nil || err.Error() != "stage is required" {
t.Fatalf("Append() error = %v, want stage is required", err)
}
if err := repo.Append(ctx, ImportRunItemEvent{EventID: "evt-1", RunID: "run-1", ItemID: "item-1", EventType: "confirm", Stage: "confirm"}); err == nil || err.Error() != "message is required" {
t.Fatalf("Append() error = %v, want message is required", err)
}
if err := store.ImportRuns().Create(ctx, ImportRun{
RunID: "run-1",
HostID: "host-1",
Mode: "strict",
AccessMode: "subscription",
State: "running",
}); err != nil {
t.Fatalf("ImportRuns().Create() error = %v", err)
}
if err := store.ImportRunItems().Create(ctx, ImportRunItem{
ItemID: "item-1",
RunID: "run-1",
BaseURL: "https://api.example.com/v1",
ProviderID: "provider-1",
APIKeyFingerprint: "fp-1",
CurrentStage: "confirm",
ConfirmationStatus: "pending",
AccessStatus: "unknown",
MatchedAccountState: "active",
AccountResolution: "created",
}); err != nil {
t.Fatalf("ImportRunItems().Create() error = %v", err)
}
if err := repo.Append(ctx, ImportRunItemEvent{
EventID: "evt-1",
RunID: "run-1",
ItemID: "item-1",
EventType: "confirm",
Stage: "confirm",
Message: "scheduled",
}); err != nil {
t.Fatalf("Append() error = %v", err)
}
events, err := repo.ListByItemID(ctx, "item-1")
if err != nil {
t.Fatalf("ListByItemID() error = %v", err)
}
if len(events) != 1 {
t.Fatalf("len(events) = %d, want 1", len(events))
}
if events[0].PayloadJSON != "{}" {
t.Fatalf("PayloadJSON = %q, want {}", events[0].PayloadJSON)
}
}
func TestImportRunItemEventsRepoListByItemIDValidation(t *testing.T) {
t.Parallel()
_, err := openTestDB(t).ImportRunEvents().ListByItemID(context.Background(), "")
if err == nil || err.Error() != "item_id is required" {
t.Fatalf("ListByItemID() error = %v, want item_id is required", err)
}
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"reflect"
"testing"
"time"
)
func TestRunStateStore(t *testing.T) {
@@ -169,3 +170,182 @@ func TestRunStateStore(t *testing.T) {
t.Fatalf("items[0].AdvisoryMessagesJSON = %q, want advisory json", items[0].AdvisoryMessagesJSON)
}
}
func TestImportRunsRepoList(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := openTestDB(t)
for _, run := range []ImportRun{
{RunID: "run-a", HostID: "host-a", Mode: "strict", AccessMode: "subscription", State: "running"},
{RunID: "run-b", HostID: "host-b", Mode: "partial", AccessMode: "self_service", State: "completed"},
} {
if err := store.ImportRuns().Create(ctx, run); err != nil {
t.Fatalf("ImportRuns().Create(%q) error = %v", run.RunID, err)
}
}
runs, err := store.ImportRuns().List(ctx, 1)
if err != nil {
t.Fatalf("ImportRuns().List(limit=1) error = %v", err)
}
if len(runs) != 1 {
t.Fatalf("ImportRuns().List(limit=1) len = %d, want 1", len(runs))
}
runs, err = store.ImportRuns().List(ctx, 0)
if err != nil {
t.Fatalf("ImportRuns().List(limit=0) error = %v", err)
}
if len(runs) != 2 {
t.Fatalf("ImportRuns().List(limit=0) len = %d, want 2", len(runs))
}
}
func TestImportRunItemsRepoCreateUpdateAndLease(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := openTestDB(t)
run := ImportRun{RunID: "run-lease", HostID: "host-lease", Mode: "strict", AccessMode: "subscription", State: "running"}
if err := store.ImportRuns().Create(ctx, run); err != nil {
t.Fatalf("ImportRuns().Create() error = %v", err)
}
item := ImportRunItem{
ItemID: "item-lease",
RunID: "run-lease",
BaseURL: "https://api.example.com/v1",
ProviderID: "provider-lease",
APIKeyFingerprint: "fp_lease",
CurrentStage: "confirm",
ConfirmationStatus: "pending",
AccessStatus: "unknown",
MatchedAccountState: "active",
AccountResolution: "created",
RequestedModelsJSON: `["model-a"]`,
AdvisoryMessagesJSON: `["first"]`,
}
if err := store.ImportRunItems().Create(ctx, item); err != nil {
t.Fatalf("ImportRunItems().Create() error = %v", err)
}
item.ConfirmationStatus = "confirmed"
item.CurrentStage = "validate"
item.AccessStatus = "active"
item.AdvisoryMessagesJSON = `["updated"]`
if err := store.ImportRunItems().Update(ctx, item); err != nil {
t.Fatalf("ImportRunItems().Update() error = %v", err)
}
got, err := store.ImportRunItems().GetByItemID(ctx, "item-lease")
if err != nil {
t.Fatalf("ImportRunItems().GetByItemID() error = %v", err)
}
if got.CurrentStage != "validate" || got.ConfirmationStatus != "confirmed" {
t.Fatalf("updated item = %+v, want validate/confirmed", got)
}
leaseItem := item
leaseItem.ItemID = "item-pending"
leaseItem.CurrentStage = "confirm"
leaseItem.ConfirmationStatus = "pending"
leaseItem.AccessStatus = "unknown"
if err := store.ImportRunItems().Create(ctx, leaseItem); err != nil {
t.Fatalf("ImportRunItems().Create(item-pending) error = %v", err)
}
now := time.Date(2026, 5, 23, 1, 2, 3, 0, time.UTC)
leased, ok, err := store.ImportRunItems().TryAcquireConfirmationLease(ctx, "item-pending", "worker-1", now, 2*time.Minute)
if err != nil {
t.Fatalf("TryAcquireConfirmationLease() error = %v", err)
}
if !ok {
t.Fatal("TryAcquireConfirmationLease() ok = false, want true")
}
if leased.LeaseOwner != "worker-1" {
t.Fatalf("LeaseOwner = %q, want worker-1", leased.LeaseOwner)
}
if leased.ConfirmationAttempts != 1 {
t.Fatalf("ConfirmationAttempts = %d, want 1", leased.ConfirmationAttempts)
}
_, ok, err = store.ImportRunItems().TryAcquireConfirmationLease(ctx, "item-pending", "worker-2", now.Add(30*time.Second), time.Minute)
if err != nil {
t.Fatalf("TryAcquireConfirmationLease(second) error = %v", err)
}
if ok {
t.Fatal("TryAcquireConfirmationLease(second) ok = true, want false while lease active")
}
}
func TestImportRunEventsRepoCreateAndHelpers(t *testing.T) {
t.Parallel()
ctx := context.Background()
store := openTestDB(t)
if err := store.ImportRuns().Create(ctx, ImportRun{
RunID: "run-events",
HostID: "host-events",
Mode: "strict",
AccessMode: "subscription",
State: "running",
}); err != nil {
t.Fatalf("ImportRuns().Create() error = %v", err)
}
if err := store.ImportRunItems().Create(ctx, ImportRunItem{
ItemID: "item-events",
RunID: "run-events",
BaseURL: "https://api.example.com/v1",
ProviderID: "provider-events",
APIKeyFingerprint: "fp_events",
CurrentStage: "confirm",
ConfirmationStatus: "pending",
AccessStatus: "unknown",
MatchedAccountState: "active",
AccountResolution: "created",
}); err != nil {
t.Fatalf("ImportRunItems().Create() error = %v", err)
}
event := ImportRunItemEvent{
EventID: "evt-create",
RunID: "run-events",
ItemID: "item-events",
EventType: "confirmation",
Stage: "confirm",
Message: "created by wrapper",
}
if err := store.ImportRunEvents().Create(ctx, event); err != nil {
t.Fatalf("ImportRunEvents().Create() error = %v", err)
}
if err := store.ImportRunItemEvents().Create(ctx, ImportRunItemEvent{
EventID: "evt-alias",
RunID: "run-events",
ItemID: "item-events",
EventType: "alias",
Stage: "confirm",
Message: "created by alias accessor",
}); err != nil {
t.Fatalf("ImportRunItemEvents().Create() error = %v", err)
}
events, err := store.ImportRunEvents().ListByItemID(ctx, "item-events")
if err != nil {
t.Fatalf("ImportRunEvents().ListByItemID() error = %v", err)
}
if len(events) != 2 {
t.Fatalf("ImportRunEvents().ListByItemID() len = %d, want 2", len(events))
}
nullable := sqlNullInt64{Int64: 42, Valid: true}
if ptr := nullable.ptr(); ptr == nil || *ptr != 42 {
t.Fatalf("sqlNullInt64.ptr() = %#v, want 42", ptr)
}
if ptr := (sqlNullInt64{}).ptr(); ptr != nil {
t.Fatalf("sqlNullInt64{}.ptr() = %#v, want nil", ptr)
}
}

View File

@@ -157,3 +157,33 @@ func TestPacksRepoGetByPackIDNotFound(t *testing.T) {
t.Fatalf("GetByPackID() error = %v, want sql.ErrNoRows", err)
}
}
func TestPacksRepoListAll(t *testing.T) {
store := openTestDB(t)
packs, err := store.Packs().ListAll(context.Background())
if err != nil {
t.Fatalf("ListAll() empty error = %v", err)
}
if len(packs) != 0 {
t.Fatalf("ListAll() empty len = %d, want 0", len(packs))
}
if _, err := store.Packs().Create(context.Background(), Pack{PackID: "pack-a", Version: "1.0.0", Checksum: "chk-a"}); err != nil {
t.Fatalf("Create(pack-a) error = %v", err)
}
if _, err := store.Packs().Create(context.Background(), Pack{PackID: "pack-b", Version: "1.1.0", Checksum: "chk-b"}); err != nil {
t.Fatalf("Create(pack-b) error = %v", err)
}
packs, err = store.Packs().ListAll(context.Background())
if err != nil {
t.Fatalf("ListAll() error = %v", err)
}
if len(packs) != 2 {
t.Fatalf("ListAll() len = %d, want 2", len(packs))
}
if packs[0].PackID != "pack-a" || packs[1].PackID != "pack-b" {
t.Fatalf("ListAll() ids = [%q %q], want [pack-a pack-b]", packs[0].PackID, packs[1].PackID)
}
}

View File

@@ -55,6 +55,33 @@ func TestProvidersRepoListByProviderID(t *testing.T) {
}
}
func TestProvidersRepoListByPackID(t *testing.T) {
store := openTestDB(t)
packID := createTestPack(t, store)
otherPackID := createTestPackWithSuffix(t, store, "other")
if _, err := store.Providers().Create(context.Background(), Provider{PackID: packID, ProviderID: "provider-a", DisplayName: "A", BaseURL: "https://a.example.com", Platform: "openai"}); err != nil {
t.Fatalf("Create(provider-a) error = %v", err)
}
if _, err := store.Providers().Create(context.Background(), Provider{PackID: packID, ProviderID: "provider-b", DisplayName: "B", BaseURL: "https://b.example.com", Platform: "openai"}); err != nil {
t.Fatalf("Create(provider-b) error = %v", err)
}
if _, err := store.Providers().Create(context.Background(), Provider{PackID: otherPackID, ProviderID: "provider-c", DisplayName: "C", BaseURL: "https://c.example.com", Platform: "openai"}); err != nil {
t.Fatalf("Create(provider-c) error = %v", err)
}
providers, err := store.Providers().ListByPackID(context.Background(), packID)
if err != nil {
t.Fatalf("ListByPackID() error = %v", err)
}
if len(providers) != 2 {
t.Fatalf("ListByPackID() count = %d, want 2", len(providers))
}
if providers[0].ProviderID != "provider-a" || providers[1].ProviderID != "provider-b" {
t.Fatalf("ListByPackID() provider ids = [%q %q], want [provider-a provider-b]", providers[0].ProviderID, providers[1].ProviderID)
}
}
func createTestPackWithSuffix(t *testing.T, store *DB, suffix string) int64 {
t.Helper()
id, err := store.Packs().Create(context.Background(), Pack{
@@ -145,6 +172,44 @@ func TestProvidersRepoGetByPackIDAndProviderIDNotFound(t *testing.T) {
}
}
func TestProvidersRepoGetByID(t *testing.T) {
store := openTestDB(t)
packID := createTestPack(t, store)
providerID, err := store.Providers().Create(context.Background(), Provider{
PackID: packID,
ProviderID: "provider-id-lookup",
DisplayName: "Lookup",
BaseURL: "https://lookup.example.com",
Platform: "openai",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
got, err := store.Providers().GetByID(context.Background(), providerID)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got.ProviderID != "provider-id-lookup" {
t.Fatalf("GetByID() provider_id = %q, want provider-id-lookup", got.ProviderID)
}
}
func TestProvidersRepoGetByIDErrors(t *testing.T) {
store := openTestDB(t)
if _, err := store.Providers().GetByID(context.Background(), 0); err == nil {
t.Fatal("GetByID(0) error = nil, want error")
}
if _, err := store.Providers().ListByPackID(context.Background(), 0); err == nil {
t.Fatal("ListByPackID(0) error = nil, want error")
}
if _, err := store.Providers().GetByID(context.Background(), 999); !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
}
}
func TestProvidersRepoGetByPackIDEmptyError(t *testing.T) {
store := openTestDB(t)
_, err := store.Providers().GetByPackIDAndProviderID(context.Background(), 0, "p")