Files
lijiaoqiao/supply-api/internal/cache/redis.go
Your Name 0196ee5d47 feat(supply-api): 完成核心模块实现
新增/修改内容:
- config: 添加配置管理(config.example.yaml, config.go)
- cache: 添加Redis缓存层(redis.go)
- domain: 添加invariants不变量验证及测试
- middleware: 添加auth认证和idempotency幂等性中间件及测试
- repository: 添加完整数据访问层(account, package, settlement, idempotency, db)
- sql: 添加幂等性表DDL脚本

代码覆盖:
- auth middleware实现凭证边界验证
- idempotency middleware实现请求幂等性
- invariants实现业务不变量检查
- repository层实现完整的数据访问逻辑

关联issue: Round-1 R1-ISSUE-006 凭证边界硬门禁
2026-04-01 08:53:28 +08:00

232 lines
6.3 KiB
Go

package cache
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"lijiaoqiao/supply-api/internal/config"
)
// RedisCache wraps a go-redis client and exposes the typed cache operations
// used by this service: token status, rate-limit counters, simple locks,
// idempotency values and sessions.
type RedisCache struct {
	// client is the pooled Redis client created by NewRedisCache and
	// released by Close.
	client *redis.Client
}
// NewRedisCache creates a Redis client from cfg and verifies connectivity with
// a PING bounded by a 5-second timeout.
//
// On success the caller owns the returned cache and should call Close when
// done with it. If the PING fails, the freshly created client (and its
// connection pool) is closed before the error is returned, so nothing leaks
// on the failure path.
func NewRedisCache(cfg config.RedisConfig) (*RedisCache, error) {
	client := redis.NewClient(&redis.Options{
		Addr:     cfg.Addr(),
		Password: cfg.Password,
		DB:       cfg.DB,
		PoolSize: cfg.PoolSize,
	})

	// Fail fast at construction time rather than on the first real request.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := client.Ping(ctx).Err(); err != nil {
		// The ping error is the one worth reporting; a Close failure here
		// would only obscure it, so it is deliberately discarded.
		_ = client.Close()
		return nil, fmt.Errorf("failed to connect to redis: %w", err)
	}
	return &RedisCache{client: client}, nil
}
// Close shuts down the underlying Redis client and its connections and
// returns any error reported while doing so.
func (r *RedisCache) Close() error {
	closeErr := r.client.Close()
	return closeErr
}
// HealthCheck sends a PING to Redis using ctx and returns its error; nil
// means the server answered.
func (r *RedisCache) HealthCheck(ctx context.Context) error {
	pong := r.client.Ping(ctx)
	return pong.Err()
}
// ==================== Token status cache ====================

// TokenStatus is the cached state of an issued token. It is serialized as
// JSON; RevokedAt and RevokedReason are omitted when zero/empty.
// NOTE(review): ExpiresAt/RevokedAt look like Unix timestamps — confirm units
// with the code that writes them.
type TokenStatus struct {
	TokenID       string `json:"token_id"`
	SubjectID     string `json:"subject_id"`
	Role          string `json:"role"`
	Status        string `json:"status"` // active, revoked, expired
	ExpiresAt     int64  `json:"expires_at"`
	RevokedAt     int64  `json:"revoked_at,omitempty"`
	RevokedReason string `json:"revoked_reason,omitempty"`
}
// GetTokenStatus loads the cached TokenStatus for tokenID from the key
// "token:status:<tokenID>".
//
// It returns (nil, nil) when the key does not exist, a wrapped error when the
// Redis read or JSON decoding fails, and the decoded status otherwise.
func (r *RedisCache) GetTokenStatus(ctx context.Context, tokenID string) (*TokenStatus, error) {
	raw, err := r.client.Get(ctx, "token:status:"+tokenID).Bytes()
	switch {
	case err == redis.Nil:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("failed to get token status: %w", err)
	}

	decoded := new(TokenStatus)
	if err := json.Unmarshal(raw, decoded); err != nil {
		return nil, fmt.Errorf("failed to unmarshal token status: %w", err)
	}
	return decoded, nil
}
// SetTokenStatus stores status as JSON under "token:status:<status.TokenID>",
// passing ttl straight through to the Redis SET.
//
// A nil status is rejected with an error rather than panicking on the
// status.TokenID dereference.
func (r *RedisCache) SetTokenStatus(ctx context.Context, status *TokenStatus, ttl time.Duration) error {
	if status == nil {
		return fmt.Errorf("failed to set token status: status is nil")
	}
	key := fmt.Sprintf("token:status:%s", status.TokenID)
	data, err := json.Marshal(status)
	if err != nil {
		return fmt.Errorf("failed to marshal token status: %w", err)
	}
	return r.client.Set(ctx, key, data, ttl).Err()
}
// InvalidateToken deletes the cached entry "token:status:<tokenID>", so a
// subsequent GetTokenStatus for that token sees a miss.
func (r *RedisCache) InvalidateToken(ctx context.Context, tokenID string) error {
	delCmd := r.client.Del(ctx, "token:status:"+tokenID)
	return delCmd.Err()
}
// ==================== Rate limiting ====================

