Files
user-system/internal/concurrent/concurrent_test.go

353 lines
9.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package concurrent
import (
"context"
"fmt"
"math/rand"
"os"
"sort"
"sync"
"sync/atomic"
"testing"
"time"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite" // pure-Go SQLite无需 CGO
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// Concurrency tests: verify that the system stays stable under highly
// concurrent workloads.

// ConcurrencyTestConfig describes the load shape of a single concurrency run.
type ConcurrencyTestConfig struct {
	// ConcurrentRequests is the number of worker goroutines started in
	// parallel (not a total request count); each worker keeps issuing
	// requests until TestDuration elapses.
	ConcurrentRequests int
	// TestDuration bounds the whole run; the runners use it as the timeout of
	// the context that stops the workers.
	TestDuration time.Duration
	// RampUpTime staggers worker start times: worker i waits
	// i*RampUpTime/ConcurrentRequests before its first request.
	RampUpTime time.Duration
	// ThinkTime, when honored by the runner, is a pause inserted before every
	// request after a worker's first one.
	ThinkTime time.Duration
}
// ConcurrencyTestResult aggregates the outcome of one concurrency run.
//
// The three request counters are updated with sync/atomic by the worker
// goroutines. They must remain the first fields: on 32-bit platforms only the
// first word of an allocated struct is guaranteed to be 64-bit aligned, which
// 64-bit atomic operations require.
type ConcurrencyTestResult struct {
	TotalRequests   int64 // requests completed (success + failure)
	SuccessRequests int64 // requests whose operation returned a nil error
	FailedRequests  int64 // requests whose operation returned an error

	// Latency statistics, filled in by CalculateMetrics.
	AvgLatency time.Duration
	P50Latency time.Duration
	P95Latency time.Duration
	P99Latency time.Duration
	MaxLatency time.Duration
	MinLatency time.Duration

	Throughput       float64 // completed requests per second of wall-clock run time
	ErrorRate        float64 // FailedRequests / TotalRequests * 100 (a percentage)
	TimeoutCount     int64   // not currently populated by the runners in this file
	ConcurrencyLevel int     // number of concurrent workers used for the run
}

// NewConcurrencyTestResult returns an empty result whose MinLatency is seeded
// with one hour, so that the first real latency sample always lowers it.
func NewConcurrencyTestResult() *ConcurrencyTestResult {
	return &ConcurrencyTestResult{MinLatency: time.Hour}
}

// CalculateMetrics derives the error rate and latency statistics.
//
// latencies holds one sample per completed request and is not modified; a
// sorted copy is used. ErrorRate is computed from the TotalRequests and
// FailedRequests counters, and is set even when latencies is empty.
// MaxLatency and MinLatency are folded into any values already stored on r.
// P50, P95 and P99 use the nearest-rank method: the p-th percentile is the
// ceil(p/100*n)-th smallest sample. When latencies is empty, the latency
// fields are left untouched.
func (r *ConcurrencyTestResult) CalculateMetrics(latencies []time.Duration) {
	if r.TotalRequests > 0 {
		r.ErrorRate = float64(r.FailedRequests) / float64(r.TotalRequests) * 100
	}

	n := len(latencies)
	if n == 0 {
		return
	}

	sorted := make([]time.Duration, n)
	copy(sorted, latencies)
	sort.Slice(sorted, func(i, j int) bool { return sorted[i] < sorted[j] })

	var total time.Duration
	for _, lat := range sorted {
		total += lat
	}
	r.AvgLatency = total / time.Duration(n)

	if sorted[n-1] > r.MaxLatency {
		r.MaxLatency = sorted[n-1]
	}
	if sorted[0] < r.MinLatency {
		r.MinLatency = sorted[0]
	}

	// nearestRank returns the pct-th percentile (1 <= pct <= 100). Integer
	// arithmetic computes ceil(n*pct/100), which always lies in [1, n], so no
	// bounds guard or float rounding is involved.
	nearestRank := func(pct int) time.Duration {
		rank := (n*pct + 99) / 100
		return sorted[rank-1]
	}
	r.P50Latency = nearestRank(50)
	r.P95Latency = nearestRank(95)
	r.P99Latency = nearestRank(99)
}
// setupConcurrentTestDB opens the shared-cache in-memory SQLite database used
// by the concurrency tests and migrates the domain.User schema into it.
//
// If the database cannot be opened, the calling test is skipped. A migration
// failure fails the test instead of being ignored. The underlying connection
// pool is closed when the test finishes. SQLite drops a shared-cache in-memory
// database once its last connection closes, so repeated runs (for example
// -count=N) start from an empty database.
//
// NOTE(review): gorm.io/driver/sqlite opens its default driver name, while the
// blank modernc.org/sqlite import registers a pure-Go driver under a different
// name. Confirm which driver is actually in use when CGO is disabled.
func setupConcurrentTestDB(t *testing.T) *gorm.DB {
	t.Helper()
	db, err := gorm.Open(sqlite.Open("file::memory:?cache=shared"), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Skipf("跳过并发数据库测试SQLite不可用: %v", err)
	}
	// Register the cleanup before migrating, so the pool is also released when
	// the migration below fails the test.
	if sqlDB, dbErr := db.DB(); dbErr == nil {
		t.Cleanup(func() { _ = sqlDB.Close() })
	}
	if err := db.AutoMigrate(&domain.User{}); err != nil {
		t.Fatalf("数据库迁移失败: %v", err)
	}
	return db
}
// runConcurrentLoad is the shared driver behind the concurrency runners.
//
// It starts config.ConcurrentRequests workers and staggers their start evenly
// across config.RampUpTime. Each worker calls op(id) in a loop, with id being
// the 0-based worker index, and pauses config.ThinkTime between consecutive
// calls. Workers stop once config.TestDuration has elapsed. A nil error from
// op counts as a success; any other error counts as a failure. Throughput and
// latency statistics are filled in on the returned result.
//
// Each worker buffers its own latency samples, and the buffers are merged only
// after all workers exit. This keeps the request path lock-free: one mutex
// shared by thousands of workers would serialize them and depress the measured
// throughput. Ramp-up and think-time waits also end early when the run is
// cancelled.
func runConcurrentLoad(t *testing.T, testName string, config ConcurrencyTestConfig, op func(id int) error) *ConcurrencyTestResult {
	t.Helper()
	result := NewConcurrencyTestResult()
	result.ConcurrencyLevel = config.ConcurrentRequests
	workers := config.ConcurrentRequests
	if workers < 0 {
		workers = 0 // make() below would panic on a negative length
	}
	t.Logf("开始并发测试: %s, 并发数: %d", testName, config.ConcurrentRequests)

	ctx, cancel := context.WithTimeout(context.Background(), config.TestDuration)
	defer cancel()

	perWorker := make([][]time.Duration, workers)
	var wg sync.WaitGroup
	startTime := time.Now()
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			if config.RampUpTime > 0 {
				delay := time.Duration(id) * config.RampUpTime / time.Duration(workers)
				if !waitOrDone(ctx, delay) {
					return
				}
			}
			var local []time.Duration
			for sent := 0; ctx.Err() == nil; sent++ {
				if sent > 0 && config.ThinkTime > 0 && !waitOrDone(ctx, config.ThinkTime) {
					break
				}
				reqStart := time.Now()
				err := op(id)
				local = append(local, time.Since(reqStart))
				atomic.AddInt64(&result.TotalRequests, 1)
				if err == nil {
					atomic.AddInt64(&result.SuccessRequests, 1)
				} else {
					atomic.AddInt64(&result.FailedRequests, 1)
				}
			}
			// Each worker writes only its own slot, so no lock is needed here.
			perWorker[id] = local
		}(i)
	}
	wg.Wait()
	elapsed := time.Since(startTime)

	total := 0
	for _, samples := range perWorker {
		total += len(samples)
	}
	latencies := make([]time.Duration, 0, total)
	for _, samples := range perWorker {
		latencies = append(latencies, samples...)
	}

	result.Throughput = float64(result.TotalRequests) / elapsed.Seconds()
	result.CalculateMetrics(latencies)
	return result
}

