// Package performance contains benchmarks and threshold tests for the
// repository and JWT code paths, together with a small atomic metrics
// collector used by those benchmarks.
package performance

import (
	"context"
	"fmt"
	"math"
	"runtime"
	"sort"
	"sync"
	"sync/atomic"
	"testing"
	"time"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"github.com/user-management-system/internal/auth"
	"github.com/user-management-system/internal/domain"
	"github.com/user-management-system/internal/repository"
)

// PerformanceMetrics aggregates counters collected during a benchmark or test.
//
// All fields are plain int64 values that are read and written through
// sync/atomic, so a single instance may be shared between goroutines.
// Construct instances with NewPerformanceMetrics; a zero-value struct has
// MinLatency == 0, which RecordLatency would never lower.
type PerformanceMetrics struct {
	RequestCount   int64 // number of RecordLatency calls
	SuccessCount   int64 // incremented directly by callers on success
	FailureCount   int64 // incremented directly by callers on failure
	TotalLatency   int64 // sum of recorded latencies, in nanoseconds
	MinLatency     int64 // smallest recorded latency, in nanoseconds
	MaxLatency     int64 // largest recorded latency, in nanoseconds
	CacheHitCount  int64 // number of RecordCacheHit calls
	CacheMissCount int64 // number of RecordCacheMiss calls
	DBQueryCount   int64 // not updated by any code in this file
	SlowQueries    int64 // recorded latencies above 100ms
}

// NewPerformanceMetrics returns a collector whose MinLatency starts at
// math.MaxInt64 so that the first recorded sample becomes the minimum.
func NewPerformanceMetrics() *PerformanceMetrics {
	return &PerformanceMetrics{MinLatency: math.MaxInt64}
}

// RecordLatency records one request that took latency nanoseconds.
//
// It increments RequestCount, adds to TotalLatency, updates the min/max
// watermarks, and counts the sample as slow when it exceeds 100ms.
func (m *PerformanceMetrics) RecordLatency(latency int64) {
	atomic.AddInt64(&m.RequestCount, 1)
	atomic.AddInt64(&m.TotalLatency, latency)

	// Lower the minimum with a CAS loop: retry only when another goroutine
	// changed the value between the load and the swap.
	for {
		old := atomic.LoadInt64(&m.MinLatency)
		if latency >= old || atomic.CompareAndSwapInt64(&m.MinLatency, old, latency) {
			break
		}
	}

	// Raise the maximum the same way.
	for {
		old := atomic.LoadInt64(&m.MaxLatency)
		if latency <= old || atomic.CompareAndSwapInt64(&m.MaxLatency, old, latency) {
			break
		}
	}

	if latency > 100_000_000 {
		atomic.AddInt64(&m.SlowQueries, 1)
	}
}

// RecordCacheHit counts one cache hit.
func (m *PerformanceMetrics) RecordCacheHit() {
	atomic.AddInt64(&m.CacheHitCount, 1)
}

// RecordCacheMiss counts one cache miss.
func (m *PerformanceMetrics) RecordCacheMiss() {
	atomic.AddInt64(&m.CacheMissCount, 1)
}

// GetP99Latency is a placeholder that always returns 0.
//
// The collector keeps only aggregate counters; a real implementation would
// need a histogram (or the raw samples) to derive a percentile.
func (m *PerformanceMetrics) GetP99Latency() time.Duration {
	return 0
}

// GetAverageLatency returns TotalLatency divided by RequestCount, or 0 when
// nothing has been recorded.
func (m *PerformanceMetrics) GetAverageLatency() time.Duration {
	count := atomic.LoadInt64(&m.RequestCount)
	if count == 0 {
		return 0
	}
	return time.Duration(atomic.LoadInt64(&m.TotalLatency) / count)
}

// GetCacheHitRate returns hits/(hits+misses) as a percentage in [0, 100], or
// 0 when no hit or miss has been recorded.
func (m *PerformanceMetrics) GetCacheHitRate() float64 {
	hits := atomic.LoadInt64(&m.CacheHitCount)
	misses := atomic.LoadInt64(&m.CacheMissCount)
	total := hits + misses
	if total == 0 {
		return 0
	}
	return float64(hits) / float64(total) * 100
}

// GetSuccessRate returns SuccessCount/RequestCount as a percentage, or 0 when
// RequestCount is zero.
func (m *PerformanceMetrics) GetSuccessRate() float64 {
	success := atomic.LoadInt64(&m.SuccessCount)
	total := atomic.LoadInt64(&m.RequestCount)
	if total == 0 {
		return 0
	}
	return float64(success) / float64(total) * 100
}

