Files
sub2api-cn-relay-manager/internal/store/sqlite/db.go
2026-05-28 21:24:05 +08:00

453 lines
11 KiB
Go

package sqlite
import (
"context"
"database/sql"
"errors"
"fmt"
"io/fs"
"sort"
"strings"
_ "modernc.org/sqlite"
"sub2api-cn-relay-manager/internal/store/migrations"
)
// execQuerier is the slice of the database/sql API that the repositories
// depend on. *sql.DB, *sql.Tx and *sql.Conn all satisfy it, which is what
// allows the same repository constructors to be bound either to the pool or
// to a single transaction.
type execQuerier interface {
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
	QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// Queries bundles every repository of the store. All repositories in one
// Queries value are built by newQueries over the same execQuerier, so a
// Queries created over a *sql.Tx runs all of its statements inside that
// transaction, while one created over the pool runs them independently.
type Queries struct {
	Hosts *HostsRepo
	Packs *PacksRepo
	Providers *ProvidersRepo
	LogicalGroups *LogicalGroupsRepo
	LogicalGroupModels *LogicalGroupModelsRepo
	LogicalGroupRoutes *LogicalGroupRoutesRepo
	LogicalGroupRouteModels *LogicalGroupRouteModelsRepo
	RouteDecisionLogs *RouteDecisionLogsRepo
	RouteFailoverEvents *RouteFailoverEventsRepo
	RouteStickyAudit *RouteStickyAuditRepo
	ProviderDrafts *ProviderDraftsRepo
	ImportBatches *ImportBatchesRepo
	ImportBatchItems *ImportBatchItemsRepo
	ImportRuns *ImportRunsRepo
	ImportRunItems *ImportRunItemsRepo
	// ImportRunEvents holds the import-run item events repository; note the
	// field name is shorter than its type name (ImportRunItemEventsRepo).
	ImportRunEvents *ImportRunItemEventsRepo
	ManagedResources *ManagedResourcesRepo
	ProbeResults *ProbeResultsRepo
	AccessClosures *AccessClosureRecordsRepo
	ReconcileRuns *ReconcileRunsRepo
}
// DB owns the SQLite connection pool together with a Queries bundle bound
// directly to that pool (outside any transaction). Values are created by Open.
type DB struct {
	sqlDB *sql.DB
	// queries is built once in Open over sqlDB. WithTx does not reuse it; it
	// builds a separate, transaction-scoped Queries for each call.
	queries *Queries
}
// Open opens the SQLite database identified by dsn, makes sure foreign-key
// enforcement is on, and applies any pending embedded migrations before
// handing back a ready-to-use *DB.
//
// The pool is pinned to a single connection: SQLite tolerates only one writer
// at a time, and routing every statement through one connection keeps HTTP
// flows that chain several writes (for example host probe refresh followed by
// import persistence) from self-deadlocking into SQLITE_BUSY.
//
// If any step after sql.Open fails, the pool is closed before the error is
// returned.
func Open(ctx context.Context, dsn string) (*DB, error) {
	pool, err := sql.Open("sqlite", withForeignKeysEnabled(dsn))
	if err != nil {
		return nil, fmt.Errorf("open sqlite database: %w", err)
	}
	pool.SetMaxOpenConns(1)
	pool.SetMaxIdleConns(1)

	// abort releases the pool and propagates cause unchanged. The Close error
	// is intentionally dropped: the setup failure is the one worth reporting.
	abort := func(cause error) (*DB, error) {
		_ = pool.Close()
		return nil, cause
	}

	if err := pool.PingContext(ctx); err != nil {
		return abort(fmt.Errorf("ping sqlite database: %w", err))
	}
	if err := ensureForeignKeys(ctx, pool); err != nil {
		return abort(err)
	}
	if err := migrate(ctx, pool); err != nil {
		return abort(err)
	}
	return &DB{sqlDB: pool, queries: newQueries(pool)}, nil
}
// Close closes the underlying connection pool and returns its error.
// Repositories obtained from db should not be used after Close.
func (db *DB) Close() error { return db.sqlDB.Close() }
// SQLDB exposes the underlying *sql.DB for callers that need raw access.
// Open limits this pool to a single open connection.
func (db *DB) SQLDB() *sql.DB { return db.sqlDB }
// The accessors below return repositories from the Queries bundle that Open
// binds directly to the connection pool, i.e. outside any transaction. Code
// running inside WithTx should use the Queries handed to its callback instead.

// Hosts returns the pool-bound hosts repository.
func (db *DB) Hosts() *HostsRepo { return db.queries.Hosts }

// Packs returns the pool-bound packs repository.
func (db *DB) Packs() *PacksRepo { return db.queries.Packs }

// Providers returns the pool-bound providers repository.
func (db *DB) Providers() *ProvidersRepo { return db.queries.Providers }

// LogicalGroups returns the pool-bound logical groups repository.
func (db *DB) LogicalGroups() *LogicalGroupsRepo { return db.queries.LogicalGroups }

// LogicalGroupModels returns the pool-bound logical group models repository.
func (db *DB) LogicalGroupModels() *LogicalGroupModelsRepo { return db.queries.LogicalGroupModels }

// LogicalGroupRoutes returns the pool-bound logical group routes repository.
func (db *DB) LogicalGroupRoutes() *LogicalGroupRoutesRepo { return db.queries.LogicalGroupRoutes }

// LogicalGroupRouteModels returns the pool-bound logical group route models repository.
func (db *DB) LogicalGroupRouteModels() *LogicalGroupRouteModelsRepo {
	return db.queries.LogicalGroupRouteModels
}

// RouteDecisionLogs returns the pool-bound route decision logs repository.
func (db *DB) RouteDecisionLogs() *RouteDecisionLogsRepo { return db.queries.RouteDecisionLogs }

// RouteFailoverEvents returns the pool-bound route failover events repository.
func (db *DB) RouteFailoverEvents() *RouteFailoverEventsRepo { return db.queries.RouteFailoverEvents }

// RouteStickyAudit returns the pool-bound route sticky audit repository.
func (db *DB) RouteStickyAudit() *RouteStickyAuditRepo { return db.queries.RouteStickyAudit }

// ProviderDrafts returns the pool-bound provider drafts repository.
func (db *DB) ProviderDrafts() *ProviderDraftsRepo { return db.queries.ProviderDrafts }

// ImportBatches returns the pool-bound import batches repository.
func (db *DB) ImportBatches() *ImportBatchesRepo { return db.queries.ImportBatches }

// ImportBatchItems returns the pool-bound import batch items repository.
func (db *DB) ImportBatchItems() *ImportBatchItemsRepo { return db.queries.ImportBatchItems }

// ImportRuns returns the pool-bound import runs repository.
func (db *DB) ImportRuns() *ImportRunsRepo { return db.queries.ImportRuns }

// ImportRunItems returns the pool-bound import run items repository.
func (db *DB) ImportRunItems() *ImportRunItemsRepo { return db.queries.ImportRunItems }

// ImportRunEvents returns the pool-bound import-run item events repository.
// It returns the same repository as ImportRunItemEvents.
func (db *DB) ImportRunEvents() *ImportRunItemEventsRepo { return db.queries.ImportRunEvents }

// ImportRunItemEvents is an alias of ImportRunEvents: both return the same
// pool-bound import-run item events repository.
func (db *DB) ImportRunItemEvents() *ImportRunItemEventsRepo { return db.queries.ImportRunEvents }

// ManagedResources returns the pool-bound managed resources repository.
func (db *DB) ManagedResources() *ManagedResourcesRepo { return db.queries.ManagedResources }

// ProbeResults returns the pool-bound probe results repository.
func (db *DB) ProbeResults() *ProbeResultsRepo { return db.queries.ProbeResults }

// AccessClosures returns the pool-bound access closure records repository.
func (db *DB) AccessClosures() *AccessClosureRecordsRepo { return db.queries.AccessClosures }

// ReconcileRuns returns the pool-bound reconcile runs repository.
func (db *DB) ReconcileRuns() *ReconcileRunsRepo { return db.queries.ReconcileRuns }
// WithTx runs fn inside a single database transaction and commits it when fn
// returns nil. fn receives a Queries whose repositories are all bound to that
// transaction.
//
// If fn returns an error, the transaction is rolled back and fn's error is
// returned. If the rollback also fails, the two errors are joined. If fn
// panics, the transaction is still rolled back before the panic keeps
// unwinding. Otherwise it would stay open and hold the pool's only
// connection until ctx is cancelled, which may be never.
//
// Because Open limits the pool to one connection, fn must use only the
// Queries it is given. Calling db's own repositories, or db.SQLDB(), from
// inside fn blocks while waiting for the connection this transaction holds.
func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error {
	tx, err := db.sqlDB.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin sqlite transaction: %w", err)
	}
	// Safety net for a panicking fn and for a failed Commit. On the normal
	// paths the transaction is already finished when this runs, so Rollback
	// just reports sql.ErrTxDone, which is deliberately ignored.
	defer func() { _ = tx.Rollback() }()

	if err := fn(newQueries(tx)); err != nil {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			return errors.Join(err, fmt.Errorf("rollback sqlite transaction: %w", rollbackErr))
		}
		return err
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit sqlite transaction: %w", err)
	}
	return nil
}
// newQueries builds the complete repository bundle over conn. Given the pool,
// each repository runs its statements independently. Given a *sql.Tx, every
// repository participates in that one transaction.
func newQueries(conn execQuerier) *Queries {
	return &Queries{
		Hosts:                   newHostsRepo(conn),
		Packs:                   newPacksRepo(conn),
		Providers:               newProvidersRepo(conn),
		LogicalGroups:           newLogicalGroupsRepo(conn),
		LogicalGroupModels:      newLogicalGroupModelsRepo(conn),
		LogicalGroupRoutes:      newLogicalGroupRoutesRepo(conn),
		LogicalGroupRouteModels: newLogicalGroupRouteModelsRepo(conn),
		RouteDecisionLogs:       newRouteDecisionLogsRepo(conn),
		RouteFailoverEvents:     newRouteFailoverEventsRepo(conn),
		RouteStickyAudit:        newRouteStickyAuditRepo(conn),
		ProviderDrafts:          newProviderDraftsRepo(conn),
		ImportBatches:           newImportBatchesRepo(conn),
		ImportBatchItems:        newImportBatchItemsRepo(conn),
		ImportRuns:              newImportRunsRepo(conn),
		ImportRunItems:          newImportRunItemsRepo(conn),
		ImportRunEvents:         newImportRunItemEventsRepo(conn),
		ManagedResources:        newManagedResourcesRepo(conn),
		ProbeResults:            newProbeResultsRepo(conn),
		AccessClosures:          newAccessClosureRecordsRepo(conn),
		ReconcileRuns:           newReconcileRunsRepo(conn),
	}
}
// migrate applies, in sorted file-name order, every embedded migration that
// is not yet recorded in schema_migrations. The ledger creation, the legacy
// backfill, and all pending migrations share one transaction. Either all of
// them take effect, or the first error rolls everything back.
func migrate(ctx context.Context, db *sql.DB) error {
	names, err := migrationFileNames()
	if err != nil {
		return err
	}
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin sqlite migration transaction: %w", err)
	}

	// applyAll performs every step that must happen inside tx. Any error it
	// returns is funnelled through a single rollback point below.
	applyAll := func() error {
		if err := ensureMigrationLedger(ctx, tx); err != nil {
			return err
		}
		applied, err := loadAppliedMigrations(ctx, tx)
		if err != nil {
			return err
		}
		if err := backfillLegacySchemaIfNeeded(ctx, tx, names, applied); err != nil {
			return err
		}
		for _, name := range names {
			if applied[name] {
				continue
			}
			body, err := readMigration(name)
			if err != nil {
				return err
			}
			if _, err := tx.ExecContext(ctx, body); err != nil {
				return fmt.Errorf("apply sqlite migration %s: %w", name, err)
			}
			if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", name); err != nil {
				return fmt.Errorf("record sqlite migration %s: %w", name, err)
			}
		}
		return nil
	}

	if err := applyAll(); err != nil {
		return rollbackMigration(tx, err)
	}
	if err := tx.Commit(); err != nil {
		_ = tx.Rollback() // best effort; after a failed Commit this normally reports ErrTxDone
		return fmt.Errorf("commit sqlite migration transaction: %w", err)
	}
	return nil
}
// withForeignKeysEnabled appends the driver's `_pragma=foreign_keys(1)` query
// parameter to dsn. It starts a query string with '?' when dsn has none, and
// otherwise extends the existing one with '&'.
func withForeignKeysEnabled(dsn string) string {
	separator := "?"
	if strings.ContainsRune(dsn, '?') {
		separator = "&"
	}
	return dsn + separator + "_pragma=foreign_keys(1)"
}
// ensureForeignKeys turns foreign-key enforcement on for the connection it
// runs on, then reads the pragma back. It fails unless SQLite reports the
// setting as exactly 1.
func ensureForeignKeys(ctx context.Context, db *sql.DB) error {
	if _, err := db.ExecContext(ctx, "PRAGMA foreign_keys = ON"); err != nil {
		return fmt.Errorf("enable sqlite foreign keys: %w", err)
	}

	var state int
	err := db.QueryRowContext(ctx, "PRAGMA foreign_keys").Scan(&state)
	switch {
	case err != nil:
		return fmt.Errorf("verify sqlite foreign keys: %w", err)
	case state != 1:
		return errors.New("sqlite foreign keys are disabled")
	default:
		return nil
	}
}
// ensureMigrationLedger creates the schema_migrations table if it does not
// exist yet. Each row records one applied migration file name (version) and
// when it was inserted (applied_at, defaulted by SQLite).
func ensureMigrationLedger(ctx context.Context, tx *sql.Tx) error {
	_, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
)`)
	if err != nil {
		return fmt.Errorf("create schema_migrations table: %w", err)
	}
	return nil
}
// loadAppliedMigrations returns the set of migration versions already recorded
// in schema_migrations, keyed by version with the value true. The returned map
// is never nil on success, so callers may add entries to it.
func loadAppliedMigrations(ctx context.Context, tx *sql.Tx) (map[string]bool, error) {
	rows, err := tx.QueryContext(ctx, "SELECT version FROM schema_migrations")
	if err != nil {
		return nil, fmt.Errorf("query applied sqlite migrations: %w", err)
	}
	defer rows.Close()

	seen := map[string]bool{}
	for rows.Next() {
		var v string
		if scanErr := rows.Scan(&v); scanErr != nil {
			return nil, fmt.Errorf("scan applied sqlite migration: %w", scanErr)
		}
		seen[v] = true
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate applied sqlite migrations: %w", iterErr)
	}
	return seen, nil
}
// migrationFileNames lists the regular *.sql files at the root of the embedded
// migrations filesystem, sorted lexically. It returns a nil slice when there
// are none.
func migrationFileNames() ([]string, error) {
	entries, err := fs.ReadDir(migrations.Files, ".")
	if err != nil {
		return nil, fmt.Errorf("read embedded sqlite migrations: %w", err)
	}

	var sqlFiles []string
	for _, e := range entries {
		if !e.IsDir() && strings.HasSuffix(e.Name(), ".sql") {
			sqlFiles = append(sqlFiles, e.Name())
		}
	}
	// fs.ReadDir already sorts by name. The explicit sort keeps the ordering
	// guarantee visible here, since application order depends on it.
	sort.Strings(sqlFiles)
	return sqlFiles, nil
}
// readMigration returns the SQL text of the embedded migration file name.
func readMigration(name string) (sqlText string, err error) {
	var raw []byte
	if raw, err = fs.ReadFile(migrations.Files, name); err != nil {
		return "", fmt.Errorf("read embedded migration %s: %w", name, err)
	}
	sqlText = string(raw)
	return sqlText, nil
}
// rollbackMigration rolls back the migration transaction and returns err. If
// the rollback itself fails, both errors are joined so neither is lost.
func rollbackMigration(tx *sql.Tx, err error) error {
	rbErr := tx.Rollback()
	if rbErr == nil {
		return err
	}
	return errors.Join(err, fmt.Errorf("rollback sqlite migration transaction: %w", rbErr))
}
// backfillLegacySchemaIfNeeded handles databases whose 0001_init.sql tables
// exist but whose schema_migrations ledger is empty, which is the "legacy"
// case named in the error below. When every table probed by
// detectLegacy0001Schema is present, it records 0001_init.sql as applied, in
// both the ledger and appliedMigrations, so migrate does not re-run it. A
// partially present legacy schema is rejected.
//
// It does nothing when there are no migrations, when the ledger already has
// entries, or when the first migration is not 0001_init.sql.
// appliedMigrations must be non-nil if a backfill can happen.
func backfillLegacySchemaIfNeeded(ctx context.Context, tx *sql.Tx, migrationNames []string, appliedMigrations map[string]bool) error {
	// Short-circuit order matters: the length check guards the index.
	if len(migrationNames) == 0 || len(appliedMigrations) != 0 || migrationNames[0] != "0001_init.sql" {
		return nil
	}
	first := migrationNames[0]

	complete, partial, err := detectLegacy0001Schema(ctx, tx)
	switch {
	case err != nil:
		return err
	case partial:
		return errors.New("legacy sqlite schema is partially applied without schema_migrations")
	case !complete:
		return nil
	}

	if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (version) VALUES (?)", first); err != nil {
		return fmt.Errorf("backfill sqlite migration %s: %w", first, err)
	}
	appliedMigrations[first] = true
	return nil
}
// detectLegacy0001Schema probes for the hosts, packs and providers tables,
// which serve as the fingerprint of 0001_init.sql. It reports:
//   - complete=true  when all of them exist,
//   - partial=true   when some but not all exist,
//   - both false     when none exist.
//
// NOTE(review): assumes these three tables are a sufficient fingerprint of
// 0001_init.sql — confirm against the migration file.
func detectLegacy0001Schema(ctx context.Context, tx *sql.Tx) (complete bool, partial bool, err error) {
	fingerprint := [...]string{"hosts", "packs", "providers"}

	present := 0
	for _, table := range fingerprint {
		ok, checkErr := tableExists(ctx, tx, table)
		if checkErr != nil {
			return false, false, checkErr
		}
		if ok {
			present++
		}
	}

	complete = present == len(fingerprint)
	partial = present > 0 && !complete
	return complete, partial, nil
}
// tableExists reports whether sqlite_master lists a table named table. Views
// and indexes are not matched. A missing row is reported as (false, nil), not
// as an error.
//
// NOTE(review): the name comparison presumably uses SQLite's default BINARY
// collation, so a table created with different letter case would not match.
// Confirm this is acceptable for the legacy-schema probe.
func tableExists(ctx context.Context, db execQuerier, table string) (bool, error) {
	row := db.QueryRowContext(
		ctx,
		"SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?",
		table,
	)

	var stored string
	switch err := row.Scan(&stored); {
	case err == nil:
		return stored == table, nil
	case errors.Is(err, sql.ErrNoRows):
		return false, nil
	default:
		return false, fmt.Errorf("check sqlite table %s: %w", table, err)
	}
}