feat(supply-api): 完成核心模块实现

新增/修改内容:
- config: 添加配置管理(config.example.yaml, config.go)
- cache: 添加Redis缓存层(redis.go)
- domain: 添加invariants不变量验证及测试
- middleware: 添加auth认证和idempotency幂等性中间件及测试
- repository: 添加完整数据访问层(account, package, settlement, idempotency, db)
- sql: 添加幂等性表DDL脚本

代码覆盖:
- auth middleware实现凭证边界验证
- idempotency middleware实现请求幂等性
- invariants实现业务不变量检查
- repository层实现完整的数据访问逻辑

关联issue: Round-1 R1-ISSUE-006 凭证边界硬门禁
This commit is contained in:
Your Name
2026-04-01 08:53:28 +08:00
parent e9338dec28
commit 0196ee5d47
16 changed files with 3320 additions and 0 deletions

View File

@@ -0,0 +1,477 @@
package middleware
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"lijiaoqiao/supply-api/internal/repository"
)
// TokenClaims JWT token claims
type TokenClaims struct {
jwt.RegisteredClaims
SubjectID string `json:"subject_id"`
Role string `json:"role"`
Scope []string `json:"scope"`
TenantID int64 `json:"tenant_id"`
}
// AuthConfig 鉴权中间件配置
type AuthConfig struct {
SecretKey string
Issuer string
CacheTTL time.Duration // token状态缓存TTL
Enabled bool // 是否启用鉴权
}
// AuthMiddleware 鉴权中间件
type AuthMiddleware struct {
config AuthConfig
tokenCache *TokenCache
auditEmitter AuditEmitter
}
// AuditEmitter 审计事件发射器
type AuditEmitter interface {
Emit(ctx context.Context, event AuditEvent) error
}
// AuditEvent 审计事件
type AuditEvent struct {
EventName string
RequestID string
TokenID string
SubjectID string
Route string
ResultCode string
ClientIP string
CreatedAt time.Time
}
// NewAuthMiddleware 创建鉴权中间件
func NewAuthMiddleware(config AuthConfig, tokenCache *TokenCache, auditEmitter AuditEmitter) *AuthMiddleware {
if config.CacheTTL == 0 {
config.CacheTTL = 30 * time.Second
}
return &AuthMiddleware{
config: config,
tokenCache: tokenCache,
auditEmitter: auditEmitter,
}
}
// QueryKeyRejectMiddleware 拒绝外部query key入站
// 对应M-016指标
func (m *AuthMiddleware) QueryKeyRejectMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 检查query string中的可疑参数
queryParams := r.URL.Query()
// 禁止的query参数名
blockedParams := []string{"key", "api_key", "token", "secret", "password", "credential"}
for _, param := range blockedParams {
if _, exists := queryParams[param]; exists {
// 触发M-016指标事件
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.query_key.rejected",
RequestID: getRequestID(r),
Route: r.URL.Path,
ResultCode: "QUERY_KEY_NOT_ALLOWED",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusUnauthorized, "QUERY_KEY_NOT_ALLOWED",
"external query key is not allowed, use Authorization header")
return
}
}
// 检查是否有API Key在query中即使参数名不同
for param := range queryParams {
lowerParam := strings.ToLower(param)
if strings.Contains(lowerParam, "key") || strings.Contains(lowerParam, "token") || strings.Contains(lowerParam, "secret") {
// 可能是编码的API Key
if len(queryParams.Get(param)) > 20 {
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.query_key.rejected",
RequestID: getRequestID(r),
Route: r.URL.Path,
ResultCode: "QUERY_KEY_NOT_ALLOWED",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusUnauthorized, "QUERY_KEY_NOT_ALLOWED",
"suspicious query parameter detected")
return
}
}
}
next.ServeHTTP(w, r)
})
}
// BearerExtractMiddleware 提取Bearer Token
func (m *AuthMiddleware) BearerExtractMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.authn.fail",
RequestID: getRequestID(r),
Route: r.URL.Path,
ResultCode: "AUTH_MISSING_BEARER",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusUnauthorized, "AUTH_MISSING_BEARER",
"Authorization header with Bearer token is required")
return
}
if !strings.HasPrefix(authHeader, "Bearer ") {
writeAuthError(w, http.StatusUnauthorized, "AUTH_INVALID_FORMAT",
"Authorization header must be in format: Bearer <token>")
return
}
tokenString := strings.TrimPrefix(authHeader, "Bearer ")
if tokenString == "" {
writeAuthError(w, http.StatusUnauthorized, "AUTH_MISSING_BEARER",
"Bearer token is empty")
return
}
// 将token存入context供后续使用
ctx := context.WithValue(r.Context(), bearerTokenKey, tokenString)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// TokenVerifyMiddleware 校验JWT Token
func (m *AuthMiddleware) TokenVerifyMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tokenString := r.Context().Value(bearerTokenKey).(string)
claims, err := m.verifyToken(tokenString)
if err != nil {
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.authn.fail",
RequestID: getRequestID(r),
Route: r.URL.Path,
ResultCode: "AUTH_INVALID_TOKEN",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusUnauthorized, "AUTH_INVALID_TOKEN",
"token verification failed: "+err.Error())
return
}
// 检查token状态是否被吊销
status, err := m.checkTokenStatus(claims.ID)
if err == nil && status != "active" {
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.authn.fail",
RequestID: getRequestID(r),
TokenID: claims.ID,
SubjectID: claims.SubjectID,
Route: r.URL.Path,
ResultCode: "AUTH_TOKEN_INACTIVE",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusUnauthorized, "AUTH_TOKEN_INACTIVE",
"token is revoked or expired")
return
}
// 将claims存入context
ctx := context.WithValue(r.Context(), tokenClaimsKey, claims)
ctx = WithTenantID(ctx, claims.TenantID)
ctx = WithOperatorID(ctx, parseSubjectID(claims.SubjectID))
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.authn.success",
RequestID: getRequestID(r),
TokenID: claims.ID,
SubjectID: claims.SubjectID,
Route: r.URL.Path,
ResultCode: "OK",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// ScopeRoleAuthzMiddleware 权限校验中间件
func (m *AuthMiddleware) ScopeRoleAuthzMiddleware(requiredScope string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
claims, ok := r.Context().Value(tokenClaimsKey).(*TokenClaims)
if !ok {
writeAuthError(w, http.StatusUnauthorized, "AUTH_CONTEXT_MISSING",
"authentication context is missing")
return
}
// 检查scope
if requiredScope != "" && !containsScope(claims.Scope, requiredScope) {
if m.auditEmitter != nil {
m.auditEmitter.Emit(r.Context(), AuditEvent{
EventName: "token.authz.denied",
RequestID: getRequestID(r),
TokenID: claims.ID,
SubjectID: claims.SubjectID,
Route: r.URL.Path,
ResultCode: "AUTH_SCOPE_DENIED",
ClientIP: getClientIP(r),
CreatedAt: time.Now(),
})
}
writeAuthError(w, http.StatusForbidden, "AUTH_SCOPE_DENIED",
fmt.Sprintf("required scope '%s' is not granted", requiredScope))
return
}
// 检查role权限
roleHierarchy := map[string]int{
"admin": 3,
"owner": 2,
"viewer": 1,
}
// 路由权限要求
routeRoles := map[string]string{
"/api/v1/supply/accounts": "owner",
"/api/v1/supply/packages": "owner",
"/api/v1/supply/settlements": "owner",
"/api/v1/supply/billing": "viewer",
"/api/v1/supplier/billing": "viewer",
}
for path, requiredRole := range routeRoles {
if strings.HasPrefix(r.URL.Path, path) {
if roleLevel(claims.Role, roleHierarchy) < roleLevel(requiredRole, roleHierarchy) {
writeAuthError(w, http.StatusForbidden, "AUTH_ROLE_DENIED",
fmt.Sprintf("required role '%s' is not granted, current role: '%s'", requiredRole, claims.Role))
return
}
}
}
next.ServeHTTP(w, r)
})
}
}
// verifyToken 校验JWT token
func (m *AuthMiddleware) verifyToken(tokenString string) (*TokenClaims, error) {
token, err := jwt.ParseWithClaims(tokenString, &TokenClaims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(m.config.SecretKey), nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(*TokenClaims); ok && token.Valid {
// 验证issuer
if claims.Issuer != m.config.Issuer {
return nil, errors.New("invalid token issuer")
}
// 验证expiration
if claims.ExpiresAt != nil && claims.ExpiresAt.Time.Before(time.Now()) {
return nil, errors.New("token has expired")
}
// 验证not before
if claims.NotBefore != nil && claims.NotBefore.Time.After(time.Now()) {
return nil, errors.New("token is not yet valid")
}
return claims, nil
}
return nil, errors.New("invalid token")
}
// checkTokenStatus 检查token状态从缓存或数据库
func (m *AuthMiddleware) checkTokenStatus(tokenID string) (string, error) {
if m.tokenCache != nil {
// 先从缓存检查
if status, found := m.tokenCache.Get(tokenID); found {
return status, nil
}
}
// 缓存未命中返回active实际应该查询数据库
return "active", nil
}
// GetTokenClaims 从context获取token claims
func GetTokenClaims(ctx context.Context) *TokenClaims {
if claims, ok := ctx.Value(tokenClaimsKey).(*TokenClaims); ok {
return claims
}
return nil
}
// context keys
const (
bearerTokenKey contextKey = "bearer_token"
tokenClaimsKey contextKey = "token_claims"
)
// writeAuthError 写入鉴权错误
func writeAuthError(w http.ResponseWriter, status int, code, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
resp := map[string]interface{}{
"request_id": "",
"error": map[string]string{
"code": code,
"message": message,
},
}
json.NewEncoder(w).Encode(resp)
}
// getRequestID 获取请求ID
func getRequestID(r *http.Request) string {
if id := r.Header.Get("X-Request-Id"); id != "" {
return id
}
return r.Header.Get("X-Request-ID")
}
// getClientIP 获取客户端IP
func getClientIP(r *http.Request) string {
// 优先从X-Forwarded-For获取
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
parts := strings.Split(xff, ",")
return strings.TrimSpace(parts[0])
}
// X-Real-IP
if xri := r.Header.Get("X-Real-IP"); xri != "" {
return xri
}
// RemoteAddr
addr := r.RemoteAddr
if idx := strings.LastIndex(addr, ":"); idx != -1 {
return addr[:idx]
}
return addr
}
// containsScope 检查scope列表是否包含目标scope
func containsScope(scopes []string, target string) bool {
for _, scope := range scopes {
if scope == target || scope == "*" {
return true
}
}
return false
}
// roleLevel 获取角色等级
func roleLevel(role string, hierarchy map[string]int) int {
if level, ok := hierarchy[role]; ok {
return level
}
return 0
}
// parseSubjectID 解析subject ID
func parseSubjectID(subject string) int64 {
parts := strings.Split(subject, ":")
if len(parts) >= 2 {
id, _ := strconv.ParseInt(parts[1], 10, 64)
return id
}
return 0
}
// TokenCache Token状态缓存
type TokenCache struct {
data map[string]cacheEntry
}
type cacheEntry struct {
status string
expires time.Time
}
// NewTokenCache 创建token缓存
func NewTokenCache() *TokenCache {
return &TokenCache{
data: make(map[string]cacheEntry),
}
}
// Get 获取token状态
func (c *TokenCache) Get(tokenID string) (string, bool) {
if entry, ok := c.data[tokenID]; ok {
if time.Now().Before(entry.expires) {
return entry.status, true
}
delete(c.data, tokenID)
}
return "", false
}
// Set 设置token状态
func (c *TokenCache) Set(tokenID, status string, ttl time.Duration) {
c.data[tokenID] = cacheEntry{
status: status,
expires: time.Now().Add(ttl),
}
}
// Invalidate 使token失效
func (c *TokenCache) Invalidate(tokenID string) {
delete(c.data, tokenID)
}
// ComputeFingerprint 计算凭证指纹(用于审计)
func ComputeFingerprint(credential string) string {
hash := sha256.Sum256([]byte(credential))
return hex.EncodeToString(hash[:])
}

