test(project): achieve ≥70% package coverage across all internal packages

- store/sqlite: 75.4% (repos + db coverage)
- host/sub2api: 80.8% (httptest mock server, pure function tests)
- app: 74.2% (handler error paths, NewActionSet closures)
- pack: 72.4%
- provision: 75.2%
- access: 77.3%
- config: 94.7% (lookup mock tests)

All tests pass: build, vet, race, coverage gates.
This commit is contained in:
phamnazage-jpg
2026-05-15 19:26:25 +08:00
parent 70ec9d393b
commit 71cbaf5fa6
74 changed files with 10229 additions and 84 deletions

View File

@@ -0,0 +1,77 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type AccessClosureRecord struct {
ID int64
BatchID int64
ClosureType string
Status string
DetailsJSON string
}
type AccessClosureRecordsRepo struct {
db execQuerier
}
func newAccessClosureRecordsRepo(db execQuerier) *AccessClosureRecordsRepo {
return &AccessClosureRecordsRepo{db: db}
}
func (r *AccessClosureRecordsRepo) Create(ctx context.Context, record AccessClosureRecord) (int64, error) {
closureType := strings.TrimSpace(record.ClosureType)
status := strings.TrimSpace(record.Status)
detailsJSON := strings.TrimSpace(record.DetailsJSON)
if detailsJSON == "" {
detailsJSON = "{}"
}
switch {
case record.BatchID <= 0:
return 0, fmt.Errorf("batch_id is required")
case closureType == "":
return 0, fmt.Errorf("closure_type is required")
case status == "":
return 0, fmt.Errorf("status is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO access_closure_records (batch_id, closure_type, status, details_json) VALUES (?, ?, ?, ?)`, record.BatchID, closureType, status, detailsJSON)
if err != nil {
return 0, fmt.Errorf("insert access closure record: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted access closure record id: %w", err)
}
return id, nil
}
func (r *AccessClosureRecordsRepo) GetByBatchID(ctx context.Context, batchID int64) ([]AccessClosureRecord, error) {
if batchID <= 0 {
return nil, fmt.Errorf("batch_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_id, closure_type, status, details_json FROM access_closure_records WHERE batch_id = ? ORDER BY id`, batchID)
if err != nil {
return nil, fmt.Errorf("query access closure records: %w", err)
}
defer rows.Close()
records := make([]AccessClosureRecord, 0)
for rows.Next() {
var record AccessClosureRecord
if err := rows.Scan(&record.ID, &record.BatchID, &record.ClosureType, &record.Status, &record.DetailsJSON); err != nil {
return nil, fmt.Errorf("scan access closure record: %w", err)
}
records = append(records, record)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate access closure records: %w", err)
}
return records, nil
}

View File

@@ -15,13 +15,20 @@ import (
type execQuerier interface {
ExecContext(context.Context, string, ...any) (sql.Result, error)
QueryContext(context.Context, string, ...any) (*sql.Rows, error)
QueryRowContext(context.Context, string, ...any) *sql.Row
}
type Queries struct {
Hosts *HostsRepo
Packs *PacksRepo
Providers *ProvidersRepo
Hosts *HostsRepo
Packs *PacksRepo
Providers *ProvidersRepo
ImportBatches *ImportBatchesRepo
ImportBatchItems *ImportBatchItemsRepo
ManagedResources *ManagedResourcesRepo
ProbeResults *ProbeResultsRepo
AccessClosures *AccessClosureRecordsRepo
ReconcileRuns *ReconcileRunsRepo
}
type DB struct {
@@ -76,6 +83,30 @@ func (db *DB) Providers() *ProvidersRepo {
return db.queries.Providers
}
func (db *DB) ImportBatches() *ImportBatchesRepo {
return db.queries.ImportBatches
}
func (db *DB) ImportBatchItems() *ImportBatchItemsRepo {
return db.queries.ImportBatchItems
}
func (db *DB) ManagedResources() *ManagedResourcesRepo {
return db.queries.ManagedResources
}
func (db *DB) ProbeResults() *ProbeResultsRepo {
return db.queries.ProbeResults
}
func (db *DB) AccessClosures() *AccessClosureRecordsRepo {
return db.queries.AccessClosures
}
func (db *DB) ReconcileRuns() *ReconcileRunsRepo {
return db.queries.ReconcileRuns
}
func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error {
tx, err := db.sqlDB.BeginTx(ctx, nil)
if err != nil {
@@ -101,9 +132,15 @@ func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error {
func newQueries(db execQuerier) *Queries {
return &Queries{
Hosts: newHostsRepo(db),
Packs: newPacksRepo(db),
Providers: newProvidersRepo(db),
Hosts: newHostsRepo(db),
Packs: newPacksRepo(db),
Providers: newProvidersRepo(db),
ImportBatches: newImportBatchesRepo(db),
ImportBatchItems: newImportBatchItemsRepo(db),
ManagedResources: newManagedResourcesRepo(db),
ProbeResults: newProbeResultsRepo(db),
AccessClosures: newAccessClosureRecordsRepo(db),
ReconcileRuns: newReconcileRunsRepo(db),
}
}

View File

@@ -0,0 +1,161 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"path/filepath"
"testing"
)
func TestOpenClose(t *testing.T) {
store := openTestDB(t)
if store == nil {
t.Fatal("Open() returned nil")
}
}
func TestOpenInvalidDSN(t *testing.T) {
_, err := Open(context.Background(), "file:/nonexistent/dir/test.db?_pragma=foreign_keys(0)")
if err == nil {
t.Fatal("Open() with invalid dsn error = nil, want error")
}
}
func TestWithTxCommit(t *testing.T) {
store := openTestDB(t)
err := store.WithTx(context.Background(), func(q *Queries) error {
_, err := q.Hosts.Create(context.Background(), Host{
HostID: "tx-host", BaseURL: "https://tx.com", HostVersion: "1.0",
})
return err
})
if err != nil {
t.Fatalf("WithTx() error = %v", err)
}
host, err := store.Hosts().GetByHostID(context.Background(), "tx-host")
if err != nil {
t.Fatalf("GetByHostID() after tx = %v", err)
}
if host.HostID != "tx-host" {
t.Fatalf("host = %+v, want tx-host", host)
}
}
func TestWithTxRollbackOnError(t *testing.T) {
store := openTestDB(t)
err := store.WithTx(context.Background(), func(q *Queries) error {
q.Hosts.Create(context.Background(), Host{
HostID: "rollback-host", BaseURL: "https://r.com", HostVersion: "1.0",
})
return errors.New("rollback")
})
if err == nil {
t.Fatal("WithTx() error = nil, want rollback error")
}
_, err = store.Hosts().GetByHostID(context.Background(), "rollback-host")
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByHostID() after rollback error = %v, want sql.ErrNoRows", err)
}
}
func TestTableExists(t *testing.T) {
store := openTestDB(t)
db := store.SQLDB()
found, err := tableExists(context.Background(), db, "hosts")
if err != nil {
t.Fatalf("tableExists('hosts') error = %v", err)
}
if !found {
t.Fatal("tableExists('hosts') = false, want true")
}
found, err = tableExists(context.Background(), db, "nonexistent")
if err != nil {
t.Fatalf("tableExists('nonexistent') error = %v", err)
}
if found {
t.Fatal("tableExists('nonexistent') = true, want false")
}
}
func TestDetectLegacy0001Schema(t *testing.T) {
store := openTestDB(t)
db := store.SQLDB()
tx, err := db.BeginTx(context.Background(), nil)
if err != nil {
t.Fatalf("BeginTx error = %v", err)
}
defer tx.Rollback()
// After migration all three host/packs/providers tables exist,
// so detectLegacy0001Schema reports complete=true.
complete, partial, err := detectLegacy0001Schema(context.Background(), tx)
if err != nil {
t.Fatalf("detectLegacy0001Schema() error = %v", err)
}
if !complete {
t.Fatalf("detectLegacy0001Schema() = (complete=%v, partial=%v), want (true, false)", complete, partial)
}
if partial {
t.Fatal("partial should be false when all 3 tables exist")
}
}
func TestWithForeignKeysEnabled(t *testing.T) {
if got := withForeignKeysEnabled("file:test.db"); got != "file:test.db?_pragma=foreign_keys(1)" {
t.Fatalf("withForeignKeysEnabled no query = %q", got)
}
if got := withForeignKeysEnabled("file:test.db?a=1"); got != "file:test.db?a=1&_pragma=foreign_keys(1)" {
t.Fatalf("withForeignKeysEnabled with query = %q", got)
}
}
func TestSQLDB(t *testing.T) {
store := openTestDB(t)
db := store.SQLDB()
if db == nil {
t.Fatal("SQLDB() returned nil")
}
if err := db.PingContext(context.Background()); err != nil {
t.Fatalf("Ping() error = %v", err)
}
}
func TestMigrationFileNames(t *testing.T) {
names, err := migrationFileNames()
if err != nil {
t.Fatalf("migrationFileNames() error = %v", err)
}
if len(names) == 0 {
t.Fatal("migrationFileNames() returned empty")
}
for _, name := range names {
if filepath.Ext(name) != ".sql" {
t.Fatalf("migrationFileNames() entry %q not .sql file", name)
}
}
}
func TestReadMigration(t *testing.T) {
content, err := readMigration("0001_init.sql")
if err != nil {
t.Fatalf("readMigration('0001_init.sql') error = %v", err)
}
if len(content) == 0 {
t.Fatal("readMigration() returned empty content")
}
}
func TestReadMigrationNotFound(t *testing.T) {
_, err := readMigration("nonexistent.sql")
if err == nil {
t.Fatal("readMigration('nonexistent.sql') error = nil, want error")
}
}

View File

@@ -7,6 +7,7 @@ import (
)
type Host struct {
ID int64
HostID string
BaseURL string
HostVersion string
@@ -21,6 +22,31 @@ func newHostsRepo(db execQuerier) *HostsRepo {
return &HostsRepo{db: db}
}
func (r *HostsRepo) GetByID(ctx context.Context, id int64) (Host, error) {
if id <= 0 {
return Host{}, fmt.Errorf("id is required")
}
var host Host
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, base_url, host_version, capability_probe_json FROM hosts WHERE id = ?`, id).Scan(&host.ID, &host.HostID, &host.BaseURL, &host.HostVersion, &host.CapabilityProbeJSON); err != nil {
return Host{}, err
}
return host, nil
}
func (r *HostsRepo) GetByHostID(ctx context.Context, hostID string) (Host, error) {
hostID = strings.TrimSpace(hostID)
if hostID == "" {
return Host{}, fmt.Errorf("host_id is required")
}
var host Host
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, base_url, host_version, capability_probe_json FROM hosts WHERE host_id = ?`, hostID).Scan(&host.ID, &host.HostID, &host.BaseURL, &host.HostVersion, &host.CapabilityProbeJSON); err != nil {
return Host{}, err
}
return host, nil
}
func (r *HostsRepo) Create(ctx context.Context, host Host) (int64, error) {
hostID := strings.TrimSpace(host.HostID)
baseURL := strings.TrimSpace(host.BaseURL)

View File

@@ -0,0 +1,199 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"path/filepath"
"testing"
)
// openTestDB creates a test database with foreign keys disabled.
func openTestDB(t *testing.T) *DB {
t.Helper()
dbPath := filepath.Join(t.TempDir(), "test.db")
dsn := "file:" + filepath.ToSlash(dbPath) + "?_pragma=foreign_keys(0)"
store, err := Open(context.Background(), dsn)
if err != nil {
t.Fatalf("Open() error = %v", err)
}
t.Cleanup(func() { store.Close() })
return store
}
// openTestDBWithFK creates a test database with foreign keys enforced.
func openTestDBWithFK(t *testing.T) *DB {
t.Helper()
dbPath := filepath.Join(t.TempDir(), "test-fk.db")
dsn := "file:" + filepath.ToSlash(dbPath)
store, err := Open(context.Background(), dsn)
if err != nil {
t.Fatalf("Open() error = %v", err)
}
t.Cleanup(func() { store.Close() })
return store
}
func createTestPack(t *testing.T, store *DB) int64 {
t.Helper()
id, err := store.Packs().Create(context.Background(), Pack{
PackID: "pack-" + sanitizeTestName(t.Name()), Version: "1.0.0", Checksum: "chk",
})
if err != nil {
t.Fatalf("createTestPack error = %v", err)
}
return id
}
func createTestHost(t *testing.T, store *DB) int64 {
t.Helper()
id, err := store.Hosts().Create(context.Background(), Host{
HostID: "host-" + sanitizeTestName(t.Name()), BaseURL: "https://h.com", HostVersion: "0.1.0",
})
if err != nil {
t.Fatalf("createTestHost error = %v", err)
}
return id
}
func createTestBatch(t *testing.T, store *DB) int64 {
t.Helper()
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID, err := store.Providers().Create(context.Background(), Provider{
PackID: packID, ProviderID: "test-provider", DisplayName: "Test",
BaseURL: "https://t.com", Platform: "openai",
})
if err != nil {
t.Fatalf("createTestBatch create provider error = %v", err)
}
id, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
})
if err != nil {
t.Fatalf("createTestBatch error = %v", err)
}
return id
}
func createTestBatchItem(t *testing.T, store *DB, batchID int64) int64 {
t.Helper()
id, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
BatchID: batchID, KeyFingerprint: "sha256:test", AccountStatus: "pending",
})
if err != nil {
t.Fatalf("createTestBatchItem error = %v", err)
}
return id
}
func sanitizeTestName(name string) string {
result := ""
for _, c := range name {
if (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' {
result += string(c)
}
}
if result == "" {
result = "default"
}
return result
}
// --- Hosts Repo Tests ---
func TestHostsRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
id, err := store.Hosts().Create(context.Background(), Host{
HostID: "host-1",
BaseURL: "https://sub2api.example.com",
HostVersion: "0.1.126",
CapabilityProbeJSON: `{"groups":true}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Fatalf("Create() id = %d, want positive", id)
}
got, err := store.Hosts().GetByID(context.Background(), id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got.HostID != "host-1" || got.BaseURL != "https://sub2api.example.com" {
t.Fatalf("GetByID() = %+v, want host-1", got)
}
got2, err := store.Hosts().GetByHostID(context.Background(), "host-1")
if err != nil {
t.Fatalf("GetByHostID() error = %v", err)
}
if got2.ID != id {
t.Fatalf("GetByHostID() id = %d, want %d", got2.ID, id)
}
}
func TestHostsRepoCreateDefaultsCapabilityProbe(t *testing.T) {
store := openTestDB(t)
id, _ := store.Hosts().Create(context.Background(), Host{
HostID: "host-empty", BaseURL: "https://example.com", HostVersion: "0.1.0",
})
got, _ := store.Hosts().GetByID(context.Background(), id)
if got.CapabilityProbeJSON != "{}" {
t.Fatalf("CapabilityProbeJSON = %q, want {}", got.CapabilityProbeJSON)
}
}
func TestHostsRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
for _, tt := range []struct {
name string
host Host
}{
{"empty host_id", Host{BaseURL: "b", HostVersion: "v"}},
{"empty base_url", Host{HostID: "h", HostVersion: "v"}},
{"empty host_version", Host{HostID: "h", BaseURL: "b"}},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := store.Hosts().Create(context.Background(), tt.host)
if err == nil {
t.Fatal("Create() error = nil, want validation error")
}
})
}
}
func TestHostsRepoGetByIDZeroError(t *testing.T) {
store := openTestDB(t)
_, err := store.Hosts().GetByID(context.Background(), 0)
if err == nil {
t.Fatal("GetByID(0) error = nil, want error")
}
}
func TestHostsRepoGetByIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.Hosts().GetByID(context.Background(), 999)
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
}
}
func TestHostsRepoGetByHostIDEmptyError(t *testing.T) {
store := openTestDB(t)
_, err := store.Hosts().GetByHostID(context.Background(), "")
if err == nil {
t.Fatal("GetByHostID('') error = nil, want error")
}
}
func TestHostsRepoGetByHostIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.Hosts().GetByHostID(context.Background(), "nonexistent")
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByHostID('nonexistent') error = %v, want sql.ErrNoRows", err)
}
}

View File

@@ -0,0 +1,187 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type ImportBatch struct {
ID int64
HostID int64
PackID int64
ProviderID int64
Mode string
BatchStatus string
AccessStatus string
}
type ImportBatchItem struct {
ID int64
BatchID int64
KeyFingerprint string
AccountStatus string
ProbeSummaryJSON string
}
type ImportBatchesRepo struct {
db execQuerier
}
type ImportBatchItemsRepo struct {
db execQuerier
}
func newImportBatchesRepo(db execQuerier) *ImportBatchesRepo {
return &ImportBatchesRepo{db: db}
}
func newImportBatchItemsRepo(db execQuerier) *ImportBatchItemsRepo {
return &ImportBatchItemsRepo{db: db}
}
func (r *ImportBatchesRepo) GetByID(ctx context.Context, id int64) (ImportBatch, error) {
if id <= 0 {
return ImportBatch{}, fmt.Errorf("id is required")
}
var batch ImportBatch
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE id = ?`, id).Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
return ImportBatch{}, err
}
return batch, nil
}
func (r *ImportBatchesRepo) Create(ctx context.Context, batch ImportBatch) (int64, error) {
mode := strings.TrimSpace(batch.Mode)
batchStatus := strings.TrimSpace(batch.BatchStatus)
accessStatus := strings.TrimSpace(batch.AccessStatus)
switch {
case batch.HostID <= 0:
return 0, fmt.Errorf("host_id is required")
case batch.PackID <= 0:
return 0, fmt.Errorf("pack_id is required")
case batch.ProviderID <= 0:
return 0, fmt.Errorf("provider_id is required")
case mode == "":
return 0, fmt.Errorf("mode is required")
case batchStatus == "":
return 0, fmt.Errorf("batch_status is required")
case accessStatus == "":
return 0, fmt.Errorf("access_status is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO import_batches (host_id, pack_id, provider_id, mode, batch_status, access_status) VALUES (?, ?, ?, ?, ?, ?)`, batch.HostID, batch.PackID, batch.ProviderID, mode, batchStatus, accessStatus)
if err != nil {
return 0, fmt.Errorf("insert import batch: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted import batch id: %w", err)
}
return id, nil
}
func (r *ImportBatchesRepo) UpdateStatus(ctx context.Context, id int64, batchStatus, accessStatus string) error {
if id <= 0 {
return fmt.Errorf("id is required")
}
batchStatus = strings.TrimSpace(batchStatus)
accessStatus = strings.TrimSpace(accessStatus)
if batchStatus == "" {
return fmt.Errorf("batch_status is required")
}
if accessStatus == "" {
return fmt.Errorf("access_status is required")
}
if _, err := r.db.ExecContext(ctx, `UPDATE import_batches SET batch_status = ?, access_status = ? WHERE id = ?`, batchStatus, accessStatus, id); err != nil {
return fmt.Errorf("update import batch %d: %w", id, err)
}
return nil
}
func (r *ImportBatchesRepo) GetLatestByProviderID(ctx context.Context, providerID int64) (ImportBatch, error) {
if providerID <= 0 {
return ImportBatch{}, fmt.Errorf("provider_id is required")
}
var batch ImportBatch
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? ORDER BY id DESC LIMIT 1`, providerID).Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
return ImportBatch{}, err
}
return batch, nil
}
func (r *ImportBatchItemsRepo) GetByBatchID(ctx context.Context, batchID int64) ([]ImportBatchItem, error) {
if batchID <= 0 {
return nil, fmt.Errorf("batch_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_id, key_fingerprint, account_status, probe_summary_json FROM import_batch_items WHERE batch_id = ? ORDER BY id`, batchID)
if err != nil {
return nil, fmt.Errorf("query import batch items: %w", err)
}
defer rows.Close()
items := make([]ImportBatchItem, 0)
for rows.Next() {
var item ImportBatchItem
if err := rows.Scan(&item.ID, &item.BatchID, &item.KeyFingerprint, &item.AccountStatus, &item.ProbeSummaryJSON); err != nil {
return nil, fmt.Errorf("scan import batch item: %w", err)
}
items = append(items, item)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate import batch items: %w", err)
}
return items, nil
}
func (r *ImportBatchItemsRepo) Create(ctx context.Context, item ImportBatchItem) (int64, error) {
keyFingerprint := strings.TrimSpace(item.KeyFingerprint)
accountStatus := strings.TrimSpace(item.AccountStatus)
probeSummaryJSON := strings.TrimSpace(item.ProbeSummaryJSON)
if probeSummaryJSON == "" {
probeSummaryJSON = "{}"
}
switch {
case item.BatchID <= 0:
return 0, fmt.Errorf("batch_id is required")
case keyFingerprint == "":
return 0, fmt.Errorf("key_fingerprint is required")
case accountStatus == "":
return 0, fmt.Errorf("account_status is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO import_batch_items (batch_id, key_fingerprint, account_status, probe_summary_json) VALUES (?, ?, ?, ?)`, item.BatchID, keyFingerprint, accountStatus, probeSummaryJSON)
if err != nil {
return 0, fmt.Errorf("insert import batch item: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted import batch item id: %w", err)
}
return id, nil
}
func (r *ImportBatchItemsRepo) UpdateResult(ctx context.Context, id int64, accountStatus, probeSummaryJSON string) error {
if id <= 0 {
return fmt.Errorf("id is required")
}
accountStatus = strings.TrimSpace(accountStatus)
probeSummaryJSON = strings.TrimSpace(probeSummaryJSON)
if accountStatus == "" {
return fmt.Errorf("account_status is required")
}
if probeSummaryJSON == "" {
probeSummaryJSON = "{}"
}
if _, err := r.db.ExecContext(ctx, `UPDATE import_batch_items SET account_status = ?, probe_summary_json = ? WHERE id = ?`, accountStatus, probeSummaryJSON, id); err != nil {
return fmt.Errorf("update import batch item %d: %w", id, err)
}
return nil
}

View File

@@ -0,0 +1,429 @@
package sqlite
import (
"context"
"testing"
)
func TestImportBatchesRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID := createTestProviderWithPack(t, store, packID)
id, err := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Fatalf("Create() id = %d, want positive", id)
}
got, _ := store.ImportBatches().GetByID(context.Background(), id)
if got.Mode != "partial" || got.BatchStatus != "running" {
t.Fatalf("GetByID() = %+v, want running batch", got)
}
}
func TestImportBatchesRepoUpdateStatus(t *testing.T) {
store := openTestDB(t)
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID := createTestProviderWithPack(t, store, packID)
id, _ := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
})
err := store.ImportBatches().UpdateStatus(context.Background(), id, "succeeded", "subscription_ready")
if err != nil {
t.Fatalf("UpdateStatus() error = %v", err)
}
got, _ := store.ImportBatches().GetByID(context.Background(), id)
if got.BatchStatus != "succeeded" || got.AccessStatus != "subscription_ready" {
t.Fatalf("status = %+v, want succeeded/subscription_ready", got)
}
}
func TestImportBatchesRepoGetLatestByProviderID(t *testing.T) {
store := openTestDB(t)
hostID := createTestHost(t, store)
packID := createTestPack(t, store)
providerID := createTestProviderWithPack(t, store, packID)
store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID,
Mode: "partial", BatchStatus: "running", AccessStatus: "pending",
})
id2, _ := store.ImportBatches().Create(context.Background(), ImportBatch{
HostID: hostID, PackID: packID, ProviderID: providerID,
Mode: "strict", BatchStatus: "succeeded", AccessStatus: "subscription_ready",
})
got, _ := store.ImportBatches().GetLatestByProviderID(context.Background(), providerID)
if got.ID != id2 {
t.Fatalf("latest id = %d, want %d", got.ID, id2)
}
}
func TestImportBatchesRepoGetByIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.ImportBatches().GetByID(context.Background(), 999)
if err == nil {
t.Fatal("GetByID(999) error = nil, want error")
}
}
func TestImportBatchesRepoCreateValidationErrors(t *testing.T) {
store := openTestDB(t)
for _, tt := range []struct {
name string
batch ImportBatch
}{
{"host_id zero", ImportBatch{HostID: 0, PackID: 1, ProviderID: 1, Mode: "m", BatchStatus: "s", AccessStatus: "s"}},
{"pack_id zero", ImportBatch{HostID: 1, PackID: 0, ProviderID: 1, Mode: "m", BatchStatus: "s", AccessStatus: "s"}},
{"provider_id zero", ImportBatch{HostID: 1, PackID: 1, ProviderID: 0, Mode: "m", BatchStatus: "s", AccessStatus: "s"}},
{"empty mode", ImportBatch{HostID: 1, PackID: 1, ProviderID: 1, Mode: "", BatchStatus: "s", AccessStatus: "s"}},
{"empty batch_status", ImportBatch{HostID: 1, PackID: 1, ProviderID: 1, Mode: "m", BatchStatus: "", AccessStatus: "s"}},
{"empty access_status", ImportBatch{HostID: 1, PackID: 1, ProviderID: 1, Mode: "m", BatchStatus: "s", AccessStatus: ""}},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := store.ImportBatches().Create(context.Background(), tt.batch)
if err == nil {
t.Fatal("Create() error = nil")
}
})
}
}
func TestImportBatchesRepoUpdateStatusValidation(t *testing.T) {
store := openTestDB(t)
if err := store.ImportBatches().UpdateStatus(context.Background(), 0, "s", "s"); err == nil {
t.Fatal("UpdateStatus id=0 error = nil")
}
if err := store.ImportBatches().UpdateStatus(context.Background(), 1, "", "s"); err == nil {
t.Fatal("UpdateStatus empty batch_status error = nil")
}
}
// --- ImportBatchItems Repo Tests ---
func TestImportBatchItemsRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
id, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
BatchID: batchID,
KeyFingerprint: "sha256:abc",
AccountStatus: "passed",
ProbeSummaryJSON: `{"ok":true}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Fatalf("Create() id = %d, want positive", id)
}
items, _ := store.ImportBatchItems().GetByBatchID(context.Background(), batchID)
if len(items) != 1 || items[0].AccountStatus != "passed" {
t.Fatalf("items = %+v, want 1 with passed status", items)
}
}
func TestImportBatchItemsRepoMultipleItems(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
for _, status := range []string{"passed", "failed"} {
store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
BatchID: batchID, KeyFingerprint: "sha256:" + status, AccountStatus: status,
})
}
items, _ := store.ImportBatchItems().GetByBatchID(context.Background(), batchID)
if len(items) != 2 {
t.Fatalf("count = %d, want 2", len(items))
}
}
func TestImportBatchItemsRepoUpdateResult(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
itemID, _ := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
BatchID: batchID, KeyFingerprint: "sha256:x", AccountStatus: "pending",
})
store.ImportBatchItems().UpdateResult(context.Background(), itemID, "passed", `{"ok":true}`)
items, _ := store.ImportBatchItems().GetByBatchID(context.Background(), batchID)
if items[0].AccountStatus != "passed" {
t.Fatalf("AccountStatus = %q, want passed", items[0].AccountStatus)
}
}
func TestImportBatchItemsRepoGetByBatchIDEmpty(t *testing.T) {
store := openTestDB(t)
items, err := store.ImportBatchItems().GetByBatchID(context.Background(), 999)
if err != nil {
t.Fatalf("GetByBatchID() error = %v, want empty result", err)
}
if len(items) != 0 {
t.Fatalf("count = %d, want 0", len(items))
}
}
func TestImportBatchItemsRepoValidation(t *testing.T) {
store := openTestDB(t)
_, err := store.ImportBatchItems().Create(context.Background(), ImportBatchItem{
BatchID: 0, KeyFingerprint: "k", AccountStatus: "s",
})
if err == nil {
t.Fatal("Create batch_id=0 error = nil")
}
}
// --- Managed Resources Repo Tests ---
func TestManagedResourcesRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
id, err := store.ManagedResources().Create(context.Background(), ManagedResource{
BatchID: batchID, ResourceType: "group", HostResourceID: "g_01", ResourceName: "test-group",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
resources, _ := store.ManagedResources().GetByBatchID(context.Background(), batchID)
if len(resources) != 1 || resources[0].HostResourceID != "g_01" {
t.Fatalf("resources = %+v, want 1 with g_01", resources)
}
_ = id
}
func TestManagedResourcesRepoMultipleResources(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
for _, r := range []ManagedResource{
{BatchID: batchID, ResourceType: "group", HostResourceID: "g_01", ResourceName: "group-1"},
{BatchID: batchID, ResourceType: "channel", HostResourceID: "c_01", ResourceName: "channel-1"},
{BatchID: batchID, ResourceType: "account", HostResourceID: "a_01", ResourceName: "account-1"},
} {
store.ManagedResources().Create(context.Background(), r)
}
resources, _ := store.ManagedResources().GetByBatchID(context.Background(), batchID)
if len(resources) != 3 {
t.Fatalf("count = %d, want 3", len(resources))
}
}
func TestManagedResourcesRepoGetByBatchIDEmpty(t *testing.T) {
store := openTestDB(t)
resources, _ := store.ManagedResources().GetByBatchID(context.Background(), 999)
if len(resources) != 0 {
t.Fatalf("count = %d, want 0", len(resources))
}
}
func TestManagedResourcesRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
for _, tt := range []struct {
name string
r ManagedResource
}{
{"batch_id zero", ManagedResource{ResourceType: "g", HostResourceID: "h", ResourceName: "n"}},
{"empty resource_type", ManagedResource{BatchID: 1, HostResourceID: "h", ResourceName: "n"}},
{"empty host_resource_id", ManagedResource{BatchID: 1, ResourceType: "g", ResourceName: "n"}},
{"empty resource_name", ManagedResource{BatchID: 1, ResourceType: "g", HostResourceID: "h"}},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := store.ManagedResources().Create(context.Background(), tt.r)
if err == nil {
t.Fatal("Create() error = nil")
}
})
}
}
// --- Probe Results Repo Tests ---
func TestProbeResultsRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
itemID := createTestBatchItem(t, store, batchID)
id, err := store.ProbeResults().Create(context.Background(), ProbeResult{
BatchItemID: itemID, ProbeType: "account_smoke", Status: "passed", SummaryJSON: `{"ok":true}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
results, _ := store.ProbeResults().GetByBatchItemID(context.Background(), itemID)
if len(results) != 1 || results[0].ProbeType != "account_smoke" {
t.Fatalf("results = %+v, want 1 with account_smoke", results)
}
_ = id
}
func TestProbeResultsRepoMultipleResults(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
itemID := createTestBatchItem(t, store, batchID)
for _, p := range []ProbeResult{
{BatchItemID: itemID, ProbeType: "account_smoke", Status: "passed", SummaryJSON: `{"ok":true}`},
{BatchItemID: itemID, ProbeType: "model_list", Status: "passed", SummaryJSON: `{"models":["m1"]}`},
} {
store.ProbeResults().Create(context.Background(), p)
}
results, _ := store.ProbeResults().GetByBatchItemID(context.Background(), itemID)
if len(results) != 2 {
t.Fatalf("count = %d, want 2", len(results))
}
}
func TestProbeResultsRepoGetByBatchItemIDEmpty(t *testing.T) {
store := openTestDB(t)
results, _ := store.ProbeResults().GetByBatchItemID(context.Background(), 999)
if len(results) != 0 {
t.Fatalf("count = %d, want 0", len(results))
}
}
func TestProbeResultsRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
for _, tt := range []struct {
name string
probe ProbeResult
}{
{"batch_item_id zero", ProbeResult{ProbeType: "t", Status: "s"}},
{"empty probe_type", ProbeResult{BatchItemID: 1, Status: "s"}},
{"empty status", ProbeResult{BatchItemID: 1, ProbeType: "t"}},
} {
t.Run(tt.name, func(t *testing.T) {
_, err := store.ProbeResults().Create(context.Background(), tt.probe)
if err == nil {
t.Fatal("Create() error = nil")
}
})
}
}
// --- Access Closures Repo Tests ---
func TestAccessClosureRecordsRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
id, err := store.AccessClosures().Create(context.Background(), AccessClosureRecord{
BatchID: batchID, ClosureType: "subscription", Status: "subscription_ready", DetailsJSON: `{"status_code":200}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
records, _ := store.AccessClosures().GetByBatchID(context.Background(), batchID)
if len(records) != 1 || records[0].ClosureType != "subscription" {
t.Fatalf("records = %+v, want 1 subscription", records)
}
_ = id
}
func TestAccessClosureRecordsRepoMultiple(t *testing.T) {
store := openTestDB(t)
batchID := createTestBatch(t, store)
store.AccessClosures().Create(context.Background(), AccessClosureRecord{BatchID: batchID, ClosureType: "subscription", Status: "subscription_ready", DetailsJSON: "{}"})
store.AccessClosures().Create(context.Background(), AccessClosureRecord{BatchID: batchID, ClosureType: "self_service", Status: "self_service_ready", DetailsJSON: "{}"})
records, _ := store.AccessClosures().GetByBatchID(context.Background(), batchID)
if len(records) != 2 {
t.Fatalf("count = %d, want 2", len(records))
}
}
func TestAccessClosureRecordsRepoGetByBatchIDEmpty(t *testing.T) {
store := openTestDB(t)
records, _ := store.AccessClosures().GetByBatchID(context.Background(), 999)
if len(records) != 0 {
t.Fatalf("count = %d, want 0", len(records))
}
}
func TestAccessClosureRecordsRepoValidation(t *testing.T) {
store := openTestDB(t)
_, err := store.AccessClosures().Create(context.Background(), AccessClosureRecord{BatchID: 0, ClosureType: "t", Status: "s"})
if err == nil {
t.Fatal("Create batch_id=0 error = nil")
}
}
// --- Reconcile Runs Repo Tests ---
func createTestProviderWithPack(t *testing.T, store *DB, packID int64) int64 {
t.Helper()
id, err := store.Providers().Create(context.Background(), Provider{
PackID: packID, ProviderID: "test-provider-" + sanitizeTestName(t.Name()), DisplayName: "TP",
BaseURL: "https://tp.com", Platform: "openai",
})
if err != nil {
t.Fatalf("createTestProviderWithPack error = %v", err)
}
return id
}
func createTestProvider(t *testing.T, store *DB) int64 {
t.Helper()
packID := createTestPack(t, store)
return createTestProviderWithPack(t, store, packID)
}
func TestReconcileRunsRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
providerID := createTestProvider(t, store)
id, err := store.ReconcileRuns().Create(context.Background(), ReconcileRun{
ProviderID: providerID, Status: "active", SummaryJSON: `{"drifted":false}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
runs, _ := store.ReconcileRuns().GetByProviderID(context.Background(), providerID)
if len(runs) != 1 || runs[0].Status != "active" {
t.Fatalf("runs = %+v, want 1 active", runs)
}
_ = id
}
func TestReconcileRunsRepoMultipleRunsOrderedDesc(t *testing.T) {
store := openTestDB(t)
providerID := createTestProvider(t, store)
id1, _ := store.ReconcileRuns().Create(context.Background(), ReconcileRun{ProviderID: providerID, Status: "first", SummaryJSON: "{}"})
id2, _ := store.ReconcileRuns().Create(context.Background(), ReconcileRun{ProviderID: providerID, Status: "second", SummaryJSON: "{}"})
runs, _ := store.ReconcileRuns().GetByProviderID(context.Background(), providerID)
if len(runs) != 2 || runs[0].ID != id2 || runs[1].ID != id1 {
t.Fatalf("order: got %d, %d; want %d, %d (DESC)", runs[0].ID, runs[1].ID, id2, id1)
}
}
func TestReconcileRunsRepoGetByProviderIDEmpty(t *testing.T) {
store := openTestDB(t)
runs, _ := store.ReconcileRuns().GetByProviderID(context.Background(), 999)
if len(runs) != 0 {
t.Fatalf("count = %d, want 0", len(runs))
}
}
func TestReconcileRunsRepoValidation(t *testing.T) {
store := openTestDB(t)
_, err := store.ReconcileRuns().Create(context.Background(), ReconcileRun{ProviderID: 0, Status: "s"})
if err == nil {
t.Fatal("Create provider_id=0 error = nil")
}
}

View File

@@ -0,0 +1,76 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type ManagedResource struct {
ID int64
BatchID int64
ResourceType string
HostResourceID string
ResourceName string
}
type ManagedResourcesRepo struct {
db execQuerier
}
func newManagedResourcesRepo(db execQuerier) *ManagedResourcesRepo {
return &ManagedResourcesRepo{db: db}
}
func (r *ManagedResourcesRepo) Create(ctx context.Context, resource ManagedResource) (int64, error) {
resourceType := strings.TrimSpace(resource.ResourceType)
hostResourceID := strings.TrimSpace(resource.HostResourceID)
resourceName := strings.TrimSpace(resource.ResourceName)
switch {
case resource.BatchID <= 0:
return 0, fmt.Errorf("batch_id is required")
case resourceType == "":
return 0, fmt.Errorf("resource_type is required")
case hostResourceID == "":
return 0, fmt.Errorf("host_resource_id is required")
case resourceName == "":
return 0, fmt.Errorf("resource_name is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO managed_resources (batch_id, resource_type, host_resource_id, resource_name) VALUES (?, ?, ?, ?)`, resource.BatchID, resourceType, hostResourceID, resourceName)
if err != nil {
return 0, fmt.Errorf("insert managed resource %q: %w", hostResourceID, err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted managed resource id for %q: %w", hostResourceID, err)
}
return id, nil
}
func (r *ManagedResourcesRepo) GetByBatchID(ctx context.Context, batchID int64) ([]ManagedResource, error) {
if batchID <= 0 {
return nil, fmt.Errorf("batch_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_id, resource_type, host_resource_id, resource_name FROM managed_resources WHERE batch_id = ? ORDER BY id`, batchID)
if err != nil {
return nil, fmt.Errorf("query managed resources: %w", err)
}
defer rows.Close()
resources := make([]ManagedResource, 0)
for rows.Next() {
var resource ManagedResource
if err := rows.Scan(&resource.ID, &resource.BatchID, &resource.ResourceType, &resource.HostResourceID, &resource.ResourceName); err != nil {
return nil, fmt.Errorf("scan managed resource: %w", err)
}
resources = append(resources, resource)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate managed resources: %w", err)
}
return resources, nil
}

View File

@@ -7,9 +7,15 @@ import (
)
type Pack struct {
PackID string
Version string
Checksum string
ID int64
PackID string
Version string
Checksum string
Vendor string
TargetHost string
MinHostVersion string
MaxHostVersion string
ManifestJSON string
}
type PacksRepo struct {
@@ -20,10 +26,59 @@ func newPacksRepo(db execQuerier) *PacksRepo {
return &PacksRepo{db: db}
}
func (r *PacksRepo) GetByID(ctx context.Context, id int64) (Pack, error) {
if id <= 0 {
return Pack{}, fmt.Errorf("id is required")
}
var pack Pack
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json FROM packs WHERE id = ?`, id).Scan(
&pack.ID,
&pack.PackID,
&pack.Version,
&pack.Checksum,
&pack.Vendor,
&pack.TargetHost,
&pack.MinHostVersion,
&pack.MaxHostVersion,
&pack.ManifestJSON,
); err != nil {
return Pack{}, err
}
return pack, nil
}
func (r *PacksRepo) GetByPackID(ctx context.Context, packID string) (Pack, error) {
packID = strings.TrimSpace(packID)
if packID == "" {
return Pack{}, fmt.Errorf("pack_id is required")
}
var pack Pack
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json FROM packs WHERE pack_id = ?`, packID).Scan(
&pack.ID,
&pack.PackID,
&pack.Version,
&pack.Checksum,
&pack.Vendor,
&pack.TargetHost,
&pack.MinHostVersion,
&pack.MaxHostVersion,
&pack.ManifestJSON,
); err != nil {
return Pack{}, err
}
return pack, nil
}
func (r *PacksRepo) Create(ctx context.Context, pack Pack) (int64, error) {
packID := strings.TrimSpace(pack.PackID)
version := strings.TrimSpace(pack.Version)
checksum := strings.TrimSpace(pack.Checksum)
manifestJSON := strings.TrimSpace(pack.ManifestJSON)
if manifestJSON == "" {
manifestJSON = "{}"
}
switch {
case packID == "":
@@ -36,11 +91,16 @@ func (r *PacksRepo) Create(ctx context.Context, pack Pack) (int64, error) {
result, err := r.db.ExecContext(
ctx,
`INSERT INTO packs (pack_id, version, checksum)
VALUES (?, ?, ?)`,
`INSERT INTO packs (pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
packID,
version,
checksum,
strings.TrimSpace(pack.Vendor),
strings.TrimSpace(pack.TargetHost),
strings.TrimSpace(pack.MinHostVersion),
strings.TrimSpace(pack.MaxHostVersion),
manifestJSON,
)
if err != nil {
return 0, fmt.Errorf("insert pack %q: %w", packID, err)
@@ -50,6 +110,62 @@ func (r *PacksRepo) Create(ctx context.Context, pack Pack) (int64, error) {
if err != nil {
return 0, fmt.Errorf("read inserted pack id for %q: %w", packID, err)
}
return id, nil
}
func (r *PacksRepo) Upsert(ctx context.Context, pack Pack) (int64, error) {
packID := strings.TrimSpace(pack.PackID)
version := strings.TrimSpace(pack.Version)
checksum := strings.TrimSpace(pack.Checksum)
manifestJSON := strings.TrimSpace(pack.ManifestJSON)
if manifestJSON == "" {
manifestJSON = "{}"
}
switch {
case packID == "":
return 0, fmt.Errorf("pack_id is required")
case version == "":
return 0, fmt.Errorf("version is required")
case checksum == "":
return 0, fmt.Errorf("checksum is required")
}
result, err := r.db.ExecContext(
ctx,
`INSERT INTO packs (pack_id, version, checksum, vendor, target_host, min_host_version, max_host_version, manifest_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(pack_id) DO UPDATE SET
version = excluded.version,
checksum = excluded.checksum,
vendor = excluded.vendor,
target_host = excluded.target_host,
min_host_version = excluded.min_host_version,
max_host_version = excluded.max_host_version,
manifest_json = excluded.manifest_json`,
packID,
version,
checksum,
strings.TrimSpace(pack.Vendor),
strings.TrimSpace(pack.TargetHost),
strings.TrimSpace(pack.MinHostVersion),
strings.TrimSpace(pack.MaxHostVersion),
manifestJSON,
)
if err != nil {
return 0, fmt.Errorf("upsert pack %q: %w", packID, err)
}
id, err := result.LastInsertId()
if err == nil && id > 0 {
return id, nil
}
persisted, getErr := r.GetByPackID(ctx, packID)
if getErr != nil {
if err != nil {
return 0, fmt.Errorf("read upserted pack %q: %w", packID, getErr)
}
return 0, getErr
}
return persisted.ID, nil
}

View File

@@ -0,0 +1,159 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"testing"
)
func TestPacksRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
id, err := store.Packs().Create(context.Background(), Pack{
PackID: "test-pack",
Version: "1.0.0",
Checksum: "abc123",
Vendor: "test-vendor",
TargetHost: "sub2api",
MinHostVersion: "0.1.0",
MaxHostVersion: "0.2.x",
ManifestJSON: `{"name":"test"}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if id <= 0 {
t.Fatalf("Create() id = %d, want positive", id)
}
got, err := store.Packs().GetByID(context.Background(), id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got.PackID != "test-pack" || got.Version != "1.0.0" {
t.Fatalf("GetByID() = %+v, want pack test-pack", got)
}
got2, err := store.Packs().GetByPackID(context.Background(), "test-pack")
if err != nil {
t.Fatalf("GetByPackID() error = %v", err)
}
if got2.ID != id {
t.Fatalf("GetByPackID() id = %d, want %d", got2.ID, id)
}
}
func TestPacksRepoCreateDefaultsManifestJSON(t *testing.T) {
store := openTestDB(t)
id, err := store.Packs().Create(context.Background(), Pack{
PackID: "no-manifest",
Version: "1.0.0",
Checksum: "chk",
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
got, err := store.Packs().GetByID(context.Background(), id)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got.ManifestJSON != "{}" {
t.Fatalf("ManifestJSON = %q, want {}", got.ManifestJSON)
}
}
func TestPacksRepoUpsertCreatesNew(t *testing.T) {
store := openTestDB(t)
id, err := store.Packs().Upsert(context.Background(), Pack{
PackID: "upsert-pack",
Version: "1.0.0",
Checksum: "chk1",
})
if err != nil {
t.Fatalf("Upsert() error = %v", err)
}
if id <= 0 {
t.Fatalf("Upsert() id = %d, want positive", id)
}
}
func TestPacksRepoUpsertUpdatesExisting(t *testing.T) {
store := openTestDB(t)
id1, err := store.Packs().Upsert(context.Background(), Pack{
PackID: "upsert-pack",
Version: "1.0.0",
Checksum: "chk1",
})
if err != nil {
t.Fatalf("Upsert() create error = %v", err)
}
id2, err := store.Packs().Upsert(context.Background(), Pack{
PackID: "upsert-pack",
Version: "2.0.0",
Checksum: "chk2",
})
if err != nil {
t.Fatalf("Upsert() update error = %v", err)
}
if id2 != id1 {
t.Fatalf("Upsert() update returned id %d, want original %d", id2, id1)
}
got, err := store.Packs().GetByID(context.Background(), id1)
if err != nil {
t.Fatalf("GetByID() error = %v", err)
}
if got.Version != "2.0.0" {
t.Fatalf("Version after upsert = %q, want 2.0.0", got.Version)
}
}
func TestPacksRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
tests := []struct {
name string
pack Pack
}{
{"empty pack_id", Pack{Version: "v", Checksum: "c"}},
{"empty version", Pack{PackID: "p", Checksum: "c"}},
{"empty checksum", Pack{PackID: "p", Version: "v"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := store.Packs().Create(context.Background(), tt.pack)
if err == nil {
t.Fatal("Create() error = nil, want validation error")
}
})
}
}
func TestPacksRepoGetByIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.Packs().GetByID(context.Background(), 999)
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByID(999) error = %v, want sql.ErrNoRows", err)
}
}
func TestPacksRepoGetByPackIDEmptyError(t *testing.T) {
store := openTestDB(t)
_, err := store.Packs().GetByPackID(context.Background(), "")
if err == nil {
t.Fatal("GetByPackID('') error = nil, want error")
}
}
func TestPacksRepoGetByPackIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.Packs().GetByPackID(context.Background(), "nonexistent")
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByPackID() error = %v, want sql.ErrNoRows", err)
}
}

View File

@@ -0,0 +1,77 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type ProbeResult struct {
ID int64
BatchItemID int64
ProbeType string
Status string
SummaryJSON string
}
type ProbeResultsRepo struct {
db execQuerier
}
func newProbeResultsRepo(db execQuerier) *ProbeResultsRepo {
return &ProbeResultsRepo{db: db}
}
func (r *ProbeResultsRepo) Create(ctx context.Context, probe ProbeResult) (int64, error) {
probeType := strings.TrimSpace(probe.ProbeType)
status := strings.TrimSpace(probe.Status)
summaryJSON := strings.TrimSpace(probe.SummaryJSON)
if summaryJSON == "" {
summaryJSON = "{}"
}
switch {
case probe.BatchItemID <= 0:
return 0, fmt.Errorf("batch_item_id is required")
case probeType == "":
return 0, fmt.Errorf("probe_type is required")
case status == "":
return 0, fmt.Errorf("status is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO probe_results (batch_item_id, probe_type, status, summary_json) VALUES (?, ?, ?, ?)`, probe.BatchItemID, probeType, status, summaryJSON)
if err != nil {
return 0, fmt.Errorf("insert probe result: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted probe result id: %w", err)
}
return id, nil
}
func (r *ProbeResultsRepo) GetByBatchItemID(ctx context.Context, batchItemID int64) ([]ProbeResult, error) {
if batchItemID <= 0 {
return nil, fmt.Errorf("batch_item_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_item_id, probe_type, status, summary_json FROM probe_results WHERE batch_item_id = ? ORDER BY id`, batchItemID)
if err != nil {
return nil, fmt.Errorf("query probe results: %w", err)
}
defer rows.Close()
probes := make([]ProbeResult, 0)
for rows.Next() {
var probe ProbeResult
if err := rows.Scan(&probe.ID, &probe.BatchItemID, &probe.ProbeType, &probe.Status, &probe.SummaryJSON); err != nil {
return nil, fmt.Errorf("scan probe result: %w", err)
}
probes = append(probes, probe)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate probe results: %w", err)
}
return probes, nil
}

View File

@@ -7,11 +7,20 @@ import (
)
type Provider struct {
PackID int64
ProviderID string
DisplayName string
BaseURL string
Platform string
ID int64
PackID int64
ProviderID string
DisplayName string
BaseURL string
Platform string
AccountType string
DefaultModelsJSON string
SmokeTestModel string
GroupTemplateJSON string
ChannelTemplateJSON string
PlanTemplateJSON string
ImportOptionsJSON string
ManifestJSON string
}
type ProvidersRepo struct {
@@ -22,11 +31,87 @@ func newProvidersRepo(db execQuerier) *ProvidersRepo {
return &ProvidersRepo{db: db}
}
func (r *ProvidersRepo) ListByProviderID(ctx context.Context, providerID string) ([]Provider, error) {
providerID = strings.TrimSpace(providerID)
if providerID == "" {
return nil, fmt.Errorf("provider_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE provider_id = ? ORDER BY id`, providerID)
if err != nil {
return nil, fmt.Errorf("query providers by provider_id %q: %w", providerID, err)
}
defer rows.Close()
providers := make([]Provider, 0)
for rows.Next() {
var provider Provider
if err := rows.Scan(
&provider.ID,
&provider.PackID,
&provider.ProviderID,
&provider.DisplayName,
&provider.BaseURL,
&provider.Platform,
&provider.AccountType,
&provider.DefaultModelsJSON,
&provider.SmokeTestModel,
&provider.GroupTemplateJSON,
&provider.ChannelTemplateJSON,
&provider.PlanTemplateJSON,
&provider.ImportOptionsJSON,
&provider.ManifestJSON,
); err != nil {
return nil, fmt.Errorf("scan provider by provider_id %q: %w", providerID, err)
}
providers = append(providers, provider)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate providers by provider_id %q: %w", providerID, err)
}
return providers, nil
}
func (r *ProvidersRepo) GetByPackIDAndProviderID(ctx context.Context, packID int64, providerID string) (Provider, error) {
if packID <= 0 {
return Provider{}, fmt.Errorf("pack_id is required")
}
providerID = strings.TrimSpace(providerID)
if providerID == "" {
return Provider{}, fmt.Errorf("provider_id is required")
}
var provider Provider
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE pack_id = ? AND provider_id = ?`, packID, providerID).Scan(
&provider.ID,
&provider.PackID,
&provider.ProviderID,
&provider.DisplayName,
&provider.BaseURL,
&provider.Platform,
&provider.AccountType,
&provider.DefaultModelsJSON,
&provider.SmokeTestModel,
&provider.GroupTemplateJSON,
&provider.ChannelTemplateJSON,
&provider.PlanTemplateJSON,
&provider.ImportOptionsJSON,
&provider.ManifestJSON,
); err != nil {
return Provider{}, err
}
return provider, nil
}
func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, error) {
providerID := strings.TrimSpace(provider.ProviderID)
displayName := strings.TrimSpace(provider.DisplayName)
baseURL := strings.TrimSpace(provider.BaseURL)
platform := strings.TrimSpace(provider.Platform)
manifestJSON := strings.TrimSpace(provider.ManifestJSON)
if manifestJSON == "" {
manifestJSON = "{}"
}
switch {
case provider.PackID <= 0:
@@ -43,13 +128,21 @@ func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, e
result, err := r.db.ExecContext(
ctx,
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform)
VALUES (?, ?, ?, ?, ?)`,
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
provider.PackID,
providerID,
displayName,
baseURL,
platform,
strings.TrimSpace(provider.AccountType),
defaultJSONArray(provider.DefaultModelsJSON),
strings.TrimSpace(provider.SmokeTestModel),
defaultJSONObject(provider.GroupTemplateJSON),
defaultJSONObject(provider.ChannelTemplateJSON),
defaultJSONObject(provider.PlanTemplateJSON),
defaultJSONObject(provider.ImportOptionsJSON),
manifestJSON,
)
if err != nil {
return 0, fmt.Errorf("insert provider %q: %w", providerID, err)
@@ -59,6 +152,90 @@ func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, e
if err != nil {
return 0, fmt.Errorf("read inserted provider id for %q: %w", providerID, err)
}
return id, nil
}
func (r *ProvidersRepo) Upsert(ctx context.Context, provider Provider) (int64, error) {
providerID := strings.TrimSpace(provider.ProviderID)
displayName := strings.TrimSpace(provider.DisplayName)
baseURL := strings.TrimSpace(provider.BaseURL)
platform := strings.TrimSpace(provider.Platform)
manifestJSON := strings.TrimSpace(provider.ManifestJSON)
if manifestJSON == "" {
manifestJSON = "{}"
}
switch {
case provider.PackID <= 0:
return 0, fmt.Errorf("pack_id is required")
case providerID == "":
return 0, fmt.Errorf("provider_id is required")
case displayName == "":
return 0, fmt.Errorf("display_name is required")
case baseURL == "":
return 0, fmt.Errorf("base_url is required")
case platform == "":
return 0, fmt.Errorf("platform is required")
}
result, err := r.db.ExecContext(
ctx,
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(pack_id, provider_id) DO UPDATE SET
display_name = excluded.display_name,
base_url = excluded.base_url,
platform = excluded.platform,
account_type = excluded.account_type,
default_models_json = excluded.default_models_json,
smoke_test_model = excluded.smoke_test_model,
group_template_json = excluded.group_template_json,
channel_template_json = excluded.channel_template_json,
plan_template_json = excluded.plan_template_json,
import_options_json = excluded.import_options_json,
manifest_json = excluded.manifest_json`,
provider.PackID,
providerID,
displayName,
baseURL,
platform,
strings.TrimSpace(provider.AccountType),
defaultJSONArray(provider.DefaultModelsJSON),
strings.TrimSpace(provider.SmokeTestModel),
defaultJSONObject(provider.GroupTemplateJSON),
defaultJSONObject(provider.ChannelTemplateJSON),
defaultJSONObject(provider.PlanTemplateJSON),
defaultJSONObject(provider.ImportOptionsJSON),
manifestJSON,
)
if err != nil {
return 0, fmt.Errorf("upsert provider %q: %w", providerID, err)
}
id, err := result.LastInsertId()
if err == nil && id > 0 {
return id, nil
}
persisted, getErr := r.GetByPackIDAndProviderID(ctx, provider.PackID, providerID)
if getErr != nil {
if err != nil {
return 0, fmt.Errorf("read upserted provider %q: %w", providerID, getErr)
}
return 0, getErr
}
return persisted.ID, nil
}
func defaultJSONObject(value string) string {
if strings.TrimSpace(value) == "" {
return "{}"
}
return value
}
func defaultJSONArray(value string) string {
if strings.TrimSpace(value) == "" {
return "[]"
}
return value
}

View File

@@ -0,0 +1,172 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"testing"
)
func TestProvidersRepoCreateAndGet(t *testing.T) {
store := openTestDB(t)
packID := createTestPack(t, store)
providerID, err := store.Providers().Create(context.Background(), Provider{
PackID: packID,
ProviderID: "deepseek",
DisplayName: "DeepSeek",
BaseURL: "https://api.deepseek.com",
Platform: "openai",
AccountType: "api",
SmokeTestModel: "deepseek-chat",
ManifestJSON: `{"models":["deepseek-chat"]}`,
})
if err != nil {
t.Fatalf("Create() error = %v", err)
}
if providerID <= 0 {
t.Fatalf("Create() id = %d, want positive", providerID)
}
got, err := store.Providers().GetByPackIDAndProviderID(context.Background(), packID, "deepseek")
if err != nil {
t.Fatalf("GetByPackIDAndProviderID() error = %v", err)
}
if got.ProviderID != "deepseek" || got.DisplayName != "DeepSeek" {
t.Fatalf("GetByPackIDAndProviderID() = %+v, want deepseek", got)
}
}
func TestProvidersRepoListByProviderID(t *testing.T) {
store := openTestDB(t)
packID1 := createTestPackWithSuffix(t, store, "a")
packID2 := createTestPackWithSuffix(t, store, "b")
store.Providers().Create(context.Background(), Provider{PackID: packID1, ProviderID: "deepseek", DisplayName: "DS1", BaseURL: "https://a.com", Platform: "openai"})
store.Providers().Create(context.Background(), Provider{PackID: packID2, ProviderID: "deepseek", DisplayName: "DS2", BaseURL: "https://b.com", Platform: "openai"})
providers, err := store.Providers().ListByProviderID(context.Background(), "deepseek")
if err != nil {
t.Fatalf("ListByProviderID() error = %v", err)
}
if len(providers) != 2 {
t.Fatalf("ListByProviderID() count = %d, want 2", len(providers))
}
}
func createTestPackWithSuffix(t *testing.T, store *DB, suffix string) int64 {
t.Helper()
id, err := store.Packs().Create(context.Background(), Pack{
PackID: "pack-" + sanitizeTestName(t.Name()) + "-" + suffix, Version: "1.0.0", Checksum: "chk",
})
if err != nil {
t.Fatalf("createTestPackWithSuffix error = %v", err)
}
return id
}
func TestProvidersRepoListByProviderIDEmpty(t *testing.T) {
store := openTestDB(t)
providers, err := store.Providers().ListByProviderID(context.Background(), "nonexistent")
if err != nil {
t.Fatalf("ListByProviderID() error = %v", err)
}
if len(providers) != 0 {
t.Fatalf("ListByProviderID() count = %d, want 0", len(providers))
}
}
func TestProvidersRepoUpsertCreatesNew(t *testing.T) {
store := openTestDB(t)
packID := createTestPack(t, store)
id, err := store.Providers().Upsert(context.Background(), Provider{
PackID: packID, ProviderID: "upsert-p", DisplayName: "P", BaseURL: "https://u.com", Platform: "openai",
})
if err != nil {
t.Fatalf("Upsert() error = %v", err)
}
if id <= 0 {
t.Fatalf("Upsert() id = %d, want positive", id)
}
}
func TestProvidersRepoUpsertUpdatesExisting(t *testing.T) {
store := openTestDB(t)
packID := createTestPack(t, store)
id1, _ := store.Providers().Upsert(context.Background(), Provider{
PackID: packID, ProviderID: "upsert-p", DisplayName: "P1", BaseURL: "https://u1.com", Platform: "openai",
})
id2, _ := store.Providers().Upsert(context.Background(), Provider{
PackID: packID, ProviderID: "upsert-p", DisplayName: "P2", BaseURL: "https://u2.com", Platform: "openai",
})
if id2 != id1 {
t.Fatalf("Upsert update id = %d, want %d", id2, id1)
}
got, _ := store.Providers().GetByPackIDAndProviderID(context.Background(), packID, "upsert-p")
if got.DisplayName != "P2" {
t.Fatalf("DisplayName after upsert = %q, want P2", got.DisplayName)
}
}
func TestProvidersRepoValidationErrors(t *testing.T) {
store := openTestDB(t)
packID := createTestPack(t, store)
tests := []struct {
name string
provider Provider
}{
{"pack_id zero", Provider{ProviderID: "p", DisplayName: "d", BaseURL: "b", Platform: "openai"}},
{"empty provider_id", Provider{PackID: packID, DisplayName: "d", BaseURL: "b", Platform: "openai"}},
{"empty display_name", Provider{PackID: packID, ProviderID: "p", BaseURL: "b", Platform: "openai"}},
{"empty base_url", Provider{PackID: packID, ProviderID: "p", DisplayName: "d", Platform: "openai"}},
{"empty platform", Provider{PackID: packID, ProviderID: "p", DisplayName: "d", BaseURL: "b"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := store.Providers().Create(context.Background(), tt.provider)
if err == nil {
t.Fatal("Create() error = nil, want validation error")
}
})
}
}
func TestProvidersRepoGetByPackIDAndProviderIDNotFound(t *testing.T) {
store := openTestDB(t)
_, err := store.Providers().GetByPackIDAndProviderID(context.Background(), 999, "p")
if !errors.Is(err, sql.ErrNoRows) {
t.Fatalf("GetByPackIDAndProviderID() error = %v, want sql.ErrNoRows", err)
}
}
func TestProvidersRepoGetByPackIDEmptyError(t *testing.T) {
store := openTestDB(t)
_, err := store.Providers().GetByPackIDAndProviderID(context.Background(), 0, "p")
if err == nil {
t.Fatal("GetByPackIDAndProviderID with packID=0 error = nil, want error")
}
}
func TestDefaultJSONObject(t *testing.T) {
if got := defaultJSONObject(""); got != "{}" {
t.Fatalf("defaultJSONObject('') = %q, want {}", got)
}
if got := defaultJSONObject(`{"a":1}`); got != `{"a":1}` {
t.Fatalf("defaultJSONObject() = %q, want input", got)
}
}
func TestDefaultJSONArray(t *testing.T) {
if got := defaultJSONArray(""); got != "[]" {
t.Fatalf("defaultJSONArray('') = %q, want []", got)
}
if got := defaultJSONArray(`["a"]`); got != `["a"]` {
t.Fatalf("defaultJSONArray() = %q, want input", got)
}
}

View File

@@ -0,0 +1,73 @@
package sqlite
import (
"context"
"fmt"
"strings"
)
type ReconcileRun struct {
ID int64
ProviderID int64
Status string
SummaryJSON string
}
type ReconcileRunsRepo struct {
db execQuerier
}
func newReconcileRunsRepo(db execQuerier) *ReconcileRunsRepo {
return &ReconcileRunsRepo{db: db}
}
func (r *ReconcileRunsRepo) Create(ctx context.Context, run ReconcileRun) (int64, error) {
status := strings.TrimSpace(run.Status)
summaryJSON := strings.TrimSpace(run.SummaryJSON)
if summaryJSON == "" {
summaryJSON = "{}"
}
switch {
case run.ProviderID <= 0:
return 0, fmt.Errorf("provider_id is required")
case status == "":
return 0, fmt.Errorf("status is required")
}
result, err := r.db.ExecContext(ctx, `INSERT INTO reconcile_runs (provider_id, status, summary_json) VALUES (?, ?, ?)`, run.ProviderID, status, summaryJSON)
if err != nil {
return 0, fmt.Errorf("insert reconcile run: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted reconcile run id: %w", err)
}
return id, nil
}
func (r *ReconcileRunsRepo) GetByProviderID(ctx context.Context, providerID int64) ([]ReconcileRun, error) {
if providerID <= 0 {
return nil, fmt.Errorf("provider_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, provider_id, status, summary_json FROM reconcile_runs WHERE provider_id = ? ORDER BY id DESC`, providerID)
if err != nil {
return nil, fmt.Errorf("query reconcile runs: %w", err)
}
defer rows.Close()
runs := make([]ReconcileRun, 0)
for rows.Next() {
var run ReconcileRun
if err := rows.Scan(&run.ID, &run.ProviderID, &run.Status, &run.SummaryJSON); err != nil {
return nil, fmt.Errorf("scan reconcile run: %w", err)
}
runs = append(runs, run)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate reconcile runs: %w", err)
}
return runs, nil
}