// Package ratelimit provides in-memory per-key rate limiters (token bucket and
// sliding window) and an HTTP middleware that applies a Limiter to handlers.
package ratelimit

import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"strconv"
	"sync"
	"time"

	// Aliased on purpose: imported under its own name "error", this package
	// would shadow the predeclared error type for the whole file and break
	// every (bool, error) signature below.
	gwerror "lijiaoqiao/gateway/pkg/error"
)

// Algorithm names a rate-limiting algorithm.
type Algorithm string

// Algorithm identifiers.
const (
	TokenBucket   Algorithm = "token_bucket"
	SlidingWindow Algorithm = "sliding_window"
	FixedWindow   Algorithm = "fixed_window"
)

// Limiter is the interface implemented by the rate limiters in this package.
type Limiter interface {
	// Allow reports whether one request for key may proceed.
	Allow(ctx context.Context, key string) (bool, error)
	// AllowToken reports whether tokens units may be consumed for key.
	AllowToken(ctx context.Context, key string, tokens int) (bool, error)
	// GetLimit returns a snapshot of the current limit state for key.
	GetLimit(key string) *Limit
}

// Limit is a snapshot of a key's limit configuration and current state.
type Limit struct {
	RPM       int       // requests per minute
	TPM       int       // tokens per minute
	Burst     int       // burst capacity
	Remaining int       // requests remaining
	ResetAt   time.Time // reset time
}

// TokenBucketLimiter is a per-key token-bucket limiter. Each key gets a bucket
// holding up to defaultRPM*burstMultiplier tokens that refills at defaultRPM
// tokens per minute.
type TokenBucketLimiter struct {
	mu              sync.RWMutex
	buckets         map[string]*tokenBucket
	defaultRPM      int
	defaultTPM      int // NOTE(review): reported by GetLimit but not enforced by the buckets.
	burstMultiplier float64
	cleanInterval   time.Duration
	idleTimeout     time.Duration

	stopOnce sync.Once
	stop     chan struct{}
}

type tokenBucket struct {
	tokens       float64
	maxTokens    float64
	tokensPerSec float64
	lastRefill   time.Time // last time the bucket was refilled, i.e. last access
	mu           sync.Mutex
}

// NewTokenBucketLimiter creates a TokenBucketLimiter and starts a background
// goroutine that evicts idle buckets every 5 minutes. Call Stop to end it.
func NewTokenBucketLimiter(defaultRPM, defaultTPM int, burstMultiplier float64) *TokenBucketLimiter {
	l := &TokenBucketLimiter{
		buckets:         make(map[string]*tokenBucket),
		defaultRPM:      defaultRPM,
		defaultTPM:      defaultTPM,
		burstMultiplier: burstMultiplier,
		cleanInterval:   5 * time.Minute,
		idleTimeout:     10 * time.Minute,
		stop:            make(chan struct{}),
	}
	go l.cleanup()
	return l
}

// Stop ends the background cleanup goroutine. It is safe to call more than
// once. The limiter keeps working afterwards, but idle buckets are no longer
// evicted.
func (l *TokenBucketLimiter) Stop() {
	l.stopOnce.Do(func() {
		if l.stop != nil {
			close(l.stop)
		}
	})
}

// Allow reports whether one request for key may proceed. It consumes one token.
func (l *TokenBucketLimiter) Allow(ctx context.Context, key string) (bool, error) {
	return l.AllowToken(ctx, key, 1)
}

// AllowToken reports whether tokens units can be taken from key's bucket, and
// takes them if so. A negative count returns an error instead of silently
// adding tokens to the bucket.
func (l *TokenBucketLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
	if tokens < 0 {
		return false, fmt.Errorf("ratelimit: token count must be non-negative, got %d", tokens)
	}

	b := l.bucketFor(key)
	b.mu.Lock()
	defer b.mu.Unlock()

	l.refill(b)
	need := float64(tokens)
	if b.tokens < need {
		return false, nil
	}
	b.tokens -= need
	return true, nil
}

// bucketFor returns key's bucket, creating one with the default limits on first
// use.
func (l *TokenBucketLimiter) bucketFor(key string) *tokenBucket {
	l.mu.RLock()
	b, ok := l.buckets[key]
	l.mu.RUnlock()
	if ok {
		return b
	}

	l.mu.Lock()
	defer l.mu.Unlock()
	// Re-check: another goroutine may have created it between the two locks.
	if b, ok = l.buckets[key]; !ok {
		b = l.newBucket(l.defaultRPM, l.defaultTPM)
		l.buckets[key] = b
	}
	return b
}

