Files
lijiaoqiao/supply-api/internal/middleware/idempotency.go
Your Name f34333dc09 fix: 修复代码审查中发现的P0/P1/P2问题
修复内容:
1. P0-01/P0-02: IAM Handler硬编码userID=1问题
   - getUserIDFromContext现在从认证中间件的context获取真实userID
   - 添加middleware.GetOperatorID公开函数
   - CheckScope方法添加未认证检查

2. P1-01: 审计服务幂等竞态条件
   - 重构锁保护范围,整个检查和插入过程在锁保护下
   - 使用defer确保锁正确释放

3. P1-02: 幂等中间件响应码硬编码
   - 添加statusCapturingResponseWriter包装器
   - 捕获实际的状态码和响应体用于幂等记录

4. P2-01: 事件ID时间戳冲突
   - generateEventID改用UUID替代时间戳

5. P2-02: ListScopes硬编码
   - 使用model.PredefinedScopes替代硬编码列表

所有supply-api测试通过
2026-04-03 12:25:22 +08:00

294 lines
8.6 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
"lijiaoqiao/supply-api/internal/repository"
)
// IdempotencyConfig holds the settings of the idempotency middleware.
// Zero durations are replaced with defaults by NewIdempotencyMiddleware.
type IdempotencyConfig struct {
	TTL           time.Duration // lifetime passed to the repository when a lock record is acquired; defaults to 24h
	ProcessingTTL time.Duration // how long a "processing" record blocks duplicate requests; defaults to 30s
	Enabled       bool          // when false, Wrap calls the handler directly without any idempotency handling
}
// IdempotencyMiddleware wraps HTTP handlers so that repeated requests carrying
// the same idempotency key are detected through the idempotency repository.
type IdempotencyMiddleware struct {
	idempotencyRepo *repository.IdempotencyRepository // storage used to look up, lock and update idempotency records
	config          IdempotencyConfig                 // effective settings (defaults already applied by the constructor)
}
// NewIdempotencyMiddleware builds an IdempotencyMiddleware backed by repo.
//
// config is taken by value; a zero TTL is replaced with 24 hours and a zero
// ProcessingTTL with 30 seconds before the middleware is constructed. All other
// values are used as given.
func NewIdempotencyMiddleware(repo *repository.IdempotencyRepository, config IdempotencyConfig) *IdempotencyMiddleware {
	const (
		fallbackTTL           = 24 * time.Hour
		fallbackProcessingTTL = 30 * time.Second
	)

	effective := config
	if effective.TTL == 0 {
		effective.TTL = fallbackTTL
	}
	if effective.ProcessingTTL == 0 {
		effective.ProcessingTTL = fallbackProcessingTTL
	}

	mw := new(IdempotencyMiddleware)
	mw.idempotencyRepo = repo
	mw.config = effective
	return mw
}
// IdempotencyKey identifies one idempotent operation: the same key value may
// be reused by different tenants, operators or API paths without clashing.
type IdempotencyKey struct {
	TenantID   int64  // tenant the request belongs to
	OperatorID int64  // operator issuing the request
	APIPath    string // request path with the "/api/v1" prefix removed
	Key        string // value of the Idempotency-Key header
}

