fix(supply-api): restore package lifecycle ownership semantics
This commit is contained in:
@@ -178,7 +178,11 @@ func (s *DBPackageStore) GetByID(ctx context.Context, supplierID, id int64) (*do
|
||||
}
|
||||
|
||||
func (s *DBPackageStore) Update(ctx context.Context, pkg *domain.Package) error {
|
||||
return s.repo.Update(ctx, pkg, pkg.Version)
|
||||
expectedVersion := 0
|
||||
if pkg.Version > 0 {
|
||||
expectedVersion = pkg.Version - 1
|
||||
}
|
||||
return s.repo.Update(ctx, pkg, expectedVersion)
|
||||
}
|
||||
|
||||
func (s *DBPackageStore) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
|
||||
|
||||
114
supply-api/internal/adapter/package_integration_test.go
Normal file
114
supply-api/internal/adapter/package_integration_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit"
|
||||
"lijiaoqiao/supply-api/internal/domain"
|
||||
"lijiaoqiao/supply-api/internal/repository"
|
||||
)
|
||||
|
||||
func createAdapterIntegrationPackage(t *testing.T, repo *repository.PackageRepository, supplierID, accountID int64, status domain.PackageStatus) *domain.Package {
|
||||
t.Helper()
|
||||
|
||||
pkg := &domain.Package{
|
||||
SupplierID: supplierID,
|
||||
AccountID: accountID,
|
||||
Platform: "openai",
|
||||
Model: "gpt-4.1-mini",
|
||||
TotalQuota: 10000,
|
||||
AvailableQuota: 10000,
|
||||
SoldQuota: 0,
|
||||
ReservedQuota: 0,
|
||||
PricePer1MInput: 0.25,
|
||||
PricePer1MOutput: 0.75,
|
||||
MinPurchase: 100,
|
||||
StartAt: time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC),
|
||||
EndAt: time.Date(2026, 5, 20, 0, 0, 0, 0, time.UTC),
|
||||
ValidDays: 30,
|
||||
Status: status,
|
||||
MaxConcurrent: 5,
|
||||
RateLimitRPM: 60,
|
||||
Version: 1,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := repo.Create(context.Background(), pkg, "req-pkg-lifecycle-int", "trace-pkg-lifecycle-int"); err != nil {
|
||||
t.Fatalf("创建测试套餐失败: %v", err)
|
||||
}
|
||||
|
||||
return pkg
|
||||
}
|
||||
|
||||
func TestDBPackageStore_Lifecycle_Integration(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("跳过集成测试(short mode)")
|
||||
}
|
||||
|
||||
pool := getIntegrationDB(t)
|
||||
if pool == nil {
|
||||
return
|
||||
}
|
||||
|
||||
repo := repository.NewPackageRepository(pool)
|
||||
store := NewDBPackageStore(repo)
|
||||
service := domain.NewPackageService(store, NewInMemoryAccountStoreAdapter(), audit.NewMemoryAuditStore())
|
||||
supplierID := time.Now().UnixNano()
|
||||
accountID := supplierID + 1000
|
||||
pkg := createAdapterIntegrationPackage(t, repo, supplierID, accountID, domain.PackageStatusDraft)
|
||||
|
||||
fetched, err := store.GetByID(context.Background(), supplierID, pkg.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("读取套餐失败: %v", err)
|
||||
}
|
||||
if fetched.SupplierID != supplierID {
|
||||
t.Fatalf("expected fetched supplier id %d, got %d", supplierID, fetched.SupplierID)
|
||||
}
|
||||
if fetched.AccountID != accountID {
|
||||
t.Fatalf("expected fetched account id %d, got %d", accountID, fetched.AccountID)
|
||||
}
|
||||
|
||||
published, err := service.Publish(context.Background(), supplierID, pkg.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("发布套餐失败: %v", err)
|
||||
}
|
||||
if published.Status != domain.PackageStatusActive {
|
||||
t.Fatalf("expected published status %q, got %q", domain.PackageStatusActive, published.Status)
|
||||
}
|
||||
|
||||
paused, err := service.Pause(context.Background(), supplierID, pkg.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("暂停套餐失败: %v", err)
|
||||
}
|
||||
if paused.Status != domain.PackageStatusPaused {
|
||||
t.Fatalf("expected paused status %q, got %q", domain.PackageStatusPaused, paused.Status)
|
||||
}
|
||||
|
||||
unlisted, err := service.Unlist(context.Background(), supplierID, pkg.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("下架套餐失败: %v", err)
|
||||
}
|
||||
if unlisted.Status != domain.PackageStatusExpired {
|
||||
t.Fatalf("expected unlisted status %q, got %q", domain.PackageStatusExpired, unlisted.Status)
|
||||
}
|
||||
|
||||
after, err := repo.GetByID(context.Background(), supplierID, pkg.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("生命周期完成后读取套餐失败: %v", err)
|
||||
}
|
||||
if after.Status != domain.PackageStatusExpired {
|
||||
t.Fatalf("expected persisted status %q, got %q", domain.PackageStatusExpired, after.Status)
|
||||
}
|
||||
if after.SupplierID != supplierID {
|
||||
t.Fatalf("expected persisted supplier id %d, got %d", supplierID, after.SupplierID)
|
||||
}
|
||||
if after.AccountID != accountID {
|
||||
t.Fatalf("expected persisted account id %d, got %d", accountID, after.AccountID)
|
||||
}
|
||||
}
|
||||
@@ -503,7 +503,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
|
||||
"request_id": getRequestID(r),
|
||||
"data": map[string]any{
|
||||
"package_id": pkg.ID,
|
||||
"supply_account_id": pkg.SupplierID,
|
||||
"supply_account_id": pkg.AccountID,
|
||||
"model": pkg.Model,
|
||||
"status": pkg.Status,
|
||||
"total_quota": pkg.TotalQuota,
|
||||
@@ -665,7 +665,7 @@ func (a *SupplyAPI) handleClonePackage(w http.ResponseWriter, r *http.Request, p
|
||||
"request_id": getRequestID(r),
|
||||
"data": map[string]any{
|
||||
"package_id": pkg.ID,
|
||||
"supply_account_id": pkg.SupplierID,
|
||||
"supply_account_id": pkg.AccountID,
|
||||
"model": pkg.Model,
|
||||
"status": pkg.Status,
|
||||
"created_at": pkg.CreatedAt,
|
||||
|
||||
@@ -703,6 +703,41 @@ func TestSupplyAPI_CreatePackageDraft_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupplyAPI_CreatePackageDraft_ResponseUsesAccountID(t *testing.T) {
|
||||
api, _, packageSvc, _, _, _ := newTestAPI()
|
||||
packageSvc.pkg.AccountID = 200
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/supply/packages/draft", strings.NewReader(`{
|
||||
"supply_account_id": 200,
|
||||
"model": "gpt-4",
|
||||
"total_quota": 10000,
|
||||
"price_per_1m_input": 0.5,
|
||||
"price_per_1m_output": 1.5,
|
||||
"valid_days": 30,
|
||||
"max_concurrent": 10,
|
||||
"rate_limit_rpm": 100
|
||||
}`))
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
api.handleCreatePackageDraft(w, req)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("expected status 201, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Data struct {
|
||||
SupplyAccountID int64 `json:"supply_account_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if resp.Data.SupplyAccountID != packageSvc.pkg.AccountID {
|
||||
t.Fatalf("expected supply_account_id %d, got %d", packageSvc.pkg.AccountID, resp.Data.SupplyAccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupplyAPI_CreatePackageDraft_MethodNotAllowed(t *testing.T) {
|
||||
api, _, _, _, _, _ := newTestAPI()
|
||||
|
||||
@@ -876,6 +911,32 @@ func TestSupplyAPI_ClonePackage_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupplyAPI_ClonePackage_ResponseUsesAccountID(t *testing.T) {
|
||||
api, _, packageSvc, _, _, _ := newTestAPI()
|
||||
packageSvc.pkg.AccountID = 200
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/clone", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
api.handlePackageActions(w, req)
|
||||
|
||||
if w.Code != http.StatusCreated {
|
||||
t.Fatalf("expected status 201, got %d body=%s", w.Code, w.Body.String())
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Data struct {
|
||||
SupplyAccountID int64 `json:"supply_account_id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if resp.Data.SupplyAccountID != packageSvc.pkg.AccountID {
|
||||
t.Fatalf("expected supply_account_id %d, got %d", packageSvc.pkg.AccountID, resp.Data.SupplyAccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupplyAPI_BatchUpdatePrice_Success(t *testing.T) {
|
||||
api, _, _, _, _, _ := newTestAPI()
|
||||
api.packageService.(*mockPackageService).batchResp = &domain.BatchUpdatePriceResponse{
|
||||
|
||||
@@ -50,7 +50,7 @@ func (r *PackageRepository) Create(ctx context.Context, pkg *domain.Package, req
|
||||
}
|
||||
|
||||
err := r.pool.QueryRow(ctx, query,
|
||||
pkg.SupplierID, pkg.AccountID, pkg.Platform, pkg.Model,
|
||||
pkg.AccountID, pkg.SupplierID, pkg.Platform, pkg.Model,
|
||||
pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota,
|
||||
pkg.PricePer1MInput, pkg.PricePer1MOutput, pkg.MinPurchase,
|
||||
startAt, endAt, pkg.ValidDays,
|
||||
@@ -85,7 +85,7 @@ func (r *PackageRepository) GetByID(ctx context.Context, supplierID, id int64) (
|
||||
pkg := &domain.Package{}
|
||||
var startAt, endAt *time.Time
|
||||
err := r.pool.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&pkg.ID, &pkg.SupplierID, &pkg.AccountID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.ID, &pkg.AccountID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota,
|
||||
&pkg.PricePer1MInput, &pkg.PricePer1MOutput, &pkg.MinPurchase,
|
||||
&startAt, &endAt, &pkg.ValidDays,
|
||||
@@ -169,7 +169,7 @@ func (r *PackageRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, sup
|
||||
|
||||
pkg := &domain.Package{}
|
||||
err := tx.QueryRow(ctx, query, id, supplierID).Scan(
|
||||
&pkg.ID, &pkg.SupplierID, &pkg.AccountID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.ID, &pkg.AccountID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota,
|
||||
&pkg.PricePer1MInput, &pkg.PricePer1MOutput,
|
||||
&pkg.Status, &pkg.Version,
|
||||
@@ -210,7 +210,7 @@ func (r *PackageRepository) List(ctx context.Context, supplierID int64) ([]*doma
|
||||
for rows.Next() {
|
||||
pkg := &domain.Package{}
|
||||
err := rows.Scan(
|
||||
&pkg.ID, &pkg.SupplierID, &pkg.AccountID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.ID, &pkg.AccountID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
|
||||
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota,
|
||||
&pkg.PricePer1MInput, &pkg.PricePer1MOutput,
|
||||
&pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM,
|
||||
|
||||
@@ -75,6 +75,17 @@ func TestPackageRepository_Create_Integration(t *testing.T) {
|
||||
if pkg.ID == 0 {
|
||||
t.Fatal("expected created package id")
|
||||
}
|
||||
|
||||
fetched, err := repo.GetByID(context.Background(), pkg.SupplierID, pkg.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("get package after create failed: %v", err)
|
||||
}
|
||||
if fetched.SupplierID != pkg.SupplierID {
|
||||
t.Fatalf("expected supplier id %d, got %d", pkg.SupplierID, fetched.SupplierID)
|
||||
}
|
||||
if fetched.AccountID != pkg.AccountID {
|
||||
t.Fatalf("expected account id %d, got %d", pkg.AccountID, fetched.AccountID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackageRepository_GetByID_Integration(t *testing.T) {
|
||||
@@ -116,23 +127,47 @@ func TestPackageRepository_List_Integration(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
rows, err := pool.Query(context.Background(), `
|
||||
SELECT id, user_id, available_quota
|
||||
FROM supply_packages
|
||||
LIMIT 10
|
||||
`)
|
||||
if err != nil {
|
||||
t.Fatalf("列出套餐失败: %v", err)
|
||||
repo := NewPackageRepository(pool)
|
||||
pkg := &domain.Package{
|
||||
SupplierID: 3001,
|
||||
AccountID: 4001,
|
||||
Platform: "anthropic",
|
||||
Model: "claude-3-7-sonnet",
|
||||
TotalQuota: 5000,
|
||||
AvailableQuota: 5000,
|
||||
PricePer1MInput: 0.4,
|
||||
PricePer1MOutput: 1.2,
|
||||
Status: domain.PackageStatusDraft,
|
||||
ValidDays: 15,
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var id, userID int64
|
||||
var availableQuota float64
|
||||
if scanErr := rows.Scan(&id, &userID, &availableQuota); scanErr != nil {
|
||||
t.Fatalf("扫描套餐失败: %v", scanErr)
|
||||
if err := repo.Create(context.Background(), pkg, "req-pkg-list-int", "trace-pkg-list-int"); err != nil {
|
||||
t.Fatalf("create package for list failed: %v", err)
|
||||
}
|
||||
|
||||
packages, err := repo.List(context.Background(), pkg.SupplierID)
|
||||
if err != nil {
|
||||
t.Fatalf("repo.List failed: %v", err)
|
||||
}
|
||||
if len(packages) == 0 {
|
||||
t.Fatal("expected packages for supplier")
|
||||
}
|
||||
|
||||
found := false
|
||||
for _, listed := range packages {
|
||||
if listed.ID == pkg.ID {
|
||||
found = true
|
||||
if listed.SupplierID != pkg.SupplierID {
|
||||
t.Fatalf("expected listed supplier id %d, got %d", pkg.SupplierID, listed.SupplierID)
|
||||
}
|
||||
if listed.AccountID != pkg.AccountID {
|
||||
t.Fatalf("expected listed account id %d, got %d", pkg.AccountID, listed.AccountID)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Fatalf("expected package %d in supplier list", pkg.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPackageRepository_UpdateQuota_Integration(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user