修复内容:
1. P0-01/P0-02: IAM Handler 硬编码 userID=1 问题
   - getUserIDFromContext 现在从认证中间件的 context 获取真实 userID
   - 添加 middleware.GetOperatorID 公开函数
   - CheckScope 方法添加未认证检查
2. P1-01: 审计服务幂等竞态条件
   - 重构锁保护范围,整个检查和插入过程在锁保护下
   - 使用 defer 确保锁正确释放
3. P1-02: 幂等中间件响应码硬编码
   - 添加 statusCapturingResponseWriter 包装器
   - 捕获实际的状态码和响应体用于幂等记录
4. P2-01: 事件ID时间戳冲突
   - generateEventID 改用 UUID 替代时间戳
5. P2-02: ListScopes 硬编码
   - 使用 model.PredefinedScopes 替代硬编码列表

所有 supply-api 测试通过
294 lines
8.6 KiB
Go
294 lines
8.6 KiB
Go
package middleware
|
||
|
||
import (
|
||
"bytes"
|
||
"context"
|
||
"crypto/sha256"
|
||
"encoding/hex"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"lijiaoqiao/supply-api/internal/repository"
|
||
)
|
||
|
||
// IdempotencyConfig holds the tunables for IdempotencyMiddleware.
// Zero durations are replaced with defaults by NewIdempotencyMiddleware.
type IdempotencyConfig struct {
	// TTL is how long an idempotency record stays valid (default 24h).
	TTL time.Duration
	// ProcessingTTL is how long a record may remain in the "processing"
	// state before a retry with the same key is let through (default 30s).
	ProcessingTTL time.Duration
	// Enabled turns idempotency handling on; when false, Wrap calls the
	// handler directly.
	Enabled bool
}
|
||
|
||
// IdempotencyMiddleware 幂等中间件
|
||
type IdempotencyMiddleware struct {
|
||
idempotencyRepo *repository.IdempotencyRepository
|
||
config IdempotencyConfig
|
||
}
|
||
|
||
// NewIdempotencyMiddleware 创建幂等中间件
|
||
func NewIdempotencyMiddleware(repo *repository.IdempotencyRepository, config IdempotencyConfig) *IdempotencyMiddleware {
|
||
if config.TTL == 0 {
|
||
config.TTL = 24 * time.Hour
|
||
}
|
||
if config.ProcessingTTL == 0 {
|
||
config.ProcessingTTL = 30 * time.Second
|
||
}
|
||
return &IdempotencyMiddleware{
|
||
idempotencyRepo: repo,
|
||
config: config,
|
||
}
|
||
}
|
||
|
||
// IdempotencyKey identifies one idempotent operation: who sent it (tenant
// and operator IDs), which endpoint it targets, and the client-chosen key.
type IdempotencyKey struct {
	// TenantID is the tenant the request is attributed to.
	TenantID int64
	// OperatorID is the operator the request is attributed to.
	OperatorID int64
	// APIPath is the request path as normalized by ExtractIdempotencyKey.
	APIPath string
	// Key is the raw Idempotency-Key header value.
	Key string
}
|
||
|
||
// ExtractIdempotencyKey 从请求中提取幂等信息
|
||
func ExtractIdempotencyKey(r *http.Request, tenantID, operatorID int64) (*IdempotencyKey, error) {
|
||
requestID := r.Header.Get("X-Request-Id")
|
||
if requestID == "" {
|
||
return nil, fmt.Errorf("missing X-Request-Id header")
|
||
}
|
||
|
||
idempotencyKey := r.Header.Get("Idempotency-Key")
|
||
if idempotencyKey == "" {
|
||
return nil, fmt.Errorf("missing Idempotency-Key header")
|
||
}
|
||
|
||
if len(idempotencyKey) < 16 || len(idempotencyKey) > 128 {
|
||
return nil, fmt.Errorf("Idempotency-Key length must be 16-128")
|
||
}
|
||
|
||
// 从路径提取API路径(去除前缀)
|
||
apiPath := r.URL.Path
|
||
if strings.HasPrefix(apiPath, "/api/v1") {
|
||
apiPath = strings.TrimPrefix(apiPath, "/api/v1")
|
||
}
|
||
|
||
return &IdempotencyKey{
|
||
TenantID: tenantID,
|
||
OperatorID: operatorID,
|
||
APIPath: apiPath,
|
||
Key: idempotencyKey,
|
||
}, nil
|
||
}
|
||
|
||
// ComputePayloadHash returns the lowercase hex-encoded SHA-256 digest of
// body. Wrap compares it with the stored hash to detect a reused key sent
// with a different payload.
func ComputePayloadHash(body []byte) string {
	hasher := sha256.New()
	_, _ = hasher.Write(body) // hash.Hash.Write never returns an error
	return hex.EncodeToString(hasher.Sum(nil))
}
|
||
|
||
// IdempotentHandler 幂等处理器函数
|
||
type IdempotentHandler func(ctx context.Context, w http.ResponseWriter, r *http.Request, record *repository.IdempotencyRecord) error
|
||
|
||
// Wrap 包装HTTP处理器以实现幂等
|
||
func (m *IdempotencyMiddleware) Wrap(handler IdempotentHandler) http.HandlerFunc {
|
||
return func(w http.ResponseWriter, r *http.Request) {
|
||
if !m.config.Enabled {
|
||
handler(r.Context(), w, r, nil)
|
||
return
|
||
}
|
||
|
||
ctx := r.Context()
|
||
|
||
// 从context获取租户和操作者ID(由鉴权中间件设置)
|
||
tenantID := getTenantID(ctx)
|
||
operatorID := getOperatorID(ctx)
|
||
|
||
// 提取幂等信息
|
||
idempKey, err := ExtractIdempotencyKey(r, tenantID, operatorID)
|
||
if err != nil {
|
||
writeIdempotencyError(w, http.StatusBadRequest, "IDEMPOTENCY_KEY_INVALID", err.Error())
|
||
return
|
||
}
|
||
|
||
// 读取请求体
|
||
body, err := io.ReadAll(r.Body)
|
||
if err != nil {
|
||
writeIdempotencyError(w, http.StatusBadRequest, "BODY_READ_ERROR", err.Error())
|
||
return
|
||
}
|
||
// 重新填充body以供后续处理
|
||
r.Body = io.NopCloser(bytes.NewBuffer(body))
|
||
|
||
// 计算payload hash
|
||
payloadHash := ComputePayloadHash(body)
|
||
|
||
// 查询已存在的幂等记录
|
||
existingRecord, err := m.idempotencyRepo.GetByKey(ctx, idempKey.TenantID, idempKey.OperatorID, idempKey.APIPath, idempKey.Key)
|
||
if err != nil {
|
||
writeIdempotencyError(w, http.StatusInternalServerError, "IDEMPOTENCY_CHECK_FAILED", err.Error())
|
||
return
|
||
}
|
||
|
||
if existingRecord != nil {
|
||
// 存在记录,处理不同情况
|
||
switch existingRecord.Status {
|
||
case repository.IdempotencyStatusSucceeded:
|
||
// 同参重放:返回原结果
|
||
if existingRecord.PayloadHash == payloadHash {
|
||
writeIdempotentReplay(w, existingRecord.ResponseCode, existingRecord.ResponseBody)
|
||
return
|
||
}
|
||
// 异参重放:返回409冲突
|
||
writeIdempotencyError(w, http.StatusConflict, "IDEMPOTENCY_PAYLOAD_MISMATCH",
|
||
fmt.Sprintf("same idempotency key but different payload, original request_id: %s", existingRecord.RequestID))
|
||
return
|
||
|
||
case repository.IdempotencyStatusProcessing:
|
||
// 处理中:检查是否超时
|
||
if time.Since(existingRecord.UpdatedAt) < m.config.ProcessingTTL {
|
||
retryAfter := m.config.ProcessingTTL - time.Since(existingRecord.UpdatedAt)
|
||
writeIdempotencyProcessing(w, int(retryAfter.Milliseconds()), existingRecord.RequestID)
|
||
return
|
||
}
|
||
// 超时:允许重试(记录会自然过期)
|
||
|
||
case repository.IdempotencyStatusFailed:
|
||
// 失败状态也允许重试
|
||
}
|
||
}
|
||
|
||
// 使用AcquireLock获取锁
|
||
requestID := r.Header.Get("X-Request-Id")
|
||
lockedRecord, err := m.idempotencyRepo.AcquireLock(ctx, idempKey.TenantID, idempKey.OperatorID, idempKey.APIPath, idempKey.Key, m.config.TTL)
|
||
if err != nil {
|
||
writeIdempotencyError(w, http.StatusInternalServerError, "IDEMPOTENCY_LOCK_FAILED", err.Error())
|
||
return
|
||
}
|
||
|
||
// 更新记录中的request_id和payload_hash
|
||
if lockedRecord.ID != 0 && (lockedRecord.RequestID == "" || lockedRecord.PayloadHash == "") {
|
||
lockedRecord.RequestID = requestID
|
||
lockedRecord.PayloadHash = payloadHash
|
||
}
|
||
|
||
// 创建包装器以捕获实际的状态码和响应体
|
||
wrappedWriter := &statusCapturingResponseWriter{ResponseWriter: w}
|
||
|
||
// 执行实际业务处理,使用包装器捕获响应
|
||
err = handler(ctx, wrappedWriter, r, lockedRecord)
|
||
|
||
// 根据处理结果更新幂等记录
|
||
if err != nil {
|
||
// 业务处理失败
|
||
errMsg, _ := json.Marshal(map[string]string{"error": err.Error()})
|
||
_ = m.idempotencyRepo.UpdateFailed(ctx, lockedRecord.ID, http.StatusInternalServerError, errMsg)
|
||
return
|
||
}
|
||
|
||
// 业务处理成功,使用捕获的实际状态码和body更新幂等记录
|
||
successBody := wrappedWriter.body
|
||
if len(successBody) == 0 {
|
||
successBody, _ = json.Marshal(map[string]interface{}{"status": "ok"})
|
||
}
|
||
_ = m.idempotencyRepo.UpdateSuccess(ctx, lockedRecord.ID, wrappedWriter.statusCode, successBody)
|
||
}
|
||
}
|
||
|
||
// writeIdempotencyError sends a JSON error envelope with the given HTTP
// status. The body has the shape
// {"error":{"code":...,"message":...},"request_id":""}, and request_id is
// always left empty here.
func writeIdempotencyError(w http.ResponseWriter, status int, code, message string) {
	payload := map[string]interface{}{
		"error": map[string]string{
			"code":    code,
			"message": message,
		},
		"request_id": "",
	}

	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	// The status line is already sent, so an encoding failure cannot be
	// reported to the client.
	_ = json.NewEncoder(w).Encode(payload)
}
|
||
|
||
// writeIdempotencyProcessing tells the client that a request with the same
// idempotency key is still being processed. It responds 202 Accepted and
// echoes the in-flight request's ID in X-Request-Id and in the body. The
// suggested wait is advertised in the non-standard Retry-After-Ms header,
// in milliseconds.
func writeIdempotencyProcessing(w http.ResponseWriter, retryAfterMs int, requestID string) {
	h := w.Header()
	h.Set("Content-Type", "application/json")
	h.Set("Retry-After-Ms", fmt.Sprint(retryAfterMs))
	h.Set("X-Request-Id", requestID)

	w.WriteHeader(http.StatusAccepted)

	errBody := map[string]string{
		"code":    "IDEMPOTENCY_IN_PROGRESS",
		"message": "request is being processed, please retry later",
	}
	// The status line is already sent, so an encoding failure cannot be
	// reported to the client.
	_ = json.NewEncoder(w).Encode(map[string]interface{}{
		"request_id": requestID,
		"error":      errBody,
	})
}
|
||
|
||
// writeIdempotentReplay re-sends a previously stored response for a repeated
// request. It marks the response with "X-Idempotent-Replay: true" and always
// uses Content-Type application/json. A nil body sends headers only.
//
// A stored status of 0 means no explicit status was captured. It is replayed
// as 200 OK, the status net/http sends implicitly; passing 0 to WriteHeader
// would panic.
func writeIdempotentReplay(w http.ResponseWriter, status int, body json.RawMessage) {
	if status == 0 {
		status = http.StatusOK
	}
	w.Header().Set("Content-Type", "application/json")
	w.Header().Set("X-Idempotent-Replay", "true")
	w.WriteHeader(status)
	if body != nil {
		// The status line is already sent, so a write error cannot be reported.
		_, _ = w.Write(body)
	}
}
|
||
|
||
// statusCapturingResponseWriter wraps an http.ResponseWriter and records the
// status code and body bytes the wrapped handler sends. Everything is still
// passed through to the underlying writer.
//
// statusCode stays 0 until a status is committed. Following net/http
// semantics, the first WriteHeader call wins, and a Write before any
// WriteHeader commits an implicit 200 OK.
type statusCapturingResponseWriter struct {
	http.ResponseWriter
	statusCode int    // first status committed to the client; 0 if none yet
	body       []byte // copy of every byte passed to Write
}