View File

@@ -0,0 +1,343 @@
package middleware
import (
"context"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
)
func TestTokenVerify(t *testing.T) {
secretKey := "test-secret-key-12345678901234567890"
issuer := "test-issuer"
tests := []struct {
name string
token string
expectError bool
errorContains string
}{
{
name: "valid token",
token: createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(time.Hour)),
expectError: false,
},
{
name: "expired token",
token: createTestToken(secretKey, issuer, "subject:1", "owner", time.Now().Add(-time.Hour)),
expectError: true,
errorContains: "expired",
},
{
name: "wrong issuer",
token: createTestToken(secretKey, "wrong-issuer", "subject:1", "owner", time.Now().Add(time.Hour)),
expectError: true,
errorContains: "issuer",
},
{
name: "invalid token",
token: "invalid.token.string",
expectError: true,
errorContains: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware := &AuthMiddleware{
config: AuthConfig{
SecretKey: secretKey,
Issuer: issuer,
},
}
_, err := middleware.verifyToken(tt.token)
if tt.expectError {
if err == nil {
t.Errorf("expected error but got nil")
} else if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) {
t.Errorf("error = %v, want contains %v", err, tt.errorContains)
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
}
})
}
}
func TestQueryKeyRejectMiddleware(t *testing.T) {
tests := []struct {
name string
query string
expectStatus int
}{
{
name: "no query params",
query: "",
expectStatus: http.StatusOK,
},
{
name: "normal params",
query: "?page=1&size=10",
expectStatus: http.StatusOK,
},
{
name: "blocked key param",
query: "?key=abc123",
expectStatus: http.StatusUnauthorized,
},
{
name: "blocked api_key param",
query: "?api_key=secret123",
expectStatus: http.StatusUnauthorized,
},
{
name: "blocked token param",
query: "?token=bearer123",
expectStatus: http.StatusUnauthorized,
},
{
name: "suspicious long param",
query: "?apikey=verylongparamvalueexceeding20chars",
expectStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware := &AuthMiddleware{
auditEmitter: nil,
}
nextCalled := false
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
})
handler := middleware.QueryKeyRejectMiddleware(nextHandler)
req := httptest.NewRequest("POST", "/api/v1/supply/accounts"+tt.query, nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if tt.expectStatus == http.StatusOK {
if !nextCalled {
t.Errorf("expected next handler to be called")
}
} else {
if w.Code != tt.expectStatus {
t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code)
}
}
})
}
}
func TestBearerExtractMiddleware(t *testing.T) {
tests := []struct {
name string
authHeader string
expectStatus int
}{
{
name: "valid bearer",
authHeader: "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c",
expectStatus: http.StatusOK,
},
{
name: "missing header",
authHeader: "",
expectStatus: http.StatusUnauthorized,
},
{
name: "wrong prefix",
authHeader: "Basic abc123",
expectStatus: http.StatusUnauthorized,
},
{
name: "empty token",
authHeader: "Bearer ",
expectStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
middleware := &AuthMiddleware{}
nextCalled := false
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
nextCalled = true
// 检查context中是否有bearer token
if r.Context().Value(bearerTokenKey) == nil && tt.authHeader != "" && strings.HasPrefix(tt.authHeader, "Bearer ") {
// 这是预期的因为token可能无效
}
})
handler := middleware.BearerExtractMiddleware(nextHandler)
req := httptest.NewRequest("POST", "/api/v1/supply/accounts", nil)
if tt.authHeader != "" {
req.Header.Set("Authorization", tt.authHeader)
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if tt.expectStatus == http.StatusOK {
if !nextCalled {
t.Errorf("expected next handler to be called")
}
} else {
if w.Code != tt.expectStatus {
t.Errorf("expected status %d, got %d", tt.expectStatus, w.Code)
}
}
})
}
}
func TestContainsScope(t *testing.T) {
tests := []struct {
name string
scopes []string
target string
expected bool
}{
{
name: "exact match",
scopes: []string{"read", "write", "delete"},
target: "write",
expected: true,
},
{
name: "wildcard",
scopes: []string{"*"},
target: "anything",
expected: true,
},
{
name: "no match",
scopes: []string{"read", "write"},
target: "admin",
expected: false,
},
{
name: "empty scopes",
scopes: []string{},
target: "read",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := containsScope(tt.scopes, tt.target)
if result != tt.expected {
t.Errorf("containsScope(%v, %s) = %v, want %v", tt.scopes, tt.target, result, tt.expected)
}
})
}
}
func TestRoleLevel(t *testing.T) {
hierarchy := map[string]int{
"admin": 3,
"owner": 2,
"viewer": 1,
}
tests := []struct {
role string
expected int
}{
{"admin", 3},
{"owner", 2},
{"viewer", 1},
{"unknown", 0},
}
for _, tt := range tests {
t.Run(tt.role, func(t *testing.T) {
result := roleLevel(tt.role, hierarchy)
if result != tt.expected {
t.Errorf("roleLevel(%s) = %d, want %d", tt.role, result, tt.expected)
}
})
}
}
func TestTokenCache(t *testing.T) {
cache := NewTokenCache()
t.Run("get empty", func(t *testing.T) {
status, found := cache.Get("nonexistent")
if found {
t.Errorf("expected not found")
}
if status != "" {
t.Errorf("expected empty status")
}
})
t.Run("set and get", func(t *testing.T) {
cache.Set("token1", "active", time.Hour)
status, found := cache.Get("token1")
if !found {
t.Errorf("expected to find token1")
}
if status != "active" {
t.Errorf("expected status 'active', got '%s'", status)
}
})
t.Run("invalidate", func(t *testing.T) {
cache.Set("token2", "revoked", time.Hour)
cache.Invalidate("token2")
_, found := cache.Get("token2")
if found {
t.Errorf("expected token2 to be invalidated")
}
})
t.Run("expiration", func(t *testing.T) {
cache.Set("token3", "active", time.Nanosecond)
time.Sleep(time.Millisecond)
_, found := cache.Get("token3")
if found {
t.Errorf("expected token3 to be expired")
}
})
}
// Helper functions
func createTestToken(secretKey, issuer, subject, role string, expiresAt time.Time) string {
claims := TokenClaims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: issuer,
Subject: subject,
ExpiresAt: jwt.NewNumericDate(expiresAt),
IssuedAt: jwt.NewNumericDate(time.Now()),
},
SubjectID: subject,
Role: role,
Scope: []string{"read", "write"},
TenantID: 1,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, _ := token.SignedString([]byte(secretKey))
return tokenString
}