// ExtractIdempotencyKey builds the idempotency key for r.
//
// The request must carry a non-empty X-Request-Id header (only its presence is
// checked here) and an Idempotency-Key header whose length is between 16 and
// 128 bytes inclusive; otherwise an error describing the violation is returned.
//
// The API path is r.URL.Path with a leading "/api/v1" removed. The prefix is
// only stripped when it is a whole path segment ("/api/v1" itself or
// "/api/v1/..."), so a path such as "/api/v10/orders" is left untouched
// instead of being mangled into "0/orders".
func ExtractIdempotencyKey(r *http.Request, tenantID, operatorID int64) (*IdempotencyKey, error) {
	if r.Header.Get("X-Request-Id") == "" {
		return nil, fmt.Errorf("missing X-Request-Id header")
	}

	idempotencyKey := r.Header.Get("Idempotency-Key")
	if idempotencyKey == "" {
		return nil, fmt.Errorf("missing Idempotency-Key header")
	}
	if len(idempotencyKey) < 16 || len(idempotencyKey) > 128 {
		return nil, fmt.Errorf("Idempotency-Key length must be 16-128")
	}

	// Strip the versioned API prefix, but only on a segment boundary.
	const apiPrefix = "/api/v1"
	apiPath := r.URL.Path
	if apiPath == apiPrefix || strings.HasPrefix(apiPath, apiPrefix+"/") {
		apiPath = apiPath[len(apiPrefix):]
	}

	return &IdempotencyKey{
		TenantID:   tenantID,
		OperatorID: operatorID,
		APIPath:    apiPath,
		Key:        idempotencyKey,
	}, nil
}
// ComputePayloadHash returns the SHA-256 digest of body as a lowercase
// hexadecimal string. A nil or empty body yields the digest of empty input.
func ComputePayloadHash(body []byte) string {
	hasher := sha256.New()
	hasher.Write(body) // hash.Hash writes never return an error
	digest := hasher.Sum(nil)
	return hex.EncodeToString(digest)
}
// IdempotentHandler is the business handler signature accepted by
// IdempotencyMiddleware.Wrap. It receives the request context, the response
// writer, the request and the idempotency record acquired for this request;
// record is nil when the middleware is disabled. A non-nil returned error makes
// Wrap mark the record as failed, a nil error makes it mark the record as
// succeeded.
type IdempotentHandler func(ctx context.Context, w http.ResponseWriter, r *http.Request, record *repository.IdempotencyRecord) error
// Wrap adapts handler into an http.HandlerFunc guarded by idempotency checks.
//
// When the middleware is disabled the handler runs directly with a nil record.
// Otherwise each request goes through these steps:
//
//  1. The tenant and operator IDs are read from the request context and the
//     idempotency headers are validated (400 IDEMPOTENCY_KEY_INVALID on error).
//  2. The body is read in full, hashed with SHA-256 and put back on the request
//     so the handler can read it again.
//  3. An existing record for the key decides what happens next:
//     succeeded with the same payload hash replays the stored response;
//     succeeded with a different hash answers 409 IDEMPOTENCY_PAYLOAD_MISMATCH;
//     processing and younger than ProcessingTTL answers 202 with a retry hint;
//     a stale processing record or a failed record lets the request run again.
//  4. A lock record is acquired, the handler runs against a capturing writer,
//     and the record is marked failed or succeeded based on the returned error.
func (m *IdempotencyMiddleware) Wrap(handler IdempotentHandler) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if !m.config.Enabled {
			// There is no record to update when idempotency is off, so the
			// handler's error has nowhere to go; drop it explicitly.
			_ = handler(r.Context(), w, r, nil)
			return
		}

		ctx := r.Context()

		// Tenant and operator IDs come from the context, where the auth
		// middleware is expected to have stored them (0 when absent).
		tenantID := getTenantID(ctx)
		operatorID := getOperatorID(ctx)

		// Validate the headers and build the idempotency key.
		idempKey, err := ExtractIdempotencyKey(r, tenantID, operatorID)
		if err != nil {
			writeIdempotencyError(w, http.StatusBadRequest, "IDEMPOTENCY_KEY_INVALID", err.Error())
			return
		}

		// Read the whole body so it can be hashed.
		body, err := io.ReadAll(r.Body)
		if err != nil {
			writeIdempotencyError(w, http.StatusBadRequest, "BODY_READ_ERROR", err.Error())
			return
		}
		// Put the body back so the business handler can read it again.
		r.Body = io.NopCloser(bytes.NewBuffer(body))

		payloadHash := ComputePayloadHash(body)

		// Look for a record left by an earlier request with the same key.
		existingRecord, err := m.idempotencyRepo.GetByKey(ctx, idempKey.TenantID, idempKey.OperatorID, idempKey.APIPath, idempKey.Key)
		if err != nil {
			writeIdempotencyError(w, http.StatusInternalServerError, "IDEMPOTENCY_CHECK_FAILED", err.Error())
			return
		}

		if existingRecord != nil {
			switch existingRecord.Status {
			case repository.IdempotencyStatusSucceeded:
				// Same payload: replay the stored response.
				if existingRecord.PayloadHash == payloadHash {
					writeIdempotentReplay(w, existingRecord.ResponseCode, existingRecord.ResponseBody)
					return
				}
				// Same key, different payload: reject with 409.
				writeIdempotencyError(w, http.StatusConflict, "IDEMPOTENCY_PAYLOAD_MISMATCH",
					fmt.Sprintf("same idempotency key but different payload, original request_id: %s", existingRecord.RequestID))
				return
			case repository.IdempotencyStatusProcessing:
				// Measure the elapsed time once so the check and the retry
				// hint agree and the hint can never become negative.
				elapsed := time.Since(existingRecord.UpdatedAt)
				if elapsed < m.config.ProcessingTTL {
					retryAfter := m.config.ProcessingTTL - elapsed
					writeIdempotencyProcessing(w, int(retryAfter.Milliseconds()), existingRecord.RequestID)
					return
				}
				// Stale processing record: fall through and run again.
			case repository.IdempotencyStatusFailed:
				// A failed attempt may be retried.
			}
		}

		// Acquire the lock record for this key.
		requestID := r.Header.Get("X-Request-Id")
		lockedRecord, err := m.idempotencyRepo.AcquireLock(ctx, idempKey.TenantID, idempKey.OperatorID, idempKey.APIPath, idempKey.Key, m.config.TTL)
		if err != nil {
			writeIdempotencyError(w, http.StatusInternalServerError, "IDEMPOTENCY_LOCK_FAILED", err.Error())
			return
		}

		// Fill in request_id and payload_hash on the in-memory record when the
		// repository returned it without them.
		if lockedRecord.ID != 0 && (lockedRecord.RequestID == "" || lockedRecord.PayloadHash == "") {
			lockedRecord.RequestID = requestID
			lockedRecord.PayloadHash = payloadHash
		}

		// Run the handler against a writer that records status and body.
		wrappedWriter := &statusCapturingResponseWriter{ResponseWriter: w}
		err = handler(ctx, wrappedWriter, r, lockedRecord)

		if err != nil {
			// NOTE(review): nothing is written to the client here; presumably
			// the handler has already reported the failure itself - confirm.
			errMsg, _ := json.Marshal(map[string]string{"error": err.Error()}) // a map[string]string always marshals
			// Best effort: the response is already out, and no logger is
			// available in this scope to report a storage failure.
			_ = m.idempotencyRepo.UpdateFailed(ctx, lockedRecord.ID, http.StatusInternalServerError, errMsg)
			return
		}

		// net/http sends an implicit 200 when the handler never calls
		// WriteHeader. Store that instead of 0, because replaying a stored 0
		// through WriteHeader would panic.
		statusCode := wrappedWriter.statusCode
		if statusCode == 0 {
			statusCode = http.StatusOK
		}

		successBody := wrappedWriter.body
		if len(successBody) == 0 {
			successBody, _ = json.Marshal(map[string]interface{}{"status": "ok"}) // constant input, cannot fail
		}
		// Best effort for the same reason as UpdateFailed above.
		_ = m.idempotencyRepo.UpdateSuccess(ctx, lockedRecord.ID, statusCode, successBody)
	}
}
// writeIdempotencyError sends a JSON error response with the given HTTP status.
// The body has the shape
//
//	{"error":{"code":<code>,"message":<message>},"request_id":""}
//
// where request_id is always the empty string.
func writeIdempotencyError(w http.ResponseWriter, status int, code, message string) {
	type errorDetail struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	}
	// Field order mirrors the alphabetical key order a map would encode to.
	payload := struct {
		Error     errorDetail `json:"error"`
		RequestID string      `json:"request_id"`
	}{
		Error: errorDetail{Code: code, Message: message},
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line is already sent, so an encoding failure cannot be reported.
	_ = json.NewEncoder(w).Encode(payload)
}
// writeIdempotencyProcessing answers 202 Accepted for a request whose
// idempotency key is still being processed. The Retry-After-Ms header carries
// retryAfterMs, and requestID is echoed in both the X-Request-Id header and
// the request_id field of the JSON body, whose error code is
// IDEMPOTENCY_IN_PROGRESS.
func writeIdempotencyProcessing(w http.ResponseWriter, retryAfterMs int, requestID string) {
	type errorDetail struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	}
	// Field order mirrors the alphabetical key order a map would encode to.
	payload := struct {
		Error     errorDetail `json:"error"`
		RequestID string      `json:"request_id"`
	}{
		Error: errorDetail{
			Code:    "IDEMPOTENCY_IN_PROGRESS",
			Message: "request is being processed, please retry later",
		},
		RequestID: requestID,
	}

	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	headers.Set("Retry-After-Ms", fmt.Sprint(retryAfterMs))
	headers.Set("X-Request-Id", requestID)
	w.WriteHeader(http.StatusAccepted)
	// The status line is already sent, so an encoding failure cannot be reported.
	_ = json.NewEncoder(w).Encode(payload)
}
// writeIdempotentReplay replays a stored response: it sets Content-Type to
// application/json, marks the response with "X-Idempotent-Replay: true",
// sends status and then writes body when it is non-empty.
//
// A status outside the three-digit range accepted by WriteHeader (for example
// a 0 stored for a handler that never set a status explicitly) is replaced
// with 200 OK, because passing such a value to WriteHeader panics.
func writeIdempotentReplay(w http.ResponseWriter, status int, body json.RawMessage) {
	if status < 100 || status > 999 {
		status = http.StatusOK
	}

	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("X-Idempotent-Replay", "true")
	w.WriteHeader(status)
	if len(body) > 0 {
		// The status line is already sent, so a write failure cannot be reported.
		_, _ = w.Write(body)
	}
}
// statusCapturingResponseWriter wraps an http.ResponseWriter and records the
// status code and the body bytes a handler produces, while still forwarding
// every call to the wrapped writer.
type statusCapturingResponseWriter struct {
	http.ResponseWriter
	statusCode int    // final status of the response; 0 until one is written
	body       []byte // all bytes passed to Write, in call order
}

