Files
lijiaoqiao/gateway/internal/ratelimit/ratelimit.go
Your Name 6924b2bafc fix: 修复6个代码质量问题
P1-01: 提取重复的角色层级定义为包级常量
- 将 roleHierarchy 提取为 roleHierarchyLevels 包级变量
- 消除重复定义

P1-02: 修复伪随机数用于加权选择
- 使用 math/rand 的线程安全随机数生成器替代时间戳
- 确保加权路由的均匀分布

P1-03: 修复 FailureRate 初始化计算错误
- 将成功时的恢复因子从 0.9 改为 0.5
- 加速失败后的恢复过程

P1-04: 为 DefaultIAMService 添加并发控制
- 添加 sync.RWMutex 保护 map 操作
- 确保所有服务方法的线程安全

P1-05: 修复 IP 伪造漏洞
- 添加 TrustedProxies 配置
- 只在来自可信代理时才使用 X-Forwarded-For

P1-06: 修复限流 key 提取逻辑错误
- 从 Authorization header 中提取 Bearer token
- 避免使用完整的 header 作为限流 key
2026-04-03 07:58:46 +08:00

357 lines
8.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ratelimit
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"strings"
	"sync"
	"time"

	gwerror "lijiaoqiao/gateway/pkg/error"
)
// Algorithm identifies a rate-limiting algorithm by name.
type Algorithm string

// Known Algorithm values.
const (
	// TokenBucket names the token-bucket algorithm.
	TokenBucket Algorithm = "token_bucket"
	// SlidingWindow names the sliding-window algorithm.
	SlidingWindow Algorithm = "sliding_window"
	// FixedWindow names the fixed-window algorithm.
	FixedWindow Algorithm = "fixed_window"
)
type (
	// Limiter decides whether work identified by a key may proceed.
	Limiter interface {
		// Allow reports whether one more request for key is permitted.
		Allow(ctx context.Context, key string) (bool, error)
		// AllowToken reports whether tokens units may be consumed for key.
		AllowToken(ctx context.Context, key string, tokens int) (bool, error)
		// GetLimit returns a snapshot of the limit state for key.
		GetLimit(key string) *Limit
	}

	// Limit is a snapshot of the rate-limit configuration and state for a key.
	Limit struct {
		RPM       int       // requests per minute
		TPM       int       // tokens per minute
		Burst     int       // burst capacity
		Remaining int       // remaining requests
		ResetAt   time.Time // reset time
	}
)
// TokenBucketLimiter is a Limiter that keeps one token bucket per key.
//
// Each bucket holds at most defaultRPM*burstMultiplier tokens and refills at
// defaultRPM/60 tokens per second. Buckets are created lazily on first use. A
// background goroutine started by NewTokenBucketLimiter periodically evicts
// buckets that have been idle long enough to be full again. Call Stop to end
// that goroutine once the limiter is no longer needed.
type TokenBucketLimiter struct {
	mu              sync.RWMutex // guards buckets
	buckets         map[string]*tokenBucket
	defaultRPM      int
	defaultTPM      int // NOTE(review): reported by GetLimit but not used to size buckets
	burstMultiplier float64
	cleanInterval   time.Duration

	done     chan struct{} // closed by Stop to terminate the cleanup goroutine
	stopOnce sync.Once
}

// tokenBucket is the state of one key's bucket. All fields are guarded by mu.
type tokenBucket struct {
	tokens       float64
	maxTokens    float64
	tokensPerSec float64
	lastRefill   time.Time
	mu           sync.Mutex
}

// available returns how many tokens the bucket holds at now once the refill
// accrued since lastRefill is credited, capped at maxTokens. It does not
// mutate the bucket. The caller must hold b.mu.
func (b *tokenBucket) available(now time.Time) float64 {
	tokens := b.tokens + now.Sub(b.lastRefill).Seconds()*b.tokensPerSec
	if tokens > b.maxTokens {
		return b.maxTokens
	}
	return tokens
}