// RateLimitKey identifies one rate-limit counter. Its fields are combined
// into the Redis key "ratelimit:<TenantID>:<Route>:<LimitType>".
type RateLimitKey struct {
	TenantID  int64
	Route     string
	LimitType string // rpm, rpd, concurrent
}
// GetRateLimit reads the current value of the counter identified by key
// without modifying it. A missing counter reads as 0.
//
// The window parameter is currently unused by this read.
func (r *RedisCache) GetRateLimit(ctx context.Context, key *RateLimitKey, window time.Duration) (int64, error) {
	counterKey := fmt.Sprintf("ratelimit:%d:%s:%s", key.TenantID, key.Route, key.LimitType)
	current, err := r.client.Get(ctx, counterKey).Int64()
	switch {
	case err == redis.Nil:
		return 0, nil
	case err != nil:
		return 0, fmt.Errorf("failed to get rate limit: %w", err)
	default:
		return current, nil
	}
}
// IncrRateLimit 增加限流计数
func (r *RedisCache) IncrRateLimit(ctx context.Context, key *RateLimitKey, window time.Duration) (int64, error) {
redisKey := fmt.Sprintf("ratelimit:%d:%s:%s", key.TenantID, key.Route, key.LimitType)
pipe := r.client.Pipeline()
incrCmd := pipe.Incr(ctx, redisKey)
pipe.Expire(ctx, redisKey, window)
_, err := pipe.Exec(ctx)
if err != nil {
return 0, fmt.Errorf("failed to increment rate limit: %w", err)
}
return incrCmd.Val(), nil
}
// CheckRateLimit counts one request against key and reports whether it is
// still within limit.
//
// It returns (true, count, nil) when the post-increment count is <= limit and
// (false, count, nil) otherwise. Every call increments the counter, including
// calls that end up rejected. On error it returns (false, 0, err).
func (r *RedisCache) CheckRateLimit(ctx context.Context, key *RateLimitKey, limit int64, window time.Duration) (bool, int64, error) {
	current, err := r.IncrRateLimit(ctx, key, window)
	if err != nil {
		return false, 0, err
	}
	withinLimit := current <= limit
	return withinLimit, current, nil
}
// ==================== Distributed lock ====================

