Files
lijiaoqiao/gateway/internal/ratelimit/ratelimit.go
Your Name 0484c7be74 feat(gateway): 实现网关核心模块
实现内容:
- internal/adapter: Provider Adapter抽象层和OpenAI实现
- internal/router: 多Provider路由(支持latency/weighted/availability策略)
- internal/handler: OpenAI兼容API端点(/v1/chat/completions, /v1/completions)
- internal/ratelimit: Token Bucket和Sliding Window限流器
- internal/alert: 告警系统(支持邮件/钉钉/飞书)
- internal/config: 配置管理
- pkg/error: 完整错误码体系
- pkg/model: API请求/响应模型

PRD对齐:
- P0-1: 统一API接入  (OpenAI兼容)
- P0-2: 基础路由与稳定性  (多Provider路由+Fallback)
- P0-4: 预算与限流  (Token Bucket限流)

注意:需要供应链模块支持后再完善成本归因和账单导出
2026-04-01 10:04:52 +08:00

337 lines
7.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package ratelimit
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	// NOTE(review): this package's name is "error", so importing it unaliased
	// shadows the builtin error type for the whole file; the limiter methods
	// that return (bool, error) would need an import alias to compile — confirm.
	"lijiaoqiao/gateway/pkg/error"
)
// Algorithm names a rate-limiting algorithm.
type Algorithm string