// NewTokenBucketLimiter creates a TokenBucketLimiter and starts its cleanup
// goroutine, which runs every five minutes until Stop is called.
func NewTokenBucketLimiter(defaultRPM, defaultTPM int, burstMultiplier float64) *TokenBucketLimiter {
	limiter := &TokenBucketLimiter{
		buckets:         make(map[string]*tokenBucket),
		defaultRPM:      defaultRPM,
		defaultTPM:      defaultTPM,
		burstMultiplier: burstMultiplier,
		cleanInterval:   5 * time.Minute,
		done:            make(chan struct{}),
	}
	go limiter.cleanup()
	return limiter
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once. The limiter keeps working after Stop, but idle buckets are no
// longer evicted.
func (l *TokenBucketLimiter) Stop() {
	l.stopOnce.Do(func() {
		if l.done != nil {
			close(l.done)
		}
	})
}

// Allow reports whether one request for key is permitted. It consumes one
// token on success.
func (l *TokenBucketLimiter) Allow(ctx context.Context, key string) (bool, error) {
	return l.AllowToken(ctx, key, 1)
}

// AllowToken reports whether tokens units can be taken from key's bucket, and
// takes them if so. A negative count is rejected with an error, because
// subtracting it would add tokens beyond the bucket's capacity.
//
// NOTE(review): tokens are drawn from the same RPM-sized bucket that Allow
// uses. defaultTPM does not influence this check.
func (l *TokenBucketLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
	if tokens < 0 {
		return false, fmt.Errorf("ratelimit: token count must be non-negative, got %d", tokens)
	}

	bucket := l.bucketFor(key)
	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	l.refill(bucket)
	need := float64(tokens)
	if bucket.tokens < need {
		return false, nil
	}
	bucket.tokens -= need
	return true, nil
}

// bucketFor returns key's bucket. If none exists, it creates one with the
// default limits.
func (l *TokenBucketLimiter) bucketFor(key string) *tokenBucket {
	l.mu.Lock()
	defer l.mu.Unlock()
	bucket, ok := l.buckets[key]
	if !ok {
		bucket = l.newBucket(l.defaultRPM, l.defaultTPM)
		l.buckets[key] = bucket
	}
	return bucket
}

// GetLimit returns a snapshot of key's limits.
//
// Remaining reflects the refill accrued up to now without consuming anything.
// A key that has never been seen would start with a full bucket, so its
// Remaining equals Burst. ResetAt is lastRefill plus one minute. It is only an
// approximation of when the bucket is next full.
func (l *TokenBucketLimiter) GetLimit(key string) *Limit {
	l.mu.RLock()
	bucket, exists := l.buckets[key]
	l.mu.RUnlock()

	if !exists {
		burst := int(float64(l.defaultRPM) * l.burstMultiplier)
		return &Limit{
			RPM:       l.defaultRPM,
			TPM:       l.defaultTPM,
			Burst:     burst,
			Remaining: burst,
		}
	}

	bucket.mu.Lock()
	defer bucket.mu.Unlock()
	return &Limit{
		RPM:       l.defaultRPM,
		TPM:       l.defaultTPM,
		Burst:     int(bucket.maxTokens),
		Remaining: int(bucket.available(time.Now())),
		ResetAt:   bucket.lastRefill.Add(time.Minute),
	}
}

// newBucket builds a full bucket sized from rpm and the burst multiplier.
// tpm is currently unused.
func (l *TokenBucketLimiter) newBucket(rpm, tpm int) *tokenBucket {
	burst := int(float64(rpm) * l.burstMultiplier)
	return &tokenBucket{
		tokens:       float64(burst),
		maxTokens:    float64(burst),
		tokensPerSec: float64(rpm) / 60.0,
		lastRefill:   time.Now(),
	}
}

// refill credits tokens accrued since the last refill, capped at maxTokens.
// The caller must hold bucket.mu.
func (l *TokenBucketLimiter) refill(bucket *tokenBucket) {
	now := time.Now()
	bucket.tokens = bucket.available(now)
	bucket.lastRefill = now
}

// cleanup runs evictIdle every cleanInterval until Stop is called.
func (l *TokenBucketLimiter) cleanup() {
	ticker := time.NewTicker(l.cleanInterval)
	defer ticker.Stop()
	for {
		select {
		case <-l.done:
			return
		case <-ticker.C:
			l.evictIdle(time.Now())
		}
	}
}

// evictIdle deletes buckets that have gone unused for more than ten minutes
// and would be full at now. Such a bucket is indistinguishable from a freshly
// created one.
//
// The check uses the projected token count rather than the stored field.
// tokens is only updated inside AllowToken, and a successful call always
// leaves it below max. Checking the stored value meant any bucket that had
// ever served a request was never evicted.
func (l *TokenBucketLimiter) evictIdle(now time.Time) {
	const idleTTL = 10 * time.Minute

	l.mu.Lock()
	defer l.mu.Unlock()
	for key, bucket := range l.buckets {
		bucket.mu.Lock()
		idle := now.Sub(bucket.lastRefill) > idleTTL
		full := bucket.available(now) >= bucket.maxTokens
		bucket.mu.Unlock()
		if idle && full {
			delete(l.buckets, key)
		}
	}
}
// SlidingWindowLimiter is a Limiter that allows at most maxRequests requests
// per key within any trailing windowSize interval.
//
// A background goroutine started by NewSlidingWindowLimiter periodically drops
// stale timestamps and evicts windows that become empty. Call Stop to end that
// goroutine once the limiter is no longer needed.
type SlidingWindowLimiter struct {
	mu            sync.RWMutex // guards windows
	windows       map[string]*slidingWindow
	windowSize    time.Duration
	maxRequests   int
	cleanInterval time.Duration

	done     chan struct{} // closed by Stop to terminate the cleanup goroutine
	stopOnce sync.Once
}

// slidingWindow holds the accepted-request timestamps for one key, in the
// order they were appended. Guarded by mu.
type slidingWindow struct {
	requests []time.Time
	mu       sync.Mutex
}

// pruneBefore drops timestamps at or before cutoff. It reuses the existing
// backing array instead of allocating a new slice. The caller must hold w.mu.
func (w *slidingWindow) pruneBefore(cutoff time.Time) {
	kept := w.requests[:0]
	for _, t := range w.requests {
		if t.After(cutoff) {
			kept = append(kept, t)
		}
	}
	w.requests = kept
}

// NewSlidingWindowLimiter creates a SlidingWindowLimiter and starts its cleanup
// goroutine, which runs every minute until Stop is called.
func NewSlidingWindowLimiter(windowSize time.Duration, maxRequests int) *SlidingWindowLimiter {
	limiter := &SlidingWindowLimiter{
		windows:       make(map[string]*slidingWindow),
		windowSize:    windowSize,
		maxRequests:   maxRequests,
		cleanInterval: 1 * time.Minute,
		done:          make(chan struct{}),
	}
	go limiter.cleanup()
	return limiter
}

// Stop terminates the background cleanup goroutine. It is safe to call more
// than once.
func (l *SlidingWindowLimiter) Stop() {
	l.stopOnce.Do(func() {
		if l.done != nil {
			close(l.done)
		}
	})
}

// Allow reports whether another request for key fits in the current window.
// If it does, the request is recorded.
func (l *SlidingWindowLimiter) Allow(ctx context.Context, key string) (bool, error) {
	l.mu.Lock()
	window, ok := l.windows[key]
	if !ok {
		window = &slidingWindow{}
		l.windows[key] = window
	}
	l.mu.Unlock()

	window.mu.Lock()
	defer window.mu.Unlock()

	now := time.Now()
	window.pruneBefore(now.Add(-l.windowSize))
	if len(window.requests) >= l.maxRequests {
		return false, nil
	}
	window.requests = append(window.requests, now)
	return true, nil
}

// AllowToken ignores tokens and counts the call as a single request.
func (l *SlidingWindowLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
	return l.Allow(ctx, key)
}

// GetLimit returns a snapshot for key without recording anything. RPM carries
// maxRequests regardless of windowSize. ResetAt is now plus windowSize.
func (l *SlidingWindowLimiter) GetLimit(key string) *Limit {
	l.mu.RLock()
	window, exists := l.windows[key]
	l.mu.RUnlock()

	remaining := l.maxRequests
	if exists {
		window.mu.Lock()
		cutoff := time.Now().Add(-l.windowSize)
		count := 0
		for _, t := range window.requests {
			if t.After(cutoff) {
				count++
			}
		}
		window.mu.Unlock()

		remaining = l.maxRequests - count
		if remaining < 0 {
			remaining = 0
		}
	}

	return &Limit{
		RPM:       l.maxRequests,
		ResetAt:   time.Now().Add(l.windowSize),
		Remaining: remaining,
	}
}

// cleanup runs evictIdle every cleanInterval until Stop is called.
func (l *SlidingWindowLimiter) cleanup() {
	ticker := time.NewTicker(l.cleanInterval)
	defer ticker.Stop()
	for {
		select {
		case <-l.done:
			return
		case <-ticker.C:
			l.evictIdle(time.Now())
		}
	}
}

// evictIdle drops timestamps older than two window lengths. It then deletes
// every window left empty, since an empty window behaves exactly like a fresh
// one. The previous rule also required the window to have been non-empty before
// pruning. That left windows which were already empty (e.g. when
// maxRequests <= 0) in the map forever.
func (l *SlidingWindowLimiter) evictIdle(now time.Time) {
	cutoff := now.Add(-l.windowSize * 2)

	l.mu.Lock()
	defer l.mu.Unlock()
	for key, window := range l.windows {
		window.mu.Lock()
		window.pruneBefore(cutoff)
		empty := len(window.requests) == 0
		window.mu.Unlock()
		if empty {
			delete(l.windows, key)
		}
	}
}
// Middleware applies a Limiter to incoming HTTP requests.
type Middleware struct {
	limiter Limiter
}

// NewMiddleware returns a Middleware that consults limiter for every request.
func NewMiddleware(limiter Limiter) *Middleware {
	return &Middleware{limiter: limiter}
}

// Limit wraps next so that each request is checked against the limiter first.
//
// The rate-limit key comes from extractRateLimitKey (the Authorization
// credentials). When that is empty, the client's host from r.RemoteAddr is used
// instead. A limiter error yields COMMON_INTERNAL_ERROR. A rejected request gets
// X-RateLimit-Limit/Remaining/Reset headers and RATE_LIMIT_EXCEEDED.
func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		key := extractRateLimitKey(r)
		if key == "" {
			key = m.remoteHost(r)
		}

		allowed, err := m.limiter.Allow(r.Context(), key)
		if err != nil {
			writeError(w, gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, "rate limiter error"))
			return
		}
		if !allowed {
			limit := m.limiter.GetLimit(key)
			w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", limit.RPM))
			w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining))
			w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix()))
			writeError(w, gwerror.NewGatewayError(gwerror.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
			return
		}
		next.ServeHTTP(w, r)
	}
}

