Files
user-system/internal/api/middleware/ratelimit.go

128 lines
2.9 KiB
Go

package middleware
import (
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/config"
)
// RateLimitMiddleware is a rate-limiting middleware. It owns a set of named
// sliding-window limiters and exposes gin handlers that consult them.
type RateLimitMiddleware struct {
// cfg is the rate-limit configuration passed to NewRateLimitMiddleware.
// NOTE(review): nothing in the visible code reads this field — confirm
// whether the hard-coded per-route limits were meant to come from it.
cfg config.RateLimitConfig
// limiters maps a limiter key to its sliding-window limiter.
limiters map[string]*SlidingWindowLimiter
// mu guards limiters.
mu sync.RWMutex
// cleanupInt is a cleanup interval; the constructor sets it to five minutes.
// NOTE(review): no cleanup routine that uses it is visible in this file.
cleanupInt time.Duration
}
// SlidingWindowLimiter is a sliding-window rate limiter: it admits a request
// only while fewer than capacity requests have been admitted within the
// trailing window. All methods lock mu, so a limiter may be shared between
// goroutines; because it embeds a sync.Mutex it must not be copied.
type SlidingWindowLimiter struct {
	mu       sync.Mutex    // guards requests
	window   time.Duration // length of the trailing window
	capacity int64         // maximum admitted requests per window
	requests []int64       // admission times (Unix milliseconds), in admission order
}

// NewSlidingWindowLimiter returns a limiter that admits at most capacity
// requests per window. A capacity of zero or less rejects every request.
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
	// requests is left nil on purpose: Allow grows it with append on demand.
	return &SlidingWindowLimiter{
		window:   window,
		capacity: capacity,
	}
}

// Allow reports whether a request arriving now is admitted. It first discards
// recorded admissions that are no longer inside the window, then admits the
// request (recording its timestamp) only if fewer than capacity admissions
// remain. Rejected requests are not recorded.
func (l *SlidingWindowLimiter) Allow() bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	now := time.Now().UnixMilli()
	cutoff := now - l.window.Milliseconds()

	// Prune expired timestamps by filtering in place. Reusing the existing
	// backing array keeps this per-request path free of allocations; the
	// write index never overtakes the read index, so no live entry is lost.
	live := l.requests[:0]
	for _, ts := range l.requests {
		if ts > cutoff {
			live = append(live, ts)
		}
	}
	l.requests = live

	// The window is full: reject without recording the attempt.
	if int64(len(l.requests)) >= l.capacity {
		return false
	}

	l.requests = append(l.requests, now)
	return true
}
// NewRateLimitMiddleware builds a rate-limit middleware holding cfg, an empty
// limiter table, and a five-minute cleanup interval.
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
	m := new(RateLimitMiddleware)
	m.cfg = cfg
	m.limiters = map[string]*SlidingWindowLimiter{}
	m.cleanupInt = 5 * time.Minute
	return m
}
// Register returns the rate-limiting handler for the registration endpoint,
// using the "register" limiter: 10 requests per 60-second window.
// NOTE(review): the limits are hard-coded here; m.cfg is not consulted.
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 10
	)
	return m.limitForKey("register", windowSeconds, maxRequests)
}
// Login returns the rate-limiting handler for the login endpoint, using the
// "login" limiter: 5 requests per 60-second window.
// NOTE(review): the limits are hard-coded here; m.cfg is not consulted.
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 5
	)
	return m.limitForKey("login", windowSeconds, maxRequests)
}
// API returns the rate-limiting handler for general API endpoints, using the
// "api" limiter: 100 requests per 60-second window.
// NOTE(review): the limits are hard-coded here; m.cfg is not consulted.
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 100
	)
	return m.limitForKey("api", windowSeconds, maxRequests)
}
// Refresh returns the rate-limiting handler for the token-refresh endpoint,
// using the "refresh" limiter: 10 requests per 60-second window.
// NOTE(review): the limits are hard-coded here; m.cfg is not consulted.
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
	const (
		windowSeconds = 60
		maxRequests   = 10
	)
	return m.limitForKey("refresh", windowSeconds, maxRequests)
}
// limitForKey returns a gin handler that enforces the limiter registered
// under key, creating it with the given window (in seconds) and capacity if
// it does not exist yet. Admitted requests continue down the handler chain;
// rejected ones receive a 429 JSON body and the chain is aborted.
//
// The limiter is looked up once, when the handler is built, so every request
// served by the returned handler shares the same window regardless of client.
// NOTE(review): confirm that a single shared quota per route (rather than
// per client) is the intended behavior.
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
	window := time.Duration(windowSeconds) * time.Second
	shared := m.getOrCreateLimiter(key, window, capacity)

	return func(c *gin.Context) {
		if shared.Allow() {
			c.Next()
			return
		}

		// Over the limit: answer with 429 and stop the remaining handlers.
		body := gin.H{
			"code":    429,
			"message": "请求过于频繁,请稍后再试",
		}
		c.JSON(429, body)
		c.Abort()
	}
}
// getOrCreateLimiter returns the limiter stored under key, creating and
// storing one with the given window and capacity when none exists. If a
// limiter is already present, it is returned as-is and window/capacity are
// ignored.
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
	// Fast path: look the key up under the read lock only.
	m.mu.RLock()
	found, ok := m.limiters[key]
	m.mu.RUnlock()
	if ok {
		return found
	}

	// Slow path: take the write lock and look again, since another goroutine
	// may have inserted the limiter between the two lock acquisitions.
	m.mu.Lock()
	defer m.mu.Unlock()

	if found, ok = m.limiters[key]; !ok {
		found = NewSlidingWindowLimiter(window, capacity)
		m.limiters[key] = found
	}
	return found
}