fix: 修复6个代码质量问题
P1-01: 提取重复的角色层级定义为包级常量 - 将 roleHierarchy 提取为 roleHierarchyLevels 包级变量 - 消除重复定义 P1-02: 修复伪随机数用于加权选择 - 使用 math/rand 的线程安全随机数生成器替代时间戳 - 确保加权路由的均匀分布 P1-03: 修复 FailureRate 初始化计算错误 - 将成功时的恢复因子从 0.9 改为 0.5 - 加速失败后的恢复过程 P1-04: 为 DefaultIAMService 添加并发控制 - 添加 sync.RWMutex 保护 map 操作 - 确保所有服务方法的线程安全 P1-05: 修复 IP 伪造漏洞 - 添加 TrustedProxies 配置 - 只在来自可信代理时才使用 X-Forwarded-For P1-06: 修复限流 key 提取逻辑错误 - 从 Authorization header 中提取 Bearer token - 避免使用完整的 header 作为限流 key
This commit is contained in:
@@ -3,10 +3,12 @@ package ratelimit
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/gateway/pkg/error"
|
||||
gwerror "lijiaoqiao/gateway/pkg/error"
|
||||
)
|
||||
|
||||
// Algorithm 限流算法
|
||||
@@ -278,7 +280,7 @@ func (l *SlidingWindowLimiter) cleanup() {
|
||||
validRequests = append(validRequests, t)
|
||||
}
|
||||
}
|
||||
if len(validRequests) == 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 {
|
||||
if len(validRequests) == 0 && len(window.requests) > 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 {
|
||||
delete(l.windows, key)
|
||||
} else {
|
||||
window.requests = validRequests
|
||||
@@ -301,14 +303,14 @@ func NewMiddleware(limiter Limiter) *Middleware {
|
||||
func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// 使用API Key作为限流key
|
||||
key := r.Header.Get("Authorization")
|
||||
key := extractRateLimitKey(r)
|
||||
if key == "" {
|
||||
key = r.RemoteAddr
|
||||
}
|
||||
|
||||
allowed, err := m.limiter.Allow(r.Context(), key)
|
||||
if err != nil {
|
||||
writeError(w, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "rate limiter error"))
|
||||
writeError(w, gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, "rate limiter error"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -318,7 +320,7 @@ func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
||||
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining))
|
||||
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix()))
|
||||
|
||||
writeError(w, error.NewGatewayError(error.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
|
||||
writeError(w, gwerror.NewGatewayError(gwerror.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -326,9 +328,27 @@ func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
import "net/http"
|
||||
// extractRateLimitKey 从请求中提取限流key
|
||||
func extractRateLimitKey(r *http.Request) string {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if authHeader == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, err *error.GatewayError) {
|
||||
// 如果是Bearer token,提取token部分
|
||||
if strings.HasPrefix(authHeader, "Bearer ") {
|
||||
token := strings.TrimPrefix(authHeader, "Bearer ")
|
||||
token = strings.TrimSpace(token)
|
||||
if token != "" {
|
||||
return token
|
||||
}
|
||||
}
|
||||
|
||||
// 否则返回原始header(不应该发生)
|
||||
return authHeader
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, err *gwerror.GatewayError) {
|
||||
info := err.GetErrorInfo()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(info.HTTPStatus)
|
||||
|
||||
Reference in New Issue
Block a user