// waitOrDone blocks for d or until ctx is done, whichever comes first. It
// reports whether the full duration elapsed before cancellation.
func waitOrDone(ctx context.Context, d time.Duration) bool {
	timer := time.NewTimer(d)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}

// runTokenValidationConcurrencyTest runs a concurrent access-token validation
// load.
//
// It first generates a fixed pool of 100 access tokens. During the run, each
// request validates a randomly chosen token from that pool, so the timed work
// is token selection plus ValidateAccessToken.
func runTokenValidationConcurrencyTest(t *testing.T, testName string, config ConcurrencyTestConfig) *ConcurrencyTestResult {
	t.Helper()
	jwtManager := auth.NewJWT("concurrent-test-secret", 2*time.Hour, 7*24*time.Hour)
	tokens := make([]string, 100)
	for i := range tokens {
		accessToken, _, err := jwtManager.GenerateTokenPair(int64(i+1), fmt.Sprintf("user%d", i))
		if err != nil {
			t.Fatalf("生成Token失败: %v", err)
		}
		tokens[i] = accessToken
	}
	return runConcurrentLoad(t, testName, config, func(int) error {
		_, err := jwtManager.ValidateAccessToken(tokens[rand.Intn(len(tokens))])
		return err
	})
}

// runConcurrencyTest runs a generic concurrent-user load. It simulates a login
// by generating a token pair for the worker's user, in place of a real login
// flow.
func runConcurrencyTest(t *testing.T, testName string, config ConcurrencyTestConfig) *ConcurrencyTestResult {
	t.Helper()
	jwtManager := auth.NewJWT("concurrent-test-secret", 2*time.Hour, 7*24*time.Hour)
	return runConcurrentLoad(t, testName, config, func(id int) error {
		_, _, err := jwtManager.GenerateTokenPair(int64(id+1), fmt.Sprintf("user%d", id))
		return err
	})
}
// shouldRunStressTest gates the large-concurrency stress tests.
//
// It skips the calling test when -short is set, or when the environment
// variable RUN_STRESS_TESTS is not exactly "1". If it returns at all, it
// returns true; t.Skip never returns.
func shouldRunStressTest(t *testing.T) bool {
	t.Helper()
	switch {
	case testing.Short():
		t.Skip("跳过大并发测试")
	case os.Getenv("RUN_STRESS_TESTS") != "1":
		t.Skip("跳过大并发压力测试;如需执行请设置 RUN_STRESS_TESTS=1")
	}
	return true
}
// Test100kConcurrentLogins is a high-concurrency "login" stress test. Each
// login is simulated by generating a token pair. The test is skipped under
// -short and unless RUN_STRESS_TESTS=1.
//
// The name is historical: the run uses 1000 concurrent workers
// (ConcurrentRequests), not 100k requests. Use a dedicated load-testing tool
// for production-scale stress tests.
func Test100kConcurrentLogins(t *testing.T) {
	shouldRunStressTest(t)
	config := ConcurrencyTestConfig{
		ConcurrentRequests: 1000,
		TestDuration:       10 * time.Second,
		RampUpTime:         1 * time.Second,
	}
	result := runConcurrencyTest(t, "大并发登录", config)
	// With no completed requests, ErrorRate and P99Latency both stay at zero,
	// and the thresholds below would pass vacuously.
	if result.TotalRequests == 0 {
		t.Fatal("应当有请求完成")
	}
	if result.ErrorRate > 1.0 {
		t.Errorf("错误率 %.2f%% 超过阈值 1%%", result.ErrorRate)
	}
	if result.P99Latency > 500*time.Millisecond {
		t.Errorf("P99延迟 %v 超过阈值 500ms", result.P99Latency)
	}
	t.Logf("总请求=%d, 成功=%d, 失败=%d, P99=%v, TPS=%.2f, 错误率=%.2f%%",
		result.TotalRequests, result.SuccessRequests, result.FailedRequests,
		result.P99Latency, result.Throughput, result.ErrorRate)
}
// Test200kConcurrentTokenValidations is a high-concurrency access-token
// validation stress test. It is skipped under -short and unless
// RUN_STRESS_TESTS=1.
//
// The name is historical: the run uses 2000 concurrent workers
// (ConcurrentRequests), not 200k requests. Use a dedicated load-testing tool
// for production-scale stress tests.
func Test200kConcurrentTokenValidations(t *testing.T) {
	shouldRunStressTest(t)
	config := ConcurrencyTestConfig{
		ConcurrentRequests: 2000,
		TestDuration:       10 * time.Second,
		RampUpTime:         1 * time.Second,
	}
	result := runTokenValidationConcurrencyTest(t, "大并发Token验证", config)
	// With no completed requests, ErrorRate and P99Latency both stay at zero,
	// and the thresholds below would pass vacuously.
	if result.TotalRequests == 0 {
		t.Fatal("应当有请求完成")
	}
	if result.ErrorRate > 0.1 {
		t.Errorf("错误率 %.2f%% 超过阈值 0.1%%", result.ErrorRate)
	}
	if result.P99Latency > 50*time.Millisecond {
		t.Errorf("P99延迟 %v 超过阈值 50ms", result.P99Latency)
	}
	t.Logf("总请求=%d, P99=%v, TPS=%.2f", result.TotalRequests, result.P99Latency, result.Throughput)
}
// TestConcurrentTokenValidation 常规并发Token验证
func TestConcurrentTokenValidation(t *testing.T) {
config := ConcurrencyTestConfig{
ConcurrentRequests: 50,
TestDuration: 3 * time.Second,
RampUpTime: 0,
}
result := runTokenValidationConcurrencyTest(t, "并发Token验证", config)
if result.TotalRequests == 0 {
t.Error("应当有请求完成")
}
t.Logf("总请求=%d, 成功=%d, TPS=%.2f", result.TotalRequests, result.SuccessRequests, result.Throughput)
}
// TestConcurrentReadWrite mixes atomic readers and writers on one int64
// counter and checks that no increment is lost. 100 reader goroutines each
// perform 100 atomic loads, while 20 writer goroutines each perform 100 atomic
// increments.
func TestConcurrentReadWrite(t *testing.T) {
	const (
		readers      = 100
		writers      = 20
		opsPerWorker = 100
	)
	var (
		counter int64
		wg      sync.WaitGroup
	)

	// spawn starts n goroutines that each run op opsPerWorker times.
	spawn := func(n int, op func()) {
		for g := 0; g < n; g++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				for k := 0; k < opsPerWorker; k++ {
					op()
				}
			}()
		}
	}
	spawn(readers, func() { _ = atomic.LoadInt64(&counter) })
	spawn(writers, func() { atomic.AddInt64(&counter, 1) })
	wg.Wait()

	final := atomic.LoadInt64(&counter)
	if want := int64(writers * opsPerWorker); final != want {
		t.Errorf("计数器不匹配: 期望 %d, 实际 %d", want, final)
	}
	t.Logf("并发读写测试完成: 读goroutines=%d, 写goroutines=%d, 最终值=%d", readers, writers, final)
}
// TestConcurrentRegistration races 20 goroutines to create the same user
// (identical Username and Email) through the repository. It checks that
// exactly one insert wins.
//
// NOTE(review): this relies on domain.User declaring a unique index on the
// username and/or email column, as the original comment stated. Confirm
// against the model definition.
func TestConcurrentRegistration(t *testing.T) {
	db := setupConcurrentTestDB(t)
	repo := repository.NewUserRepository(db)
	ctx := context.Background()
	var wg sync.WaitGroup
	var successCount int64
	var errorCount int64
	concurrency := 20
	for i := 0; i < concurrency; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			user := &domain.User{
				Username: "concurrent_user",
				Email:    domain.StrPtr("concurrent@example.com"),
				Password: "hashedpassword",
				Status:   domain.UserStatusActive,
			}
			if err := repo.Create(ctx, user); err == nil {
				atomic.AddInt64(&successCount, 1)
			} else {
				atomic.AddInt64(&errorCount, 1)
			}
		}()
	}
	wg.Wait()
	t.Logf("并发注册: 成功=%d, 失败=%d (唯一约束)", successCount, errorCount)
	// The unique index allows at most one insert to succeed.
	if successCount > 1 {
		t.Errorf("并发注册期望最多1个成功实际 %d", successCount)
	}
	// Zero successes also satisfies the check above, but then the race was
	// never exercised. For example, the row may already have existed, or every
	// insert may have failed for an unrelated reason. Treat it as a failure.
	if successCount == 0 {
		t.Errorf("并发注册没有任何成功的请求, 失败=%d", errorCount)
	}
}