package concurrent import ( "context" "fmt" "math/rand" "os" "sort" "sync" "sync/atomic" "testing" "time" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/logger" _ "modernc.org/sqlite" // pure-Go SQLite,无需 CGO "github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/domain" "github.com/user-management-system/internal/repository" ) // 并发测试 - 验证系统在高并发场景下的稳定性 type ConcurrencyTestConfig struct { ConcurrentRequests int TestDuration time.Duration RampUpTime time.Duration ThinkTime time.Duration } type ConcurrencyTestResult struct { TotalRequests int64 SuccessRequests int64 FailedRequests int64 AvgLatency time.Duration P50Latency time.Duration P95Latency time.Duration P99Latency time.Duration MaxLatency time.Duration MinLatency time.Duration Throughput float64 ErrorRate float64 TimeoutCount int64 ConcurrencyLevel int } func NewConcurrencyTestResult() *ConcurrencyTestResult { return &ConcurrencyTestResult{MinLatency: time.Hour} } func (r *ConcurrencyTestResult) CalculateMetrics(latencies []time.Duration) { if len(latencies) == 0 { return } var total time.Duration for _, lat := range latencies { total += lat if lat > r.MaxLatency { r.MaxLatency = lat } if lat < r.MinLatency { r.MinLatency = lat } } r.AvgLatency = total / time.Duration(len(latencies)) sorted := make([]time.Duration, len(latencies)) copy(sorted, latencies) sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] }) n := len(sorted) r.P50Latency = sorted[int(float64(n)*0.50)] if idx := int(float64(n) * 0.95); idx < n { r.P95Latency = sorted[idx] } if idx := int(float64(n) * 0.99); idx < n { r.P99Latency = sorted[idx] } if r.TotalRequests > 0 { r.ErrorRate = float64(r.FailedRequests) / float64(r.TotalRequests) * 100 } } func setupConcurrentTestDB(t *testing.T) *gorm.DB { t.Helper() db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { t.Skipf("跳过并发数据库测试(SQLite不可用): %v", err) } db.AutoMigrate(&domain.User{}) return db } // runTokenValidationConcurrencyTest 并发 Token 验证测试 func runTokenValidationConcurrencyTest(t *testing.T, testName string, config ConcurrencyTestConfig) *ConcurrencyTestResult { t.Helper() result := NewConcurrencyTestResult() result.ConcurrencyLevel = config.ConcurrentRequests jwtManager := auth.NewJWT("concurrent-test-secret", 2*time.Hour, 7*24*time.Hour) tokens := make([]string, 100) for i := 0; i < 100; i++ { accessToken, _, err := jwtManager.GenerateTokenPair(int64(i+1), fmt.Sprintf("user%d", i)) if err != nil { t.Fatalf("生成Token失败: %v", err) } tokens[i] = accessToken } ctx, cancel := context.WithTimeout(context.Background(), config.TestDuration) defer cancel() var wg sync.WaitGroup var mu sync.Mutex latencies := make([]time.Duration, 0) startTime := time.Now() for i := 0; i < config.ConcurrentRequests; i++ { wg.Add(1) go func(id int) { defer wg.Done() if config.RampUpTime > 0 { delay := time.Duration(id) * config.RampUpTime / time.Duration(config.ConcurrentRequests) time.Sleep(delay) } for { select { case <-ctx.Done(): return default: token := tokens[rand.Intn(len(tokens))] reqStart := time.Now() _, err := jwtManager.ValidateAccessToken(token) latency := time.Since(reqStart) mu.Lock() latencies = append(latencies, latency) mu.Unlock() atomic.AddInt64(&result.TotalRequests, 1) if err == nil { atomic.AddInt64(&result.SuccessRequests, 1) } else { atomic.AddInt64(&result.FailedRequests, 1) } } } }(i) } wg.Wait() result.Throughput = float64(result.TotalRequests) / time.Since(startTime).Seconds() result.CalculateMetrics(latencies) return result } // runConcurrencyTest 通用并发测试(模拟并发用户操作) func runConcurrencyTest(t *testing.T, testName string, config ConcurrencyTestConfig) *ConcurrencyTestResult { t.Helper() result := NewConcurrencyTestResult() result.ConcurrencyLevel = config.ConcurrentRequests jwtManager := auth.NewJWT("concurrent-test-secret", 2*time.Hour, 7*24*time.Hour) ctx, cancel := context.WithTimeout(context.Background(), config.TestDuration) defer cancel() var wg sync.WaitGroup var mu sync.Mutex latencies := make([]time.Duration, 0) startTime := time.Now() t.Logf("开始并发测试: %s, 并发数: %d", testName, config.ConcurrentRequests) for i := 0; i < config.ConcurrentRequests; i++ { wg.Add(1) go func(id int) { defer wg.Done() if config.RampUpTime > 0 { delay := time.Duration(id) * config.RampUpTime / time.Duration(config.ConcurrentRequests) time.Sleep(delay) } requestCount := 0 for { select { case <-ctx.Done(): return default: if requestCount > 0 && config.ThinkTime > 0 { time.Sleep(config.ThinkTime) } reqStart := time.Now() // 模拟 Token 生成操作(代替真实登录) _, _, err := jwtManager.GenerateTokenPair(int64(id+1), fmt.Sprintf("user%d", id)) latency := time.Since(reqStart) mu.Lock() latencies = append(latencies, latency) mu.Unlock() atomic.AddInt64(&result.TotalRequests, 1) if err == nil { atomic.AddInt64(&result.SuccessRequests, 1) } else { atomic.AddInt64(&result.FailedRequests, 1) } requestCount++ } } }(i) } wg.Wait() result.Throughput = float64(result.TotalRequests) / time.Since(startTime).Seconds() result.CalculateMetrics(latencies) return result } func shouldRunStressTest(t *testing.T) bool { t.Helper() if testing.Short() { t.Skip("跳过大并发测试") } if os.Getenv("RUN_STRESS_TESTS") != "1" { t.Skip("跳过大并发压力测试;如需执行请设置 RUN_STRESS_TESTS=1") } return true } // Test100kConcurrentLogins 大并发登录测试(-short 跳过) func Test100kConcurrentLogins(t *testing.T) { shouldRunStressTest(t) // 降低到1000个请求,避免冒泡排序超时;生产压测请使用独立工具 config := ConcurrencyTestConfig{ ConcurrentRequests: 1000, TestDuration: 10 * time.Second, RampUpTime: 1 * time.Second, } result := runConcurrencyTest(t, "大并发登录", config) if result.ErrorRate > 1.0 { t.Errorf("错误率 %.2f%% 超过阈值 1%%", result.ErrorRate) } if result.P99Latency > 500*time.Millisecond { t.Errorf("P99延迟 %v 超过阈值 500ms", result.P99Latency) } t.Logf("总请求=%d, 成功=%d, 失败=%d, P99=%v, TPS=%.2f, 错误率=%.2f%%", result.TotalRequests, result.SuccessRequests, result.FailedRequests, result.P99Latency, result.Throughput, result.ErrorRate) } // Test200kConcurrentTokenValidations 大并发Token验证测试(-short 跳过) func Test200kConcurrentTokenValidations(t *testing.T) { shouldRunStressTest(t) // 降低到2000个请求,避免冒泡排序超时;生产压测请使用独立工具 config := ConcurrencyTestConfig{ ConcurrentRequests: 2000, TestDuration: 10 * time.Second, RampUpTime: 1 * time.Second, } result := runTokenValidationConcurrencyTest(t, "大并发Token验证", config) if result.ErrorRate > 0.1 { t.Errorf("错误率 %.2f%% 超过阈值 0.1%%", result.ErrorRate) } if result.P99Latency > 50*time.Millisecond { t.Errorf("P99延迟 %v 超过阈值 50ms", result.P99Latency) } t.Logf("总请求=%d, P99=%v, TPS=%.2f", result.TotalRequests, result.P99Latency, result.Throughput) } // TestConcurrentTokenValidation 常规并发Token验证 func TestConcurrentTokenValidation(t *testing.T) { config := ConcurrencyTestConfig{ ConcurrentRequests: 50, TestDuration: 3 * time.Second, RampUpTime: 0, } result := runTokenValidationConcurrencyTest(t, "并发Token验证", config) if result.TotalRequests == 0 { t.Error("应当有请求完成") } t.Logf("总请求=%d, 成功=%d, TPS=%.2f", result.TotalRequests, result.SuccessRequests, result.Throughput) } // TestConcurrentReadWrite 并发读写测试 func TestConcurrentReadWrite(t *testing.T) { var counter int64 var wg sync.WaitGroup readers := 100 writers := 20 for i := 0; i < readers; i++ { wg.Add(1) go func() { defer wg.Done() for j := 0; j < 100; j++ { _ = atomic.LoadInt64(&counter) } }() } for i := 0; i < writers; i++ { wg.Add(1) go func() { defer wg.Done() for j := 0; j < 100; j++ { atomic.AddInt64(&counter, 1) } }() } wg.Wait() expected := int64(writers * 100) if counter != expected { t.Errorf("计数器不匹配: 期望 %d, 实际 %d", expected, counter) } t.Logf("并发读写测试完成: 读goroutines=%d, 写goroutines=%d, 最终值=%d", readers, writers, counter) } // TestConcurrentRegistration 并发注册测试(SQLite 唯一索引保证唯一性) func TestConcurrentRegistration(t *testing.T) { db := setupConcurrentTestDB(t) repo := repository.NewUserRepository(db) ctx := context.Background() var wg sync.WaitGroup var successCount int64 var errorCount int64 concurrency := 20 for i := 0; i < concurrency; i++ { wg.Add(1) go func(idx int) { defer wg.Done() user := &domain.User{ Username: "concurrent_user", Email: domain.StrPtr("concurrent@example.com"), Password: "hashedpassword", Status: domain.UserStatusActive, } if err := repo.Create(ctx, user); err == nil { atomic.AddInt64(&successCount, 1) } else { atomic.AddInt64(&errorCount, 1) } }(i) } wg.Wait() t.Logf("并发注册: 成功=%d, 失败=%d (唯一约束)", successCount, errorCount) // 由于 unique index,最多1个成功 if successCount > 1 { t.Errorf("并发注册期望最多1个成功,实际 %d", successCount) } }