P1-01: 提取重复的角色层级定义为包级常量
- 将 roleHierarchy 提取为 roleHierarchyLevels 包级变量
- 消除重复定义

P1-02: 修复伪随机数用于加权选择
- 使用 math/rand 的线程安全随机数生成器替代时间戳
- 确保加权路由的均匀分布

P1-03: 修复 FailureRate 初始化计算错误
- 将成功时的恢复因子从 0.9 改为 0.5
- 加速失败后的恢复过程

P1-04: 为 DefaultIAMService 添加并发控制
- 添加 sync.RWMutex 保护 map 操作
- 确保所有服务方法的线程安全

P1-05: 修复 IP 伪造漏洞
- 添加 TrustedProxies 配置
- 只在来自可信代理时才使用 X-Forwarded-For

P1-06: 修复限流 key 提取逻辑错误
- 从 Authorization header 中提取 Bearer token
- 避免使用完整的 header 作为限流 key
334 lines
8.7 KiB
Go
334 lines
8.7 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestTokenBucketLimiter(t *testing.T) {
|
|
t.Run("allows requests within limit", func(t *testing.T) {
|
|
limiter := NewTokenBucketLimiter(60, 60000, 1.5) // 60 RPM
|
|
ctx := context.Background()
|
|
|
|
// Should allow multiple requests
|
|
for i := 0; i < 5; i++ {
|
|
allowed, err := limiter.Allow(ctx, "test-key")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if !allowed {
|
|
t.Errorf("request %d should be allowed", i+1)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("blocks requests over limit", func(t *testing.T) {
|
|
// Use very low limits for testing
|
|
limiter := &TokenBucketLimiter{
|
|
buckets: make(map[string]*tokenBucket),
|
|
defaultRPM: 2,
|
|
defaultTPM: 100,
|
|
burstMultiplier: 1.0,
|
|
cleanInterval: 10 * time.Minute,
|
|
}
|
|
// Pre-fill the bucket to capacity
|
|
key := "test-key"
|
|
bucket := limiter.newBucket(2, 100)
|
|
limiter.buckets[key] = bucket
|
|
|
|
ctx := context.Background()
|
|
|
|
// First two should be allowed
|
|
allowed, _ := limiter.Allow(ctx, key)
|
|
if !allowed {
|
|
t.Error("first request should be allowed")
|
|
}
|
|
|
|
allowed, _ = limiter.Allow(ctx, key)
|
|
if !allowed {
|
|
t.Error("second request should be allowed")
|
|
}
|
|
|
|
// Third should be blocked
|
|
allowed, _ = limiter.Allow(ctx, key)
|
|
if allowed {
|
|
t.Error("third request should be blocked")
|
|
}
|
|
})
|
|
|
|
t.Run("refills tokens over time", func(t *testing.T) {
|
|
limiter := &TokenBucketLimiter{
|
|
buckets: make(map[string]*tokenBucket),
|
|
defaultRPM: 60,
|
|
defaultTPM: 60000,
|
|
burstMultiplier: 1.0,
|
|
cleanInterval: 10 * time.Minute,
|
|
}
|
|
key := "test-key"
|
|
|
|
// Consume all tokens
|
|
for i := 0; i < 60; i++ {
|
|
limiter.Allow(context.Background(), key)
|
|
}
|
|
|
|
// Should be blocked now
|
|
allowed, _ := limiter.Allow(context.Background(), key)
|
|
if allowed {
|
|
t.Error("should be blocked after consuming all tokens")
|
|
}
|
|
|
|
// Manually backdate the refill time to simulate time passing
|
|
limiter.buckets[key].lastRefill = time.Now().Add(-2 * time.Minute)
|
|
|
|
// Should allow again after time-based refill
|
|
allowed, _ = limiter.Allow(context.Background(), key)
|
|
if !allowed {
|
|
t.Error("should allow after token refill")
|
|
}
|
|
})
|
|
|
|
t.Run("separate buckets for different keys", func(t *testing.T) {
|
|
limiter := NewTokenBucketLimiter(2, 100, 1.0)
|
|
ctx := context.Background()
|
|
|
|
// Exhaust key1
|
|
limiter.Allow(ctx, "key1")
|
|
limiter.Allow(ctx, "key1")
|
|
|
|
// key1 should be blocked
|
|
allowed, _ := limiter.Allow(ctx, "key1")
|
|
if allowed {
|
|
t.Error("key1 should be rate limited")
|
|
}
|
|
|
|
// key2 should still work
|
|
allowed, _ = limiter.Allow(ctx, "key2")
|
|
if !allowed {
|
|
t.Error("key2 should be allowed")
|
|
}
|
|
})
|
|
|
|
t.Run("get limit returns correct values", func(t *testing.T) {
|
|
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
|
|
limiter.Allow(context.Background(), "test-key")
|
|
|
|
limit := limiter.GetLimit("test-key")
|
|
if limit.RPM != 60 {
|
|
t.Errorf("expected RPM 60, got %d", limit.RPM)
|
|
}
|
|
if limit.TPM != 60000 {
|
|
t.Errorf("expected TPM 60000, got %d", limit.TPM)
|
|
}
|
|
if limit.Burst != 90 { // 60 * 1.5
|
|
t.Errorf("expected Burst 90, got %d", limit.Burst)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestSlidingWindowLimiter(t *testing.T) {
|
|
t.Run("allows requests within window", func(t *testing.T) {
|
|
limiter := NewSlidingWindowLimiter(time.Minute, 5)
|
|
ctx := context.Background()
|
|
|
|
for i := 0; i < 5; i++ {
|
|
allowed, err := limiter.Allow(ctx, "test-key")
|
|
if err != nil {
|
|
t.Fatalf("unexpected error: %v", err)
|
|
}
|
|
if !allowed {
|
|
t.Errorf("request %d should be allowed", i+1)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("blocks requests over window limit", func(t *testing.T) {
|
|
limiter := NewSlidingWindowLimiter(time.Minute, 2)
|
|
ctx := context.Background()
|
|
|
|
limiter.Allow(ctx, "test-key")
|
|
limiter.Allow(ctx, "test-key")
|
|
|
|
allowed, _ := limiter.Allow(ctx, "test-key")
|
|
if allowed {
|
|
t.Error("third request should be blocked")
|
|
}
|
|
})
|
|
|
|
t.Run("sliding window respects time", func(t *testing.T) {
|
|
limiter := &SlidingWindowLimiter{
|
|
windows: make(map[string]*slidingWindow),
|
|
windowSize: time.Minute,
|
|
maxRequests: 2,
|
|
cleanInterval: 10 * time.Minute,
|
|
}
|
|
ctx := context.Background()
|
|
key := "test-key"
|
|
|
|
// Make requests
|
|
limiter.Allow(ctx, key)
|
|
limiter.Allow(ctx, key)
|
|
|
|
// Should be blocked
|
|
allowed, _ := limiter.Allow(ctx, key)
|
|
if allowed {
|
|
t.Error("should be blocked after reaching limit")
|
|
}
|
|
|
|
// Simulate time passing - move window forward
|
|
limiter.windows[key].requests[0] = time.Now().Add(-2 * time.Minute)
|
|
limiter.windows[key].requests[1] = time.Now().Add(-2 * time.Minute)
|
|
|
|
// Should allow now
|
|
allowed, _ = limiter.Allow(ctx, key)
|
|
if !allowed {
|
|
t.Error("should allow after old requests expire from window")
|
|
}
|
|
})
|
|
|
|
t.Run("separate windows for different keys", func(t *testing.T) {
|
|
limiter := NewSlidingWindowLimiter(time.Minute, 1)
|
|
ctx := context.Background()
|
|
|
|
limiter.Allow(ctx, "key1")
|
|
|
|
allowed, _ := limiter.Allow(ctx, "key1")
|
|
if allowed {
|
|
t.Error("key1 should be rate limited")
|
|
}
|
|
|
|
allowed, _ = limiter.Allow(ctx, "key2")
|
|
if !allowed {
|
|
t.Error("key2 should be allowed")
|
|
}
|
|
})
|
|
|
|
t.Run("get limit returns correct remaining", func(t *testing.T) {
|
|
limiter := NewSlidingWindowLimiter(time.Minute, 10)
|
|
ctx := context.Background()
|
|
|
|
limiter.Allow(ctx, "test-key")
|
|
limiter.Allow(ctx, "test-key")
|
|
limiter.Allow(ctx, "test-key")
|
|
|
|
limit := limiter.GetLimit("test-key")
|
|
if limit.RPM != 10 {
|
|
t.Errorf("expected RPM 10, got %d", limit.RPM)
|
|
}
|
|
if limit.Remaining != 7 {
|
|
t.Errorf("expected Remaining 7, got %d", limit.Remaining)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestMiddleware(t *testing.T) {
|
|
t.Run("allows request when under limit", func(t *testing.T) {
|
|
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
|
|
middleware := NewMiddleware(limiter)
|
|
|
|
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.Header.Set("Authorization", "Bearer test-token")
|
|
rr := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", rr.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("sets rate limit headers when blocked", func(t *testing.T) {
|
|
// Use very low limit so request is blocked
|
|
limiter := &TokenBucketLimiter{
|
|
buckets: make(map[string]*tokenBucket),
|
|
defaultRPM: 1,
|
|
defaultTPM: 100,
|
|
burstMultiplier: 1.0,
|
|
cleanInterval: 10 * time.Minute,
|
|
}
|
|
// Exhaust the bucket - key is the extracted token, not the full Authorization header
|
|
key := "test-token"
|
|
bucket := limiter.newBucket(1, 100)
|
|
bucket.tokens = 0
|
|
limiter.buckets[key] = bucket
|
|
|
|
middleware := NewMiddleware(limiter)
|
|
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Error("handler should not be called")
|
|
}))
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+key)
|
|
rr := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
// Headers should be set when rate limited
|
|
if rr.Header().Get("X-RateLimit-Limit") == "" {
|
|
t.Error("expected X-RateLimit-Limit header to be set")
|
|
}
|
|
if rr.Header().Get("X-RateLimit-Remaining") == "" {
|
|
t.Error("expected X-RateLimit-Remaining header to be set")
|
|
}
|
|
if rr.Header().Get("X-RateLimit-Reset") == "" {
|
|
t.Error("expected X-RateLimit-Reset header to be set")
|
|
}
|
|
})
|
|
|
|
t.Run("blocks request when over limit", func(t *testing.T) {
|
|
// Use very low limit
|
|
limiter := &TokenBucketLimiter{
|
|
buckets: make(map[string]*tokenBucket),
|
|
defaultRPM: 1,
|
|
defaultTPM: 100,
|
|
burstMultiplier: 1.0,
|
|
cleanInterval: 10 * time.Minute,
|
|
}
|
|
// Exhaust the bucket - key is the extracted token, not the full Authorization header
|
|
key := "test-token"
|
|
bucket := limiter.newBucket(1, 100)
|
|
bucket.tokens = 0 // Exhaust
|
|
limiter.buckets[key] = bucket
|
|
|
|
middleware := NewMiddleware(limiter)
|
|
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
t.Error("handler should not be called")
|
|
}))
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+key)
|
|
rr := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Errorf("expected status 429, got %d", rr.Code)
|
|
}
|
|
})
|
|
|
|
t.Run("uses remote addr when no auth header", func(t *testing.T) {
|
|
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
|
|
middleware := NewMiddleware(limiter)
|
|
|
|
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest("GET", "/test", nil)
|
|
// No Authorization header
|
|
rr := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("expected status 200, got %d", rr.Code)
|
|
}
|
|
})
|
|
}
|