fix: 修复6个代码质量问题
P1-01: 提取重复的角色层级定义为包级常量 - 将 roleHierarchy 提取为 roleHierarchyLevels 包级变量 - 消除重复定义 P1-02: 修复伪随机数用于加权选择 - 使用 math/rand 的线程安全随机数生成器替代时间戳 - 确保加权路由的均匀分布 P1-03: 修复 FailureRate 初始化计算错误 - 将成功时的恢复因子从 0.9 改为 0.5 - 加速失败后的恢复过程 P1-04: 为 DefaultIAMService 添加并发控制 - 添加 sync.RWMutex 保护 map 操作 - 确保所有服务方法的线程安全 P1-05: 修复 IP 伪造漏洞 - 添加 TrustedProxies 配置 - 只在来自可信代理时才使用 X-Forwarded-For P1-06: 修复限流 key 提取逻辑错误 - 从 Authorization header 中提取 Bearer token - 避免使用完整的 header 作为限流 key
This commit is contained in:
@@ -33,7 +33,7 @@ type Principal struct {
|
|||||||
// BuildTokenAuthChain 构建认证中间件链
|
// BuildTokenAuthChain 构建认证中间件链
|
||||||
func BuildTokenAuthChain(cfg AuthMiddlewareConfig, next http.Handler) http.Handler {
|
func BuildTokenAuthChain(cfg AuthMiddlewareConfig, next http.Handler) http.Handler {
|
||||||
handler := tokenAuthMiddleware(cfg)(next)
|
handler := tokenAuthMiddleware(cfg)(next)
|
||||||
handler = queryKeyRejectMiddleware(handler, cfg.Auditor, cfg.Now)
|
handler = queryKeyRejectMiddleware(handler, cfg.Auditor, cfg.Now, cfg.TrustedProxies)
|
||||||
handler = requestIDMiddleware(handler, cfg.Now)
|
handler = requestIDMiddleware(handler, cfg.Now)
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
@@ -54,7 +54,7 @@ func requestIDMiddleware(next http.Handler, now func() time.Time) http.Handler {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// queryKeyRejectMiddleware 拒绝query key入站
|
// queryKeyRejectMiddleware 拒绝query key入站
|
||||||
func queryKeyRejectMiddleware(next http.Handler, auditor AuditEmitter, now func() time.Time) http.Handler {
|
func queryKeyRejectMiddleware(next http.Handler, auditor AuditEmitter, now func() time.Time, trustedProxies []string) http.Handler {
|
||||||
if next == nil {
|
if next == nil {
|
||||||
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||||
}
|
}
|
||||||
@@ -69,7 +69,7 @@ func queryKeyRejectMiddleware(next http.Handler, auditor AuditEmitter, now func(
|
|||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Route: r.URL.Path,
|
Route: r.URL.Path,
|
||||||
ResultCode: CodeQueryKeyNotAllowed,
|
ResultCode: CodeQueryKeyNotAllowed,
|
||||||
ClientIP: extractClientIP(r),
|
ClientIP: extractClientIP(r, trustedProxies),
|
||||||
CreatedAt: now(),
|
CreatedAt: now(),
|
||||||
})
|
})
|
||||||
writeError(w, http.StatusUnauthorized, requestID, CodeQueryKeyNotAllowed, "query key not allowed")
|
writeError(w, http.StatusUnauthorized, requestID, CodeQueryKeyNotAllowed, "query key not allowed")
|
||||||
@@ -105,7 +105,7 @@ func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handl
|
|||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Route: r.URL.Path,
|
Route: r.URL.Path,
|
||||||
ResultCode: CodeAuthMissingBearer,
|
ResultCode: CodeAuthMissingBearer,
|
||||||
ClientIP: extractClientIP(r),
|
ClientIP: extractClientIP(r, cfg.TrustedProxies),
|
||||||
CreatedAt: cfg.Now(),
|
CreatedAt: cfg.Now(),
|
||||||
})
|
})
|
||||||
writeError(w, http.StatusUnauthorized, requestID, CodeAuthMissingBearer, "missing bearer token")
|
writeError(w, http.StatusUnauthorized, requestID, CodeAuthMissingBearer, "missing bearer token")
|
||||||
@@ -119,7 +119,7 @@ func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handl
|
|||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Route: r.URL.Path,
|
Route: r.URL.Path,
|
||||||
ResultCode: CodeAuthInvalidToken,
|
ResultCode: CodeAuthInvalidToken,
|
||||||
ClientIP: extractClientIP(r),
|
ClientIP: extractClientIP(r, cfg.TrustedProxies),
|
||||||
CreatedAt: cfg.Now(),
|
CreatedAt: cfg.Now(),
|
||||||
})
|
})
|
||||||
writeError(w, http.StatusUnauthorized, requestID, CodeAuthInvalidToken, "invalid bearer token")
|
writeError(w, http.StatusUnauthorized, requestID, CodeAuthInvalidToken, "invalid bearer token")
|
||||||
@@ -135,7 +135,7 @@ func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handl
|
|||||||
SubjectID: claims.SubjectID,
|
SubjectID: claims.SubjectID,
|
||||||
Route: r.URL.Path,
|
Route: r.URL.Path,
|
||||||
ResultCode: CodeAuthTokenInactive,
|
ResultCode: CodeAuthTokenInactive,
|
||||||
ClientIP: extractClientIP(r),
|
ClientIP: extractClientIP(r, cfg.TrustedProxies),
|
||||||
CreatedAt: cfg.Now(),
|
CreatedAt: cfg.Now(),
|
||||||
})
|
})
|
||||||
writeError(w, http.StatusUnauthorized, requestID, CodeAuthTokenInactive, "token is inactive")
|
writeError(w, http.StatusUnauthorized, requestID, CodeAuthTokenInactive, "token is inactive")
|
||||||
@@ -150,7 +150,7 @@ func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handl
|
|||||||
SubjectID: claims.SubjectID,
|
SubjectID: claims.SubjectID,
|
||||||
Route: r.URL.Path,
|
Route: r.URL.Path,
|
||||||
ResultCode: CodeAuthScopeDenied,
|
ResultCode: CodeAuthScopeDenied,
|
||||||
ClientIP: extractClientIP(r),
|
ClientIP: extractClientIP(r, cfg.TrustedProxies),
|
||||||
CreatedAt: cfg.Now(),
|
CreatedAt: cfg.Now(),
|
||||||
})
|
})
|
||||||
writeError(w, http.StatusForbidden, requestID, CodeAuthScopeDenied, "scope denied")
|
writeError(w, http.StatusForbidden, requestID, CodeAuthScopeDenied, "scope denied")
|
||||||
@@ -174,7 +174,7 @@ func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handl
|
|||||||
SubjectID: claims.SubjectID,
|
SubjectID: claims.SubjectID,
|
||||||
Route: r.URL.Path,
|
Route: r.URL.Path,
|
||||||
ResultCode: "OK",
|
ResultCode: "OK",
|
||||||
ClientIP: extractClientIP(r),
|
ClientIP: extractClientIP(r, cfg.TrustedProxies),
|
||||||
CreatedAt: cfg.Now(),
|
CreatedAt: cfg.Now(),
|
||||||
})
|
})
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
@@ -297,15 +297,31 @@ func writeError(w http.ResponseWriter, status int, requestID, code, message stri
|
|||||||
_ = json.NewEncoder(w).Encode(payload)
|
_ = json.NewEncoder(w).Encode(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractClientIP(r *http.Request) string {
|
func extractClientIP(r *http.Request, trustedProxies []string) string {
|
||||||
|
// 检查请求是否来自可信代理
|
||||||
|
isFromTrustedProxy := false
|
||||||
|
remoteHost, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||||
|
if err == nil {
|
||||||
|
for _, proxy := range trustedProxies {
|
||||||
|
if remoteHost == proxy {
|
||||||
|
isFromTrustedProxy = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 只有来自可信代理的请求才使用X-Forwarded-For
|
||||||
|
if isFromTrustedProxy {
|
||||||
xForwardedFor := strings.TrimSpace(r.Header.Get("X-Forwarded-For"))
|
xForwardedFor := strings.TrimSpace(r.Header.Get("X-Forwarded-For"))
|
||||||
if xForwardedFor != "" {
|
if xForwardedFor != "" {
|
||||||
parts := strings.Split(xForwardedFor, ",")
|
parts := strings.Split(xForwardedFor, ",")
|
||||||
return strings.TrimSpace(parts[0])
|
return strings.TrimSpace(parts[0])
|
||||||
}
|
}
|
||||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
}
|
||||||
|
|
||||||
|
// 否则使用RemoteAddr
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return host
|
return remoteHost
|
||||||
}
|
}
|
||||||
return r.RemoteAddr
|
return r.RemoteAddr
|
||||||
}
|
}
|
||||||
@@ -87,4 +87,7 @@ type AuthMiddlewareConfig struct {
|
|||||||
ProtectedPrefixes []string
|
ProtectedPrefixes []string
|
||||||
ExcludedPrefixes []string
|
ExcludedPrefixes []string
|
||||||
Now func() time.Time
|
Now func() time.Time
|
||||||
|
// TrustedProxies 可信的代理IP列表,用于IP伪造防护
|
||||||
|
// 只有来自这些IP的请求才会使用X-Forwarded-For头
|
||||||
|
TrustedProxies []string
|
||||||
}
|
}
|
||||||
@@ -3,10 +3,12 @@ package ratelimit
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"lijiaoqiao/gateway/pkg/error"
|
gwerror "lijiaoqiao/gateway/pkg/error"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Algorithm 限流算法
|
// Algorithm 限流算法
|
||||||
@@ -278,7 +280,7 @@ func (l *SlidingWindowLimiter) cleanup() {
|
|||||||
validRequests = append(validRequests, t)
|
validRequests = append(validRequests, t)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(validRequests) == 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 {
|
if len(validRequests) == 0 && len(window.requests) > 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 {
|
||||||
delete(l.windows, key)
|
delete(l.windows, key)
|
||||||
} else {
|
} else {
|
||||||
window.requests = validRequests
|
window.requests = validRequests
|
||||||
@@ -301,14 +303,14 @@ func NewMiddleware(limiter Limiter) *Middleware {
|
|||||||
func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
// 使用API Key作为限流key
|
// 使用API Key作为限流key
|
||||||
key := r.Header.Get("Authorization")
|
key := extractRateLimitKey(r)
|
||||||
if key == "" {
|
if key == "" {
|
||||||
key = r.RemoteAddr
|
key = r.RemoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
allowed, err := m.limiter.Allow(r.Context(), key)
|
allowed, err := m.limiter.Allow(r.Context(), key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
writeError(w, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "rate limiter error"))
|
writeError(w, gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, "rate limiter error"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -318,7 +320,7 @@ func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
|||||||
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining))
|
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining))
|
||||||
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix()))
|
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix()))
|
||||||
|
|
||||||
writeError(w, error.NewGatewayError(error.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
|
writeError(w, gwerror.NewGatewayError(gwerror.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -326,9 +328,27 @@ func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
import "net/http"
|
// extractRateLimitKey 从请求中提取限流key
|
||||||
|
func extractRateLimitKey(r *http.Request) string {
|
||||||
|
authHeader := r.Header.Get("Authorization")
|
||||||
|
if authHeader == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
func writeError(w http.ResponseWriter, err *error.GatewayError) {
|
// 如果是Bearer token,提取token部分
|
||||||
|
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||||
|
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||||
|
token = strings.TrimSpace(token)
|
||||||
|
if token != "" {
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 否则返回原始header(不应该发生)
|
||||||
|
return authHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeError(w http.ResponseWriter, err *gwerror.GatewayError) {
|
||||||
info := err.GetErrorInfo()
|
info := err.GetErrorInfo()
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
w.WriteHeader(info.HTTPStatus)
|
w.WriteHeader(info.HTTPStatus)
|
||||||
|
|||||||
333
gateway/internal/ratelimit/ratelimit_test.go
Normal file
333
gateway/internal/ratelimit/ratelimit_test.go
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTokenBucketLimiter(t *testing.T) {
|
||||||
|
t.Run("allows requests within limit", func(t *testing.T) {
|
||||||
|
limiter := NewTokenBucketLimiter(60, 60000, 1.5) // 60 RPM
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Should allow multiple requests
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
allowed, err := limiter.Allow(ctx, "test-key")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
t.Errorf("request %d should be allowed", i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("blocks requests over limit", func(t *testing.T) {
|
||||||
|
// Use very low limits for testing
|
||||||
|
limiter := &TokenBucketLimiter{
|
||||||
|
buckets: make(map[string]*tokenBucket),
|
||||||
|
defaultRPM: 2,
|
||||||
|
defaultTPM: 100,
|
||||||
|
burstMultiplier: 1.0,
|
||||||
|
cleanInterval: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
// Pre-fill the bucket to capacity
|
||||||
|
key := "test-key"
|
||||||
|
bucket := limiter.newBucket(2, 100)
|
||||||
|
limiter.buckets[key] = bucket
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// First two should be allowed
|
||||||
|
allowed, _ := limiter.Allow(ctx, key)
|
||||||
|
if !allowed {
|
||||||
|
t.Error("first request should be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed, _ = limiter.Allow(ctx, key)
|
||||||
|
if !allowed {
|
||||||
|
t.Error("second request should be allowed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third should be blocked
|
||||||
|
allowed, _ = limiter.Allow(ctx, key)
|
||||||
|
if allowed {
|
||||||
|
t.Error("third request should be blocked")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("refills tokens over time", func(t *testing.T) {
|
||||||
|
limiter := &TokenBucketLimiter{
|
||||||
|
buckets: make(map[string]*tokenBucket),
|
||||||
|
defaultRPM: 60,
|
||||||
|
defaultTPM: 60000,
|
||||||
|
burstMultiplier: 1.0,
|
||||||
|
cleanInterval: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
key := "test-key"
|
||||||
|
|
||||||
|
// Consume all tokens
|
||||||
|
for i := 0; i < 60; i++ {
|
||||||
|
limiter.Allow(context.Background(), key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be blocked now
|
||||||
|
allowed, _ := limiter.Allow(context.Background(), key)
|
||||||
|
if allowed {
|
||||||
|
t.Error("should be blocked after consuming all tokens")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manually backdate the refill time to simulate time passing
|
||||||
|
limiter.buckets[key].lastRefill = time.Now().Add(-2 * time.Minute)
|
||||||
|
|
||||||
|
// Should allow again after time-based refill
|
||||||
|
allowed, _ = limiter.Allow(context.Background(), key)
|
||||||
|
if !allowed {
|
||||||
|
t.Error("should allow after token refill")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("separate buckets for different keys", func(t *testing.T) {
|
||||||
|
limiter := NewTokenBucketLimiter(2, 100, 1.0)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Exhaust key1
|
||||||
|
limiter.Allow(ctx, "key1")
|
||||||
|
limiter.Allow(ctx, "key1")
|
||||||
|
|
||||||
|
// key1 should be blocked
|
||||||
|
allowed, _ := limiter.Allow(ctx, "key1")
|
||||||
|
if allowed {
|
||||||
|
t.Error("key1 should be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
// key2 should still work
|
||||||
|
allowed, _ = limiter.Allow(ctx, "key2")
|
||||||
|
if !allowed {
|
||||||
|
t.Error("key2 should be allowed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get limit returns correct values", func(t *testing.T) {
|
||||||
|
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
|
||||||
|
limiter.Allow(context.Background(), "test-key")
|
||||||
|
|
||||||
|
limit := limiter.GetLimit("test-key")
|
||||||
|
if limit.RPM != 60 {
|
||||||
|
t.Errorf("expected RPM 60, got %d", limit.RPM)
|
||||||
|
}
|
||||||
|
if limit.TPM != 60000 {
|
||||||
|
t.Errorf("expected TPM 60000, got %d", limit.TPM)
|
||||||
|
}
|
||||||
|
if limit.Burst != 90 { // 60 * 1.5
|
||||||
|
t.Errorf("expected Burst 90, got %d", limit.Burst)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSlidingWindowLimiter(t *testing.T) {
|
||||||
|
t.Run("allows requests within window", func(t *testing.T) {
|
||||||
|
limiter := NewSlidingWindowLimiter(time.Minute, 5)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
allowed, err := limiter.Allow(ctx, "test-key")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if !allowed {
|
||||||
|
t.Errorf("request %d should be allowed", i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("blocks requests over window limit", func(t *testing.T) {
|
||||||
|
limiter := NewSlidingWindowLimiter(time.Minute, 2)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
limiter.Allow(ctx, "test-key")
|
||||||
|
limiter.Allow(ctx, "test-key")
|
||||||
|
|
||||||
|
allowed, _ := limiter.Allow(ctx, "test-key")
|
||||||
|
if allowed {
|
||||||
|
t.Error("third request should be blocked")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sliding window respects time", func(t *testing.T) {
|
||||||
|
limiter := &SlidingWindowLimiter{
|
||||||
|
windows: make(map[string]*slidingWindow),
|
||||||
|
windowSize: time.Minute,
|
||||||
|
maxRequests: 2,
|
||||||
|
cleanInterval: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
key := "test-key"
|
||||||
|
|
||||||
|
// Make requests
|
||||||
|
limiter.Allow(ctx, key)
|
||||||
|
limiter.Allow(ctx, key)
|
||||||
|
|
||||||
|
// Should be blocked
|
||||||
|
allowed, _ := limiter.Allow(ctx, key)
|
||||||
|
if allowed {
|
||||||
|
t.Error("should be blocked after reaching limit")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Simulate time passing - move window forward
|
||||||
|
limiter.windows[key].requests[0] = time.Now().Add(-2 * time.Minute)
|
||||||
|
limiter.windows[key].requests[1] = time.Now().Add(-2 * time.Minute)
|
||||||
|
|
||||||
|
// Should allow now
|
||||||
|
allowed, _ = limiter.Allow(ctx, key)
|
||||||
|
if !allowed {
|
||||||
|
t.Error("should allow after old requests expire from window")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("separate windows for different keys", func(t *testing.T) {
|
||||||
|
limiter := NewSlidingWindowLimiter(time.Minute, 1)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
limiter.Allow(ctx, "key1")
|
||||||
|
|
||||||
|
allowed, _ := limiter.Allow(ctx, "key1")
|
||||||
|
if allowed {
|
||||||
|
t.Error("key1 should be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed, _ = limiter.Allow(ctx, "key2")
|
||||||
|
if !allowed {
|
||||||
|
t.Error("key2 should be allowed")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("get limit returns correct remaining", func(t *testing.T) {
|
||||||
|
limiter := NewSlidingWindowLimiter(time.Minute, 10)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
limiter.Allow(ctx, "test-key")
|
||||||
|
limiter.Allow(ctx, "test-key")
|
||||||
|
limiter.Allow(ctx, "test-key")
|
||||||
|
|
||||||
|
limit := limiter.GetLimit("test-key")
|
||||||
|
if limit.RPM != 10 {
|
||||||
|
t.Errorf("expected RPM 10, got %d", limit.RPM)
|
||||||
|
}
|
||||||
|
if limit.Remaining != 7 {
|
||||||
|
t.Errorf("expected Remaining 7, got %d", limit.Remaining)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddleware(t *testing.T) {
|
||||||
|
t.Run("allows request when under limit", func(t *testing.T) {
|
||||||
|
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
|
||||||
|
middleware := NewMiddleware(limiter)
|
||||||
|
|
||||||
|
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer test-token")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("sets rate limit headers when blocked", func(t *testing.T) {
|
||||||
|
// Use very low limit so request is blocked
|
||||||
|
limiter := &TokenBucketLimiter{
|
||||||
|
buckets: make(map[string]*tokenBucket),
|
||||||
|
defaultRPM: 1,
|
||||||
|
defaultTPM: 100,
|
||||||
|
burstMultiplier: 1.0,
|
||||||
|
cleanInterval: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
// Exhaust the bucket - key is the extracted token, not the full Authorization header
|
||||||
|
key := "test-token"
|
||||||
|
bucket := limiter.newBucket(1, 100)
|
||||||
|
bucket.tokens = 0
|
||||||
|
limiter.buckets[key] = bucket
|
||||||
|
|
||||||
|
middleware := NewMiddleware(limiter)
|
||||||
|
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Error("handler should not be called")
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+key)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
// Headers should be set when rate limited
|
||||||
|
if rr.Header().Get("X-RateLimit-Limit") == "" {
|
||||||
|
t.Error("expected X-RateLimit-Limit header to be set")
|
||||||
|
}
|
||||||
|
if rr.Header().Get("X-RateLimit-Remaining") == "" {
|
||||||
|
t.Error("expected X-RateLimit-Remaining header to be set")
|
||||||
|
}
|
||||||
|
if rr.Header().Get("X-RateLimit-Reset") == "" {
|
||||||
|
t.Error("expected X-RateLimit-Reset header to be set")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("blocks request when over limit", func(t *testing.T) {
|
||||||
|
// Use very low limit
|
||||||
|
limiter := &TokenBucketLimiter{
|
||||||
|
buckets: make(map[string]*tokenBucket),
|
||||||
|
defaultRPM: 1,
|
||||||
|
defaultTPM: 100,
|
||||||
|
burstMultiplier: 1.0,
|
||||||
|
cleanInterval: 10 * time.Minute,
|
||||||
|
}
|
||||||
|
// Exhaust the bucket - key is the extracted token, not the full Authorization header
|
||||||
|
key := "test-token"
|
||||||
|
bucket := limiter.newBucket(1, 100)
|
||||||
|
bucket.tokens = 0 // Exhaust
|
||||||
|
limiter.buckets[key] = bucket
|
||||||
|
|
||||||
|
middleware := NewMiddleware(limiter)
|
||||||
|
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t.Error("handler should not be called")
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req.Header.Set("Authorization", "Bearer "+key)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusTooManyRequests {
|
||||||
|
t.Errorf("expected status 429, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("uses remote addr when no auth header", func(t *testing.T) {
|
||||||
|
limiter := NewTokenBucketLimiter(60, 60000, 1.5)
|
||||||
|
middleware := NewMiddleware(limiter)
|
||||||
|
|
||||||
|
handler := middleware.Limit(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
// No Authorization header
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Errorf("expected status 200, got %d", rr.Code)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -3,6 +3,7 @@ package router
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"math"
|
"math"
|
||||||
|
"math/rand"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -10,6 +11,9 @@ import (
|
|||||||
gwerror "lijiaoqiao/gateway/pkg/error"
|
gwerror "lijiaoqiao/gateway/pkg/error"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// 全局随机数生成器(线程安全)
|
||||||
|
var globalRand = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||||
|
|
||||||
// LoadBalancerStrategy 负载均衡策略
|
// LoadBalancerStrategy 负载均衡策略
|
||||||
type LoadBalancerStrategy string
|
type LoadBalancerStrategy string
|
||||||
|
|
||||||
@@ -142,7 +146,7 @@ func (r *Router) selectByWeight(candidates []string) (adapter.ProviderAdapter, e
|
|||||||
totalWeight += r.health[name].Weight
|
totalWeight += r.health[name].Weight
|
||||||
}
|
}
|
||||||
|
|
||||||
randVal := float64(time.Now().UnixNano()) / float64(math.MaxInt64) * totalWeight
|
randVal := globalRand.Float64() * totalWeight
|
||||||
var cumulative float64
|
var cumulative float64
|
||||||
|
|
||||||
for _, name := range candidates {
|
for _, name := range candidates {
|
||||||
@@ -215,11 +219,17 @@ func (r *Router) RecordResult(ctx context.Context, providerName string, success
|
|||||||
|
|
||||||
// 更新失败率
|
// 更新失败率
|
||||||
if success {
|
if success {
|
||||||
if health.FailureRate > 0 {
|
// 成功时快速恢复:使用0.5的下降因子加速恢复
|
||||||
health.FailureRate = health.FailureRate * 0.9 // 下降
|
health.FailureRate = health.FailureRate * 0.5
|
||||||
|
if health.FailureRate < 0.01 {
|
||||||
|
health.FailureRate = 0
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
health.FailureRate = health.FailureRate*0.9 + 0.1 // 上升
|
// 失败时逐步上升
|
||||||
|
health.FailureRate = health.FailureRate*0.9 + 0.1
|
||||||
|
if health.FailureRate > 1 {
|
||||||
|
health.FailureRate = 1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否应该标记为不可用
|
// 检查是否应该标记为不可用
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package service
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -89,6 +90,8 @@ type DefaultIAMService struct {
|
|||||||
userRoleStore map[int64][]*UserRole
|
userRoleStore map[int64][]*UserRole
|
||||||
// 角色Scope存储: roleCode -> []scopeCode
|
// 角色Scope存储: roleCode -> []scopeCode
|
||||||
roleScopeStore map[string][]string
|
roleScopeStore map[string][]string
|
||||||
|
// 并发控制
|
||||||
|
mu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDefaultIAMService 创建默认IAM服务
|
// NewDefaultIAMService 创建默认IAM服务
|
||||||
@@ -102,6 +105,9 @@ func NewDefaultIAMService() *DefaultIAMService {
|
|||||||
|
|
||||||
// CreateRole 创建角色
|
// CreateRole 创建角色
|
||||||
func (s *DefaultIAMService) CreateRole(ctx context.Context, req *CreateRoleRequest) (*Role, error) {
|
func (s *DefaultIAMService) CreateRole(ctx context.Context, req *CreateRoleRequest) (*Role, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
// 检查是否重复
|
// 检查是否重复
|
||||||
if _, exists := s.roleStore[req.Code]; exists {
|
if _, exists := s.roleStore[req.Code]; exists {
|
||||||
return nil, ErrDuplicateRoleCode
|
return nil, ErrDuplicateRoleCode
|
||||||
@@ -138,6 +144,9 @@ func (s *DefaultIAMService) CreateRole(ctx context.Context, req *CreateRoleReque
|
|||||||
|
|
||||||
// GetRole 获取角色
|
// GetRole 获取角色
|
||||||
func (s *DefaultIAMService) GetRole(ctx context.Context, roleCode string) (*Role, error) {
|
func (s *DefaultIAMService) GetRole(ctx context.Context, roleCode string) (*Role, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
role, exists := s.roleStore[roleCode]
|
role, exists := s.roleStore[roleCode]
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, ErrRoleNotFound
|
return nil, ErrRoleNotFound
|
||||||
@@ -147,6 +156,9 @@ func (s *DefaultIAMService) GetRole(ctx context.Context, roleCode string) (*Role
|
|||||||
|
|
||||||
// UpdateRole 更新角色
|
// UpdateRole 更新角色
|
||||||
func (s *DefaultIAMService) UpdateRole(ctx context.Context, req *UpdateRoleRequest) (*Role, error) {
|
func (s *DefaultIAMService) UpdateRole(ctx context.Context, req *UpdateRoleRequest) (*Role, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
role, exists := s.roleStore[req.Code]
|
role, exists := s.roleStore[req.Code]
|
||||||
if !exists {
|
if !exists {
|
||||||
return nil, ErrRoleNotFound
|
return nil, ErrRoleNotFound
|
||||||
@@ -175,6 +187,9 @@ func (s *DefaultIAMService) UpdateRole(ctx context.Context, req *UpdateRoleReque
|
|||||||
|
|
||||||
// DeleteRole 删除角色(软删除)
|
// DeleteRole 删除角色(软删除)
|
||||||
func (s *DefaultIAMService) DeleteRole(ctx context.Context, roleCode string) error {
|
func (s *DefaultIAMService) DeleteRole(ctx context.Context, roleCode string) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
role, exists := s.roleStore[roleCode]
|
role, exists := s.roleStore[roleCode]
|
||||||
if !exists {
|
if !exists {
|
||||||
return ErrRoleNotFound
|
return ErrRoleNotFound
|
||||||
@@ -187,6 +202,9 @@ func (s *DefaultIAMService) DeleteRole(ctx context.Context, roleCode string) err
|
|||||||
|
|
||||||
// ListRoles 列出角色
|
// ListRoles 列出角色
|
||||||
func (s *DefaultIAMService) ListRoles(ctx context.Context, roleType string) ([]*Role, error) {
|
func (s *DefaultIAMService) ListRoles(ctx context.Context, roleType string) ([]*Role, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
var roles []*Role
|
var roles []*Role
|
||||||
for _, role := range s.roleStore {
|
for _, role := range s.roleStore {
|
||||||
if roleType == "" || role.Type == roleType {
|
if roleType == "" || role.Type == roleType {
|
||||||
@@ -198,6 +216,9 @@ func (s *DefaultIAMService) ListRoles(ctx context.Context, roleType string) ([]*
|
|||||||
|
|
||||||
// AssignRole 分配角色
|
// AssignRole 分配角色
|
||||||
func (s *DefaultIAMService) AssignRole(ctx context.Context, req *AssignRoleRequest) (*UserRole, error) {
|
func (s *DefaultIAMService) AssignRole(ctx context.Context, req *AssignRoleRequest) (*UserRole, error) {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
// 检查角色是否存在
|
// 检查角色是否存在
|
||||||
if _, exists := s.roleStore[req.RoleCode]; !exists {
|
if _, exists := s.roleStore[req.RoleCode]; !exists {
|
||||||
return nil, ErrRoleNotFound
|
return nil, ErrRoleNotFound
|
||||||
@@ -226,6 +247,9 @@ func (s *DefaultIAMService) AssignRole(ctx context.Context, req *AssignRoleReque
|
|||||||
|
|
||||||
// RevokeRole 撤销角色
|
// RevokeRole 撤销角色
|
||||||
func (s *DefaultIAMService) RevokeRole(ctx context.Context, userID int64, roleCode string, tenantID int64) error {
|
func (s *DefaultIAMService) RevokeRole(ctx context.Context, userID int64, roleCode string, tenantID int64) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
for _, ur := range s.userRoleStore[userID] {
|
for _, ur := range s.userRoleStore[userID] {
|
||||||
if ur.RoleCode == roleCode && ur.TenantID == tenantID {
|
if ur.RoleCode == roleCode && ur.TenantID == tenantID {
|
||||||
ur.IsActive = false
|
ur.IsActive = false
|
||||||
@@ -237,6 +261,9 @@ func (s *DefaultIAMService) RevokeRole(ctx context.Context, userID int64, roleCo
|
|||||||
|
|
||||||
// GetUserRoles 获取用户角色
|
// GetUserRoles 获取用户角色
|
||||||
func (s *DefaultIAMService) GetUserRoles(ctx context.Context, userID int64) ([]*UserRole, error) {
|
func (s *DefaultIAMService) GetUserRoles(ctx context.Context, userID int64) ([]*UserRole, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
var userRoles []*UserRole
|
var userRoles []*UserRole
|
||||||
for _, ur := range s.userRoleStore[userID] {
|
for _, ur := range s.userRoleStore[userID] {
|
||||||
if ur.IsActive {
|
if ur.IsActive {
|
||||||
@@ -248,7 +275,10 @@ func (s *DefaultIAMService) GetUserRoles(ctx context.Context, userID int64) ([]*
|
|||||||
|
|
||||||
// CheckScope 检查用户是否有指定Scope
|
// CheckScope 检查用户是否有指定Scope
|
||||||
func (s *DefaultIAMService) CheckScope(ctx context.Context, userID int64, requiredScope string) (bool, error) {
|
func (s *DefaultIAMService) CheckScope(ctx context.Context, userID int64, requiredScope string) (bool, error) {
|
||||||
scopes, err := s.GetUserScopes(ctx, userID)
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
scopes, err := s.getUserScopesLocked(userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
@@ -263,6 +293,14 @@ func (s *DefaultIAMService) CheckScope(ctx context.Context, userID int64, requir
|
|||||||
|
|
||||||
// GetUserScopes 获取用户所有Scope
|
// GetUserScopes 获取用户所有Scope
|
||||||
func (s *DefaultIAMService) GetUserScopes(ctx context.Context, userID int64) ([]string, error) {
|
func (s *DefaultIAMService) GetUserScopes(ctx context.Context, userID int64) ([]string, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
defer s.mu.RUnlock()
|
||||||
|
|
||||||
|
return s.getUserScopesLocked(userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getUserScopesLocked 获取用户所有Scope(内部使用,需要持有锁)
|
||||||
|
func (s *DefaultIAMService) getUserScopesLocked(userID int64) ([]string, error) {
|
||||||
var allScopes []string
|
var allScopes []string
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user