Files
sub2api-cn-relay-manager/internal/store/sqlite/db.go
2026-05-12 23:25:02 +08:00

341 lines
7.5 KiB
Go

package sqlite
import (
"context"
"database/sql"
"errors"
"fmt"
"io/fs"
"sort"
"strings"
_ "modernc.org/sqlite"
"sub2api-cn-relay-manager/internal/store/migrations"
)
// execQuerier is the minimal database handle the repositories are built on.
// It lists only ExecContext and QueryRowContext, which both *sql.DB and
// *sql.Tx provide, so newQueries can be given either the connection pool or
// an open transaction.
type execQuerier interface {
	// ExecContext runs a statement and reports its sql.Result.
	ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
	// QueryRowContext runs a query and returns its single-row handle.
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
// Queries groups the repositories that share one underlying database handle.
// newQueries builds all three fields from the same handle, which is either
// the connection pool (see Open) or a transaction (see WithTx).
type Queries struct {
	// Hosts is the repository returned by DB.Hosts.
	Hosts *HostsRepo
	// Packs is the repository returned by DB.Packs.
	Packs *PacksRepo
	// Providers is the repository returned by DB.Providers.
	Providers *ProvidersRepo
}
// DB is the handle returned by Open. It owns the *sql.DB and a Queries set
// bound directly to that pool (i.e. outside of any transaction).
type DB struct {
	// sqlDB is the underlying database/sql handle; Close closes it.
	sqlDB *sql.DB
	// queries holds the non-transactional repositories exposed by Hosts,
	// Packs and Providers.
	queries *Queries
}
// Open opens the SQLite database identified by dsn and prepares it for use.
//
// The DSN is first extended by withForeignKeysEnabled, then the database is
// pinged, foreign-key enforcement is verified by ensureForeignKeys, and the
// embedded migrations are applied by migrate. If any step after sql.Open
// fails, the handle is closed before the error is returned.
func Open(ctx context.Context, dsn string) (*DB, error) {
	handle, err := sql.Open("sqlite", withForeignKeysEnabled(dsn))
	if err != nil {
		return nil, fmt.Errorf("open sqlite database: %w", err)
	}

	// Once the handle exists, every failure path must release it. The error
	// from Close is deliberately dropped: the original cause is what the
	// caller needs to see.
	abort := func(cause error) (*DB, error) {
		_ = handle.Close()
		return nil, cause
	}

	if err = handle.PingContext(ctx); err != nil {
		return abort(fmt.Errorf("ping sqlite database: %w", err))
	}
	if err = ensureForeignKeys(ctx, handle); err != nil {
		return abort(err)
	}
	if err = migrate(ctx, handle); err != nil {
		return abort(err)
	}

	opened := &DB{sqlDB: handle}
	opened.queries = newQueries(handle)
	return opened, nil
}
// Close closes the underlying *sql.DB and returns the error from that call.
func (db *DB) Close() error {
	pool := db.sqlDB
	return pool.Close()
}
// SQLDB exposes the underlying *sql.DB held by db.
func (db *DB) SQLDB() *sql.DB {
	raw := db.sqlDB
	return raw
}
// Hosts returns the hosts repository from db's non-transactional Queries set.
func (db *DB) Hosts() *HostsRepo {
	repos := db.queries
	return repos.Hosts
}
// Packs returns the packs repository from db's non-transactional Queries set.
func (db *DB) Packs() *PacksRepo {
	repos := db.queries
	return repos.Packs
}
// Providers returns the providers repository from db's non-transactional
// Queries set.
func (db *DB) Providers() *ProvidersRepo {
	repos := db.queries
	return repos.Providers
}
// WithTx runs fn inside a single database transaction.
//
// fn receives a Queries set bound to the transaction. The outcome depends on
// how fn finishes:
//
//   - fn returns an error: the transaction is rolled back and that error is
//     returned, joined with the rollback error if the rollback also fails.
//   - fn returns nil: the transaction is committed; a commit failure is
//     returned wrapped.
//   - fn panics: the transaction is rolled back while the panic unwinds, so
//     the connection is not left pinned to an open transaction. The panic is
//     not recovered and continues to propagate to the caller.
//
// An error from BeginTx is returned wrapped and fn is never called.
func (db *DB) WithTx(ctx context.Context, fn func(*Queries) error) error {
	tx, err := db.sqlDB.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin sqlite transaction: %w", err)
	}

	// fnReturned stays false only while fn is executing, so the deferred
	// rollback fires exclusively when fn panicked. On the normal paths below
	// the transaction is finished explicitly.
	fnReturned := false
	defer func() {
		if !fnReturned {
			// Best effort: the in-flight panic is what the caller will see.
			_ = tx.Rollback()
		}
	}()

	fnErr := fn(newQueries(tx))
	fnReturned = true

	if fnErr != nil {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			return errors.Join(fnErr, fmt.Errorf("rollback sqlite transaction: %w", rollbackErr))
		}
		return fnErr
	}

	if err := tx.Commit(); err != nil {
		// Best-effort cleanup; the commit error is the one worth reporting.
		_ = tx.Rollback()
		return fmt.Errorf("commit sqlite transaction: %w", err)
	}
	return nil
}
// newQueries builds a Queries set whose three repositories all use db as
// their database handle.
func newQueries(db execQuerier) *Queries {
	set := new(Queries)
	set.Hosts = newHostsRepo(db)
	set.Packs = newPacksRepo(db)
	set.Providers = newProvidersRepo(db)
	return set
}
// migrate brings the database schema up to date using the embedded migration
// files.
//
// All work happens in one transaction: the schema_migrations ledger is
// created if missing, already-applied versions are loaded, a legacy schema is
// backfilled into the ledger when applicable, and then every migration not
// yet recorded is executed and recorded, in the order returned by
// migrationFileNames. Any failure before the commit rolls the whole
// transaction back via rollbackMigration.
func migrate(ctx context.Context, db *sql.DB) error {
	versions, err := migrationFileNames()
	if err != nil {
		return err
	}

	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("begin sqlite migration transaction: %w", err)
	}

	// Every step that can fail before the commit lives in this closure so
	// that there is exactly one rollback site below.
	runPending := func() error {
		if err := ensureMigrationLedger(ctx, tx); err != nil {
			return err
		}
		applied, err := loadAppliedMigrations(ctx, tx)
		if err != nil {
			return err
		}
		// May add the first migration to applied without executing it.
		if err := backfillLegacySchemaIfNeeded(ctx, tx, versions, applied); err != nil {
			return err
		}

		for _, version := range versions {
			if applied[version] {
				continue
			}
			script, err := readMigration(version)
			if err != nil {
				return err
			}
			if _, err := tx.ExecContext(ctx, script); err != nil {
				return fmt.Errorf("apply sqlite migration %s: %w", version, err)
			}
			const recordSQL = "INSERT INTO schema_migrations (version) VALUES (?)"
			if _, err := tx.ExecContext(ctx, recordSQL, version); err != nil {
				return fmt.Errorf("record sqlite migration %s: %w", version, err)
			}
		}
		return nil
	}

	if err := runPending(); err != nil {
		return rollbackMigration(tx, err)
	}

	if err := tx.Commit(); err != nil {
		// Best-effort cleanup; the commit error is the one worth reporting.
		_ = tx.Rollback()
		return fmt.Errorf("commit sqlite migration transaction: %w", err)
	}
	return nil
}
// withForeignKeysEnabled returns dsn with the query parameter
// "_pragma=foreign_keys(1)" appended. The parameter is joined with "&" when
// dsn already contains a "?", and with "?" otherwise.
func withForeignKeysEnabled(dsn string) string {
	separator := "?"
	if strings.IndexByte(dsn, '?') >= 0 {
		separator = "&"
	}
	return dsn + separator + "_pragma=foreign_keys(1)"
}
// ensureForeignKeys issues "PRAGMA foreign_keys = ON" and then reads the
// pragma back. It returns an error if either statement fails or if the value
// read back is not 1.
func ensureForeignKeys(ctx context.Context, db *sql.DB) error {
	_, err := db.ExecContext(ctx, "PRAGMA foreign_keys = ON")
	if err != nil {
		return fmt.Errorf("enable sqlite foreign keys: %w", err)
	}

	var state int
	row := db.QueryRowContext(ctx, "PRAGMA foreign_keys")
	if err = row.Scan(&state); err != nil {
		return fmt.Errorf("verify sqlite foreign keys: %w", err)
	}

	if state == 1 {
		return nil
	}
	return errors.New("sqlite foreign keys are disabled")
}
// ensureMigrationLedger creates the schema_migrations table inside tx if it
// does not exist yet. The table has a version primary key and an applied_at
// column that defaults to CURRENT_TIMESTAMP.
func ensureMigrationLedger(ctx context.Context, tx *sql.Tx) error {
	_, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS schema_migrations (
version TEXT PRIMARY KEY,
applied_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
)`)
	if err == nil {
		return nil
	}
	return fmt.Errorf("create schema_migrations table: %w", err)
}
// loadAppliedMigrations reads every version stored in schema_migrations
// through tx and returns them as a set (version -> true). The returned map is
// non-nil on success, even when the table is empty.
func loadAppliedMigrations(ctx context.Context, tx *sql.Tx) (map[string]bool, error) {
	cursor, err := tx.QueryContext(ctx, "SELECT version FROM schema_migrations")
	if err != nil {
		return nil, fmt.Errorf("query applied sqlite migrations: %w", err)
	}
	defer cursor.Close()

	seen := map[string]bool{}
	for cursor.Next() {
		var recorded string
		if scanErr := cursor.Scan(&recorded); scanErr != nil {
			return nil, fmt.Errorf("scan applied sqlite migration: %w", scanErr)
		}
		seen[recorded] = true
	}

	// Next returning false can mean either exhaustion or failure; Err tells
	// the two apart.
	if iterErr := cursor.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate applied sqlite migrations: %w", iterErr)
	}
	return seen, nil
}
// migrationFileNames lists the top-level entries of migrations.Files whose
// names end in ".sql", skipping directories, and returns the names sorted in
// ascending string order. The result is nil when no such file exists.
func migrationFileNames() ([]string, error) {
	entries, err := fs.ReadDir(migrations.Files, ".")
	if err != nil {
		return nil, fmt.Errorf("read embedded sqlite migrations: %w", err)
	}

	var sqlFiles []string
	for _, entry := range entries {
		if name := entry.Name(); !entry.IsDir() && strings.HasSuffix(name, ".sql") {
			sqlFiles = append(sqlFiles, name)
		}
	}

	sort.Sort(sort.StringSlice(sqlFiles))
	return sqlFiles, nil
}
// readMigration returns the contents of the file called name inside
// migrations.Files as a string, or a wrapped error if it cannot be read.
func readMigration(name string) (string, error) {
	raw, err := fs.ReadFile(migrations.Files, name)
	if err == nil {
		return string(raw), nil
	}
	return "", fmt.Errorf("read embedded migration %s: %w", name, err)
}
// rollbackMigration rolls tx back and returns err. If the rollback itself
// fails, the result is err joined with the wrapped rollback error so neither
// failure is lost.
func rollbackMigration(tx *sql.Tx, err error) error {
	rollbackErr := tx.Rollback()
	if rollbackErr == nil {
		return err
	}
	return errors.Join(err, fmt.Errorf("rollback sqlite migration transaction: %w", rollbackErr))
}
// backfillLegacySchemaIfNeeded records the first migration as applied for a
// database whose schema exists but whose ledger is empty.
//
// It acts only when there is at least one migration, nothing has been
// recorded in appliedMigrations, and the first migration is "0001_init.sql".
// In that case detectLegacy0001Schema decides what happens next:
//
//   - schema fully present: the first migration is inserted into
//     schema_migrations and marked in appliedMigrations (the map is mutated
//     in place), so the caller will not execute it;
//   - schema partially present: an error is returned;
//   - schema absent: nothing is done.
func backfillLegacySchemaIfNeeded(ctx context.Context, tx *sql.Tx, migrationNames []string, appliedMigrations map[string]bool) error {
	switch {
	case len(migrationNames) == 0, len(appliedMigrations) != 0:
		return nil
	case migrationNames[0] != "0001_init.sql":
		return nil
	}
	initial := migrationNames[0]

	complete, partial, err := detectLegacy0001Schema(ctx, tx)
	switch {
	case err != nil:
		return err
	case partial:
		return errors.New("legacy sqlite schema is partially applied without schema_migrations")
	case !complete:
		return nil
	}

	const recordSQL = "INSERT INTO schema_migrations (version) VALUES (?)"
	if _, err := tx.ExecContext(ctx, recordSQL, initial); err != nil {
		return fmt.Errorf("backfill sqlite migration %s: %w", initial, err)
	}
	appliedMigrations[initial] = true
	return nil
}
// detectLegacy0001Schema checks how many of the tables "hosts", "packs" and
// "providers" exist.
//
// It returns complete=true when all of them exist, partial=true when some but
// not all exist, and both false when none exist. A lookup failure is returned
// as err with both flags false.
func detectLegacy0001Schema(ctx context.Context, tx *sql.Tx) (complete bool, partial bool, err error) {
	required := [...]string{"hosts", "packs", "providers"}

	present := 0
	for i := range required {
		ok, lookupErr := tableExists(ctx, tx, required[i])
		if lookupErr != nil {
			return false, false, lookupErr
		}
		if ok {
			present++
		}
	}

	if present == 0 {
		return false, false, nil
	}
	if present == len(required) {
		return true, false, nil
	}
	return false, true, nil
}
// tableExists reports whether sqlite_master contains an entry of type
// 'table' whose name equals table. A missing row yields (false, nil); any
// other query or scan failure is returned wrapped.
func tableExists(ctx context.Context, db execQuerier, table string) (bool, error) {
	const lookupSQL = "SELECT name FROM sqlite_master WHERE type = 'table' AND name = ?"

	var got string
	row := db.QueryRowContext(ctx, lookupSQL, table)
	switch err := row.Scan(&got); {
	case err == nil:
		return got == table, nil
	case errors.Is(err, sql.ErrNoRows):
		return false, nil
	default:
		return false, fmt.Errorf("check sqlite table %s: %w", table, err)
	}
}