128 lines
2.9 KiB
Go
128 lines
2.9 KiB
Go
package middleware
|
|
|
|
import (
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/user-management-system/internal/config"
|
|
)
|
|
|
|
// RateLimitMiddleware is the rate-limiting middleware. It hands out gin
// handlers (Register, Login, API, Refresh) backed by lazily created
// SlidingWindowLimiter instances.
type RateLimitMiddleware struct {
	// cfg is the rate-limit configuration passed to NewRateLimitMiddleware.
	cfg config.RateLimitConfig
	// limiters holds one limiter per rate-limit key; entries are created
	// lazily by getOrCreateLimiter. Guarded by mu.
	limiters map[string]*SlidingWindowLimiter
	// mu guards limiters (read lock for lookups, write lock for inserts).
	mu sync.RWMutex
	// cleanupInt is presumably the interval for purging idle limiters
	// (inferred from the name — confirm). NewRateLimitMiddleware sets it to
	// 5 minutes.
	cleanupInt time.Duration
}
|
|
|
|
// SlidingWindowLimiter is a sliding-window rate limiter: it admits at most
// capacity requests within any trailing period of length window. Allow
// serializes access to the limiter's state through mu.
type SlidingWindowLimiter struct {
	mu       sync.Mutex
	window   time.Duration // length of the trailing window
	capacity int64         // maximum number of requests admitted per window
	// requests holds Unix-millisecond (wall-clock) timestamps of admitted
	// requests that may still fall inside the window.
	requests []int64
}

// NewSlidingWindowLimiter creates a limiter that admits at most capacity
// requests per window. A capacity <= 0 rejects every request.
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
	return &SlidingWindowLimiter{
		window:   window,
		capacity: capacity,
		requests: make([]int64, 0),
	}
}

// Allow reports whether a request arriving now fits within the limit. If it
// does, the request is recorded and Allow returns true; otherwise nothing is
// recorded and Allow returns false.
func (l *SlidingWindowLimiter) Allow() bool {
	l.mu.Lock()
	defer l.mu.Unlock()

	// UnixMilli uses the wall clock, not the monotonic reading, so the clock
	// can move backwards between calls.
	now := time.Now().UnixMilli()
	cutoff := now - l.window.Milliseconds()

	// Prune in place, reusing the backing array instead of allocating a new
	// slice on every request. Entries newer than now can only exist if the
	// wall clock was stepped backwards; drop them so that such an adjustment
	// cannot keep the limiter full for the size of the step.
	kept := l.requests[:0]
	for _, ts := range l.requests {
		if ts > cutoff && ts <= now {
			kept = append(kept, ts)
		}
	}
	l.requests = kept

	if int64(len(l.requests)) >= l.capacity {
		return false
	}
	l.requests = append(l.requests, now)
	return true
}
|
|
|
|
// NewRateLimitMiddleware 创建限流中间件
|
|
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
|
|
return &RateLimitMiddleware{
|
|
cfg: cfg,
|
|
limiters: make(map[string]*SlidingWindowLimiter),
|
|
cleanupInt: 5 * time.Minute,
|
|
}
|
|
}
|
|
|
|
// Register 返回注册接口的限流中间件
|
|
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
|
|
return m.limitForKey("register", 60, 10)
|
|
}
|
|
|
|
// Login 返回登录接口的限流中间件
|
|
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
|
|
return m.limitForKey("login", 60, 5)
|
|
}
|
|
|
|
// API 返回 API 接口的限流中间件
|
|
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
|
|
return m.limitForKey("api", 60, 100)
|
|
}
|
|
|
|
// Refresh 返回刷新令牌的限流中间件
|
|
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
|
|
return m.limitForKey("refresh", 60, 10)
|
|
}
|
|
|
|
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
|
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
|
|
|
|
return func(c *gin.Context) {
|
|
if !limiter.Allow() {
|
|
c.JSON(429, gin.H{
|
|
"code": 429,
|
|
"message": "请求过于频繁,请稍后再试",
|
|
})
|
|
c.Abort()
|
|
return
|
|
}
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
|
|
m.mu.RLock()
|
|
limiter, exists := m.limiters[key]
|
|
m.mu.RUnlock()
|
|
|
|
if exists {
|
|
return limiter
|
|
}
|
|
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
|
|
// 双重检查
|
|
if limiter, exists = m.limiters[key]; exists {
|
|
return limiter
|
|
}
|
|
|
|
limiter = NewSlidingWindowLimiter(window, capacity)
|
|
m.limiters[key] = limiter
|
|
return limiter
|
|
}
|