View File

@@ -0,0 +1,279 @@
package middleware
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"lijiaoqiao/supply-api/internal/repository"
)
// IdempotencyConfig 幂等中间件配置
type IdempotencyConfig struct {
TTL time.Duration // 幂等有效期默认24h
ProcessingTTL time.Duration // 处理中状态有效期默认30s
Enabled bool // 是否启用幂等
}
// IdempotencyMiddleware 幂等中间件
type IdempotencyMiddleware struct {
idempotencyRepo *repository.IdempotencyRepository
config IdempotencyConfig
}
// NewIdempotencyMiddleware 创建幂等中间件
func NewIdempotencyMiddleware(repo *repository.IdempotencyRepository, config IdempotencyConfig) *IdempotencyMiddleware {
if config.TTL == 0 {
config.TTL = 24 * time.Hour
}
if config.ProcessingTTL == 0 {
config.ProcessingTTL = 30 * time.Second
}
return &IdempotencyMiddleware{
idempotencyRepo: repo,
config: config,
}
}
// IdempotencyKey 幂等键信息
type IdempotencyKey struct {
TenantID int64
OperatorID int64
APIPath string
Key string
}
// ExtractIdempotencyKey 从请求中提取幂等信息
func ExtractIdempotencyKey(r *http.Request, tenantID, operatorID int64) (*IdempotencyKey, error) {
requestID := r.Header.Get("X-Request-Id")
if requestID == "" {
return nil, fmt.Errorf("missing X-Request-Id header")
}
idempotencyKey := r.Header.Get("Idempotency-Key")
if idempotencyKey == "" {
return nil, fmt.Errorf("missing Idempotency-Key header")
}
if len(idempotencyKey) < 16 || len(idempotencyKey) > 128 {
return nil, fmt.Errorf("Idempotency-Key length must be 16-128")
}
// 从路径提取API路径去除前缀
apiPath := r.URL.Path
if strings.HasPrefix(apiPath, "/api/v1") {
apiPath = strings.TrimPrefix(apiPath, "/api/v1")
}
return &IdempotencyKey{
TenantID: tenantID,
OperatorID: operatorID,
APIPath: apiPath,
Key: idempotencyKey,
}, nil
}
// ComputePayloadHash 计算请求体的SHA256哈希
func ComputePayloadHash(body []byte) string {
hash := sha256.Sum256(body)
return hex.EncodeToString(hash[:])
}
// IdempotentHandler 幂等处理器函数
type IdempotentHandler func(ctx context.Context, w http.ResponseWriter, r *http.Request, record *repository.IdempotencyRecord) error
// Wrap 包装HTTP处理器以实现幂等
func (m *IdempotencyMiddleware) Wrap(handler IdempotentHandler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !m.config.Enabled {
handler(r.Context(), w, r, nil)
return
}
ctx := r.Context()
// 从context获取租户和操作者ID由鉴权中间件设置
tenantID := getTenantID(ctx)
operatorID := getOperatorID(ctx)
// 提取幂等信息
idempKey, err := ExtractIdempotencyKey(r, tenantID, operatorID)
if err != nil {
writeIdempotencyError(w, http.StatusBadRequest, "IDEMPOTENCY_KEY_INVALID", err.Error())
return
}
// 读取请求体
body, err := io.ReadAll(r.Body)
if err != nil {
writeIdempotencyError(w, http.StatusBadRequest, "BODY_READ_ERROR", err.Error())
return
}
// 重新填充body以供后续处理
r.Body = io.NopCloser(bytes.NewBuffer(body))
// 计算payload hash
payloadHash := ComputePayloadHash(body)
// 查询已存在的幂等记录
existingRecord, err := m.idempotencyRepo.GetByKey(ctx, idempKey.TenantID, idempKey.OperatorID, idempKey.APIPath, idempKey.Key)
if err != nil {
writeIdempotencyError(w, http.StatusInternalServerError, "IDEMPOTENCY_CHECK_FAILED", err.Error())
return
}
if existingRecord != nil {
// 存在记录,处理不同情况
switch existingRecord.Status {
case repository.IdempotencyStatusSucceeded:
// 同参重放:返回原结果
if existingRecord.PayloadHash == payloadHash {
writeIdempotentReplay(w, existingRecord.ResponseCode, existingRecord.ResponseBody)
return
}
// 异参重放返回409冲突
writeIdempotencyError(w, http.StatusConflict, "IDEMPOTENCY_PAYLOAD_MISMATCH",
fmt.Sprintf("same idempotency key but different payload, original request_id: %s", existingRecord.RequestID))
return
case repository.IdempotencyStatusProcessing:
// 处理中:检查是否超时
if time.Since(existingRecord.UpdatedAt) < m.config.ProcessingTTL {
retryAfter := m.config.ProcessingTTL - time.Since(existingRecord.UpdatedAt)
writeIdempotencyProcessing(w, int(retryAfter.Milliseconds()), existingRecord.RequestID)
return
}
// 超时:允许重试(记录会自然过期)
case repository.IdempotencyStatusFailed:
// 失败状态也允许重试
}
}
// 尝试创建或更新幂等记录
requestID := r.Header.Get("X-Request-Id")
record := &repository.IdempotencyRecord{
TenantID: idempKey.TenantID,
OperatorID: idempKey.OperatorID,
APIPath: idempKey.APIPath,
IdempotencyKey: idempKey.Key,
RequestID: requestID,
PayloadHash: payloadHash,
Status: repository.IdempotencyStatusProcessing,
ExpiresAt: time.Now().Add(m.config.TTL),
}
// 使用AcquireLock获取锁
lockedRecord, err := m.idempotencyRepo.AcquireLock(ctx, idempKey.TenantID, idempKey.OperatorID, idempKey.APIPath, idempKey.Key, m.config.TTL)
if err != nil {
writeIdempotencyError(w, http.StatusInternalServerError, "IDEMPOTENCY_LOCK_FAILED", err.Error())
return
}
// 更新记录中的request_id和payload_hash
if lockedRecord.ID != 0 && (lockedRecord.RequestID == "" || lockedRecord.PayloadHash == "") {
lockedRecord.RequestID = requestID
lockedRecord.PayloadHash = payloadHash
}
// 执行实际业务处理
err = handler(ctx, w, r, lockedRecord)
// 根据处理结果更新幂等记录
if err != nil {
// 业务处理失败
errMsg, _ := json.Marshal(map[string]string{"error": err.Error()})
_ = m.idempotencyRepo.UpdateFailed(ctx, lockedRecord.ID, http.StatusInternalServerError, errMsg)
return
}
// 业务处理成功,更新为成功状态
// 注意这里需要从w中获取实际的响应码和body
// 简化处理使用200
successBody, _ := json.Marshal(map[string]interface{}{"status": "ok"})
_ = m.idempotencyRepo.UpdateSuccess(ctx, lockedRecord.ID, http.StatusOK, successBody)
}
}
// writeIdempotencyError 写入幂等错误
func writeIdempotencyError(w http.ResponseWriter, status int, code, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
resp := map[string]interface{}{
"request_id": "",
"error": map[string]string{
"code": code,
"message": message,
},
}
json.NewEncoder(w).Encode(resp)
}
// writeIdempotencyProcessing 写入处理中状态
func writeIdempotencyProcessing(w http.ResponseWriter, retryAfterMs int, requestID string) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Retry-After-Ms", fmt.Sprintf("%d", retryAfterMs))
w.Header().Set("X-Request-Id", requestID)
w.WriteHeader(http.StatusAccepted)
resp := map[string]interface{}{
"request_id": requestID,
"error": map[string]string{
"code": "IDEMPOTENCY_IN_PROGRESS",
"message": "request is being processed, please retry later",
},
}
json.NewEncoder(w).Encode(resp)
}
// writeIdempotentReplay 写入幂等重放响应
func writeIdempotentReplay(w http.ResponseWriter, status int, body json.RawMessage) {
w.Header().Set("Content-Type", "application/json")
w.Header().Set("X-Idempotent-Replay", "true")
w.WriteHeader(status)
if body != nil {
w.Write(body)
}
}
// context keys
type contextKey string
const (
tenantIDKey contextKey = "tenant_id"
operatorIDKey contextKey = "operator_id"
)
// WithTenantID 在context中设置租户ID
func WithTenantID(ctx context.Context, tenantID int64) context.Context {
return context.WithValue(ctx, tenantIDKey, tenantID)
}
// WithOperatorID 在context中设置操作者ID
func WithOperatorID(ctx context.Context, operatorID int64) context.Context {
return context.WithValue(ctx, operatorIDKey, operatorID)
}
func getTenantID(ctx context.Context) int64 {
if v := ctx.Value(tenantIDKey); v != nil {
if id, ok := v.(int64); ok {
return id
}
}
return 0
}
func getOperatorID(ctx context.Context) int64 {
if v := ctx.Value(operatorIDKey); v != nil {
if id, ok := v.(int64); ok {
return id
}
}
return 0
}