// GetLimit returns the current limit state for key. Remaining reflects tokens
// refilled up to now. ResetAt is when the bucket will be full again if no more
// tokens are consumed. Keys without a bucket report the defaults.
func (l *TokenBucketLimiter) GetLimit(key string) *Limit {
	l.mu.RLock()
	b, ok := l.buckets[key]
	l.mu.RUnlock()
	if !ok {
		return &Limit{
			RPM:   l.defaultRPM,
			TPM:   l.defaultTPM,
			Burst: int(float64(l.defaultRPM) * l.burstMultiplier),
		}
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	l.refill(b)
	resetAt := b.lastRefill
	if missing := b.maxTokens - b.tokens; missing > 0 && b.tokensPerSec > 0 {
		resetAt = resetAt.Add(time.Duration(missing / b.tokensPerSec * float64(time.Second)))
	}
	return &Limit{
		RPM:       l.defaultRPM,
		TPM:       l.defaultTPM,
		Burst:     int(b.maxTokens),
		Remaining: int(b.tokens),
		ResetAt:   resetAt,
	}
}

// newBucket creates a full bucket with capacity rpm*burstMultiplier, refilling
// at rpm tokens per minute. tpm is currently unused.
func (l *TokenBucketLimiter) newBucket(rpm, tpm int) *tokenBucket {
	burst := int(float64(rpm) * l.burstMultiplier)
	return &tokenBucket{
		tokens:       float64(burst),
		maxTokens:    float64(burst),
		tokensPerSec: float64(rpm) / 60.0,
		lastRefill:   time.Now(),
	}
}

// refill credits tokens accrued since the last refill, capped at maxTokens.
// The caller must hold bucket.mu.
func (l *TokenBucketLimiter) refill(bucket *tokenBucket) {
	now := time.Now()
	elapsed := now.Sub(bucket.lastRefill).Seconds()
	bucket.tokens += elapsed * bucket.tokensPerSec
	if bucket.tokens > bucket.maxTokens {
		bucket.tokens = bucket.maxTokens
	}
	bucket.lastRefill = now
}

// cleanup periodically evicts idle buckets until Stop is called.
func (l *TokenBucketLimiter) cleanup() {
	ticker := time.NewTicker(l.cleanInterval)
	defer ticker.Stop()
	for {
		select {
		case <-l.stop:
			return
		case now := <-ticker.C:
			l.evictIdle(now)
		}
	}
}

// evictIdle removes buckets that have not been touched for longer than
// idleTimeout and would be full by now. A new bucket starts full, so evicting
// such a bucket does not change what its key is allowed.
//
// A concurrent AllowToken that already holds a pointer to an evicted bucket
// may consume from it once more; that is accepted as a negligible error.
func (l *TokenBucketLimiter) evictIdle(now time.Time) {
	l.mu.Lock()
	defer l.mu.Unlock()
	for key, b := range l.buckets {
		b.mu.Lock()
		idle := now.Sub(b.lastRefill)
		// tokens is refilled only on access, so project the refill forward.
		// Otherwise a bucket drained right before going idle would never be
		// evicted.
		projected := b.tokens + idle.Seconds()*b.tokensPerSec
		if idle > l.idleTimeout && projected >= b.maxTokens {
			delete(l.buckets, key)
		}
		b.mu.Unlock()
	}
}

// SlidingWindowLimiter allows at most maxRequests requests per key within any
// trailing window of windowSize.
type SlidingWindowLimiter struct {
	mu            sync.RWMutex
	windows       map[string]*slidingWindow
	windowSize    time.Duration
	maxRequests   int
	cleanInterval time.Duration

	stopOnce sync.Once
	stop     chan struct{}
}

type slidingWindow struct {
	requests []time.Time // timestamps of admitted requests, in call order
	mu       sync.Mutex
}

// NewSlidingWindowLimiter creates a SlidingWindowLimiter and starts a
// background goroutine that evicts stale windows every minute. Call Stop to
// end it.
func NewSlidingWindowLimiter(windowSize time.Duration, maxRequests int) *SlidingWindowLimiter {
	l := &SlidingWindowLimiter{
		windows:       make(map[string]*slidingWindow),
		windowSize:    windowSize,
		maxRequests:   maxRequests,
		cleanInterval: 1 * time.Minute,
		stop:          make(chan struct{}),
	}
	go l.cleanup()
	return l
}

// Stop ends the background cleanup goroutine. It is safe to call more than
// once.
func (l *SlidingWindowLimiter) Stop() {
	l.stopOnce.Do(func() {
		if l.stop != nil {
			close(l.stop)
		}
	})
}

// windowFor returns key's window, creating an empty one on first use.
func (l *SlidingWindowLimiter) windowFor(key string) *slidingWindow {
	l.mu.RLock()
	w, ok := l.windows[key]
	l.mu.RUnlock()
	if ok {
		return w
	}

	l.mu.Lock()
	defer l.mu.Unlock()
	if w, ok = l.windows[key]; !ok {
		w = &slidingWindow{}
		l.windows[key] = w
	}
	return w
}

// keepAfter filters ts in place, keeping only timestamps strictly after cutoff.
func keepAfter(ts []time.Time, cutoff time.Time) []time.Time {
	kept := ts[:0]
	for _, t := range ts {
		if t.After(cutoff) {
			kept = append(kept, t)
		}
	}
	return kept
}

// Allow admits a request for key if fewer than maxRequests were admitted in the
// trailing windowSize, and records it.
func (l *SlidingWindowLimiter) Allow(ctx context.Context, key string) (bool, error) {
	w := l.windowFor(key)
	w.mu.Lock()
	defer w.mu.Unlock()

	now := time.Now()
	w.requests = keepAfter(w.requests, now.Add(-l.windowSize))
	if len(w.requests) >= l.maxRequests {
		return false, nil
	}
	w.requests = append(w.requests, now)
	return true, nil
}

// AllowToken ignores tokens and counts the call as a single request.
func (l *SlidingWindowLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
	return l.Allow(ctx, key)
}

// GetLimit returns the current limit state for key. ResetAt is when the oldest
// counted request leaves the window, freeing a slot. With no counted requests
// it is now+windowSize.
func (l *SlidingWindowLimiter) GetLimit(key string) *Limit {
	now := time.Now()
	limit := &Limit{
		RPM:       l.maxRequests,
		Remaining: l.maxRequests,
		ResetAt:   now.Add(l.windowSize),
	}

	l.mu.RLock()
	w, ok := l.windows[key]
	l.mu.RUnlock()
	if !ok {
		return limit
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	cutoff := now.Add(-l.windowSize)
	count := 0
	var oldest time.Time
	for _, t := range w.requests {
		if !t.After(cutoff) {
			continue
		}
		if count == 0 || t.Before(oldest) {
			oldest = t
		}
		count++
	}
	if count > 0 {
		limit.ResetAt = oldest.Add(l.windowSize)
	}
	limit.Remaining = l.maxRequests - count
	if limit.Remaining < 0 {
		limit.Remaining = 0
	}
	return limit
}

// cleanup periodically evicts stale windows until Stop is called.
func (l *SlidingWindowLimiter) cleanup() {
	ticker := time.NewTicker(l.cleanInterval)
	defer ticker.Stop()
	for {
		select {
		case <-l.stop:
			return
		case now := <-ticker.C:
			l.evictExpired(now)
		}
	}
}

// evictExpired drops timestamps older than twice the window size and removes
// windows left empty. Empty windows include ones that never recorded a
// request; the previous version panicked on those by indexing an empty slice.
// A concurrent Allow that already holds a removed window may record one request
// that is not counted afterward; that is accepted as a negligible error.
func (l *SlidingWindowLimiter) evictExpired(now time.Time) {
	l.mu.Lock()
	defer l.mu.Unlock()
	cutoff := now.Add(-2 * l.windowSize)
	for key, w := range l.windows {
		w.mu.Lock()
		w.requests = keepAfter(w.requests, cutoff)
		if len(w.requests) == 0 {
			delete(l.windows, key)
		}
		w.mu.Unlock()
	}
}

// Middleware applies a Limiter to HTTP handlers.
type Middleware struct {
	limiter Limiter
}

// NewMiddleware returns a Middleware backed by limiter.
func NewMiddleware(limiter Limiter) *Middleware {
	return &Middleware{limiter: limiter}
}

// Limit wraps next. Requests rejected by the limiter get an error response, and
// rejections also carry X-RateLimit-* headers. Limiter errors produce an
// internal-error response.
func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		key := clientKey(r)

		allowed, err := m.limiter.Allow(r.Context(), key)
		if err != nil {
			writeError(w, gwerror.NewGatewayError(gwerror.COMMON_INTERNAL_ERROR, "rate limiter error"))
			return
		}
		if !allowed {
			if limit := m.limiter.GetLimit(key); limit != nil {
				h := w.Header()
				h.Set("X-RateLimit-Limit", strconv.Itoa(limit.RPM))
				h.Set("X-RateLimit-Remaining", strconv.Itoa(limit.Remaining))
				h.Set("X-RateLimit-Reset", strconv.FormatInt(limit.ResetAt.Unix(), 10))
			}
			writeError(w, gwerror.NewGatewayError(gwerror.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
			return
		}
		next.ServeHTTP(w, r)
	}
}

// clientKey picks the rate-limit key: the Authorization header when present,
// otherwise the client host. RemoteAddr is "host:port", and the port is
// stripped so a client cannot get a fresh limit by opening a new connection.
func clientKey(r *http.Request) string {
	if key := r.Header.Get("Authorization"); key != "" {
		return key
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}

// writeError writes err as {"error":{"message":...,"code":...}} using the HTTP
// status from err.GetErrorInfo().
func writeError(w http.ResponseWriter, err *gwerror.GatewayError) {
	info := err.GetErrorInfo()

	var body struct {
		Error struct {
			Message string `json:"message"`
			Code    string `json:"code"`
		} `json:"error"`
	}
	// fmt.Sprint renders the fields like the old "%s" verb did without
	// assuming their concrete types. json.Marshal then escapes quotes and
	// control characters that the hand-built string used to emit verbatim.
	body.Error.Message = fmt.Sprint(err.Message)
	body.Error.Code = fmt.Sprint(err.Code)
	// Marshal cannot fail for a struct of plain strings.
	payload, _ := json.Marshal(body)

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(info.HTTPStatus)
	// Best effort: the status line is already sent, so a failed write cannot
	// be reported to the client.
	_, _ = w.Write(payload)
}