From 1c088e2dd4d0916b65f5c193e8410596c77d89a6 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 20 Apr 2026 11:36:07 +0800 Subject: [PATCH] fix(supply-api): restore package lifecycle ownership semantics --- supply-api/internal/adapter/adapter.go | 6 +- .../adapter/package_integration_test.go | 114 ++++++++++++++++++ supply-api/internal/httpapi/supply_api.go | 4 +- .../internal/httpapi/supply_api_test.go | 61 ++++++++++ supply-api/internal/repository/package.go | 8 +- .../repository/package_integration_test.go | 61 ++++++++-- 6 files changed, 234 insertions(+), 20 deletions(-) create mode 100644 supply-api/internal/adapter/package_integration_test.go diff --git a/supply-api/internal/adapter/adapter.go b/supply-api/internal/adapter/adapter.go index ecffc6c7..e1f9b7c4 100644 --- a/supply-api/internal/adapter/adapter.go +++ b/supply-api/internal/adapter/adapter.go @@ -178,7 +178,11 @@ func (s *DBPackageStore) GetByID(ctx context.Context, supplierID, id int64) (*do } func (s *DBPackageStore) Update(ctx context.Context, pkg *domain.Package) error { - return s.repo.Update(ctx, pkg, pkg.Version) + expectedVersion := 0 + if pkg.Version > 0 { + expectedVersion = pkg.Version - 1 + } + return s.repo.Update(ctx, pkg, expectedVersion) } func (s *DBPackageStore) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) { diff --git a/supply-api/internal/adapter/package_integration_test.go b/supply-api/internal/adapter/package_integration_test.go new file mode 100644 index 00000000..86d31096 --- /dev/null +++ b/supply-api/internal/adapter/package_integration_test.go @@ -0,0 +1,114 @@ +//go:build integration +// +build integration + +package adapter + +import ( + "context" + "testing" + "time" + + "lijiaoqiao/supply-api/internal/audit" + "lijiaoqiao/supply-api/internal/domain" + "lijiaoqiao/supply-api/internal/repository" +) + +func createAdapterIntegrationPackage(t *testing.T, repo *repository.PackageRepository, supplierID, accountID int64, status domain.PackageStatus) *domain.Package { + t.Helper() + + pkg := &domain.Package{ + SupplierID: supplierID, + AccountID: accountID, + Platform: "openai", + Model: "gpt-4.1-mini", + TotalQuota: 10000, + AvailableQuota: 10000, + SoldQuota: 0, + ReservedQuota: 0, + PricePer1MInput: 0.25, + PricePer1MOutput: 0.75, + MinPurchase: 100, + StartAt: time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC), + EndAt: time.Date(2026, 5, 20, 0, 0, 0, 0, time.UTC), + ValidDays: 30, + Status: status, + MaxConcurrent: 5, + RateLimitRPM: 60, + Version: 1, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := repo.Create(context.Background(), pkg, "req-pkg-lifecycle-int", "trace-pkg-lifecycle-int"); err != nil { + t.Fatalf("创建测试套餐失败: %v", err) + } + + return pkg +} + +func TestDBPackageStore_Lifecycle_Integration(t *testing.T) { + if testing.Short() { + t.Skip("跳过集成测试(short mode)") + } + + pool := getIntegrationDB(t) + if pool == nil { + return + } + + repo := repository.NewPackageRepository(pool) + store := NewDBPackageStore(repo) + service := domain.NewPackageService(store, NewInMemoryAccountStoreAdapter(), audit.NewMemoryAuditStore()) + supplierID := time.Now().UnixNano() + accountID := supplierID + 1000 + pkg := createAdapterIntegrationPackage(t, repo, supplierID, accountID, domain.PackageStatusDraft) + + fetched, err := store.GetByID(context.Background(), supplierID, pkg.ID) + if err != nil { + t.Fatalf("读取套餐失败: %v", err) + } + if fetched.SupplierID != supplierID { + t.Fatalf("expected fetched supplier id %d, got %d", supplierID, fetched.SupplierID) + } + if fetched.AccountID != accountID { + t.Fatalf("expected fetched account id %d, got %d", accountID, fetched.AccountID) + } + + published, err := service.Publish(context.Background(), supplierID, pkg.ID) + if err != nil { + t.Fatalf("发布套餐失败: %v", err) + } + if published.Status != domain.PackageStatusActive { + t.Fatalf("expected published status %q, got %q", domain.PackageStatusActive, published.Status) + } + + paused, err := service.Pause(context.Background(), supplierID, pkg.ID) + if err != nil { + t.Fatalf("暂停套餐失败: %v", err) + } + if paused.Status != domain.PackageStatusPaused { + t.Fatalf("expected paused status %q, got %q", domain.PackageStatusPaused, paused.Status) + } + + unlisted, err := service.Unlist(context.Background(), supplierID, pkg.ID) + if err != nil { + t.Fatalf("下架套餐失败: %v", err) + } + if unlisted.Status != domain.PackageStatusExpired { + t.Fatalf("expected unlisted status %q, got %q", domain.PackageStatusExpired, unlisted.Status) + } + + after, err := repo.GetByID(context.Background(), supplierID, pkg.ID) + if err != nil { + t.Fatalf("生命周期完成后读取套餐失败: %v", err) + } + if after.Status != domain.PackageStatusExpired { + t.Fatalf("expected persisted status %q, got %q", domain.PackageStatusExpired, after.Status) + } + if after.SupplierID != supplierID { + t.Fatalf("expected persisted supplier id %d, got %d", supplierID, after.SupplierID) + } + if after.AccountID != accountID { + t.Fatalf("expected persisted account id %d, got %d", accountID, after.AccountID) + } +} diff --git a/supply-api/internal/httpapi/supply_api.go b/supply-api/internal/httpapi/supply_api.go index ce504343..a1b8334a 100644 --- a/supply-api/internal/httpapi/supply_api.go +++ b/supply-api/internal/httpapi/supply_api.go @@ -503,7 +503,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ "request_id": getRequestID(r), "data": map[string]any{ "package_id": pkg.ID, - "supply_account_id": pkg.SupplierID, + "supply_account_id": pkg.AccountID, "model": pkg.Model, "status": pkg.Status, "total_quota": pkg.TotalQuota, @@ -665,7 +665,7 @@ func (a *SupplyAPI) handleClonePackage(w http.ResponseWriter, r *http.Request, p "request_id": getRequestID(r), "data": map[string]any{ "package_id": pkg.ID, - "supply_account_id": pkg.SupplierID, + "supply_account_id": pkg.AccountID, "model": pkg.Model, "status": pkg.Status, "created_at": pkg.CreatedAt, diff --git a/supply-api/internal/httpapi/supply_api_test.go b/supply-api/internal/httpapi/supply_api_test.go index b3adfac3..9af0eee7 100644 --- a/supply-api/internal/httpapi/supply_api_test.go +++ b/supply-api/internal/httpapi/supply_api_test.go @@ -703,6 +703,41 @@ func TestSupplyAPI_CreatePackageDraft_Success(t *testing.T) { } } +func TestSupplyAPI_CreatePackageDraft_ResponseUsesAccountID(t *testing.T) { + api, _, packageSvc, _, _, _ := newTestAPI() + packageSvc.pkg.AccountID = 200 + + req := httptest.NewRequest("POST", "/api/v1/supply/packages/draft", strings.NewReader(`{ + "supply_account_id": 200, + "model": "gpt-4", + "total_quota": 10000, + "price_per_1m_input": 0.5, + "price_per_1m_output": 1.5, + "valid_days": 30, + "max_concurrent": 10, + "rate_limit_rpm": 100 + }`)) + w := httptest.NewRecorder() + + api.handleCreatePackageDraft(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected status 201, got %d body=%s", w.Code, w.Body.String()) + } + + var resp struct { + Data struct { + SupplyAccountID int64 `json:"supply_account_id"` + } `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.Data.SupplyAccountID != packageSvc.pkg.AccountID { + t.Fatalf("expected supply_account_id %d, got %d", packageSvc.pkg.AccountID, resp.Data.SupplyAccountID) + } +} + func TestSupplyAPI_CreatePackageDraft_MethodNotAllowed(t *testing.T) { api, _, _, _, _, _ := newTestAPI() @@ -876,6 +911,32 @@ func TestSupplyAPI_ClonePackage_Success(t *testing.T) { } } +func TestSupplyAPI_ClonePackage_ResponseUsesAccountID(t *testing.T) { + api, _, packageSvc, _, _, _ := newTestAPI() + packageSvc.pkg.AccountID = 200 + + req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/clone", nil) + w := httptest.NewRecorder() + + api.handlePackageActions(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("expected status 201, got %d body=%s", w.Code, w.Body.String()) + } + + var resp struct { + Data struct { + SupplyAccountID int64 `json:"supply_account_id"` + } `json:"data"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if resp.Data.SupplyAccountID != packageSvc.pkg.AccountID { + t.Fatalf("expected supply_account_id %d, got %d", packageSvc.pkg.AccountID, resp.Data.SupplyAccountID) + } +} + func TestSupplyAPI_BatchUpdatePrice_Success(t *testing.T) { api, _, _, _, _, _ := newTestAPI() api.packageService.(*mockPackageService).batchResp = &domain.BatchUpdatePriceResponse{ diff --git a/supply-api/internal/repository/package.go b/supply-api/internal/repository/package.go index 5cc6ab6a..a808066e 100644 --- a/supply-api/internal/repository/package.go +++ b/supply-api/internal/repository/package.go @@ -50,7 +50,7 @@ func (r *PackageRepository) Create(ctx context.Context, pkg *domain.Package, req } err := r.pool.QueryRow(ctx, query, - pkg.SupplierID, pkg.AccountID, pkg.Platform, pkg.Model, + pkg.AccountID, pkg.SupplierID, pkg.Platform, pkg.Model, pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota, pkg.PricePer1MInput, pkg.PricePer1MOutput, pkg.MinPurchase, startAt, endAt, pkg.ValidDays, @@ -85,7 +85,7 @@ func (r *PackageRepository) GetByID(ctx context.Context, supplierID, id int64) ( pkg := &domain.Package{} var startAt, endAt *time.Time err := r.pool.QueryRow(ctx, query, id, supplierID).Scan( - &pkg.ID, &pkg.SupplierID, &pkg.AccountID, &pkg.Platform, &pkg.Model, + &pkg.ID, &pkg.AccountID, &pkg.SupplierID, &pkg.Platform, &pkg.Model, &pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota, &pkg.PricePer1MInput, &pkg.PricePer1MOutput, &pkg.MinPurchase, &startAt, &endAt, &pkg.ValidDays, @@ -169,7 +169,7 @@ func (r *PackageRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, sup pkg := &domain.Package{} err := tx.QueryRow(ctx, query, id, supplierID).Scan( - &pkg.ID, &pkg.SupplierID, &pkg.AccountID, &pkg.Platform, &pkg.Model, + &pkg.ID, &pkg.AccountID, &pkg.SupplierID, &pkg.Platform, &pkg.Model, &pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota, &pkg.PricePer1MInput, &pkg.PricePer1MOutput, &pkg.Status, &pkg.Version, @@ -210,7 +210,7 @@ func (r *PackageRepository) List(ctx context.Context, supplierID int64) ([]*doma for rows.Next() { pkg := &domain.Package{} err := rows.Scan( - &pkg.ID, &pkg.SupplierID, &pkg.AccountID, &pkg.Platform, &pkg.Model, + &pkg.ID, &pkg.AccountID, &pkg.SupplierID, &pkg.Platform, &pkg.Model, &pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.PricePer1MInput, &pkg.PricePer1MOutput, &pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM, diff --git a/supply-api/internal/repository/package_integration_test.go b/supply-api/internal/repository/package_integration_test.go index 45b3babd..12fd2bd0 100644 --- a/supply-api/internal/repository/package_integration_test.go +++ b/supply-api/internal/repository/package_integration_test.go @@ -75,6 +75,17 @@ func TestPackageRepository_Create_Integration(t *testing.T) { if pkg.ID == 0 { t.Fatal("expected created package id") } + + fetched, err := repo.GetByID(context.Background(), pkg.SupplierID, pkg.ID) + if err != nil { + t.Fatalf("get package after create failed: %v", err) + } + if fetched.SupplierID != pkg.SupplierID { + t.Fatalf("expected supplier id %d, got %d", pkg.SupplierID, fetched.SupplierID) + } + if fetched.AccountID != pkg.AccountID { + t.Fatalf("expected account id %d, got %d", pkg.AccountID, fetched.AccountID) + } } func TestPackageRepository_GetByID_Integration(t *testing.T) { @@ -116,23 +127,47 @@ func TestPackageRepository_List_Integration(t *testing.T) { return } - rows, err := pool.Query(context.Background(), ` - SELECT id, user_id, available_quota - FROM supply_packages - LIMIT 10 - `) - if err != nil { - t.Fatalf("列出套餐失败: %v", err) + repo := NewPackageRepository(pool) + pkg := &domain.Package{ + SupplierID: 3001, + AccountID: 4001, + Platform: "anthropic", + Model: "claude-3-7-sonnet", + TotalQuota: 5000, + AvailableQuota: 5000, + PricePer1MInput: 0.4, + PricePer1MOutput: 1.2, + Status: domain.PackageStatusDraft, + ValidDays: 15, } - defer rows.Close() - for rows.Next() { - var id, userID int64 - var availableQuota float64 - if scanErr := rows.Scan(&id, &userID, &availableQuota); scanErr != nil { - t.Fatalf("扫描套餐失败: %v", scanErr) + if err := repo.Create(context.Background(), pkg, "req-pkg-list-int", "trace-pkg-list-int"); err != nil { + t.Fatalf("create package for list failed: %v", err) + } + + packages, err := repo.List(context.Background(), pkg.SupplierID) + if err != nil { + t.Fatalf("repo.List failed: %v", err) + } + if len(packages) == 0 { + t.Fatal("expected packages for supplier") + } + + found := false + for _, listed := range packages { + if listed.ID == pkg.ID { + found = true + if listed.SupplierID != pkg.SupplierID { + t.Fatalf("expected listed supplier id %d, got %d", pkg.SupplierID, listed.SupplierID) + } + if listed.AccountID != pkg.AccountID { + t.Fatalf("expected listed account id %d, got %d", pkg.AccountID, listed.AccountID) + } } } + if !found { + t.Fatalf("expected package %d in supplier list", pkg.ID) + } } func TestPackageRepository_UpdateQuota_Integration(t *testing.T) {