Files
lijiaoqiao/supply-api/internal/middleware/ratelimit_test.go
Your Name 8ac23bf7d4 test: improve coverage and fix sanitizer bug
- Fix MaskMap to properly handle []string sensitive fields
- Add missing slice handling in sanitizer
- Add comprehensive tests for GetMetrics and CreateEventsBatch
- Improve audit/handler coverage from 49.8% to 68.8%
- Fix test expectations to match actual sanitizer behavior
- All tests pass
2026-04-08 07:44:58 +08:00

539 lines
12 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"net/http"
"net/http/httptest"
"sync"
"testing"
"time"
)
// TestP005_TokenBucketAlgorithm verifies the capacity behaviour of the token
// bucket: a freshly created bucket must admit exactly as many requests as its
// capacity and reject the next one.
func TestP005_TokenBucketAlgorithm(t *testing.T) {
	const capacity = 10
	bucket := NewTokenBucket(capacity, 1) // capacity 10, refill rate 1 per second

	// The bucket starts with 10 tokens. Assert every one of the first 10
	// requests instead of only the first, so a bucket that runs dry early
	// is reported rather than silently ignored.
	for i := 0; i < capacity; i++ {
		if !bucket.Allow() {
			t.Errorf("request %d of the first %d should be allowed", i+1, capacity)
		}
	}

	// The 11th request must be rejected because no tokens are left.
	if bucket.Allow() {
		t.Error("request beyond capacity should be denied")
	}

	t.Log("P0-05: 令牌桶容量验证通过")
}
// TestP005_TokenBucketRefill checks that a drained bucket admits requests
// again after waiting for tokens to be replenished.
func TestP005_TokenBucketRefill(t *testing.T) {
	const capacity = 5
	tb := NewTokenBucket(capacity, 100) // capacity 5, refill rate 100 per second

	// Spend every token; the individual results are not of interest here.
	for spent := 0; spent < capacity; spent++ {
		tb.Allow()
	}

	// Nothing should be left at this point.
	if allowed := tb.Allow(); allowed {
		t.Error("bucket should be empty")
	}

	// 100 tokens/second is one token per 10ms, so a 20ms pause should
	// replenish about two tokens.
	const refillWait = 20 * time.Millisecond
	time.Sleep(refillWait)

	if allowed := tb.Allow(); !allowed {
		t.Error("after refill, request should be allowed")
	}

	t.Log("P0-05: 令牌补充验证通过")
}
// TestP005_RateLimitHeaders checks that a request served through the
// middleware carries the X-RateLimit-* response headers.
func TestP005_RateLimitHeaders(t *testing.T) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("ok"))
	})

	cfg := RateLimitConfig{
		Enabled:    true,
		Requests:   100,
		Window:     time.Minute,
		HeaderType: "X-RateLimit",
	}
	mw := &RateLimitMiddleware{
		config:  &cfg,
		buckets: make(map[string]*TokenBucket),
		next:    okHandler,
	}

	rec := httptest.NewRecorder()
	mw.ServeHTTP(rec, httptest.NewRequest("GET", "/test", nil))

	// Each of these headers must be set to a non-empty value.
	required := []string{
		"X-RateLimit-Limit",
		"X-RateLimit-Remaining",
		"X-RateLimit-Reset",
	}
	for _, name := range required {
		if rec.Header().Get(name) == "" {
			t.Errorf("%s header should be present", name)
		}
	}

	t.Log("P0-05: 限流响应头验证通过")
}
// TestP005_MultiDimensionRateLimit simulates per-tenant limiting: each tenant
// owns its own bucket and is cut off independently once its quota is spent.
func TestP005_MultiDimensionRateLimit(t *testing.T) {
	// Both buckets are created up front; a slice keeps the tenant_1-then-
	// tenant_2 check order deterministic.
	tenants := []struct {
		name   string
		quota  int
		bucket *TokenBucket
	}{
		{name: "tenant_1", quota: 100, bucket: NewTokenBucket(100, 10)},
		{name: "tenant_2", quota: 50, bucket: NewTokenBucket(50, 5)},
	}

	for _, tenant := range tenants {
		// Every request within the quota should pass.
		for n := 0; n < tenant.quota; n++ {
			if !tenant.bucket.Allow() {
				t.Errorf("%s request %d should be allowed", tenant.name, n)
			}
		}

		// The first request past the quota should be refused.
		if tenant.bucket.Allow() {
			t.Errorf("%s exceeded limit", tenant.name)
		}
	}

	t.Log("P0-05: 多维度限流验证通过")
}
// TestP005_RateLimitConcurrency fires 200 goroutines at a bucket of 100
// tokens with refill rate 0 and expects exactly 100 of them to be admitted.
func TestP005_RateLimitConcurrency(t *testing.T) {
	const workers = 200

	tb := NewTokenBucket(100, 0) // capacity 100, refill rate 0

	var (
		guard    sync.Mutex // protects admitted
		pending  sync.WaitGroup
		admitted int
	)

	pending.Add(workers)
	for n := 0; n < workers; n++ {
		go func() {
			defer pending.Done()
			if !tb.Allow() {
				return
			}
			guard.Lock()
			defer guard.Unlock()
			admitted++
		}()
	}
	pending.Wait()

	// Only the bucket's 100 tokens' worth of requests may pass.
	if admitted != 100 {
		t.Errorf("expected 100 allowed, got %d", admitted)
	}

	t.Log("P0-05: 并发安全性验证通过")
}
// TestP005_DegradationOnLimitExceeded sends one request more than the
// configured limit with degradation enabled and expects a 429 for it.
func TestP005_DegradationOnLimitExceeded(t *testing.T) {
	degraded := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusTooManyRequests)
		w.Write([]byte("rate limited"))
	})
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write([]byte("ok"))
	})

	cfg := RateLimitConfig{
		Enabled:            true,
		Requests:           10,
		Window:             time.Second,
		DegradationEnabled: true,
		DegradationHandler: degraded,
	}
	mw := &RateLimitMiddleware{
		config:  &cfg,
		buckets: make(map[string]*TokenBucket),
		next:    okHandler,
	}

	// send issues a fresh GET /test through the middleware and returns the
	// recorder holding the response.
	send := func() *httptest.ResponseRecorder {
		rec := httptest.NewRecorder()
		mw.ServeHTTP(rec, httptest.NewRequest("GET", "/test", nil))
		return rec
	}

	// Use up the 10 requests the configuration allows.
	for sent := 0; sent < 10; sent++ {
		send()
	}

	// The 11th request should receive the rate-limited response.
	if rec := send(); rec.Code != http.StatusTooManyRequests {
		t.Errorf("expected 429, got %d", rec.Code)
	}

	t.Log("P0-05: 限流降级验证通过")
}
// TestP005_RateLimitConfig checks the values returned by
// DefaultRateLimitConfig: enabled, 1000 requests, one-minute window.
func TestP005_RateLimitConfig(t *testing.T) {
	defaults := DefaultRateLimitConfig()

	if enabled := defaults.Enabled; !enabled {
		t.Error("rate limit should be enabled by default")
	}

	if got := defaults.Requests; got != 1000 {
		t.Errorf("expected default requests 1000, got %d", got)
	}

	if got := defaults.Window; got != time.Minute {
		t.Errorf("expected default window 1 minute, got %v", got)
	}

	t.Log("P0-05: 限流配置验证通过")
}
// TestP005_Summary logs a human-readable summary of the P0-05 rate-limit
// test suite. It makes no assertions.
func TestP005_Summary(t *testing.T) {
	summary := []string{
		"=== P0-05 限流策略测试总结 ===",
		"问题: PRD P0要求基础限流策略但所有技术文档均未定义限流算法",
		"",
		"修复方案:",
		" - 令牌桶算法 (Token Bucket)",
		" - 多维度限流 (tenant/user/IP/endpoint)",
		" - 滑动窗口 (Sliding Window)",
		" - 降级策略 (返回429或队列)",
	}

	// One t.Log call per entry, in order.
	for _, line := range summary {
		t.Log(line)
	}
}
// ==================== TokenBucket Reset Tests ====================