// setupBenchmarkDB opens an in-memory SQLite database, migrates the
// domain.User schema, and aborts the benchmark if any step fails.
func setupBenchmarkDB(b *testing.B) *gorm.DB {
	b.Helper()

	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		b.Fatalf("打开数据库失败: %v", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		b.Fatalf("获取底层数据库连接失败: %v", err)
	}
	// Every new connection to ":memory:" gets its own empty database, so the
	// pool must be pinned to one connection; otherwise parallel queries can
	// land on a connection that has no tables or rows.
	sqlDB.SetMaxOpenConns(1)
	b.Cleanup(func() {
		// Best-effort close; the in-memory database is discarded either way.
		_ = sqlDB.Close()
	})

	if err := db.AutoMigrate(&domain.User{}); err != nil {
		b.Fatalf("数据库迁移失败: %v", err)
	}
	return db
}

// BenchmarkGetUserByID measures repository lookups by primary key, run in
// parallel against a single pre-inserted user.
func BenchmarkGetUserByID(b *testing.B) {
	db := setupBenchmarkDB(b)
	repo := repository.NewUserRepository(db)
	ctx := context.Background()

	// Seed the user that every iteration looks up.
	user := &domain.User{
		Username: "benchuser",
		Email:    domain.StrPtr("bench@example.com"),
		Password: "hash",
		Status:   domain.UserStatusActive,
	}
	if err := repo.Create(ctx, user); err != nil {
		b.Fatalf("创建测试用户失败: %v", err)
	}

	metrics := NewPerformanceMetrics()

	b.ResetTimer()
	b.RunParallel(func(pb *testing.PB) {
		for pb.Next() {
			start := time.Now()
			_, err := repo.GetByID(ctx, user.ID)
			metrics.RecordLatency(time.Since(start).Nanoseconds())

			// A successful lookup is counted as a cache hit and a failed one
			// as a miss; no real cache is consulted here.
			if err == nil {
				atomic.AddInt64(&metrics.SuccessCount, 1)
				metrics.RecordCacheHit()
			} else {
				atomic.AddInt64(&metrics.FailureCount, 1)
				metrics.RecordCacheMiss()
			}
		}
	})

	b.ReportMetric(float64(metrics.GetAverageLatency().Nanoseconds())/1e6, "avg_latency_ms")
	b.ReportMetric(metrics.GetCacheHitRate(), "cache_hit_rate")
}

// BenchmarkTokenGeneration measures JWT token-pair generation.
func BenchmarkTokenGeneration(b *testing.B) {
	jwtManager := auth.NewJWT("benchmark-secret", 2*time.Hour, 7*24*time.Hour)
	metrics := NewPerformanceMetrics()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		start := time.Now()
		_, _, err := jwtManager.GenerateTokenPair(1, "benchuser")
		metrics.RecordLatency(time.Since(start).Nanoseconds())

		if err == nil {
			atomic.AddInt64(&metrics.SuccessCount, 1)
		} else {
			atomic.AddInt64(&metrics.FailureCount, 1)
		}
	}

	b.ReportMetric(float64(metrics.GetAverageLatency().Nanoseconds())/1e6, "avg_latency_ms")
	b.ReportMetric(metrics.GetSuccessRate(), "success_rate")
}

// BenchmarkTokenValidation measures validation of a single pre-generated
// access token.
func BenchmarkTokenValidation(b *testing.B) {
	jwtManager := auth.NewJWT("benchmark-secret", 2*time.Hour, 7*24*time.Hour)
	accessToken, _, err := jwtManager.GenerateTokenPair(1, "benchuser")
	if err != nil {
		b.Fatalf("生成Token失败: %v", err)
	}

	metrics := NewPerformanceMetrics()

	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		start := time.Now()
		_, err := jwtManager.ValidateAccessToken(accessToken)
		metrics.RecordLatency(time.Since(start).Nanoseconds())

		if err == nil {
			atomic.AddInt64(&metrics.SuccessCount, 1)
		} else {
			atomic.AddInt64(&metrics.FailureCount, 1)
		}
	}

	b.ReportMetric(float64(metrics.GetAverageLatency().Nanoseconds())/1e6, "avg_latency_ms")
	b.ReportMetric(metrics.GetSuccessRate(), "success_rate")
}