// WriteHeader records statusCode on the first call only and forwards every
// call to the underlying writer. net/http's own writer ignores superfluous
// calls, so later codes never reach the client.
func (w *statusCapturingResponseWriter) WriteHeader(statusCode int) {
	if w.statusCode == 0 {
		w.statusCode = statusCode
	}
	w.ResponseWriter.WriteHeader(statusCode)
}

// Write captures b and forwards it. A Write before WriteHeader implicitly
// commits 200 OK, so that is what gets recorded.
func (w *statusCapturingResponseWriter) Write(b []byte) (int, error) {
	if w.statusCode == 0 {
		w.statusCode = http.StatusOK
	}
	w.body = append(w.body, b...)
	return w.ResponseWriter.Write(b)
}

// Unwrap exposes the underlying writer so http.ResponseController can reach
// optional interfaces (flushing, hijacking, deadlines) hidden by this wrapper.
func (w *statusCapturingResponseWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
|
||
|
||
// contextKey is a package-private type for context keys, so the keys below
// cannot collide with keys defined by other packages.
type contextKey string

// Context keys under which the tenant and operator IDs are stored.
const (
	tenantIDKey   = contextKey("tenant_id")
	operatorIDKey = contextKey("operator_id")
)
|
||
|
||
// WithTenantID 在context中设置租户ID
|
||
func WithTenantID(ctx context.Context, tenantID int64) context.Context {
|
||
return context.WithValue(ctx, tenantIDKey, tenantID)
|
||
}
|
||
|
||
// WithOperatorID 在context中设置操作者ID
|
||
func WithOperatorID(ctx context.Context, operatorID int64) context.Context {
|
||
return context.WithValue(ctx, operatorIDKey, operatorID)
|
||
}
|
||
|
||
func getTenantID(ctx context.Context) int64 {
|
||
if v := ctx.Value(tenantIDKey); v != nil {
|
||
if id, ok := v.(int64); ok {
|
||
return id
|
||
}
|
||
}
|
||
return 0
|
||
}
|
||
|
||
func getOperatorID(ctx context.Context) int64 {
|
||
if v := ctx.Value(operatorIDKey); v != nil {
|
||
if id, ok := v.(int64); ok {
|
||
return id
|
||
}
|
||
}
|
||
return 0
|
||
}
|
||
|
||
// GetOperatorID 公开函数,从context获取操作者ID
|
||
func GetOperatorID(ctx context.Context) int64 {
|
||
return getOperatorID(ctx)
|
||
}
|