// remoteHost returns the host part of r.RemoteAddr.
//
// RemoteAddr is normally "ip:port". Keying on it verbatim would give every new
// TCP connection (new ephemeral port) its own bucket, so an unauthenticated
// client could bypass the limit just by reconnecting. A value that is not in
// host:port form is returned unchanged.
//
// NOTE(review): behind a reverse proxy this is the proxy's address. Confirm the
// deployment resolves the real client IP upstream.
func (m *Middleware) remoteHost(r *http.Request) string {
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil || host == "" {
		return r.RemoteAddr
	}
	return host
}
// extractRateLimitKey derives a rate-limit key from the Authorization header.
//
// For the Bearer scheme, the key is the token itself. The scheme name is
// matched case-insensitively (RFC 7235), so "bearer x" and "Bearer x" share a
// bucket. A Bearer header with no token yields "", so the caller falls back to
// per-client limiting instead of lumping all such requests into one shared
// bucket. Any other non-empty header is returned whole (trimmed). An absent or
// blank header yields "".
func extractRateLimitKey(r *http.Request) string {
	const scheme = "Bearer"

	auth := strings.TrimSpace(r.Header.Get("Authorization"))
	if auth == "" {
		return ""
	}

	if len(auth) >= len(scheme) && strings.EqualFold(auth[:len(scheme)], scheme) {
		rest := auth[len(scheme):]
		// Require a separator so that e.g. "Bearerabc" is not mistaken for the scheme.
		if rest == "" || rest[0] == ' ' || rest[0] == '\t' {
			return strings.TrimSpace(rest)
		}
	}

	// Non-Bearer schemes: use the whole header as the key.
	return auth
}
// writeError sends err as a JSON body {"error":{"message":...,"code":...}},
// using the HTTP status from err.GetErrorInfo().
//
// The body is produced with encoding/json. The previous fmt-based template
// emitted invalid (and injectable) JSON whenever the message or code contained
// a quote, backslash or control character.
func writeError(w http.ResponseWriter, err *gwerror.GatewayError) {
	info := err.GetErrorInfo()

	var body struct {
		Error struct {
			Message string `json:"message"`
			Code    string `json:"code"`
		} `json:"error"`
	}
	// fmt.Sprint yields the same text the old %s verbs produced, whatever the
	// concrete field types are.
	body.Error.Message = fmt.Sprint(err.Message)
	body.Error.Code = fmt.Sprint(err.Code)

	payload, _ := json.Marshal(body) // cannot fail: the value is a struct of two strings

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(info.HTTPStatus)
	_, _ = w.Write(payload) // headers are already sent; a failed write cannot be reported to the client
}