// TestP99LatencyThreshold samples each operation repeatedly and fails when
// the 99th-percentile latency exceeds the case's threshold.
func TestP99LatencyThreshold(t *testing.T) {
	const samples = 100

	jwtManager := auth.NewJWT("test-secret", 2*time.Hour, 7*24*time.Hour)

	testCases := []struct {
		name        string
		operation   func() (time.Duration, error)
		thresholdMs int64
	}{
		{
			name: "JWT生成P99",
			operation: func() (time.Duration, error) {
				start := time.Now()
				_, _, err := jwtManager.GenerateTokenPair(1, "testuser")
				return time.Since(start), err
			},
			thresholdMs: 100,
		},
		{
			name: "模拟用户查询P99",
			operation: func() (time.Duration, error) {
				// Simulated query: a fixed 2ms sleep stands in for the lookup.
				start := time.Now()
				time.Sleep(2 * time.Millisecond)
				return time.Since(start), nil
			},
			thresholdMs: 50,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			latencies := make([]time.Duration, samples)
			for i := range latencies {
				d, err := tc.operation()
				if err != nil {
					t.Fatalf("执行操作失败: %v", err)
				}
				latencies[i] = d
			}

			// A percentile is only meaningful on ordered samples; without the
			// sort, the index below would pick an arbitrary observation.
			sort.Slice(latencies, func(i, j int) bool { return latencies[i] < latencies[j] })

			// Nearest-rank P99: ceil(0.99*n) as a 1-based rank, minus one for
			// the 0-based index (98 when n is 100).
			p99Index := (samples*99+99)/100 - 1
			p99Latency := latencies[p99Index]

			threshold := time.Duration(tc.thresholdMs) * time.Millisecond
			if p99Latency > threshold {
				t.Errorf("P99响应时间 %v 超过阈值 %v", p99Latency, threshold)
			}
		})
	}
}

// TestCacheHitRate feeds a simulated hit/miss mix into PerformanceMetrics and
// checks that the reported hit rate meets the expected minimum.
func TestCacheHitRate(t *testing.T) {
	testCases := []struct {
		name            string
		operations      int
		expectedHitRate float64
		simulateHitRate float64
	}{
		{"用户查询缓存命中率", 1000, 90.0, 92.5},
		{"Token验证缓存命中率", 1000, 95.0, 96.8},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			metrics := NewPerformanceMetrics()

			// Round rather than truncate so floating-point representation
			// error in the product cannot drop one hit.
			hits := int64(math.Round(float64(tc.operations) * tc.simulateHitRate / 100))
			misses := int64(tc.operations) - hits

			for i := int64(0); i < hits; i++ {
				metrics.RecordCacheHit()
			}
			for i := int64(0); i < misses; i++ {
				metrics.RecordCacheMiss()
			}

			hitRate := metrics.GetCacheHitRate()
			if hitRate < tc.expectedHitRate {
				t.Errorf("缓存命中率 %.2f%% 低于期望 %.2f%%", hitRate, tc.expectedHitRate)
			}
		})
	}
}

// TestThroughput runs a fixed number of workers, each repeating a simulated
// operation (a sleep) until the deadline, and checks the resulting
// operations-per-second against the expected minimum.
func TestThroughput(t *testing.T) {
	testCases := []struct {
		name             string
		duration         time.Duration
		expectedTPS      int
		concurrency      int
		operationLatency time.Duration
	}{
		{"登录吞吐量", 2 * time.Second, 100, 20, 5 * time.Millisecond},
		{"用户查询吞吐量", 2 * time.Second, 500, 50, 2 * time.Millisecond},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			ctx, cancel := context.WithTimeout(context.Background(), tc.duration)
			defer cancel()

			var completed int64
			var wg sync.WaitGroup
			wg.Add(tc.concurrency)

			opLatency := tc.operationLatency
			startTime := time.Now()
			for i := 0; i < tc.concurrency; i++ {
				go func() {
					defer wg.Done()
					for {
						select {
						case <-ctx.Done():
							return
						default:
							time.Sleep(opLatency)
							atomic.AddInt64(&completed, 1)
						}
					}
				}()
			}
			wg.Wait()

			elapsed := time.Since(startTime).Seconds()
			tps := float64(atomic.LoadInt64(&completed)) / elapsed
			if tps < float64(tc.expectedTPS) {
				t.Errorf("吞吐量 %.2f TPS 低于期望 %d TPS", tps, tc.expectedTPS)
			}
			t.Logf("实际吞吐量: %.2f TPS", tps)
		})
	}
}

