Files
lijiaoqiao/supply-api/internal/middleware/idempotency_test.go
Your Name 0196ee5d47 feat(supply-api): 完成核心模块实现
新增/修改内容:
- config: 添加配置管理(config.example.yaml, config.go)
- cache: 添加Redis缓存层(redis.go)
- domain: 添加invariants不变量验证及测试
- middleware: 添加auth认证和idempotency幂等性中间件及测试
- repository: 添加完整数据访问层(account, package, settlement, idempotency, db)
- sql: 添加幂等性表DDL脚本

代码覆盖:
- auth middleware实现凭证边界验证
- idempotency middleware实现请求幂等性
- invariants实现业务不变量检查
- repository层实现完整的数据访问逻辑

关联issue: Round-1 R1-ISSUE-006 凭证边界硬门禁
2026-04-01 08:53:28 +08:00

212 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"testing"
	"time"

	"lijiaoqiao/supply-api/internal/repository"
)
// MockIdempotencyRepository is an in-memory mock of the idempotency repository
// used by the middleware tests. Records live in a plain map with no locking,
// so an instance must not be used from multiple goroutines at once.
type MockIdempotencyRepository struct {
	// records maps the composite key produced by buildKey from
	// (tenantID, operatorID, apiPath, idempotencyKey) to the stored record.
	records map[string]*repository.IdempotencyRecord
}
// NewMockIdempotencyRepository returns a mock repository with an empty,
// already-allocated record map, so records can be stored immediately.
func NewMockIdempotencyRepository() *MockIdempotencyRepository {
	repo := new(MockIdempotencyRepository)
	repo.records = map[string]*repository.IdempotencyRecord{}
	return repo
}
// GetByKey returns the record stored under (tenantID, operatorID, apiPath,
// idempotencyKey) if it exists and the current time is still strictly before
// its ExpiresAt. Missing or expired entries yield (nil, nil); this mock never
// returns a non-nil error. ctx is ignored.
func (r *MockIdempotencyRepository) GetByKey(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (*repository.IdempotencyRecord, error) {
	rec, found := r.records[buildKey(tenantID, operatorID, apiPath, idempotencyKey)]
	if !found || !time.Now().Before(rec.ExpiresAt) {
		return nil, nil
	}
	return rec, nil
}
// Create stores record under the key derived from its tenant, operator, API
// path and idempotency key, replacing any record already stored there. It
// always returns nil; ctx is ignored.
func (r *MockIdempotencyRepository) Create(ctx context.Context, record *repository.IdempotencyRecord) error {
	r.records[buildKey(record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey)] = record
	return nil
}
// UpdateSuccess is a no-op in this mock. Arguments (context, record id,
// response code, response body) are ignored and nil is always returned, so
// stored records keep whatever Status they were saved with.
func (*MockIdempotencyRepository) UpdateSuccess(context.Context, int64, int, json.RawMessage) error {
	return nil
}
// UpdateFailed is a no-op in this mock. Arguments (context, record id,
// response code, response body) are ignored and nil is always returned, so
// stored records keep whatever Status they were saved with.
func (*MockIdempotencyRepository) UpdateFailed(context.Context, int64, int, json.RawMessage) error {
	return nil
}
// AcquireLock stores a new record in IdempotencyStatusProcessing state,
// expiring ttl from now, under the key built from the given identifiers, and
// returns that record with a nil error. ctx is ignored.
//
// NOTE(review): any record already stored under the same key, expired or not,
// is overwritten, so this mock never reports a competing in-flight request.
// Confirm no test relies on lock contention through this double.
func (r *MockIdempotencyRepository) AcquireLock(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string, ttl time.Duration) (*repository.IdempotencyRecord, error) {
	lock := &repository.IdempotencyRecord{
		TenantID:       tenantID,
		OperatorID:     operatorID,
		APIPath:        apiPath,
		IdempotencyKey: idempotencyKey,
		RequestID:      "test-request-id", // fixed placeholder value
		PayloadHash:    "",                // no payload is passed in, so nothing to hash
		Status:         repository.IdempotencyStatusProcessing,
		ExpiresAt:      time.Now().Add(ttl),
	}
	r.records[buildKey(tenantID, operatorID, apiPath, idempotencyKey)] = lock
	return lock, nil
}
// buildKey composes the mock's map key as
// "<tenantID>:<operatorID>:<apiPath>:<idempotencyKey>".
//
// The numeric IDs are rendered in decimal. The previous string(rune(id))
// conversion produced a single Unicode code point instead of digits: every
// out-of-range or negative ID became "\uFFFD", so distinct IDs collided, and
// ID 58 rendered as ":" (the separator itself). Decimal digits never contain
// ':', so the two ID fields are always unambiguous.
//
// NOTE(review): apiPath and idempotencyKey are joined verbatim. Values that
// themselves contain ':' could still produce identical keys. This is
// acceptable for the fixed paths and keys used in these tests.
func buildKey(tenantID, operatorID int64, apiPath, idempotencyKey string) string {
	return strings.Join([]string{
		strconv.FormatInt(tenantID, 10),
		strconv.FormatInt(operatorID, 10),
		apiPath,
		idempotencyKey,
	}, ":")
}
func TestComputePayloadHash(t *testing.T) {
tests := []struct {
name string
body []byte
expected string
}{
{
name: "empty body",
body: []byte{},
expected: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
},
{
name: "simple JSON",
body: []byte(`{"key":"value"}`),
expected: computeExpectedHash(`{"key":"value"}`),
},
{
name: "JSON with spaces",
body: []byte(`{ "key": "value" }`),
expected: computeExpectedHash(`{ "key": "value" }`),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ComputePayloadHash(tt.body)
if result != tt.expected {
t.Errorf("ComputePayloadHash() = %v, want %v", result, tt.expected)
}
})
}
}
// computeExpectedHash returns the lowercase hexadecimal SHA-256 digest of s.
// It is computed independently of ComputePayloadHash so tests can compare
// the two.
func computeExpectedHash(s string) string {
	digest := sha256.New()
	// hash.Hash.Write is documented never to return an error.
	digest.Write([]byte(s))
	return hex.EncodeToString(digest.Sum(nil))
}
// TestExtractIdempotencyKey checks the header validation done by
// ExtractIdempotencyKey.
//
// The table covers four cases. With both X-Request-Id and a 17-character
// Idempotency-Key present, a non-nil result and no error are expected. A
// missing X-Request-Id, a missing Idempotency-Key, or a key that is too short
// must each fail with an error whose message contains wantErrSubstr.
func TestExtractIdempotencyKey(t *testing.T) {
	cases := []struct {
		name          string
		headers       map[string]string
		expectError   bool
		wantErrSubstr string // substring expected in err.Error(); used only when expectError is true
	}{
		{
			name: "valid headers",
			headers: map[string]string{
				"X-Request-Id":    "req-123",
				"Idempotency-Key": "idem-key-12345678",
			},
		},
		{
			name: "missing X-Request-Id",
			headers: map[string]string{
				"Idempotency-Key": "idem-key-12345678",
			},
			expectError:   true,
			wantErrSubstr: "missing X-Request-Id header",
		},
		{
			name: "missing Idempotency-Key",
			headers: map[string]string{
				"X-Request-Id": "req-123",
			},
			expectError:   true,
			wantErrSubstr: "missing Idempotency-Key header",
		},
		{
			name: "Idempotency-Key too short",
			headers: map[string]string{
				"X-Request-Id":    "req-123",
				"Idempotency-Key": "short",
			},
			expectError:   true,
			wantErrSubstr: "Idempotency-Key length must be 16-128",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest("POST", "/api/v1/supply/accounts", nil)
			for header, value := range tc.headers {
				req.Header.Set(header, value)
			}

			// NOTE(review): the trailing 1, 1 are presumably tenant and
			// operator IDs; confirm against ExtractIdempotencyKey's signature.
			got, err := ExtractIdempotencyKey(req, 1, 1)

			if !tc.expectError {
				if err != nil {
					t.Errorf("unexpected error: %v", err)
				}
				if got == nil {
					t.Errorf("expected result but got nil")
				}
				return
			}

			switch {
			case err == nil:
				t.Errorf("expected error but got nil")
			case !strings.Contains(err.Error(), tc.wantErrSubstr):
				t.Errorf("error = %v, want contains %v", err, tc.wantErrSubstr)
			}
		})
	}
}
func TestIdempotentHandler(t *testing.T) {
// 创建测试handler
testHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, record *repository.IdempotencyRecord) error {
w.WriteHeader(http.StatusCreated)
json.NewEncoder(w).Encode(map[string]string{"status": "created"})
return nil
}
middleware := NewIdempotencyMiddleware(nil, IdempotencyConfig{
Enabled: false, // 禁用幂等只测试handler包装
})
handler := middleware.Wrap(testHandler)
t.Run("handler executes successfully", func(t *testing.T) {
req := httptest.NewRequest("POST", "/api/v1/supply/accounts", strings.NewReader(`{"key":"value"}`))
req.Header.Set("X-Request-Id", "req-123")
req.Header.Set("Idempotency-Key", "idem-key-12345678")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Errorf("expected status %d, got %d", http.StatusCreated, w.Code)
}
})
}