diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go index 43c35797..01321b32 100644 --- a/backend/cmd/server/main.go +++ b/backend/cmd/server/main.go @@ -150,6 +150,9 @@ func runMainServer() { log.Fatalf("Failed to initialize application: %v", err) } defer app.Cleanup() + if err := app.Bootstrap(); err != nil { + log.Fatalf("Failed to bootstrap application state: %v", err) + } // 启动服务器 go func() { diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index e9ce927e..62cbe7a8 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -18,14 +18,16 @@ import ( "github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/setup" "github.com/google/wire" "github.com/redis/go-redis/v9" ) type Application struct { - Server *http.Server - Cleanup func() + Server *http.Server + Cleanup func() + Bootstrap func() error } func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { @@ -53,9 +55,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { // Cleanup function provider provideCleanup, + provideBootstrap, // Application struct - wire.Struct(new(Application), "Server", "Cleanup"), + wire.Struct(new(Application), "Server", "Cleanup", "Bootstrap"), ) return nil, nil } @@ -71,6 +74,28 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { } } +func provideBootstrap(settingService *service.SettingService, userRepo service.UserRepository, cfg *config.Config) func() error { + return newBootstrapFunc(settingService.InitializeDefaultSettings, setup.RecoverAutoSetupAdmin, userRepo, cfg) +} + +func newBootstrapFunc( + initDefaults func(context.Context) error, + recoverAdmin func(context.Context, service.UserRepository, *config.Config) error, + userRepo service.UserRepository, + cfg *config.Config, +) func() error { + return func() error { + ctx := context.Background() + if err := initDefaults(ctx); err != nil { + return err + } + if err := recoverAdmin(ctx, userRepo, cfg); err != nil { + return err + } + return nil + } +} + func provideCleanup( entClient *ent.Client, rdb *redis.Client, diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 1aff823f..77eb8b57 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/server" "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/Wei-Shaw/sub2api/internal/setup" "github.com/redis/go-redis/v9" "log" "net/http" @@ -260,9 +261,11 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService) + bootstrap := provideBootstrap(settingService, userRepository, configConfig) application := &Application{ - Server: httpServer, - Cleanup: v, + Server: httpServer, + Cleanup: v, + Bootstrap: bootstrap, } return application, nil } @@ -270,8 +273,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { // wire.go: type Application struct { - Server *http.Server - Cleanup func() + Server *http.Server + Cleanup func() + Bootstrap func() error } func providePrivacyClientFactory() service.PrivacyClientFactory { @@ -285,6 +289,23 @@ func provideServiceBuildInfo(buildInfo handler.BuildInfo) service.BuildInfo { } } +func provideBootstrap(settingService *service.SettingService, userRepo service.UserRepository, cfg *config.Config) func() error { + return newBootstrapFunc(settingService.InitializeDefaultSettings, setup.RecoverAutoSetupAdmin, userRepo, cfg) +} + +func newBootstrapFunc(initDefaults func(context.Context) error, recoverAdmin func(context.Context, service.UserRepository, *config.Config) error, userRepo service.UserRepository, cfg *config.Config) func() error { + return func() error { + ctx := context.Background() + if err := initDefaults(ctx); err != nil { + return err + } + if err := recoverAdmin(ctx, userRepo, cfg); err != nil { + return err + } + return nil + } +} + func provideCleanup( entClient *ent.Client, rdb *redis.Client, diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index 4db6f144..c88c09d9 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -1,6 +1,8 @@ package main import ( + "context" + "errors" "testing" "time" @@ -83,3 +85,47 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { cleanup() }) } + +func TestNewBootstrapFunc_RunsDefaultsBeforeRecovery(t *testing.T) { + cfg := &config.Config{} + order := make([]string, 0, 2) + + bootstrap := newBootstrapFunc( + func(context.Context) error { + order = append(order, "defaults") + return nil + }, + func(_ context.Context, gotRepo service.UserRepository, got *config.Config) error { + require.Nil(t, gotRepo) + require.Same(t, cfg, got) + order = append(order, "recover") + return nil + }, + nil, + cfg, + ) + + require.NoError(t, bootstrap()) + require.Equal(t, []string{"defaults", "recover"}, order) +} + +func TestNewBootstrapFunc_StopsWhenDefaultsFail(t *testing.T) { + cfg := &config.Config{} + wantErr := errors.New("defaults failed") + recoverCalled := false + + bootstrap := newBootstrapFunc( + func(context.Context) error { + return wantErr + }, + func(context.Context, service.UserRepository, *config.Config) error { + recoverCalled = true + return nil + }, + nil, + cfg, + ) + + require.ErrorIs(t, bootstrap(), wantErr) + require.False(t, recoverCalled) +} diff --git a/backend/internal/integration/e2e_user_flow_test.go b/backend/internal/integration/e2e_user_flow_test.go index 5489d0a3..11557190 100644 --- a/backend/internal/integration/e2e_user_flow_test.go +++ b/backend/internal/integration/e2e_user_flow_test.go @@ -33,7 +33,7 @@ func TestUserRegistrationAndLogin(t *testing.T) { } body, _ := json.Marshal(payload) - resp, err := doRequest(t, "POST", "/api/auth/register", body, "") + resp, err := doRequest(t, "POST", "/api/v1/auth/register", body, "") if err != nil { t.Skipf("注册接口不可用,跳过用户流程测试: %v", err) return @@ -64,7 +64,7 @@ func TestUserRegistrationAndLogin(t *testing.T) { } body, _ := json.Marshal(payload) - resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + resp, err := doRequest(t, "POST", "/api/v1/auth/login", body, "") if err != nil { t.Fatalf("登录请求失败: %v", err) } @@ -111,7 +111,7 @@ func TestUserRegistrationAndLogin(t *testing.T) { // 步骤 3: 使用 JWT 获取当前用户信息 t.Run("获取当前用户信息", func(t *testing.T) { - resp, err := doRequest(t, "GET", "/api/user/me", nil, accessToken) + resp, err := doRequest(t, "GET", "/api/v1/auth/me", nil, accessToken) if err != nil { t.Fatalf("请求失败: %v", err) } @@ -144,7 +144,7 @@ func TestAPIKeyLifecycle(t *testing.T) { } body, _ := json.Marshal(payload) - resp, err := doRequest(t, "POST", "/api/keys", body, accessToken) + resp, err := doRequest(t, "POST", "/api/v1/keys", body, accessToken) if err != nil { t.Fatalf("创建 API Key 请求失败: %v", err) } @@ -215,7 +215,7 @@ func TestAPIKeyLifecycle(t *testing.T) { // 步骤 3: 查询用量记录 t.Run("查询用量记录", func(t *testing.T) { - resp, err := doRequest(t, "GET", "/api/usage/dashboard", nil, accessToken) + resp, err := doRequest(t, "GET", "/api/v1/usage/dashboard/stats", nil, accessToken) if err != nil { t.Fatalf("用量查询请求失败: %v", err) } @@ -279,7 +279,7 @@ func loginTestUser(t *testing.T) string { } body, _ := json.Marshal(payload) - resp, err := doRequest(t, "POST", "/api/auth/login", body, "") + resp, err := doRequest(t, "POST", "/api/v1/auth/login", body, "") if err != nil { return "" } diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 0383f3bc..9b593a72 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -1538,44 +1538,6 @@ func (s *UsageLogRepoSuite) TestGetUserUsageTrend() { s.Require().GreaterOrEqual(len(trend), 2) } -// --- GetAPIKeyUsageTrend --- - -func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend() { - user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrend@test.com"}) - apiKey1 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend1", Name: "k1"}) - apiKey2 := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrend2", Name: "k2"}) - account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrends"}) - - base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) - s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base) - s.createUsageLog(user, apiKey2, account, 50, 100, 0.5, base) - s.createUsageLog(user, apiKey1, account, 100, 200, 1.0, base.Add(24*time.Hour)) - - startTime := base.Add(-1 * time.Hour) - endTime := base.Add(48 * time.Hour) - - trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "day", 10) - s.Require().NoError(err, "GetAPIKeyUsageTrend") - s.Require().GreaterOrEqual(len(trend), 2) -} - -func (s *UsageLogRepoSuite) TestGetAPIKeyUsageTrend_HourlyGranularity() { - user := mustCreateUser(s.T(), s.client, &service.User{Email: "keytrendh@test.com"}) - apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-keytrendh", Name: "k"}) - account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-keytrendh"}) - - base := time.Date(2025, 1, 15, 12, 0, 0, 0, time.UTC) - s.createUsageLog(user, apiKey, account, 100, 200, 1.0, base) - s.createUsageLog(user, apiKey, account, 50, 100, 0.5, base.Add(1*time.Hour)) - - startTime := base.Add(-1 * time.Hour) - endTime := base.Add(3 * time.Hour) - - trend, err := s.repo.GetAPIKeyUsageTrend(s.ctx, startTime, endTime, "hour", 10) - s.Require().NoError(err, "GetAPIKeyUsageTrend hourly") - s.Require().Len(trend, 2) -} - // --- ListWithFilters (additional filter tests) --- func (s *UsageLogRepoSuite) TestListWithFilters_ApiKeyFilter() { diff --git a/backend/internal/setup/setup.go b/backend/internal/setup/setup.go index 92e85d14..37099e83 100644 --- a/backend/internal/setup/setup.go +++ b/backend/internal/setup/setup.go @@ -6,6 +6,7 @@ import ( "crypto/tls" "database/sql" "encoding/hex" + "errors" "fmt" "os" "strconv" @@ -14,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" "github.com/Wei-Shaw/sub2api/internal/repository" "github.com/Wei-Shaw/sub2api/internal/service" @@ -404,11 +406,11 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) { defer cancel() var totalUsers int64 - if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users").Scan(&totalUsers); err != nil { + if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM public.users").Scan(&totalUsers); err != nil { return false, "", err } var adminUsers int64 - if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM users WHERE role = $1", service.RoleAdmin).Scan(&adminUsers); err != nil { + if err := db.QueryRowContext(ctx, "SELECT COUNT(1) FROM public.users WHERE role = $1", service.RoleAdmin).Scan(&adminUsers); err != nil { return false, "", err } decision := decideAdminBootstrap(totalUsers, adminUsers) @@ -442,7 +444,7 @@ func createAdminUser(cfg *SetupConfig) (bool, string, error) { _, err = db.ExecContext( ctx, - `INSERT INTO users (email, password_hash, role, balance, concurrency, status, created_at, updated_at) + `INSERT INTO public.users (email, password_hash, role, balance, concurrency, status, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, admin.Email, admin.PasswordHash, @@ -706,3 +708,68 @@ func AutoSetupFromEnv() error { logger.LegacyPrintf("setup", "%s", "Auto setup completed successfully!") return nil } + +// RecoverAutoSetupAdmin repairs an interrupted bootstrap by creating the admin +// user when the initialized application state still has no users. +func RecoverAutoSetupAdmin(ctx context.Context, userRepo service.UserRepository, cfg *config.Config) error { + if cfg == nil || userRepo == nil { + return nil + } + if ctx == nil { + ctx = context.Background() + } + + if _, err := userRepo.GetFirstAdmin(ctx); err == nil { + logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonAdminExists) + return nil + } else if !errors.Is(err, service.ErrUserNotFound) { + return err + } + + _, page, err := userRepo.List(ctx, pagination.PaginationParams{Page: 1, PageSize: 1}) + if err != nil { + return err + } + if page != nil && page.Total > 0 { + logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonUsersExistWithoutAdmin) + return nil + } + + adminEmail := strings.TrimSpace(cfg.Default.AdminEmail) + if adminEmail == "" { + adminEmail = "admin@sub2api.local" + } + + adminPassword := getEnvOrDefault("ADMIN_PASSWORD", cfg.Default.AdminPassword) + if strings.TrimSpace(adminPassword) == "" { + password, genErr := generateSecret(16) + if genErr != nil { + return fmt.Errorf("failed to generate admin password: %w", genErr) + } + adminPassword = password + fmt.Printf("Generated admin password (one-time): %s\n", adminPassword) + fmt.Println("IMPORTANT: Save this password! It will not be shown again.") + } + + admin := &service.User{ + Email: getEnvOrDefault("ADMIN_EMAIL", adminEmail), + Role: service.RoleAdmin, + Status: service.StatusActive, + Balance: 0, + Concurrency: setupDefaultAdminConcurrency(), + } + if err := admin.SetPassword(adminPassword); err != nil { + return err + } + + if err := userRepo.Create(ctx, admin); err != nil { + if errors.Is(err, service.ErrEmailExists) { + logger.LegacyPrintf("setup", "startup admin recovery result: created=false reason=%s", adminBootstrapReasonAdminExists) + return nil + } + return err + } + + logger.LegacyPrintf("setup", "startup admin recovery result: created=true reason=%s", adminBootstrapReasonEmptyDatabase) + return nil +} diff --git a/backend/internal/setup/setup_security_test.go b/backend/internal/setup/setup_security_test.go index 8e115b82..e461b506 100644 --- a/backend/internal/setup/setup_security_test.go +++ b/backend/internal/setup/setup_security_test.go @@ -126,8 +126,12 @@ func TestQuoteIdentifier_SafetyInvariant(t *testing.T) { quoted := quoteIdentifier(attack) // Invariant 1: Output always starts and ends with exactly one double quote - if !strings.HasPrefix(quoted, `"`) { t.Errorf("must start with double quote") } - if !strings.HasSuffix(quoted, `"`) { t.Errorf("must end with double quote") } + if !strings.HasPrefix(quoted, `"`) { + t.Errorf("must start with double quote") + } + if !strings.HasSuffix(quoted, `"`) { + t.Errorf("must end with double quote") + } // Invariant 2: All internal double quotes are escaped (doubled) inner := quoted[1 : len(quoted)-1] @@ -139,19 +143,28 @@ func TestQuoteIdentifier_SafetyInvariant(t *testing.T) { // Invariant 3: When used in SQL, the result is a single valid identifier sql := fmt.Sprintf("CREATE DATABASE %s", quoted) - if !strings.Contains(sql, quoted) { t.Error("SQL must contain the exact quoted identifier") } + if !strings.Contains(sql, quoted) { + t.Error("SQL must contain the exact quoted identifier") + } }) } } -func min(a, b int) int { if a < b { return a }; return b } +func min(a, b int) int { + if a < b { + return a + } + return b +} func hashString(s string) int { h := 0 for _, c := range s { h = h*31 + int(c) } - if h < 0 { h = -h } + if h < 0 { + h = -h + } return h % 10000 } diff --git a/frontend/src/components/common/Card.vue b/frontend/src/components/common/Card.vue new file mode 100644 index 00000000..e28ddfc4 --- /dev/null +++ b/frontend/src/components/common/Card.vue @@ -0,0 +1,15 @@ +