// TestMemoryUsage logs how much the live heap (MemStats.Alloc) changes across
// 10000 generate-and-validate token round trips. It asserts nothing about the
// amount; it fails only if a token operation returns an error.
func TestMemoryUsage(t *testing.T) {
	var m runtime.MemStats
	runtime.GC()
	runtime.ReadMemStats(&m)
	baselineMemory := m.Alloc

	jwtManager := auth.NewJWT("test-secret", 2*time.Hour, 7*24*time.Hour)
	for i := 0; i < 10000; i++ {
		accessToken, _, err := jwtManager.GenerateTokenPair(int64(i%100), "testuser")
		if err != nil {
			t.Fatalf("生成Token失败: %v", err)
		}
		if _, err := jwtManager.ValidateAccessToken(accessToken); err != nil {
			t.Fatalf("验证Token失败: %v", err)
		}
	}

	runtime.GC()
	runtime.ReadMemStats(&m)
	afterMemory := m.Alloc

	// Convert to signed before subtracting: the heap may have shrunk.
	memoryGrowth := float64(int64(afterMemory)-int64(baselineMemory)) / 1024 / 1024
	t.Logf("内存变化: %.2f MB", memoryGrowth)
}

// TestGCPressure allocates 8 MiB of short-lived buffers ten times, forcing a
// collection after each round, and fails when the average GC pause over the
// observed cycles exceeds 100ms.
func TestGCPressure(t *testing.T) {
	var m runtime.MemStats
	runtime.GC()
	runtime.ReadMemStats(&m)
	startPauseNs := m.PauseTotalNs
	startNumGC := m.NumGC

	for i := 0; i < 10; i++ {
		payload := make([][]byte, 0, 128)
		for j := 0; j < 128; j++ {
			payload = append(payload, make([]byte, 64*1024))
		}
		// Keep the buffers reachable until here so the allocations are not
		// optimized away.
		runtime.KeepAlive(payload)
		runtime.GC()
	}

	runtime.ReadMemStats(&m)
	gcCycles := m.NumGC - startNumGC
	if gcCycles == 0 {
		t.Skip("no GC cycle observed")
	}

	avgPauseNs := (m.PauseTotalNs - startPauseNs) / uint64(gcCycles)
	avgPauseMs := float64(avgPauseNs) / 1e6
	if avgPauseMs > 100 {
		t.Errorf("平均GC停顿 %.2f ms 超过阈值 100 ms", avgPauseMs)
	}
	t.Logf("平均GC停顿: %.2f ms", avgPauseMs)
}

// TestConnectionPool simulates 1000 requests spread round-robin over ten
// connection IDs and checks that usage is balanced (max-min spread <= 50).
func TestConnectionPool(t *testing.T) {
	connections := make(map[string]int)
	var mu sync.Mutex

	for i := 0; i < 1000; i++ {
		connID := fmt.Sprintf("conn-%d", i%10)
		mu.Lock()
		connections[connID]++
		mu.Unlock()
	}

	// 10000 is above the 1000 total requests, so it is a safe initial minimum.
	maxUsage, minUsage := 0, 10000
	for _, count := range connections {
		if count > maxUsage {
			maxUsage = count
		}
		if count < minUsage {
			minUsage = count
		}
	}

	if maxUsage-minUsage > 50 {
		t.Errorf("连接池使用不均衡,最大使用 %d,最小使用 %d", maxUsage, minUsage)
	}
	t.Logf("连接池复用分布: max=%d, min=%d", maxUsage, minUsage)
}

// TestResourceLeak starts 100 goroutines that each sleep 100ms, waits 200ms,
// and fails when the goroutine count is still more than 10 above the
// starting value.
func TestResourceLeak(t *testing.T) {
	initialGoroutines := runtime.NumGoroutine()

	for i := 0; i < 100; i++ {
		go func() {
			time.Sleep(100 * time.Millisecond)
		}()
	}

	// Give every goroutine time to finish before counting again.
	time.Sleep(200 * time.Millisecond)

	finalGoroutines := runtime.NumGoroutine()
	goroutineDiff := finalGoroutines - initialGoroutines
	if goroutineDiff > 10 {
		t.Errorf("可能的goroutine泄漏,差值: %d", goroutineDiff)
	}
	t.Logf("Goroutine数量变化: %d", goroutineDiff)
}