// Algorithm values declared by this package.
const (
	// TokenBucket is the "token_bucket" algorithm.
	TokenBucket = Algorithm("token_bucket")
	// SlidingWindow is the "sliding_window" algorithm.
	SlidingWindow = Algorithm("sliding_window")
	// FixedWindow is the "fixed_window" algorithm.
	FixedWindow = Algorithm("fixed_window")
)
// Limiter is the rate limiter interface.
type Limiter interface {
	// Allow reports whether a request for key is permitted.
	Allow(ctx context.Context, key string) (bool, error)

	// AllowToken reports whether key is permitted to consume the given
	// number of tokens.
	AllowToken(ctx context.Context, key string, tokens int) (bool, error)

	// GetLimit returns the limit currently in effect for key.
	GetLimit(key string) *Limit
}
// Limit is a rate limit configuration together with its current state.
type Limit struct {
	// RPM is the number of requests per minute.
	RPM int
	// TPM is the number of tokens per minute.
	TPM int
	// Burst is the burst capacity.
	Burst int
	// Remaining is the number of requests left.
	Remaining int
	// ResetAt is the reset time.
	ResetAt time.Time
}
// TokenBucketLimiter is a token-bucket rate limiter that keeps one bucket
// per key.
type TokenBucketLimiter struct {
	// mu guards the buckets map; each bucket carries its own lock.
	mu sync.RWMutex
	// buckets holds the per-key buckets.
	buckets map[string]*tokenBucket
	// defaultRPM is the requests-per-minute rate used for new buckets.
	defaultRPM int
	// defaultTPM is the tokens-per-minute value reported by GetLimit.
	defaultTPM int
	// burstMultiplier scales the RPM into a bucket's capacity.
	burstMultiplier float64
	// cleanInterval is the period of the background cleanup ticker.
	cleanInterval time.Duration
}
// tokenBucket is the state of a single key's bucket.
type tokenBucket struct {
	// mu guards every field below.
	mu sync.Mutex

	// tokens is the number of tokens available as of lastRefill.
	tokens float64
	// maxTokens is the bucket capacity.
	maxTokens float64
	// tokensPerSec is the refill rate.
	tokensPerSec float64
	// lastRefill is when tokens was last brought up to date.
	lastRefill time.Time
}
// NewTokenBucketLimiter creates a token-bucket limiter.
//
// defaultRPM and defaultTPM are the per-minute request and token limits, and
// burstMultiplier scales defaultRPM into each bucket's capacity. The
// constructor also starts the background cleanup goroutine, which runs on a
// five-minute ticker.
func NewTokenBucketLimiter(defaultRPM, defaultTPM int, burstMultiplier float64) *TokenBucketLimiter {
	const defaultCleanInterval = 5 * time.Minute

	l := new(TokenBucketLimiter)
	l.buckets = map[string]*tokenBucket{}
	l.defaultRPM = defaultRPM
	l.defaultTPM = defaultTPM
	l.burstMultiplier = burstMultiplier
	l.cleanInterval = defaultCleanInterval

	// Start the cleanup goroutine.
	go l.cleanup()

	return l
}
// Allow reports whether a single request for key is permitted. It is
// AllowToken with a cost of one token.
func (l *TokenBucketLimiter) Allow(ctx context.Context, key string) (bool, error) {
	const requestCost = 1
	return l.AllowToken(ctx, key, requestCost)
}
// AllowToken reports whether key may consume the given number of tokens and,
// if so, deducts them from the key's bucket.
//
// The bucket is created on first use from the limiter's defaults. A negative
// tokens value is rejected with an error: subtracting it would add tokens to
// the bucket and let a caller top up its own quota. ctx is accepted to
// satisfy Limiter but is not consulted.
func (l *TokenBucketLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
	if tokens < 0 {
		return false, fmt.Errorf("ratelimit: negative token count %d", tokens)
	}

	// Look up the bucket, creating it on first use.
	l.mu.Lock()
	bucket, exists := l.buckets[key]
	if !exists {
		bucket = l.newBucket(l.defaultRPM, l.defaultTPM)
		l.buckets[key] = bucket
	}
	l.mu.Unlock()

	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	// Credit the tokens accrued since the last access.
	l.refill(bucket)

	// Deny without consuming anything when the bucket cannot cover the cost.
	cost := float64(tokens)
	if bucket.tokens < cost {
		return false, nil
	}
	bucket.tokens -= cost
	return true, nil
}
// GetLimit returns the limit currently in effect for key.
//
// A key without a bucket is reported with its full burst available and the
// ResetAt a bucket created right now would carry, rather than zero remaining
// and a zero time. For an existing bucket, Remaining includes the tokens
// accrued since the last refill, capped at the bucket capacity; the bucket
// itself is not modified.
func (l *TokenBucketLimiter) GetLimit(key string) *Limit {
	l.mu.RLock()
	bucket, exists := l.buckets[key]
	l.mu.RUnlock()

	if !exists {
		burst := int(float64(l.defaultRPM) * l.burstMultiplier)
		return &Limit{
			RPM:       l.defaultRPM,
			TPM:       l.defaultTPM,
			Burst:     burst,
			Remaining: burst,
			ResetAt:   time.Now().Add(time.Minute),
		}
	}

	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	// bucket.tokens is only updated on access, so project it forward to now.
	available := bucket.tokens + time.Since(bucket.lastRefill).Seconds()*bucket.tokensPerSec
	if available > bucket.maxTokens {
		available = bucket.maxTokens
	}

	return &Limit{
		RPM:       l.defaultRPM,
		TPM:       l.defaultTPM,
		Burst:     int(bucket.maxTokens),
		Remaining: int(available),
		ResetAt:   bucket.lastRefill.Add(time.Minute),
	}
}
// newBucket returns a full bucket for the given requests-per-minute rate.
// Its capacity is rpm scaled by the burst multiplier and truncated to a whole
// number, and it refills at rpm/60 tokens per second. tpm is accepted but not
// used.
func (l *TokenBucketLimiter) newBucket(rpm, tpm int) *tokenBucket {
	capacity := float64(int(float64(rpm) * l.burstMultiplier))

	b := new(tokenBucket)
	b.tokens = capacity
	b.maxTokens = capacity
	b.tokensPerSec = float64(rpm) / 60.0
	b.lastRefill = time.Now()
	return b
}
// refill credits bucket with the tokens accrued since its last refill, caps
// the total at the bucket capacity, and stamps the refill time. Both callers
// in this file hold bucket.mu when calling it.
func (l *TokenBucketLimiter) refill(bucket *tokenBucket) {
	now := time.Now()
	sinceLast := now.Sub(bucket.lastRefill).Seconds()

	// Add the newly accrued tokens, never exceeding capacity.
	total := bucket.tokens + sinceLast*bucket.tokensPerSec
	if total > bucket.maxTokens {
		total = bucket.maxTokens
	}

	bucket.tokens = total
	bucket.lastRefill = now
}
// cleanup runs on the limiter's cleanInterval ticker and evicts buckets that
// have been idle for more than ten minutes and would by now have refilled to
// capacity.
//
// bucket.tokens is only updated when the bucket is accessed, so the stored
// value is stale for an idle bucket. Comparing it directly against the
// capacity would keep every bucket that was not full at its last access
// forever, so the refill is projected forward to now before deciding. The
// loop has no stop signal and runs for the life of the process.
func (l *TokenBucketLimiter) cleanup() {
	const idleTimeout = 10 * time.Minute

	ticker := time.NewTicker(l.cleanInterval)
	defer ticker.Stop()

	for range ticker.C {
		l.mu.Lock()
		now := time.Now()
		for key, bucket := range l.buckets {
			bucket.mu.Lock()
			idle := now.Sub(bucket.lastRefill)
			projected := bucket.tokens + idle.Seconds()*bucket.tokensPerSec
			// An idle, fully refilled bucket is indistinguishable from a
			// freshly created one, so dropping it loses no state.
			if idle > idleTimeout && projected >= bucket.maxTokens {
				delete(l.buckets, key)
			}
			bucket.mu.Unlock()
		}
		l.mu.Unlock()
	}
}
// SlidingWindowLimiter is a sliding-window rate limiter that keeps one
// window of request timestamps per key.
type SlidingWindowLimiter struct {
	// mu guards the windows map; each window carries its own lock.
	mu sync.RWMutex
	// windows holds the per-key request logs.
	windows map[string]*slidingWindow
	// windowSize is the length of the sliding window.
	windowSize time.Duration
	// maxRequests is the number of requests allowed within one window.
	maxRequests int
	// cleanInterval is the period of the background cleanup ticker.
	cleanInterval time.Duration
}
// slidingWindow is the request log of a single key.
type slidingWindow struct {
	// mu guards requests.
	mu sync.Mutex
	// requests holds the timestamps of the admitted requests.
	requests []time.Time
}
// NewSlidingWindowLimiter creates a sliding-window limiter that admits at
// most maxRequests requests per key within any windowSize span. It also
// starts the background cleanup goroutine, which runs on a one-minute ticker.
func NewSlidingWindowLimiter(windowSize time.Duration, maxRequests int) *SlidingWindowLimiter {
	l := new(SlidingWindowLimiter)
	l.windows = map[string]*slidingWindow{}
	l.windowSize = windowSize
	l.maxRequests = maxRequests
	l.cleanInterval = time.Minute

	go l.cleanup()

	return l
}
// Allow reports whether a request for key fits in the current window and, if
// so, records it.
//
// The key's window is created on first use. Timestamps that have left the
// window are dropped on every call; the slice is filtered in place so the
// request path does not allocate a new slice each time. ctx is accepted to
// satisfy Limiter but is not consulted.
func (l *SlidingWindowLimiter) Allow(ctx context.Context, key string) (bool, error) {
	// Look up the window, creating it on first use.
	l.mu.Lock()
	window, exists := l.windows[key]
	if !exists {
		window = &slidingWindow{}
		l.windows[key] = window
	}
	l.mu.Unlock()

	window.mu.Lock()
	defer window.mu.Unlock()

	now := time.Now()
	cutoff := now.Add(-l.windowSize)

	// Drop expired timestamps, reusing the backing array. The write index
	// never passes the read index, so no unread element is overwritten.
	live := window.requests[:0]
	for _, ts := range window.requests {
		if ts.After(cutoff) {
			live = append(live, ts)
		}
	}
	window.requests = live

	// Deny when the window is already full.
	if len(window.requests) >= l.maxRequests {
		return false, nil
	}

	window.requests = append(window.requests, now)
	return true, nil
}
// AllowToken satisfies Limiter for the sliding-window limiter. The window
// only counts requests, so tokens is ignored and every call is treated as a
// single request.
func (l *SlidingWindowLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
	_ = tokens // deliberately unused: each call counts as one request

	allowed, err := l.Allow(ctx, key)
	return allowed, err
}
// GetLimit returns the limit currently in effect for key.
//
// RPM carries maxRequests, Remaining is maxRequests minus the requests still
// inside the window (never below zero for a known key), and ResetAt is one
// full window from now. A key without a window reports maxRequests remaining.
// The window itself is not modified.
func (l *SlidingWindowLimiter) GetLimit(key string) *Limit {
	l.mu.RLock()
	window, found := l.windows[key]
	l.mu.RUnlock()

	left := l.maxRequests
	if found {
		window.mu.Lock()
		threshold := time.Now().Add(-l.windowSize)
		// Each request still inside the window uses up one slot.
		for _, ts := range window.requests {
			if ts.After(threshold) {
				left--
			}
		}
		window.mu.Unlock()

		if left < 0 {
			left = 0
		}
	}

	limit := new(Limit)
	limit.RPM = l.maxRequests
	limit.ResetAt = time.Now().Add(l.windowSize)
	limit.Remaining = left
	return limit
}
// cleanup runs on the limiter's cleanInterval ticker. On each tick it trims
// every window to the timestamps from the last two window lengths and evicts
// windows that have none left.
//
// A window with no timestamps at all is evicted as well. This covers a window
// Allow has just created, one belonging to a limiter whose maxRequests is
// zero or less, and one emptied by an earlier sweep. Indexing its last
// element, as the previous version did, panics on an empty slice. The loop
// has no stop signal and runs for the life of the process.
func (l *SlidingWindowLimiter) cleanup() {
	ticker := time.NewTicker(l.cleanInterval)
	defer ticker.Stop()

	for range ticker.C {
		l.mu.Lock()
		cutoff := time.Now().Add(-2 * l.windowSize)
		for key, window := range l.windows {
			window.mu.Lock()

			// Filter in place; the write index never passes the read index.
			recent := window.requests[:0]
			for _, ts := range window.requests {
				if ts.After(cutoff) {
					recent = append(recent, ts)
				}
			}
			window.requests = recent

			// Nothing recent means the key has been idle for at least two
			// windows (or was never used), so its entry can go.
			if len(recent) == 0 {
				delete(l.windows, key)
			}

			window.mu.Unlock()
		}
		l.mu.Unlock()
	}
}
// Middleware is the rate-limiting HTTP middleware.
type Middleware struct {
	// limiter decides whether each request is admitted.
	limiter Limiter
}
// NewMiddleware returns a Middleware that consults limiter for every request.
func NewMiddleware(limiter Limiter) *Middleware {
	m := new(Middleware)
	m.limiter = limiter
	return m
}
// Limit wraps next so that every request is checked against the limiter
// before it is served.
//
// The limiter key is the Authorization header value. Requests without one
// fall back to the client host taken from RemoteAddr. The port is stripped
// because it changes with every connection, which would otherwise give each
// new connection its own quota. If RemoteAddr cannot be split it is used
// unchanged.
//
// A limiter error is answered with COMMON_INTERNAL_ERROR. A denied request is
// answered with RATE_LIMIT_EXCEEDED plus the X-RateLimit-Limit,
// X-RateLimit-Remaining and X-RateLimit-Reset headers when the limiter
// reports a limit for the key.
func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Use the API key as the limiter key.
		key := r.Header.Get("Authorization")
		if key == "" {
			key = r.RemoteAddr
			if host, _, splitErr := net.SplitHostPort(r.RemoteAddr); splitErr == nil && host != "" {
				key = host
			}
		}

		allowed, err := m.limiter.Allow(r.Context(), key)
		if err != nil {
			writeError(w, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "rate limiter error"))
			return
		}

		if !allowed {
			// GetLimit returns a pointer; guard against an implementation
			// that has nothing to report for this key.
			if limit := m.limiter.GetLimit(key); limit != nil {
				header := w.Header()
				header.Set("X-RateLimit-Limit", fmt.Sprintf("%d", limit.RPM))
				header.Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining))
				header.Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix()))
			}
			writeError(w, error.NewGatewayError(error.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
			return
		}

		next.ServeHTTP(w, r)
	}
}
import "net/http"
func writeError(w http.ResponseWriter, err *error.GatewayError) {
info := err.GetErrorInfo()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(info.HTTPStatus)
w.Write([]byte(fmt.Sprintf(`{"error":{"message":"%s","code":"%s"}}`, err.Message, err.Code)))
}