// TestTokenBucket_Reset checks that Reset restores a partially consumed
// bucket to its full capacity.
func TestTokenBucket_Reset(t *testing.T) {
	tb := NewTokenBucket(5, 1) // capacity 5, refill 1/second

	// Take three of the five tokens.
	for used := 0; used < 3; used++ {
		tb.Allow()
	}

	// Two tokens should be left.
	if left := tb.Remaining(); left != 2 {
		t.Errorf("expected 2 remaining after 3 allows, got %d", left)
	}

	tb.Reset()

	// After the reset the bucket should report its full capacity again.
	if left := tb.Remaining(); left != 5 {
		t.Errorf("expected 5 remaining after reset, got %d", left)
	}
}
// TestTokenBucket_RefillOverflow tests that refilling never pushes the token
// count above the bucket's capacity.
func TestTokenBucket_RefillOverflow(t *testing.T) {
	bucket := NewTokenBucket(5, 100) // capacity 5, refill 100/second

	// Consume all tokens.
	for i := 0; i < 5; i++ {
		bucket.Allow()
	}

	// Remaining should be 0.
	if got := bucket.Remaining(); got != 0 {
		t.Errorf("expected 0 remaining, got %d", got)
	}

	// Wait long enough that an uncapped refill would clearly exceed the
	// capacity: 100 tokens/second over 100ms is 10 tokens, twice the
	// capacity. A wait of only 50ms would produce exactly 5 tokens, so the
	// test would pass even if the cap were missing.
	time.Sleep(100 * time.Millisecond)

	// Remaining should be capped at capacity (5).
	if got := bucket.Remaining(); got != 5 {
		t.Errorf("expected 5 (capped at capacity), got %d", got)
	}
}
// TestGetBucket_NewBucketCreation tests bucket creation on first access
func TestGetBucket_NewBucketCreation(t *testing.T) {
config := &RateLimitConfig{
Enabled: true,
Requests: 50,
Window: time.Minute,
}
rl := &RateLimitMiddleware{
config: config,
buckets: make(map[string]*TokenBucket),
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
}
// First access - should create new bucket
bucket1 := rl.getBucket("new-key")
if bucket1 == nil {
t.Fatal("expected non-nil bucket")
}
// Second access - should return same bucket
bucket2 := rl.getBucket("new-key")
if bucket1 != bucket2 {
t.Error("expected same bucket for same key")
}
// Different key - should create new bucket
bucket3 := rl.getBucket("different-key")
if bucket3 == bucket1 {
t.Error("expected different bucket for different key")
}
}
// ==================== getRateLimitKey Tests ====================
func TestGetRateLimitKey_Default(t *testing.T) {
config := &RateLimitConfig{
Enabled: true,
Requests: 100,
Window: time.Minute,
// No LimitBy... set
}
rl := &RateLimitMiddleware{
config: config,
buckets: make(map[string]*TokenBucket),
}
req := httptest.NewRequest("GET", "/test", nil)
key := rl.getRateLimitKey(req)
if key != "default" {
t.Errorf("expected 'default', got '%s'", key)
}
}
func TestGetRateLimitKey_ByTenant(t *testing.T) {
config := &RateLimitConfig{
Enabled: true,
Requests: 100,
Window: time.Minute,
LimitByTenant: true,
}
rl := &RateLimitMiddleware{
config: config,
buckets: make(map[string]*TokenBucket),
}
req := httptest.NewRequest("GET", "/test", nil)
key := rl.getRateLimitKey(req)
if key != "tenant:0" {
t.Errorf("expected 'tenant:0', got '%s'", key)
}
}
func TestGetRateLimitKey_ByUser(t *testing.T) {
config := &RateLimitConfig{
Enabled: true,
Requests: 100,
Window: time.Minute,
LimitByUser: true,
}
rl := &RateLimitMiddleware{
config: config,
buckets: make(map[string]*TokenBucket),
}
req := httptest.NewRequest("GET", "/test", nil)
key := rl.getRateLimitKey(req)
if key != "user:unknown" {
t.Errorf("expected 'user:unknown', got '%s'", key)
}
}
func TestGetRateLimitKey_ByIP(t *testing.T) {
config := &RateLimitConfig{
Enabled: true,
Requests: 100,
Window: time.Minute,
LimitByIP: true,
}
rl := &RateLimitMiddleware{
config: config,
buckets: make(map[string]*TokenBucket),
}
req := httptest.NewRequest("GET", "/test", nil)
key := rl.getRateLimitKey(req)
if key == "" {
t.Error("expected non-empty key")
}
}
func TestGetRateLimitKey_ByEndpoint(t *testing.T) {
config := &RateLimitConfig{
Enabled: true,
Requests: 100,
Window: time.Minute,
LimitByEndpoint: true,
}
rl := &RateLimitMiddleware{
config: config,
buckets: make(map[string]*TokenBucket),
}
req := httptest.NewRequest("GET", "/api/v1/test", nil)
key := rl.getRateLimitKey(req)
if key != "/api/v1/test" {
t.Errorf("expected '/api/v1/test', got '%s'", key)
}
}
func TestGetRateLimitKey_MultipleDimensions(t *testing.T) {
config := &RateLimitConfig{
Enabled: true,
Requests: 100,
Window: time.Minute,
LimitByTenant: true,
LimitByUser: true,
LimitByIP: true,
}
rl := &RateLimitMiddleware{
config: config,
buckets: make(map[string]*TokenBucket),
}
req := httptest.NewRequest("GET", "/test", nil)
key := rl.getRateLimitKey(req)
// Should contain multiple parts
if key == "default" {
t.Error("expected multi-part key")
}
if key == "" {
t.Error("expected non-empty key")
}
}
// ==================== handleRateLimitExceeded Tests ====================
func TestHandleRateLimitExceeded_FallbackCode(t *testing.T) {
config := &RateLimitConfig{
Enabled: true,
Requests: 10,
Window: time.Second,
FallbackCode: http.StatusServiceUnavailable,
// DegradationEnabled = false
}
rl := &RateLimitMiddleware{
config: config,
buckets: make(map[string]*TokenBucket),
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Error("next handler should not be called")
}),
}
bucket := NewTokenBucket(10, 0)
// Exhaust all tokens
for i := 0; i < 10; i++ {
bucket.Allow()
}
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
// Directly call handleRateLimitExceeded
rl.handleRateLimitExceeded(w, req, bucket)
if w.Code != http.StatusServiceUnavailable {
t.Errorf("expected status %d, got %d", http.StatusServiceUnavailable, w.Code)
}
}
func TestHandleRateLimitExceeded_WithoutDegradation(t *testing.T) {
config := &RateLimitConfig{
Enabled: true,
Requests: 1,
Window: time.Second,
FallbackCode: http.StatusTooManyRequests,
// DegradationEnabled = false by default
}
handler := &RateLimitMiddleware{
config: config,
buckets: make(map[string]*TokenBucket),
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}),
}
// First request should succeed
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Errorf("first request: expected 200, got %d", w.Code)
}
// Second request should be rate limited
w = httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusTooManyRequests {
t.Errorf("second request: expected 429, got %d", w.Code)
}
}
func TestServeHTTP_Disabled(t *testing.T) {
config := &RateLimitConfig{
Enabled: false, // Disabled
}
nextCalled := false
handler := &RateLimitMiddleware{
config: config,
buckets: make(map[string]*TokenBucket),
next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
w.WriteHeader(http.StatusOK)
}),
}
req := httptest.NewRequest("GET", "/test", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if !nextCalled {
t.Error("next handler should be called when rate limit is disabled")
}
}