fix(supply-api): restore package lifecycle ownership semantics

This commit is contained in:
Your Name
2026-04-20 11:36:07 +08:00
parent 00ff6363bd
commit 1c088e2dd4
6 changed files with 234 additions and 20 deletions

View File

@@ -178,7 +178,11 @@ func (s *DBPackageStore) GetByID(ctx context.Context, supplierID, id int64) (*do
}
func (s *DBPackageStore) Update(ctx context.Context, pkg *domain.Package) error {
return s.repo.Update(ctx, pkg, pkg.Version)
expectedVersion := 0
if pkg.Version > 0 {
expectedVersion = pkg.Version - 1
}
return s.repo.Update(ctx, pkg, expectedVersion)
}
func (s *DBPackageStore) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {

View File

@@ -0,0 +1,114 @@
//go:build integration
// +build integration
package adapter
import (
"context"
"testing"
"time"
"lijiaoqiao/supply-api/internal/audit"
"lijiaoqiao/supply-api/internal/domain"
"lijiaoqiao/supply-api/internal/repository"
)
func createAdapterIntegrationPackage(t *testing.T, repo *repository.PackageRepository, supplierID, accountID int64, status domain.PackageStatus) *domain.Package {
t.Helper()
pkg := &domain.Package{
SupplierID: supplierID,
AccountID: accountID,
Platform: "openai",
Model: "gpt-4.1-mini",
TotalQuota: 10000,
AvailableQuota: 10000,
SoldQuota: 0,
ReservedQuota: 0,
PricePer1MInput: 0.25,
PricePer1MOutput: 0.75,
MinPurchase: 100,
StartAt: time.Date(2026, 4, 20, 0, 0, 0, 0, time.UTC),
EndAt: time.Date(2026, 5, 20, 0, 0, 0, 0, time.UTC),
ValidDays: 30,
Status: status,
MaxConcurrent: 5,
RateLimitRPM: 60,
Version: 1,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if err := repo.Create(context.Background(), pkg, "req-pkg-lifecycle-int", "trace-pkg-lifecycle-int"); err != nil {
t.Fatalf("创建测试套餐失败: %v", err)
}
return pkg
}
func TestDBPackageStore_Lifecycle_Integration(t *testing.T) {
if testing.Short() {
t.Skip("跳过集成测试short mode")
}
pool := getIntegrationDB(t)
if pool == nil {
return
}
repo := repository.NewPackageRepository(pool)
store := NewDBPackageStore(repo)
service := domain.NewPackageService(store, NewInMemoryAccountStoreAdapter(), audit.NewMemoryAuditStore())
supplierID := time.Now().UnixNano()
accountID := supplierID + 1000
pkg := createAdapterIntegrationPackage(t, repo, supplierID, accountID, domain.PackageStatusDraft)
fetched, err := store.GetByID(context.Background(), supplierID, pkg.ID)
if err != nil {
t.Fatalf("读取套餐失败: %v", err)
}
if fetched.SupplierID != supplierID {
t.Fatalf("expected fetched supplier id %d, got %d", supplierID, fetched.SupplierID)
}
if fetched.AccountID != accountID {
t.Fatalf("expected fetched account id %d, got %d", accountID, fetched.AccountID)
}
published, err := service.Publish(context.Background(), supplierID, pkg.ID)
if err != nil {
t.Fatalf("发布套餐失败: %v", err)
}
if published.Status != domain.PackageStatusActive {
t.Fatalf("expected published status %q, got %q", domain.PackageStatusActive, published.Status)
}
paused, err := service.Pause(context.Background(), supplierID, pkg.ID)
if err != nil {
t.Fatalf("暂停套餐失败: %v", err)
}
if paused.Status != domain.PackageStatusPaused {
t.Fatalf("expected paused status %q, got %q", domain.PackageStatusPaused, paused.Status)
}
unlisted, err := service.Unlist(context.Background(), supplierID, pkg.ID)
if err != nil {
t.Fatalf("下架套餐失败: %v", err)
}
if unlisted.Status != domain.PackageStatusExpired {
t.Fatalf("expected unlisted status %q, got %q", domain.PackageStatusExpired, unlisted.Status)
}
after, err := repo.GetByID(context.Background(), supplierID, pkg.ID)
if err != nil {
t.Fatalf("生命周期完成后读取套餐失败: %v", err)
}
if after.Status != domain.PackageStatusExpired {
t.Fatalf("expected persisted status %q, got %q", domain.PackageStatusExpired, after.Status)
}
if after.SupplierID != supplierID {
t.Fatalf("expected persisted supplier id %d, got %d", supplierID, after.SupplierID)
}
if after.AccountID != accountID {
t.Fatalf("expected persisted account id %d, got %d", accountID, after.AccountID)
}
}

View File

@@ -503,7 +503,7 @@ func (a *SupplyAPI) handleCreatePackageDraft(w http.ResponseWriter, r *http.Requ
"request_id": getRequestID(r),
"data": map[string]any{
"package_id": pkg.ID,
"supply_account_id": pkg.SupplierID,
"supply_account_id": pkg.AccountID,
"model": pkg.Model,
"status": pkg.Status,
"total_quota": pkg.TotalQuota,
@@ -665,7 +665,7 @@ func (a *SupplyAPI) handleClonePackage(w http.ResponseWriter, r *http.Request, p
"request_id": getRequestID(r),
"data": map[string]any{
"package_id": pkg.ID,
"supply_account_id": pkg.SupplierID,
"supply_account_id": pkg.AccountID,
"model": pkg.Model,
"status": pkg.Status,
"created_at": pkg.CreatedAt,

View File

@@ -703,6 +703,41 @@ func TestSupplyAPI_CreatePackageDraft_Success(t *testing.T) {
}
}
func TestSupplyAPI_CreatePackageDraft_ResponseUsesAccountID(t *testing.T) {
api, _, packageSvc, _, _, _ := newTestAPI()
packageSvc.pkg.AccountID = 200
req := httptest.NewRequest("POST", "/api/v1/supply/packages/draft", strings.NewReader(`{
"supply_account_id": 200,
"model": "gpt-4",
"total_quota": 10000,
"price_per_1m_input": 0.5,
"price_per_1m_output": 1.5,
"valid_days": 30,
"max_concurrent": 10,
"rate_limit_rpm": 100
}`))
w := httptest.NewRecorder()
api.handleCreatePackageDraft(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected status 201, got %d body=%s", w.Code, w.Body.String())
}
var resp struct {
Data struct {
SupplyAccountID int64 `json:"supply_account_id"`
} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if resp.Data.SupplyAccountID != packageSvc.pkg.AccountID {
t.Fatalf("expected supply_account_id %d, got %d", packageSvc.pkg.AccountID, resp.Data.SupplyAccountID)
}
}
func TestSupplyAPI_CreatePackageDraft_MethodNotAllowed(t *testing.T) {
api, _, _, _, _, _ := newTestAPI()
@@ -876,6 +911,32 @@ func TestSupplyAPI_ClonePackage_Success(t *testing.T) {
}
}
func TestSupplyAPI_ClonePackage_ResponseUsesAccountID(t *testing.T) {
api, _, packageSvc, _, _, _ := newTestAPI()
packageSvc.pkg.AccountID = 200
req := httptest.NewRequest("POST", "/api/v1/supply/packages/1/clone", nil)
w := httptest.NewRecorder()
api.handlePackageActions(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("expected status 201, got %d body=%s", w.Code, w.Body.String())
}
var resp struct {
Data struct {
SupplyAccountID int64 `json:"supply_account_id"`
} `json:"data"`
}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if resp.Data.SupplyAccountID != packageSvc.pkg.AccountID {
t.Fatalf("expected supply_account_id %d, got %d", packageSvc.pkg.AccountID, resp.Data.SupplyAccountID)
}
}
func TestSupplyAPI_BatchUpdatePrice_Success(t *testing.T) {
api, _, _, _, _, _ := newTestAPI()
api.packageService.(*mockPackageService).batchResp = &domain.BatchUpdatePriceResponse{

View File

@@ -50,7 +50,7 @@ func (r *PackageRepository) Create(ctx context.Context, pkg *domain.Package, req
}
err := r.pool.QueryRow(ctx, query,
pkg.SupplierID, pkg.AccountID, pkg.Platform, pkg.Model,
pkg.AccountID, pkg.SupplierID, pkg.Platform, pkg.Model,
pkg.TotalQuota, pkg.AvailableQuota, pkg.SoldQuota, pkg.ReservedQuota,
pkg.PricePer1MInput, pkg.PricePer1MOutput, pkg.MinPurchase,
startAt, endAt, pkg.ValidDays,
@@ -85,7 +85,7 @@ func (r *PackageRepository) GetByID(ctx context.Context, supplierID, id int64) (
pkg := &domain.Package{}
var startAt, endAt *time.Time
err := r.pool.QueryRow(ctx, query, id, supplierID).Scan(
&pkg.ID, &pkg.SupplierID, &pkg.AccountID, &pkg.Platform, &pkg.Model,
&pkg.ID, &pkg.AccountID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota,
&pkg.PricePer1MInput, &pkg.PricePer1MOutput, &pkg.MinPurchase,
&startAt, &endAt, &pkg.ValidDays,
@@ -169,7 +169,7 @@ func (r *PackageRepository) GetForUpdate(ctx context.Context, tx pgxpool.Tx, sup
pkg := &domain.Package{}
err := tx.QueryRow(ctx, query, id, supplierID).Scan(
&pkg.ID, &pkg.SupplierID, &pkg.AccountID, &pkg.Platform, &pkg.Model,
&pkg.ID, &pkg.AccountID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota, &pkg.ReservedQuota,
&pkg.PricePer1MInput, &pkg.PricePer1MOutput,
&pkg.Status, &pkg.Version,
@@ -210,7 +210,7 @@ func (r *PackageRepository) List(ctx context.Context, supplierID int64) ([]*doma
for rows.Next() {
pkg := &domain.Package{}
err := rows.Scan(
&pkg.ID, &pkg.SupplierID, &pkg.AccountID, &pkg.Platform, &pkg.Model,
&pkg.ID, &pkg.AccountID, &pkg.SupplierID, &pkg.Platform, &pkg.Model,
&pkg.TotalQuota, &pkg.AvailableQuota, &pkg.SoldQuota,
&pkg.PricePer1MInput, &pkg.PricePer1MOutput,
&pkg.Status, &pkg.MaxConcurrent, &pkg.RateLimitRPM,

View File

@@ -75,6 +75,17 @@ func TestPackageRepository_Create_Integration(t *testing.T) {
if pkg.ID == 0 {
t.Fatal("expected created package id")
}
fetched, err := repo.GetByID(context.Background(), pkg.SupplierID, pkg.ID)
if err != nil {
t.Fatalf("get package after create failed: %v", err)
}
if fetched.SupplierID != pkg.SupplierID {
t.Fatalf("expected supplier id %d, got %d", pkg.SupplierID, fetched.SupplierID)
}
if fetched.AccountID != pkg.AccountID {
t.Fatalf("expected account id %d, got %d", pkg.AccountID, fetched.AccountID)
}
}
func TestPackageRepository_GetByID_Integration(t *testing.T) {
@@ -116,23 +127,47 @@ func TestPackageRepository_List_Integration(t *testing.T) {
return
}
rows, err := pool.Query(context.Background(), `
SELECT id, user_id, available_quota
FROM supply_packages
LIMIT 10
`)
if err != nil {
t.Fatalf("列出套餐失败: %v", err)
repo := NewPackageRepository(pool)
pkg := &domain.Package{
SupplierID: 3001,
AccountID: 4001,
Platform: "anthropic",
Model: "claude-3-7-sonnet",
TotalQuota: 5000,
AvailableQuota: 5000,
PricePer1MInput: 0.4,
PricePer1MOutput: 1.2,
Status: domain.PackageStatusDraft,
ValidDays: 15,
}
defer rows.Close()
for rows.Next() {
var id, userID int64
var availableQuota float64
if scanErr := rows.Scan(&id, &userID, &availableQuota); scanErr != nil {
t.Fatalf("扫描套餐失败: %v", scanErr)
if err := repo.Create(context.Background(), pkg, "req-pkg-list-int", "trace-pkg-list-int"); err != nil {
t.Fatalf("create package for list failed: %v", err)
}
packages, err := repo.List(context.Background(), pkg.SupplierID)
if err != nil {
t.Fatalf("repo.List failed: %v", err)
}
if len(packages) == 0 {
t.Fatal("expected packages for supplier")
}
found := false
for _, listed := range packages {
if listed.ID == pkg.ID {
found = true
if listed.SupplierID != pkg.SupplierID {
t.Fatalf("expected listed supplier id %d, got %d", pkg.SupplierID, listed.SupplierID)
}
if listed.AccountID != pkg.AccountID {
t.Fatalf("expected listed account id %d, got %d", pkg.AccountID, listed.AccountID)
}
}
}
if !found {
t.Fatalf("expected package %d in supplier list", pkg.ID)
}
}
func TestPackageRepository_UpdateQuota_Integration(t *testing.T) {