Files
lijiaoqiao/supply-api/internal/cache/redis.go
Your Name ed0961d486 fix(supply-api): 修复编译错误和测试问题
- 添加 ErrNotFound 和 ErrConcurrencyConflict 错误定义
- 修复 pgx.NullTime 替换为 *time.Time
- 修复 db.go 事务类型 (pgx.Tx vs pgxpool.Tx)
- 移除未使用的导入和变量
- 修复 NewSupplyAPI 调用参数
- 修复中间件链路 handler 类型问题
- 修复适配器类型引用 (storage.InMemoryAccountStore 等)
- 所有测试通过

Test: go test ./...
2026-04-01 13:03:44 +08:00

232 lines
6.3 KiB
Go

package cache
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
"lijiaoqiao/supply-api/internal/config"
)
// RedisCache wraps a go-redis client and exposes the typed cache operations
// used by the supply API: token status, rate-limit counters, locks,
// idempotency records and sessions.
type RedisCache struct {
	// client is the underlying go-redis client created by NewRedisCache.
	client *redis.Client
}
// NewRedisCache builds a RedisCache from cfg and verifies connectivity with a
// PING bounded by a 5-second timeout.
//
// On success the caller owns the returned cache and should call Close when it
// is no longer needed. If the PING fails, the freshly created client is closed
// before the error is returned so its connection pool is not leaked.
func NewRedisCache(cfg config.RedisConfig) (*RedisCache, error) {
	client := redis.NewClient(&redis.Options{
		Addr:     cfg.Addr(),
		Password: cfg.Password,
		DB:       cfg.DB,
		PoolSize: cfg.PoolSize,
	})

	// Verify the connection before handing the client out.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := client.Ping(ctx).Err(); err != nil {
		// Release the client's resources. The ping error is the one worth
		// reporting, so a secondary Close error is deliberately ignored.
		_ = client.Close()
		return nil, fmt.Errorf("failed to connect to redis: %w", err)
	}
	return &RedisCache{client: client}, nil
}
// Close shuts down the underlying Redis client and releases its connections.
// Call it once, when the cache is no longer needed.
func (r *RedisCache) Close() error {
	closeErr := r.client.Close()
	return closeErr
}
// HealthCheck issues a PING using the caller's context. A nil result means the
// Redis server answered; otherwise the PING error is returned unchanged.
func (r *RedisCache) HealthCheck(ctx context.Context) error {
	pingErr := r.client.Ping(ctx).Err()
	return pingErr
}
// ==================== Token状态缓存 ====================
// TokenStatus is the cached state of an issued token. It is stored as JSON,
// keyed by TokenID.
//
// NOTE(review): ExpiresAt/RevokedAt look like Unix timestamps; the unit is not
// visible here, so confirm with the writers of this record.
type TokenStatus struct {
	TokenID       string `json:"token_id"`
	SubjectID     string `json:"subject_id"`
	Role          string `json:"role"`
	Status        string `json:"status"` // one of: active, revoked, expired
	ExpiresAt     int64  `json:"expires_at"`
	RevokedAt     int64  `json:"revoked_at,omitempty"`
	RevokedReason string `json:"revoked_reason,omitempty"`
}
// GetTokenStatus loads the cached TokenStatus for tokenID from the key
// "token:status:<tokenID>".
//
// A cache miss is not an error: it yields (nil, nil). Redis failures and
// payloads that are not valid TokenStatus JSON are returned as wrapped errors.
func (r *RedisCache) GetTokenStatus(ctx context.Context, tokenID string) (*TokenStatus, error) {
	raw, err := r.client.Get(ctx, fmt.Sprintf("token:status:%s", tokenID)).Bytes()
	switch {
	case err == redis.Nil:
		return nil, nil
	case err != nil:
		return nil, fmt.Errorf("failed to get token status: %w", err)
	}

	status := new(TokenStatus)
	if decodeErr := json.Unmarshal(raw, status); decodeErr != nil {
		return nil, fmt.Errorf("failed to unmarshal token status: %w", decodeErr)
	}
	return status, nil
}
// SetTokenStatus writes status as JSON to "token:status:<status.TokenID>" with
// the given ttl, replacing any existing entry. status must be non-nil.
func (r *RedisCache) SetTokenStatus(ctx context.Context, status *TokenStatus, ttl time.Duration) error {
	cacheKey := fmt.Sprintf("token:status:%s", status.TokenID)
	payload, marshalErr := json.Marshal(status)
	if marshalErr != nil {
		return fmt.Errorf("failed to marshal token status: %w", marshalErr)
	}
	setCmd := r.client.Set(ctx, cacheKey, payload, ttl)
	return setCmd.Err()
}
// InvalidateToken deletes the cached status entry "token:status:<tokenID>".
// Only the Redis command error is reported, so deleting a key that is already
// absent is not treated as a failure.
func (r *RedisCache) InvalidateToken(ctx context.Context, tokenID string) error {
	cacheKey := "token:status:" + tokenID
	delCmd := r.client.Del(ctx, cacheKey)
	return delCmd.Err()
}
// ==================== 限流 ====================
// RateLimitKey identifies a single rate-limit counter. Its fields are combined
// into the Redis key "ratelimit:<TenantID>:<Route>:<LimitType>".
type RateLimitKey struct {
	TenantID  int64
	Route     string
	LimitType string // one of: rpm, rpd, concurrent
}
// GetRateLimit reads the current counter value for key without modifying it.
// A counter that does not exist reads as 0.
//
// window is not used by this read (presumably kept for signature parity with
// IncrRateLimit).
func (r *RedisCache) GetRateLimit(ctx context.Context, key *RateLimitKey, window time.Duration) (int64, error) {
	counterKey := fmt.Sprintf("ratelimit:%d:%s:%s", key.TenantID, key.Route, key.LimitType)
	current, err := r.client.Get(ctx, counterKey).Int64()
	switch {
	case err == redis.Nil:
		return 0, nil
	case err != nil:
		return 0, fmt.Errorf("failed to get rate limit: %w", err)
	default:
		return current, nil
	}
}
// IncrRateLimit increments the fixed-window counter for key and returns the
// new value.
//
// The window expiry is applied only when the counter has no TTL yet, so the
// window is anchored at its first request. Re-arming the expiry on every call
// would slide the window forward with each request: under steady traffic the
// key would never expire and the count would grow without bound, eventually
// rejecting callers whose real rate is far below the limit.
//
// Checking PTTL (rather than "count == 1") also repairs a counter that was
// left without a TTL, e.g. if the process died between INCR and PEXPIRE.
func (r *RedisCache) IncrRateLimit(ctx context.Context, key *RateLimitKey, window time.Duration) (int64, error) {
	redisKey := fmt.Sprintf("ratelimit:%d:%s:%s", key.TenantID, key.Route, key.LimitType)

	pipe := r.client.Pipeline()
	incrCmd := pipe.Incr(ctx, redisKey)
	ttlCmd := pipe.PTTL(ctx, redisKey)
	if _, err := pipe.Exec(ctx); err != nil {
		return 0, fmt.Errorf("failed to increment rate limit: %w", err)
	}

	// PTTL yields a negative duration when the key has no expiry (-1) or does
	// not exist (-2); in either case start the window now.
	if ttlCmd.Val() < 0 {
		if err := r.client.PExpire(ctx, redisKey, window).Err(); err != nil {
			return 0, fmt.Errorf("failed to set rate limit window: %w", err)
		}
	}
	return incrCmd.Val(), nil
}
// CheckRateLimit counts one request against key via IncrRateLimit and reports
// whether the updated count is still within limit.
//
// It returns (allowed, currentCount, err). The counter is incremented even
// when the request ends up rejected. On error it returns (false, 0, err).
func (r *RedisCache) CheckRateLimit(ctx context.Context, key *RateLimitKey, limit int64, window time.Duration) (bool, int64, error) {
	current, incrErr := r.IncrRateLimit(ctx, key, window)
	if incrErr != nil {
		return false, 0, incrErr
	}
	allowed := current <= limit
	return allowed, current, nil
}
// ==================== 分布式锁 ====================
// AcquireLock tries to take the lock named lockKey by creating the key
// "lock:<lockKey>" with SET NX and the given ttl. It reports true when this
// call created the key and false when the key already existed.
//
// NOTE(review): every holder stores the same value ("1"), so the lock carries
// no owner identity, and the matching release deletes the key unconditionally.
// If ttl lapses while the holder is still working, another caller can acquire
// the lock, and the first holder's release will then remove the new holder's
// lock. Consider an owner token plus a compare-and-delete release.
func (r *RedisCache) AcquireLock(ctx context.Context, lockKey string, ttl time.Duration) (bool, error) {
	acquired, err := r.client.SetNX(ctx, "lock:"+lockKey, "1", ttl).Result()
	if err != nil {
		return false, fmt.Errorf("failed to acquire lock: %w", err)
	}
	return acquired, nil
}
// ReleaseLock deletes the key "lock:<lockKey>" unconditionally.
//
// NOTE(review): the delete does not check who holds the lock. A caller whose
// lock already expired can therefore remove a lock that someone else has since
// acquired.
func (r *RedisCache) ReleaseLock(ctx context.Context, lockKey string) error {
	lockRedisKey := "lock:" + lockKey
	delCmd := r.client.Del(ctx, lockRedisKey)
	return delCmd.Err()
}
// ==================== 幂等缓存 ====================
// GetIdempotency returns the short-lived idempotency record stored under
// "idempotency:<key>".
//
// A missing record yields ("", nil), which callers cannot distinguish from a
// stored empty string. Other Redis failures are returned as wrapped errors.
func (r *RedisCache) GetIdempotency(ctx context.Context, key string) (string, error) {
	record, err := r.client.Get(ctx, fmt.Sprintf("idempotency:%s", key)).Result()
	if err != nil {
		if err == redis.Nil {
			return "", nil
		}
		return "", fmt.Errorf("failed to get idempotency: %w", err)
	}
	return record, nil
}
// SetIdempotency stores value under "idempotency:<key>" with the given ttl.
// It uses a plain SET, so any existing record for the key is overwritten.
func (r *RedisCache) SetIdempotency(ctx context.Context, key, value string, ttl time.Duration) error {
	setCmd := r.client.Set(ctx, "idempotency:"+key, value, ttl)
	return setCmd.Err()
}
// ==================== Session缓存 ====================
// SessionData is the JSON payload cached for a session.
//
// NOTE(review): CreatedAt looks like a Unix timestamp; its unit is not visible
// here, so confirm with the code that populates it.
type SessionData struct {
	UserID    int64  `json:"user_id"`
	TenantID  int64  `json:"tenant_id"`
	Role      string `json:"role"`
	CreatedAt int64  `json:"created_at"`
}
// GetSession fetches and decodes the session stored under "session:<sessionID>".
//
// A missing session yields (nil, nil). Redis failures and payloads that are not
// valid SessionData JSON are returned as wrapped errors.
func (r *RedisCache) GetSession(ctx context.Context, sessionID string) (*SessionData, error) {
	sessionKey := "session:" + sessionID
	blob, err := r.client.Get(ctx, sessionKey).Bytes()
	if err != nil {
		if err == redis.Nil {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	var out SessionData
	if decodeErr := json.Unmarshal(blob, &out); decodeErr != nil {
		return nil, fmt.Errorf("failed to unmarshal session: %w", decodeErr)
	}
	return &out, nil
}
// SetSession encodes session as JSON and stores it under "session:<sessionID>"
// with the given ttl, replacing any previous value for that key.
func (r *RedisCache) SetSession(ctx context.Context, sessionID string, session *SessionData, ttl time.Duration) error {
	encoded, err := json.Marshal(session)
	if err != nil {
		return fmt.Errorf("failed to marshal session: %w", err)
	}
	return r.client.Set(ctx, "session:"+sessionID, encoded, ttl).Err()
}
// DeleteSession removes the cached session stored under "session:<sessionID>".
// Only the Redis command error is reported.
func (r *RedisCache) DeleteSession(ctx context.Context, sessionID string) error {
	sessionKey := "session:" + sessionID
	delErr := r.client.Del(ctx, sessionKey).Err()
	return delErr
}