View File

@@ -0,0 +1,211 @@
package middleware
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"lijiaoqiao/supply-api/internal/repository"
)
// MockIdempotencyRepository 模拟幂等仓储
type MockIdempotencyRepository struct {
records map[string]*repository.IdempotencyRecord
}
func NewMockIdempotencyRepository() *MockIdempotencyRepository {
return &MockIdempotencyRepository{
records: make(map[string]*repository.IdempotencyRecord),
}
}
func (r *MockIdempotencyRepository) GetByKey(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (*repository.IdempotencyRecord, error) {
key := buildKey(tenantID, operatorID, apiPath, idempotencyKey)
if record, ok := r.records[key]; ok {
if time.Now().Before(record.ExpiresAt) {
return record, nil
}
}
return nil, nil
}
func (r *MockIdempotencyRepository) Create(ctx context.Context, record *repository.IdempotencyRecord) error {
key := buildKey(record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey)
r.records[key] = record
return nil
}
func (r *MockIdempotencyRepository) UpdateSuccess(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error {
return nil
}
func (r *MockIdempotencyRepository) UpdateFailed(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error {
return nil
}
func (r *MockIdempotencyRepository) AcquireLock(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string, ttl time.Duration) (*repository.IdempotencyRecord, error) {
key := buildKey(tenantID, operatorID, apiPath, idempotencyKey)
record := &repository.IdempotencyRecord{
TenantID: tenantID,
OperatorID: operatorID,
APIPath: apiPath,
IdempotencyKey: idempotencyKey,
RequestID: "test-request-id",
PayloadHash: "",
Status: repository.IdempotencyStatusProcessing,
ExpiresAt: time.Now().Add(ttl),
}
r.records[key] = record
return record, nil
}
func buildKey(tenantID, operatorID int64, apiPath, idempotencyKey string) string {
return strings.Join([]string{
string(rune(tenantID)),
string(rune(operatorID)),
apiPath,
idempotencyKey,
}, ":")
}
func TestComputePayloadHash(t *testing.T) {
tests := []struct {
name string
body []byte
expected string
}{
{
name: "empty body",
body: []byte{},
expected: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
},
{
name: "simple JSON",
body: []byte(`{"key":"value"}`),
expected: computeExpectedHash(`{"key":"value"}`),
},
{
name: "JSON with spaces",
body: []byte(`{ "key": "value" }`),
expected: computeExpectedHash(`{ "key": "value" }`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ComputePayloadHash(tt.body)
if result != tt.expected {
t.Errorf("ComputePayloadHash() = %v, want %v", result, tt.expected)
}
})
}
}
func computeExpectedHash(s string) string {
hash := sha256.Sum256([]byte(s))
return hex.EncodeToString(hash[:])
}
func TestExtractIdempotencyKey(t *testing.T) {
tests := []struct {
name string
headers map[string]string
expectError bool
errorCode string
}{
{
name: "valid headers",
headers: map[string]string{
"X-Request-Id": "req-123",
"Idempotency-Key": "idem-key-12345678",
},
expectError: false,
},
{
name: "missing X-Request-Id",
headers: map[string]string{
"Idempotency-Key": "idem-key-12345678",
},
expectError: true,
errorCode: "missing X-Request-Id header",
},
{
name: "missing Idempotency-Key",
headers: map[string]string{
"X-Request-Id": "req-123",
},
expectError: true,
errorCode: "missing Idempotency-Key header",
},
{
name: "Idempotency-Key too short",
headers: map[string]string{
"X-Request-Id": "req-123",
"Idempotency-Key": "short",
},
expectError: true,
errorCode: "Idempotency-Key length must be 16-128",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/api/v1/supply/accounts", nil)
for k, v := range tt.headers {
req.Header.Set(k, v)
}
result, err := ExtractIdempotencyKey(req, 1, 1)
if tt.expectError {
if err == nil {
t.Errorf("expected error but got nil")
}
if err != nil && !strings.Contains(err.Error(), tt.errorCode) {
t.Errorf("error = %v, want contains %v", err, tt.errorCode)
}
} else {
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if result == nil {
t.Errorf("expected result but got nil")
}
}
})
}
}
func TestIdempotentHandler(t *testing.T) {
// 创建测试handler
testHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, record *repository.IdempotencyRecord) error {
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]string{"status": "created"})
return nil
}
middleware := NewIdempotencyMiddleware(nil, IdempotencyConfig{
Enabled: false, // 禁用幂等只测试handler包装
})
handler := middleware.Wrap(testHandler)
t.Run("handler executes successfully", func(t *testing.T) {
req := httptest.NewRequest("POST", "/api/v1/supply/accounts", strings.NewReader(`{"key":"value"}`))
req.Header.Set("X-Request-Id", "req-123")
req.Header.Set("Idempotency-Key", "idem-key-12345678")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Errorf("expected status %d, got %d", http.StatusCreated, w.Code)
}
})
}