fix startup bootstrap recovery and local verification

This commit is contained in:
2026-04-23 10:27:13 +08:00
parent 32b2c23a04
commit fa0aacc559
9 changed files with 211 additions and 59 deletions

View File

@@ -6,6 +6,7 @@ import (
"crypto/tls"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"os"
"strconv"
@@ -14,6 +15,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/repository"
"github.com/Wei-Shaw/sub2api/internal/service"
@@ -404,11 +406,11 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) {
defer cancel()
var totalUsers int64
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users").Scan(&totalUsers); err != nil {
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM public.users").Scan(&totalUsers); err != nil {
return false, "", err
}
var adminUsers int64
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users WHERE role = $1", service.RoleAdmin).Scan(&adminUsers); err != nil {
if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM public.users WHERE role = $1", service.RoleAdmin).Scan(&adminUsers); err != nil {
return false, "", err
}
decision := decideAdminBootstrap(totalUsers, adminUsers)
@@ -442,7 +444,7 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) {
_, err = db.ExecContext(
ctx,
`INSERT INTO users (email, password_hash, role, balance, concurrency, status, created_at, updated_at)
`INSERT INTO public.users (email, password_hash, role, balance, concurrency, status, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`,
admin.Email,
admin.PasswordHash,
@@ -706,3 +708,68 @@ func AutoSetupFromEnv() error {
logger.LegacyPrintf("setup", "%s", "Auto setup completed successfully!")
return nil
}
// RecoverAutoSetupAdmin repairs an interrupted bootstrap by creating the admin
// user when the initialized application state still has no users.
func RecoverAutoSetupAdmin(ctx context.Context, userRepo service.UserRepository, cfg *config.Config) error {
if cfg == nil || userRepo == nil {
return nil
}
if ctx == nil {
ctx = context.Background()
}
if _, err := userRepo.GetFirstAdmin(ctx); err == nil {
logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonAdminExists)
return nil
} else if !errors.Is(err, service.ErrUserNotFound) {
return err
}
_, page, err := userRepo.List(ctx, pagination.PaginationParams{Page: 1, PageSize: 1})
if err != nil {
return err
}
if page != nil && page.Total > 0 {
logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonUsersExistWithoutAdmin)
return nil
}
adminEmail := strings.TrimSpace(cfg.Default.AdminEmail)
if adminEmail == "" {
adminEmail = "admin@sub2api.local"
}
adminPassword := getEnvOrDefault("ADMIN_PASSWORD", cfg.Default.AdminPassword)
if strings.TrimSpace(adminPassword) == "" {
password, genErr := generateSecret(16)
if genErr != nil {
return fmt.Errorf("failed to generate admin password: %w", genErr)
}
adminPassword = password
fmt.Printf("Generated admin password (one-time): %s\n", adminPassword)
fmt.Println("IMPORTANT: Save this password! It will not be shown again.")
}
admin := &service.User{
Email: getEnvOrDefault("ADMIN_EMAIL", adminEmail),
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 0,
Concurrency: setupDefaultAdminConcurrency(),
}
if err := admin.SetPassword(adminPassword); err != nil {
return err
}
if err := userRepo.Create(ctx, admin); err != nil {
if errors.Is(err, service.ErrEmailExists) {
logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonAdminExists)
return nil
}
return err
}
logger.LegacyPrintf("setup", "startup admin recovery result: created=true reason=%s", adminBootstrapReasonEmptyDatabase)
return nil
}

View File

@@ -126,8 +126,12 @@ func TestQuoteIdentifier_SafetyInvariant(t *testing.T) {
quoted := quoteIdentifier(attack)
// Invariant 1: Output always starts and ends with exactly one double quote
if !strings.HasPrefix(quoted, `"`) { t.Errorf("must start with double quote") }
if !strings.HasSuffix(quoted, `"`) { t.Errorf("must end with double quote") }
if !strings.HasPrefix(quoted, `"`) {
t.Errorf("must start with double quote")
}
if !strings.HasSuffix(quoted, `"`) {
t.Errorf("must end with double quote")
}
// Invariant 2: All internal double quotes are escaped (doubled)
inner := quoted[1 : len(quoted)-1]
@@ -139,19 +143,28 @@ func TestQuoteIdentifier_SafetyInvariant(t *testing.T) {
// Invariant 3: When used in SQL, the result is a single valid identifier
sql := fmt.Sprintf("CREATE DATABASE %s", quoted)
if !strings.Contains(sql, quoted) { t.Error("SQL must contain the exact quoted identifier") }
if !strings.Contains(sql, quoted) {
t.Error("SQL must contain the exact quoted identifier")
}
})
}
}
func min(a, b int) int { if a < b { return a }; return b }
func min(a, b int) int {
if a < b {
return a
}
return b
}
func hashString(s string) int {
h := 0
for _, c := range s {
h = h*31 + int(c)
}
if h < 0 { h = -h }
if h < 0 {
h = -h
}
return h % 10000
}