// WriteHeader records statusCode and forwards the call to the wrapped writer.
//
// Only the first final status is kept: informational 1xx codes (other than
// 101) may precede the real status, and a later superfluous call does not
// change what was sent, so neither may overwrite the captured value.
func (w *statusCapturingResponseWriter) WriteHeader(statusCode int) {
	informational := statusCode >= 100 && statusCode < 200 && statusCode != http.StatusSwitchingProtocols
	if w.statusCode == 0 && !informational {
		w.statusCode = statusCode
	}
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write appends b to the captured body and forwards it to the wrapped writer,
// returning the wrapped writer's result.
//
// A Write without a preceding WriteHeader makes net/http send an implicit
// 200 OK, so that status is recorded here instead of leaving it at 0.
func (w *statusCapturingResponseWriter) Write(b []byte) (int, error) {
	if w.statusCode == 0 {
		w.statusCode = http.StatusOK
	}
	w.body = append(w.body, b...)
	return w.ResponseWriter.Write(b)
}
// contextKey is a private string type for the context keys of this package;
// using a dedicated type keeps them distinct from keys of any other type.
type contextKey string

const (
	tenantIDKey   contextKey = "tenant_id"   // key under which the tenant ID (int64) is stored
	operatorIDKey contextKey = "operator_id" // key under which the operator ID (int64) is stored
)
// WithTenantID returns a copy of ctx that carries tenantID under the
// package-private tenant key, where getTenantID reads it back.
func WithTenantID(ctx context.Context, tenantID int64) context.Context {
	return context.WithValue(ctx, tenantIDKey, tenantID)
}
// WithOperatorID returns a copy of ctx that carries operatorID under the
// package-private operator key, where getOperatorID reads it back.
func WithOperatorID(ctx context.Context, operatorID int64) context.Context {
	return context.WithValue(ctx, operatorIDKey, operatorID)
}
// getTenantID returns the tenant ID stored in ctx, or 0 when no value is
// present or the stored value is not an int64.
func getTenantID(ctx context.Context) int64 {
	// A comma-ok assertion on a nil or mistyped value simply reports false.
	tenantID, found := ctx.Value(tenantIDKey).(int64)
	if !found {
		return 0
	}
	return tenantID
}
// getOperatorID returns the operator ID stored in ctx, or 0 when no value is
// present or the stored value is not an int64.
func getOperatorID(ctx context.Context) int64 {
	// A comma-ok assertion on a nil or mistyped value simply reports false.
	operatorID, found := ctx.Value(operatorIDKey).(int64)
	if !found {
		return 0
	}
	return operatorID
}
// GetOperatorID is the exported accessor for the operator ID stored in ctx.
// It delegates to getOperatorID and therefore returns 0 when no int64
// operator ID is present.
func GetOperatorID(ctx context.Context) int64 {
	return getOperatorID(ctx)
}