Files
lijiaoqiao/supply-api/internal/middleware/idempotency_test.go
Your Name ed0961d486 fix(supply-api): 修复编译错误和测试问题
- 添加 ErrNotFound 和 ErrConcurrencyConflict 错误定义
- 修复 pgx.NullTime 替换为 *time.Time
- 修复 db.go 事务类型 (pgx.Tx vs pgxpool.Tx)
- 移除未使用的导入和变量
- 修复 NewSupplyAPI 调用参数
- 修复中间件链路 handler 类型问题
- 修复适配器类型引用 (storage.InMemoryAccountStore 等)
- 所有测试通过

Test: go test ./...
2026-04-01 13:03:44 +08:00

212 lines
5.5 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strconv"
	"strings"
	"testing"
	"time"

	"lijiaoqiao/supply-api/internal/repository"
)
// MockIdempotencyRepository is an in-memory test double for the idempotency
// repository. Records are stored in a map keyed by
// buildKey(tenantID, operatorID, apiPath, idempotencyKey).
//
// NOTE(review): the map is not guarded by a mutex, so this mock is only safe
// for single-goroutine use.
type MockIdempotencyRepository struct {
	// records maps a buildKey(...) key to the record stored under it.
	records map[string]*repository.IdempotencyRecord
}
// NewMockIdempotencyRepository returns a mock repository with an empty,
// ready-to-write record map.
func NewMockIdempotencyRepository() *MockIdempotencyRepository {
	repo := new(MockIdempotencyRepository)
	repo.records = map[string]*repository.IdempotencyRecord{}
	return repo
}
// GetByKey looks up the record stored for the given scope and idempotency key.
// It returns (nil, nil) both when no record exists and when the stored record
// has already expired (ExpiresAt is not after the current time); it never
// returns an error.
func (r *MockIdempotencyRepository) GetByKey(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string) (*repository.IdempotencyRecord, error) {
	stored, found := r.records[buildKey(tenantID, operatorID, apiPath, idempotencyKey)]
	if !found || !time.Now().Before(stored.ExpiresAt) {
		// Missing and expired records are indistinguishable to callers.
		return nil, nil
	}
	return stored, nil
}
// Create stores record under the key derived from its own tenant, operator,
// API path and idempotency key, overwriting any existing entry. It always
// returns nil.
func (r *MockIdempotencyRepository) Create(ctx context.Context, record *repository.IdempotencyRecord) error {
	r.records[buildKey(record.TenantID, record.OperatorID, record.APIPath, record.IdempotencyKey)] = record
	return nil
}
// UpdateSuccess is a no-op in this mock. It ignores all arguments, leaves
// r.records untouched, and always reports success.
func (r *MockIdempotencyRepository) UpdateSuccess(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error {
	// Nothing is persisted: a later GetByKey returns the record as it was
	// last written by Create or AcquireLock.
	return nil
}
// UpdateFailed is a no-op in this mock. It ignores all arguments, leaves
// r.records untouched, and always reports success.
func (r *MockIdempotencyRepository) UpdateFailed(ctx context.Context, id int64, responseCode int, responseBody json.RawMessage) error {
	// Nothing is persisted: a later GetByKey returns the record as it was
	// last written by Create or AcquireLock.
	return nil
}
// AcquireLock unconditionally creates a record in the processing state for the
// given scope and key, stores it (replacing any existing entry, expired or
// not), and returns it. The record expires ttl from now and carries a fixed
// RequestID of "test-request-id" and an empty PayloadHash. It never returns an
// error.
func (r *MockIdempotencyRepository) AcquireLock(ctx context.Context, tenantID, operatorID int64, apiPath, idempotencyKey string, ttl time.Duration) (*repository.IdempotencyRecord, error) {
	locked := new(repository.IdempotencyRecord)
	locked.TenantID = tenantID
	locked.OperatorID = operatorID
	locked.APIPath = apiPath
	locked.IdempotencyKey = idempotencyKey
	locked.RequestID = "test-request-id"
	locked.PayloadHash = ""
	locked.Status = repository.IdempotencyStatusProcessing
	locked.ExpiresAt = time.Now().Add(ttl)

	r.records[buildKey(tenantID, operatorID, apiPath, idempotencyKey)] = locked
	return locked, nil
}
// buildKey derives the map key used by MockIdempotencyRepository from the
// idempotency scope (tenant, operator, API path) and the client-supplied
// idempotency key, joined with ":".
//
// The numeric IDs are rendered as base-10 text. The earlier string(rune(id))
// conversion produced a single Unicode character per ID instead, which made
// distinct IDs collide:
//   - int64 was truncated to int32 (1 and 1<<32+1 mapped to the same key);
//   - every invalid code point (negative, surrogate range, > 0x10FFFF) became
//     U+FFFD;
//   - ID 58 rendered as ":" — the separator itself.
//
// apiPath and idempotencyKey are still joined verbatim, so a ":" inside them
// can make two different tuples share a key; that is acceptable for a mock.
func buildKey(tenantID, operatorID int64, apiPath, idempotencyKey string) string {
	return strings.Join([]string{
		strconv.FormatInt(tenantID, 10),
		strconv.FormatInt(operatorID, 10),
		apiPath,
		idempotencyKey,
	}, ":")
}
// TestComputePayloadHash checks that ComputePayloadHash returns the lowercase
// hex SHA-256 digest of the exact body bytes. The two JSON cases differ only
// in whitespace and each expects the digest of its own literal bytes, so no
// JSON normalisation is expected before hashing.
func TestComputePayloadHash(t *testing.T) {
	type hashCase struct {
		name string
		body []byte
		want string
	}
	cases := []hashCase{
		// Well-known SHA-256 digest of zero bytes.
		{name: "empty body", body: []byte{}, want: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"},
		{name: "simple JSON", body: []byte(`{"key":"value"}`), want: computeExpectedHash(`{"key":"value"}`)},
		{name: "JSON with spaces", body: []byte(`{ "key": "value" }`), want: computeExpectedHash(`{ "key": "value" }`)},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := ComputePayloadHash(tc.body); got != tc.want {
				t.Errorf("ComputePayloadHash() = %v, want %v", got, tc.want)
			}
		})
	}
}
// computeExpectedHash returns the lowercase hex encoding of the SHA-256 digest
// of s. The payload-hash tests use it as their reference value.
func computeExpectedHash(s string) string {
	digest := sha256.New()
	// hash.Hash.Write is documented never to return an error.
	digest.Write([]byte(s))
	return hex.EncodeToString(digest.Sum(nil))
}
// TestExtractIdempotencyKey exercises header validation in
// ExtractIdempotencyKey. A request carrying both X-Request-Id and a
// sufficiently long Idempotency-Key must yield a non-nil result. Missing
// either header, or sending a key that is too short, must produce an error
// whose message contains the expected text.
func TestExtractIdempotencyKey(t *testing.T) {
	type headerCase struct {
		name        string
		headers     map[string]string
		expectError bool
		wantErrText string // substring expected in err.Error() when expectError is set
	}
	cases := []headerCase{
		{
			name: "valid headers",
			headers: map[string]string{
				"X-Request-Id":    "req-123",
				"Idempotency-Key": "idem-key-12345678",
			},
		},
		{
			name: "missing X-Request-Id",
			headers: map[string]string{
				"Idempotency-Key": "idem-key-12345678",
			},
			expectError: true,
			wantErrText: "missing X-Request-Id header",
		},
		{
			name: "missing Idempotency-Key",
			headers: map[string]string{
				"X-Request-Id": "req-123",
			},
			expectError: true,
			wantErrText: "missing Idempotency-Key header",
		},
		{
			name: "Idempotency-Key too short",
			headers: map[string]string{
				"X-Request-Id":    "req-123",
				"Idempotency-Key": "short",
			},
			expectError: true,
			wantErrText: "Idempotency-Key length must be 16-128",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest("POST", "/api/v1/supply/accounts", nil)
			for name, value := range tc.headers {
				req.Header.Set(name, value)
			}

			result, err := ExtractIdempotencyKey(req, 1, 1)

			if !tc.expectError {
				if err != nil {
					t.Errorf("unexpected error: %v", err)
				}
				if result == nil {
					t.Errorf("expected result but got nil")
				}
				return
			}

			switch {
			case err == nil:
				t.Errorf("expected error but got nil")
			case !strings.Contains(err.Error(), tc.wantErrText):
				t.Errorf("error = %v, want contains %v", err, tc.wantErrText)
			}
		})
	}
}
// TestIdempotentHandler checks that Wrap turns an idempotent handler func into
// an http.Handler that runs it and lets its 201 status reach the recorder.
// Idempotency is disabled in the config (and nil is passed as the first
// constructor argument — presumably the repository), so only the wrapping
// path is exercised.
func TestIdempotentHandler(t *testing.T) {
	// testHandler replies 201 with a small JSON body. The Encode error is
	// returned to the wrapper instead of being silently discarded.
	testHandler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, record *repository.IdempotencyRecord) error {
		w.WriteHeader(http.StatusCreated)
		return json.NewEncoder(w).Encode(map[string]string{"status": "created"})
	}

	// Named mw rather than "middleware" so the local does not reuse the
	// package's own name.
	mw := NewIdempotencyMiddleware(nil, IdempotencyConfig{
		Enabled: false, // idempotency off: only the handler wrapping is under test
	})
	handler := mw.Wrap(testHandler)

	t.Run("handler executes successfully", func(t *testing.T) {
		req := httptest.NewRequest("POST", "/api/v1/supply/accounts", strings.NewReader(`{"key":"value"}`))
		req.Header.Set("X-Request-Id", "req-123")
		req.Header.Set("Idempotency-Key", "idem-key-12345678")
		w := httptest.NewRecorder()

		handler.ServeHTTP(w, req)

		if w.Code != http.StatusCreated {
			t.Errorf("expected status %d, got %d", http.StatusCreated, w.Code)
		}
	})
}