// AcquireLock makes a single, non-blocking attempt to take the lock named
// lockKey by creating "lock:<lockKey>" with SET NX and the given ttl. It
// returns true if this call created the key and false if the key already
// existed.
//
// NOTE(review): the stored value is the constant "1", so the lock records no
// owner identity; ReleaseLock therefore cannot verify ownership.
func (r *RedisCache) AcquireLock(ctx context.Context, lockKey string, ttl time.Duration) (bool, error) {
	acquired, err := r.client.SetNX(ctx, "lock:"+lockKey, "1", ttl).Result()
	if err != nil {
		return false, fmt.Errorf("failed to acquire lock: %w", err)
	}
	return acquired, nil
}
// ReleaseLock deletes "lock:<lockKey>" unconditionally.
//
// NOTE(review): because the delete is not tied to an owner token, a holder
// whose lock already expired and was re-acquired by someone else will delete
// the other holder's lock. A safe release needs a per-holder token checked and
// deleted atomically, which requires an API change.
func (r *RedisCache) ReleaseLock(ctx context.Context, lockKey string) error {
	delCmd := r.client.Del(ctx, "lock:"+lockKey)
	return delCmd.Err()
}
// ==================== Idempotency cache (short-lived) ====================

// GetIdempotency returns the value stored under "idempotency:<key>".
//
// A missing key yields ("", nil), so callers cannot distinguish a miss from a
// stored empty string. Other Redis failures are returned wrapped.
func (r *RedisCache) GetIdempotency(ctx context.Context, key string) (string, error) {
	stored, err := r.client.Get(ctx, "idempotency:"+key).Result()
	switch {
	case err == redis.Nil:
		return "", nil
	case err != nil:
		return "", fmt.Errorf("failed to get idempotency: %w", err)
	}
	return stored, nil
}
// SetIdempotency stores value under "idempotency:<key>" with the given ttl,
// overwriting any existing entry.
//
// NOTE(review): this is a plain SET, not SET NX, so on its own it does not
// stop two concurrent requests carrying the same key from both proceeding.
func (r *RedisCache) SetIdempotency(ctx context.Context, key, value string, ttl time.Duration) error {
	setCmd := r.client.Set(ctx, "idempotency:"+key, value, ttl)
	return setCmd.Err()
}
// ==================== Session cache ====================

// SessionData is the per-session payload cached as JSON under
// "session:<sessionID>".
// NOTE(review): CreatedAt looks like a Unix timestamp — confirm units.
type SessionData struct {
	UserID    int64  `json:"user_id"`
	TenantID  int64  `json:"tenant_id"`
	Role      string `json:"role"`
	CreatedAt int64  `json:"created_at"`
}
// GetSession loads the session stored under "session:<sessionID>".
//
// It returns (nil, nil) on a cache miss, a wrapped error if the Redis read or
// JSON decoding fails, and the decoded session otherwise.
func (r *RedisCache) GetSession(ctx context.Context, sessionID string) (*SessionData, error) {
	payload, err := r.client.Get(ctx, "session:"+sessionID).Bytes()
	switch {
	case err == redis.Nil:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	session := &SessionData{}
	if err := json.Unmarshal(payload, session); err != nil {
		return nil, fmt.Errorf("failed to unmarshal session: %w", err)
	}
	return session, nil
}
// SetSession stores session as JSON under "session:<sessionID>", passing ttl
// straight through to the Redis SET.
//
// A nil session is rejected. Marshalling nil would store the JSON literal
// "null"; decoding "null" into a SessionData struct is a no-op, so a later
// read would yield a non-nil, zero-valued session (UserID 0, TenantID 0)
// instead of a miss.
func (r *RedisCache) SetSession(ctx context.Context, sessionID string, session *SessionData, ttl time.Duration) error {
	if session == nil {
		return fmt.Errorf("failed to set session: session is nil")
	}
	key := fmt.Sprintf("session:%s", sessionID)
	data, err := json.Marshal(session)
	if err != nil {
		return fmt.Errorf("failed to marshal session: %w", err)
	}
	return r.client.Set(ctx, key, data, ttl).Err()
}
// DeleteSession removes the cached entry "session:<sessionID>" and returns
// any error from the Redis DEL.
func (r *RedisCache) DeleteSession(ctx context.Context, sessionID string) error {
	delCmd := r.client.Del(ctx, "session:"+sessionID)
	return delCmd.Err()
}