fix: 修复6个代码质量问题
P1-01: 提取重复的角色层级定义为包级常量 - 将 roleHierarchy 提取为 roleHierarchyLevels 包级变量 - 消除重复定义 P1-02: 修复伪随机数用于加权选择 - 使用 math/rand 的线程安全随机数生成器替代时间戳 - 确保加权路由的均匀分布 P1-03: 修复 FailureRate 初始化计算错误 - 将成功时的恢复因子从 0.9 改为 0.5 - 加速失败后的恢复过程 P1-04: 为 DefaultIAMService 添加并发控制 - 添加 sync.RWMutex 保护 map 操作 - 确保所有服务方法的线程安全 P1-05: 修复 IP 伪造漏洞 - 添加 TrustedProxies 配置 - 只在来自可信代理时才使用 X-Forwarded-For P1-06: 修复限流 key 提取逻辑错误 - 从 Authorization header 中提取 Bearer token - 避免使用完整的 header 作为限流 key
This commit is contained in:
@@ -3,10 +3,12 @@ package ratelimit
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/gateway/pkg/error"
|
||||
gwerror "lijiaoqiao/gateway/pkg/error"
|
||||
)
|
||||
|
||||
// Algorithm 限流算法
|
||||
@@ -278,7 +280,7 @@ func (l *SlidingWindowLimiter) cleanup() {
|
||||
validRequests = append(validRequests, t)
|
||||
}
|
||||
}
|
||||
if len(validRequests) == 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 {
|
||||
if len(validRequests) == 0 && len(window.requests) > 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 {
|
||||
delete(l.windows, key)
|
||||
} else {
|
||||
window.requests = validRequests
|
||||
@@ -301,14 +303,14 @@ func NewMiddleware(limiter Limiter) *Middleware {
|
||||
func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// 使用API Key作为限流key
|
||||
key := r.Header.Get("Authorization")
|
||||
key := extractRateLimitKey(r)
|
||||
if key == "" {
|
||||
key = r.RemoteAddr
|
||||
}
|
||||
|
||||
allowed, err := m.limiter.Allow(r.Context(), key)
|
||||
if err != nil {
|
||||
writeError(w, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "rate limiter error"))
|
||||
writeError(w, gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, "rate limiter error"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -318,7 +320,7 @@ func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
||||
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining))
|
||||
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix()))
|
||||
|
||||
writeError(w, error.NewGatewayError(error.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
|
||||
writeError(w, gwerror.NewGatewayError(gwerror.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -326,9 +328,27 @@ func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
import "net/http"
|
||||
// extractRateLimitKey 从请求中提取限流key
|
||||
func extractRateLimitKey(r *http.Request) string {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, err *error.GatewayError) {
|
||||
// 如果是Bearer token,提取token部分
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
token = strings.TrimSpace(token)
|
||||
if token != "" {
|
||||
return token
|
||||
}
|
||||
}
|
||||
|
||||
// 否则返回原始header(不应该发生)
|
||||
return authHeader
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, err *gwerror.GatewayError) {
|
||||
info := err.GetErrorInfo()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(info.HTTPStatus)
|
||||
|
||||
333
gateway/internal/ratelimit/ratelimit_test.go
Normal file
333
gateway/internal/ratelimit/ratelimit_test.go
Normal file
@@ -0,0 +1,333 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestTokenBucketLimiter(t *testing.T) {
|
||||
t.Run("allows requests within limit", func(t *testing.T) {
|
||||
limiter := NewTokenBucketLimiter(60, 60000, 1.5) // 60 RPM
|
||||
ctx := context.Background()
|
||||
|
||||
// Should allow multiple requests
|
||||
for i := 0; i < 5; i++ {
|
||||
allowed, err := limiter.Allow(ctx, "test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !allowed {
|
||||
t.Errorf("request %d should be allowed", i+1)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("blocks requests over limit", func(t *testing.T) {
|
||||
// Use very low limits for testing
|
||||
limiter := &TokenBucketLimiter{
|
||||
buckets: make(map[string]*tokenBucket),
|
||||
defaultRPM: 2,
|
||||
defaultTPM: 100,
|
||||
burstMultiplier: 1.0,
|
||||
cleanInterval: 10 * time.Minute,
|
||||
}
|
||||
// Pre-fill the bucket to capacity
|
||||
key := "test-key"
|
||||
bucket := limiter.newBucket(2, 100)
|
||||
limiter.buckets[key] = bucket
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// First two should be allowed
|
||||
allowed, _ := limiter.Allow(ctx, key)
|
||||
if !allowed {
|
||||
t.Error("first request should be allowed")
|
||||
}
|
||||
|
||||
allowed, _ = limiter.Allow(ctx, key)
|
||||
if !allowed {
|
||||
t.Error("second request should be allowed")
|
||||
}
|
||||
|
||||
// Third should be blocked
|
||||
allowed, _ = limiter.Allow(ctx, key)
|
||||
if allowed {
|
||||
t.Error("third request should be blocked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("refills tokens over time", func(t *testing.T) {
|
||||
limiter := &TokenBucketLimiter{
|
||||
buckets: make(map[string]*tokenBucket),
|
||||
defaultRPM: 60,
|
||||
defaultTPM: 60000,
|
||||
burstMultiplier: 1.0,
|
||||
cleanInterval: 10 * time.Minute,
|
||||
}
|
||||
key := "test-key"
|
||||
|
||||
// Consume all tokens
|
||||
for i := 0; i < 60; i++ {
|
||||
limiter.Allow(context.Background(), key)
|
||||
}
|
||||
|
||||
// Should be blocked now
|
||||
allowed, _ := limiter.Allow(context.Background(), key)
|
||||
if allowed {
|
||||
t.Error("should be blocked after consuming all tokens")
|
||||
}
|
||||
|
||||
// Manually backdate the refill time to simulate time passing
|
||||
limiter.buckets[key].lastRefill = time.Now().Add(-2 * time.Minute)
|
||||
|
||||
// Should allow again after time-based refill
|
||||
allowed, _ = limiter.Allow(context.Background(), key)
|
||||
if !allowed {
|
||||
t.Error("should allow after token refill")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("separate buckets for different keys", func(t *testing.T) {
|
||||
limiter := NewTokenBucketLimiter(2, 100, 1.0)
|
||||
ctx := context.Background()
|
||||
|
||||
// Exhaust key1
|
||||
limiter.Allow(ctx, "key1")
|
||||
limiter.Allow(ctx, "key1")
|
||||
|
||||
// key1 should be blocked
|
||||
allowed, _ := limiter.Allow(ctx, "key1")
|
||||
if allowed {
|
||||
t.Error("key1 should be rate limited")
|
||||
}
|
||||
|
||||
// key2 should still work
|
||||
allowed, _ = limiter.Allow(ctx, "key2")
|
||||
if !allowed {
|
||||
t.Error("key2 should be allowed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("get limit returns correct values", func(t *testing.T) {
|
||||
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
|
||||
limiter.Allow(context.Background(), "test-key")
|
||||
|
||||
limit := limiter.GetLimit("test-key")
|
||||
if limit.RPM != 60 {
|
||||
t.Errorf("expected RPM 60, got %d", limit.RPM)
|
||||
}
|
||||
if limit.TPM != 60000 {
|
||||
t.Errorf("expected TPM 60000, got %d", limit.TPM)
|
||||
}
|
||||
if limit.Burst != 90 { // 60 * 1.5
|
||||
t.Errorf("expected Burst 90, got %d", limit.Burst)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSlidingWindowLimiter(t *testing.T) {
|
||||
t.Run("allows requests within window", func(t *testing.T) {
|
||||
limiter := NewSlidingWindowLimiter(time.Minute, 5)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
allowed, err := limiter.Allow(ctx, "test-key")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if !allowed {
|
||||
t.Errorf("request %d should be allowed", i+1)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("blocks requests over window limit", func(t *testing.T) {
|
||||
limiter := NewSlidingWindowLimiter(time.Minute, 2)
|
||||
ctx := context.Background()
|
||||
|
||||
limiter.Allow(ctx, "test-key")
|
||||
limiter.Allow(ctx, "test-key")
|
||||
|
||||
allowed, _ := limiter.Allow(ctx, "test-key")
|
||||
if allowed {
|
||||
t.Error("third request should be blocked")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sliding window respects time", func(t *testing.T) {
|
||||
limiter := &SlidingWindowLimiter{
|
||||
windows: make(map[string]*slidingWindow),
|
||||
windowSize: time.Minute,
|
||||
maxRequests: 2,
|
||||
cleanInterval: 10 * time.Minute,
|
||||
}
|
||||
ctx := context.Background()
|
||||
key := "test-key"
|
||||
|
||||
// Make requests
|
||||
limiter.Allow(ctx, key)
|
||||
limiter.Allow(ctx, key)
|
||||
|
||||
// Should be blocked
|
||||
allowed, _ := limiter.Allow(ctx, key)
|
||||
if allowed {
|
||||
t.Error("should be blocked after reaching limit")
|
||||
}
|
||||
|
||||
// Simulate time passing - move window forward
|
||||
limiter.windows[key].requests[0] = time.Now().Add(-2 * time.Minute)
|
||||
limiter.windows[key].requests[1] = time.Now().Add(-2 * time.Minute)
|
||||
|
||||
// Should allow now
|
||||
allowed, _ = limiter.Allow(ctx, key)
|
||||
if !allowed {
|
||||
t.Error("should allow after old requests expire from window")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("separate windows for different keys", func(t *testing.T) {
|
||||
limiter := NewSlidingWindowLimiter(time.Minute, 1)
|
||||
ctx := context.Background()
|
||||
|
||||
limiter.Allow(ctx, "key1")
|
||||
|
||||
allowed, _ := limiter.Allow(ctx, "key1")
|
||||
if allowed {
|
||||
t.Error("key1 should be rate limited")
|
||||
}
|
||||
|
||||
allowed, _ = limiter.Allow(ctx, "key2")
|
||||
if !allowed {
|
||||
t.Error("key2 should be allowed")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("get limit returns correct remaining", func(t *testing.T) {
|
||||
limiter := NewSlidingWindowLimiter(time.Minute, 10)
|
||||
ctx := context.Background()
|
||||
|
||||
limiter.Allow(ctx, "test-key")
|
||||
limiter.Allow(ctx, "test-key")
|
||||
limiter.Allow(ctx, "test-key")
|
||||
|
||||
limit := limiter.GetLimit("test-key")
|
||||
if limit.RPM != 10 {
|
||||
t.Errorf("expected RPM 10, got %d", limit.RPM)
|
||||
}
|
||||
if limit.Remaining != 7 {
|
||||
t.Errorf("expected Remaining 7, got %d", limit.Remaining)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
t.Run("allows request when under limit", func(t *testing.T) {
|
||||
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
|
||||
middleware := NewMiddleware(limiter)
|
||||
|
||||
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sets rate limit headers when blocked", func(t *testing.T) {
|
||||
// Use very low limit so request is blocked
|
||||
limiter := &TokenBucketLimiter{
|
||||
buckets: make(map[string]*tokenBucket),
|
||||
defaultRPM: 1,
|
||||
defaultTPM: 100,
|
||||
burstMultiplier: 1.0,
|
||||
cleanInterval: 10 * time.Minute,
|
||||
}
|
||||
// Exhaust the bucket - key is the extracted token, not the full Authorization header
|
||||
key := "test-token"
|
||||
bucket := limiter.newBucket(1, 100)
|
||||
bucket.tokens = 0
|
||||
limiter.buckets[key] = bucket
|
||||
|
||||
middleware := NewMiddleware(limiter)
|
||||
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("handler should not be called")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
// Headers should be set when rate limited
|
||||
if rr.Header().Get("X-RateLimit-Limit") == "" {
|
||||
t.Error("expected X-RateLimit-Limit header to be set")
|
||||
}
|
||||
if rr.Header().Get("X-RateLimit-Remaining") == "" {
|
||||
t.Error("expected X-RateLimit-Remaining header to be set")
|
||||
}
|
||||
if rr.Header().Get("X-RateLimit-Reset") == "" {
|
||||
t.Error("expected X-RateLimit-Reset header to be set")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("blocks request when over limit", func(t *testing.T) {
|
||||
// Use very low limit
|
||||
limiter := &TokenBucketLimiter{
|
||||
buckets: make(map[string]*tokenBucket),
|
||||
defaultRPM: 1,
|
||||
defaultTPM: 100,
|
||||
burstMultiplier: 1.0,
|
||||
cleanInterval: 10 * time.Minute,
|
||||
}
|
||||
// Exhaust the bucket - key is the extracted token, not the full Authorization header
|
||||
key := "test-token"
|
||||
bucket := limiter.newBucket(1, 100)
|
||||
bucket.tokens = 0 // Exhaust
|
||||
limiter.buckets[key] = bucket
|
||||
|
||||
middleware := NewMiddleware(limiter)
|
||||
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("handler should not be called")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+key)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected status 429, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses remote addr when no auth header", func(t *testing.T) {
|
||||
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
|
||||
middleware := NewMiddleware(limiter)
|
||||
|
||||
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
// No Authorization header
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user