chore: initial public snapshot for github upload
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,853 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/accountgroup"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type AccountRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *accountRepository
|
||||
}
|
||||
|
||||
type schedulerCacheRecorder struct {
|
||||
setAccounts []*service.Account
|
||||
accounts map[int64]*service.Account
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetSnapshot(ctx context.Context, bucket service.SchedulerBucket) ([]*service.Account, bool, error) {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetSnapshot(ctx context.Context, bucket service.SchedulerBucket, accounts []service.Account) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetAccount(ctx context.Context, accountID int64) (*service.Account, error) {
|
||||
if s.accounts == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.accounts[accountID], nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetAccount(ctx context.Context, account *service.Account) error {
|
||||
s.setAccounts = append(s.setAccounts, account)
|
||||
if s.accounts == nil {
|
||||
s.accounts = make(map[int64]*service.Account)
|
||||
}
|
||||
if account != nil {
|
||||
s.accounts[account.ID] = account
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) DeleteAccount(ctx context.Context, accountID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) UpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) TryLockBucket(ctx context.Context, bucket service.SchedulerBucket, ttl time.Duration) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) ListBuckets(ctx context.Context) ([]service.SchedulerBucket, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) GetOutboxWatermark(ctx context.Context) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *schedulerCacheRecorder) SetOutboxWatermark(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||
}
|
||||
|
||||
func TestAccountRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(AccountRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *AccountRepoSuite) TestCreate() {
|
||||
account := &service.Account{
|
||||
Name: "test-create",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Credentials: map[string]any{},
|
||||
Extra: map[string]any{},
|
||||
Concurrency: 3,
|
||||
Priority: 50,
|
||||
Schedulable: true,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, account)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(account.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "original"})
|
||||
|
||||
account.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "sync-update", Status: service.StatusActive, Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Status = service.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusDisabled, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdate_SyncSchedulerSnapshotOnCredentialsChange() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "sync-credentials-update",
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.1",
|
||||
},
|
||||
},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5": "gpt-5.2",
|
||||
},
|
||||
}
|
||||
err := s.repo.Update(s.ctx, account)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
mapping, ok := cacheRecorder.setAccounts[0].Credentials["model_mapping"].(map[string]any)
|
||||
s.Require().True(ok)
|
||||
s.Require().Equal("gpt-5.2", mapping["gpt-5"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "to-delete"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestDelete_WithGroupBindings() {
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-del"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-del"})
|
||||
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
|
||||
|
||||
err := s.repo.Delete(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "Delete should cascade remove bindings")
|
||||
|
||||
count, err := s.client.AccountGroup.Query().Where(accountgroup.AccountIDEQ(account.ID)).Count(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count, "expected bindings to be removed")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *AccountRepoSuite) TestList() {
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc1"})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc2"})
|
||||
|
||||
accounts, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(accounts, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListWithFilters() {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func(client *dbent.Client)
|
||||
platform string
|
||||
accType string
|
||||
status string
|
||||
search string
|
||||
wantCount int
|
||||
validate func(accounts []service.Account)
|
||||
}{
|
||||
{
|
||||
name: "filter_by_platform",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI})
|
||||
},
|
||||
platform: service.PlatformOpenAI,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.PlatformOpenAI, accounts[0].Platform)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_type",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t1", Type: service.AccountTypeOAuth})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "t2", Type: service.AccountTypeAPIKey})
|
||||
},
|
||||
accType: service.AccountTypeAPIKey,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.AccountTypeAPIKey, accounts[0].Type)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_status",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "s1", Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "s2", Status: service.StatusDisabled})
|
||||
},
|
||||
status: service.StatusDisabled,
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Equal(service.StatusDisabled, accounts[0].Status)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter_by_search",
|
||||
setup: func(client *dbent.Client) {
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "alpha-account"})
|
||||
mustCreateAccount(s.T(), client, &service.Account{Name: "beta-account"})
|
||||
},
|
||||
search: "alpha",
|
||||
wantCount: 1,
|
||||
validate: func(accounts []service.Account) {
|
||||
s.Require().Contains(accounts[0].Name, "alpha")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
tx := testEntTx(s.T())
|
||||
client := tx.Client()
|
||||
repo := newAccountRepositoryWithSQL(client, tx, nil)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.setup(client)
|
||||
|
||||
accounts, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, tt.platform, tt.accType, tt.status, tt.search, 0)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, tt.wantCount)
|
||||
if tt.validate != nil {
|
||||
tt.validate(accounts)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// --- ListByGroup / ListActive / ListByPlatform ---
|
||||
|
||||
func (s *AccountRepoSuite) TestListByGroup() {
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-list"})
|
||||
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Status: service.StatusActive})
|
||||
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Status: service.StatusActive})
|
||||
mustBindAccountToGroup(s.T(), s.client, acc1.ID, group.ID, 2)
|
||||
mustBindAccountToGroup(s.T(), s.client, acc2.ID, group.ID, 1)
|
||||
|
||||
accounts, err := s.repo.ListByGroup(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListByGroup")
|
||||
s.Require().Len(accounts, 2)
|
||||
// Should be ordered by priority
|
||||
s.Require().Equal(acc2.ID, accounts[0].ID, "expected acc2 first (priority=1)")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListActive() {
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "active1", Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "inactive1", Status: service.StatusDisabled})
|
||||
|
||||
accounts, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal("active1", accounts[0].Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListByPlatform() {
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p1", Platform: service.PlatformAnthropic, Status: service.StatusActive})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "p2", Platform: service.PlatformOpenAI, Status: service.StatusActive})
|
||||
|
||||
accounts, err := s.repo.ListByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListByPlatform")
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
|
||||
}
|
||||
|
||||
// --- Preload and VirtualFields ---
|
||||
|
||||
func (s *AccountRepoSuite) TestPreload_And_VirtualFields() {
|
||||
proxy := mustCreateProxy(s.T(), s.client, &service.Proxy{Name: "p1"})
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
|
||||
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc1",
|
||||
ProxyID: &proxy.ID,
|
||||
})
|
||||
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().NotNil(got.Proxy, "expected Proxy preload")
|
||||
s.Require().Equal(proxy.ID, got.Proxy.ID)
|
||||
s.Require().Len(got.GroupIDs, 1, "expected GroupIDs to be populated")
|
||||
s.Require().Equal(group.ID, got.GroupIDs[0])
|
||||
s.Require().Len(got.Groups, 1, "expected Groups to be populated")
|
||||
s.Require().Equal(group.ID, got.Groups[0].ID)
|
||||
|
||||
accounts, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", "acc", 0)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().NotNil(accounts[0].Proxy, "expected Proxy preload in list")
|
||||
s.Require().Equal(proxy.ID, accounts[0].Proxy.ID)
|
||||
s.Require().Len(accounts[0].GroupIDs, 1, "expected GroupIDs in list")
|
||||
s.Require().Equal(group.ID, accounts[0].GroupIDs[0])
|
||||
}
|
||||
|
||||
// --- GroupBinding / AddToGroup / RemoveFromGroup / BindGroups / GetGroups ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGroupBinding_And_BindGroups() {
|
||||
g1 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g1"})
|
||||
g2 := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g2"})
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc"})
|
||||
|
||||
s.Require().NoError(s.repo.AddToGroup(s.ctx, account.ID, g1.ID, 10), "AddToGroup")
|
||||
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups")
|
||||
s.Require().Len(groups, 1, "expected 1 group")
|
||||
s.Require().Equal(g1.ID, groups[0].ID)
|
||||
|
||||
s.Require().NoError(s.repo.RemoveFromGroup(s.ctx, account.ID, g1.ID), "RemoveFromGroup")
|
||||
groups, err = s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups after remove")
|
||||
s.Require().Empty(groups, "expected 0 groups after remove")
|
||||
|
||||
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{g1.ID, g2.ID}), "BindGroups")
|
||||
groups, err = s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetGroups after bind")
|
||||
s.Require().Len(groups, 2, "expected 2 groups after bind")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBindGroups_EmptyList() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-empty"})
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-empty"})
|
||||
mustBindAccountToGroup(s.T(), s.client, account.ID, group.ID, 1)
|
||||
|
||||
s.Require().NoError(s.repo.BindGroups(s.ctx, account.ID, []int64{}), "BindGroups empty")
|
||||
|
||||
groups, err := s.repo.GetGroups(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(groups, "expected 0 groups after binding empty list")
|
||||
}
|
||||
|
||||
// --- Schedulable ---
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulable() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
|
||||
|
||||
sched, err := s.repo.ListSchedulable(s.ctx)
|
||||
s.Require().NoError(err, "ListSchedulable")
|
||||
ids := idsOfAccounts(sched)
|
||||
s.Require().Contains(ids, okAcc.ID)
|
||||
s.Require().NotContains(ids, overloaded.ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupID_TimeBoundaries_And_StatusUpdates() {
|
||||
now := time.Now()
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sched"})
|
||||
|
||||
okAcc := mustCreateAccount(s.T(), s.client, &service.Account{Name: "ok", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, okAcc.ID, group.ID, 1)
|
||||
|
||||
future := now.Add(10 * time.Minute)
|
||||
overloaded := mustCreateAccount(s.T(), s.client, &service.Account{Name: "over", Schedulable: true, OverloadUntil: &future})
|
||||
mustBindAccountToGroup(s.T(), s.client, overloaded.ID, group.ID, 1)
|
||||
|
||||
rateLimited := mustCreateAccount(s.T(), s.client, &service.Account{Name: "rl", Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, rateLimited.ID, group.ID, 1)
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, rateLimited.ID, now.Add(10*time.Minute)), "SetRateLimited")
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, overloaded.ID, "boom"), "SetError")
|
||||
|
||||
sched, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListSchedulableByGroupID")
|
||||
s.Require().Len(sched, 1, "expected only ok account schedulable")
|
||||
s.Require().Equal(okAcc.ID, sched[0].ID)
|
||||
|
||||
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, rateLimited.ID), "ClearRateLimit")
|
||||
sched2, err := s.repo.ListSchedulableByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ListSchedulableByGroupID after ClearRateLimit")
|
||||
s.Require().Len(sched2, 2, "expected 2 schedulable accounts after ClearRateLimit")
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByPlatform() {
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(service.PlatformAnthropic, accounts[0].Platform)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestListSchedulableByGroupIDAndPlatform() {
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{Name: "g-sp"})
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a1", Platform: service.PlatformAnthropic, Schedulable: true})
|
||||
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "a2", Platform: service.PlatformOpenAI, Schedulable: true})
|
||||
mustBindAccountToGroup(s.T(), s.client, a1.ID, group.ID, 1)
|
||||
mustBindAccountToGroup(s.T(), s.client, a2.ID, group.ID, 2)
|
||||
|
||||
accounts, err := s.repo.ListSchedulableByGroupIDAndPlatform(s.ctx, group.ID, service.PlatformAnthropic)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(a1.ID, accounts[0].ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetSchedulable() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-sched", Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.SetSchedulable(s.ctx, account.ID, false))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(got.Schedulable)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_SyncSchedulerSnapshotOnDisabled() {
|
||||
account1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-1", Status: service.StatusActive, Schedulable: true})
|
||||
account2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-2", Status: service.StatusActive, Schedulable: true})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
disabled := service.StatusDisabled
|
||||
rows, err := s.repo.BulkUpdate(s.ctx, []int64{account1.ID, account2.ID}, service.AccountBulkUpdate{
|
||||
Status: &disabled,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(2), rows)
|
||||
|
||||
s.Require().Len(cacheRecorder.setAccounts, 2)
|
||||
ids := map[int64]struct{}{}
|
||||
for _, acc := range cacheRecorder.setAccounts {
|
||||
ids[acc.ID] = struct{}{}
|
||||
}
|
||||
s.Require().Contains(ids, account1.ID)
|
||||
s.Require().Contains(ids, account2.ID)
|
||||
}
|
||||
|
||||
// --- SetOverloaded / SetRateLimited / ClearRateLimit ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetOverloaded() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-over"})
|
||||
until := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.OverloadUntil)
|
||||
s.Require().WithinDuration(until, *got.OverloadUntil, time.Second)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestSetRateLimited() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-rl"})
|
||||
resetAt := time.Date(2025, 6, 15, 14, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, resetAt))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.RateLimitedAt)
|
||||
s.Require().NotNil(got.RateLimitResetAt)
|
||||
s.Require().WithinDuration(resetAt, *got.RateLimitResetAt, time.Second)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearRateLimit() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-clear"})
|
||||
until := time.Now().Add(1 * time.Hour)
|
||||
s.Require().NoError(s.repo.SetOverloaded(s.ctx, account.ID, until))
|
||||
s.Require().NoError(s.repo.SetRateLimited(s.ctx, account.ID, until))
|
||||
|
||||
s.Require().NoError(s.repo.ClearRateLimit(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got.RateLimitedAt)
|
||||
s.Require().Nil(got.RateLimitResetAt)
|
||||
s.Require().Nil(got.OverloadUntil)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() {
|
||||
acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"})
|
||||
acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"})
|
||||
|
||||
until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second)
|
||||
reason := `{"rule":"429","matched_keyword":"too many requests"}`
|
||||
s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason))
|
||||
|
||||
gotByID, err := s.repo.GetByID(s.ctx, acc1.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(gotByID.TempUnschedulableUntil)
|
||||
s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second)
|
||||
s.Require().Equal(reason, gotByID.TempUnschedulableReason)
|
||||
|
||||
gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(gotByIDs, 2)
|
||||
s.Require().Equal(acc2.ID, gotByIDs[0].ID)
|
||||
s.Require().Nil(gotByIDs[0].TempUnschedulableUntil)
|
||||
s.Require().Equal("", gotByIDs[0].TempUnschedulableReason)
|
||||
s.Require().Equal(acc1.ID, gotByIDs[1].ID)
|
||||
s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil)
|
||||
s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second)
|
||||
s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason)
|
||||
|
||||
s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID))
|
||||
cleared, err := s.repo.GetByID(s.ctx, acc1.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(cleared.TempUnschedulableUntil)
|
||||
s.Require().Equal("", cleared.TempUnschedulableReason)
|
||||
}
|
||||
|
||||
// --- UpdateLastUsed ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateLastUsed() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-used"})
|
||||
s.Require().Nil(account.LastUsedAt)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateLastUsed(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.LastUsedAt)
|
||||
}
|
||||
|
||||
// --- SetError ---
|
||||
|
||||
func (s *AccountRepoSuite) TestSetError() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-err", Status: service.StatusActive})
|
||||
|
||||
s.Require().NoError(s.repo.SetError(s.ctx, account.ID, "something went wrong"))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusError, got.Status)
|
||||
s.Require().Equal("something went wrong", got.ErrorMessage)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestClearError_SyncSchedulerSnapshotOnRecovery() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-clear-err",
|
||||
Status: service.StatusError,
|
||||
ErrorMessage: "temporary error",
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
s.Require().NoError(s.repo.ClearError(s.ctx, account.ID))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusActive, got.Status)
|
||||
s.Require().Empty(got.ErrorMessage)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
}
|
||||
|
||||
// --- UpdateSessionWindow ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateSessionWindow() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-win"})
|
||||
start := time.Date(2025, 6, 15, 10, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2025, 6, 15, 15, 0, 0, 0, time.UTC)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateSessionWindow(s.ctx, account.ID, &start, &end, "active"))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.SessionWindowStart)
|
||||
s.Require().NotNil(got.SessionWindowEnd)
|
||||
s.Require().Equal("active", got.SessionWindowStatus)
|
||||
}
|
||||
|
||||
// --- UpdateExtra ---
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_MergesFields() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra",
|
||||
Extra: map[string]any{"a": "1"},
|
||||
})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"b": "2"}), "UpdateExtra")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("1", got.Extra["a"])
|
||||
s.Require().Equal("2", got.Extra["b"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_EmptyUpdates() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-extra-empty"})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{}))
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_NilExtra() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-nil-extra", Extra: nil})
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{"key": "val"}))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("val", got.Extra["key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerNeutralSkipsOutboxAndSyncsFreshSnapshot() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-neutral",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Extra: map[string]any{"codex_usage_updated_at": "old"},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{
|
||||
accounts: map[int64]*service.Account{
|
||||
account.ID: {
|
||||
ID: account.ID,
|
||||
Platform: account.Platform,
|
||||
Status: service.StatusDisabled,
|
||||
Extra: map[string]any{
|
||||
"codex_usage_updated_at": "old",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
|
||||
updates := map[string]any{
|
||||
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||
"codex_5h_used_percent": 88.5,
|
||||
"session_window_utilization": 0.42,
|
||||
}
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, updates))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, account.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("2026-03-11T10:00:00Z", got.Extra["codex_usage_updated_at"])
|
||||
s.Require().Equal(88.5, got.Extra["codex_5h_used_percent"])
|
||||
s.Require().Equal(0.42, got.Extra["session_window_utilization"])
|
||||
|
||||
var outboxCount int
|
||||
s.Require().NoError(scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &outboxCount))
|
||||
s.Require().Zero(outboxCount)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().NotNil(cacheRecorder.accounts[account.ID])
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.accounts[account.ID].Status)
|
||||
s.Require().Equal("2026-03-11T10:00:00Z", cacheRecorder.accounts[account.ID].Extra["codex_usage_updated_at"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_ExhaustedCodexSnapshotSyncsSchedulerCache() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-codex-exhausted",
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
cacheRecorder := &schedulerCacheRecorder{}
|
||||
s.repo.schedulerCache = cacheRecorder
|
||||
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": "2026-03-12T13:00:00Z",
|
||||
"codex_7d_reset_after_seconds": 86400,
|
||||
}))
|
||||
|
||||
var count int
|
||||
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(0, count)
|
||||
s.Require().Len(cacheRecorder.setAccounts, 1)
|
||||
s.Require().Equal(account.ID, cacheRecorder.setAccounts[0].ID)
|
||||
s.Require().Equal(service.StatusActive, cacheRecorder.setAccounts[0].Status)
|
||||
s.Require().Equal(100.0, cacheRecorder.setAccounts[0].Extra["codex_7d_used_percent"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestUpdateExtra_SchedulerRelevantStillEnqueuesOutbox() {
|
||||
account := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-extra-mixed",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Extra: map[string]any{},
|
||||
})
|
||||
_, err := s.repo.sql.ExecContext(s.ctx, "TRUNCATE scheduler_outbox")
|
||||
s.Require().NoError(err)
|
||||
|
||||
s.Require().NoError(s.repo.UpdateExtra(s.ctx, account.ID, map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
"codex_usage_updated_at": "2026-03-11T10:00:00Z",
|
||||
}))
|
||||
|
||||
var count int
|
||||
err = scanSingleRow(s.ctx, s.repo.sql, "SELECT COUNT(*) FROM scheduler_outbox", nil, &count)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(1, count)
|
||||
}
|
||||
|
||||
// --- GetByCRSAccountID ---
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID() {
|
||||
crsID := "crs-12345"
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "acc-crs",
|
||||
Extra: map[string]any{"crs_account_id": crsID},
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, crsID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got)
|
||||
s.Require().Equal("acc-crs", got.Name)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID_NotFound() {
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, "non-existent")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestGetByCRSAccountID_EmptyString() {
|
||||
got, err := s.repo.GetByCRSAccountID(s.ctx, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got)
|
||||
}
|
||||
|
||||
// --- BulkUpdate ---
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate() {
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk1", Priority: 1})
|
||||
a2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk2", Priority: 1})
|
||||
|
||||
newPriority := 99
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID, a2.ID}, service.AccountBulkUpdate{
|
||||
Priority: &newPriority,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.Require().GreaterOrEqual(affected, int64(1), "expected at least one affected row")
|
||||
|
||||
got1, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
got2, _ := s.repo.GetByID(s.ctx, a2.ID)
|
||||
s.Require().Equal(99, got1.Priority)
|
||||
s.Require().Equal(99, got2.Priority)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeCredentials() {
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "bulk-cred",
|
||||
Credentials: map[string]any{"existing": "value"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Credentials: map[string]any{"new_key": "new_value"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
s.Require().Equal("value", got.Credentials["existing"])
|
||||
s.Require().Equal("new_value", got.Credentials["new_key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_MergeExtra() {
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "bulk-extra",
|
||||
Extra: map[string]any{"existing": "val"},
|
||||
})
|
||||
|
||||
_, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{
|
||||
Extra: map[string]any{"new_key": "new_val"},
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
got, _ := s.repo.GetByID(s.ctx, a1.ID)
|
||||
s.Require().Equal("val", got.Extra["existing"])
|
||||
s.Require().Equal("new_val", got.Extra["new_key"])
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyIDs() {
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{}, service.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
|
||||
func (s *AccountRepoSuite) TestBulkUpdate_EmptyUpdates() {
|
||||
a1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "bulk-empty"})
|
||||
|
||||
affected, err := s.repo.BulkUpdate(s.ctx, []int64{a1.ID}, service.AccountBulkUpdate{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected)
|
||||
}
|
||||
|
||||
func idsOfAccounts(accounts []service.Account) []int64 {
|
||||
out := make([]int64, 0, len(accounts))
|
||||
for i := range accounts {
|
||||
out = append(out, accounts[i].ID)
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// AESEncryptor implements SecretEncryptor using AES-256-GCM
|
||||
type AESEncryptor struct {
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewAESEncryptor creates a new AES encryptor
|
||||
func NewAESEncryptor(cfg *config.Config) (service.SecretEncryptor, error) {
|
||||
key, err := hex.DecodeString(cfg.Totp.EncryptionKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid totp encryption key: %w", err)
|
||||
}
|
||||
|
||||
if len(key) != 32 {
|
||||
return nil, fmt.Errorf("totp encryption key must be 32 bytes (64 hex chars), got %d bytes", len(key))
|
||||
}
|
||||
|
||||
return &AESEncryptor{key: key}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext using AES-256-GCM
|
||||
// Output format: base64(nonce + ciphertext + tag)
|
||||
func (e *AESEncryptor) Encrypt(plaintext string) (string, error) {
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create gcm: %w", err)
|
||||
}
|
||||
|
||||
// Generate a random nonce
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return "", fmt.Errorf("generate nonce: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt the plaintext
|
||||
// Seal appends the ciphertext and tag to the nonce
|
||||
ciphertext := gcm.Seal(nonce, nonce, []byte(plaintext), nil)
|
||||
|
||||
// Encode as base64
|
||||
return base64.StdEncoding.EncodeToString(ciphertext), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext using AES-256-GCM
|
||||
func (e *AESEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||
// Decode from base64
|
||||
data, err := base64.StdEncoding.DecodeString(ciphertext)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode base64: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(e.key)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create cipher: %w", err)
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create gcm: %w", err)
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(data) < nonceSize {
|
||||
return "", fmt.Errorf("ciphertext too short")
|
||||
}
|
||||
|
||||
// Extract nonce and ciphertext
|
||||
nonce, ciphertextData := data[:nonceSize], data[nonceSize:]
|
||||
|
||||
// Decrypt
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertextData, nil)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decrypt: %w", err)
|
||||
}
|
||||
|
||||
return string(plaintext), nil
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func uniqueTestValue(t *testing.T, prefix string) string {
|
||||
t.Helper()
|
||||
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
|
||||
return fmt.Sprintf("%s-%s", prefix, safeName)
|
||||
}
|
||||
|
||||
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "target-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "other-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := newUserRepositoryWithSQL(entClient, tx)
|
||||
|
||||
u1 := &service.User{
|
||||
Email: uniqueTestValue(t, "u1") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u1))
|
||||
|
||||
u2 := &service.User{
|
||||
Email: uniqueTestValue(t, "u2") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u2))
|
||||
|
||||
u3 := &service.User{
|
||||
Email: uniqueTestValue(t, "u3") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u3))
|
||||
|
||||
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), affected)
|
||||
|
||||
u1After, err := repo.GetByID(ctx, u1.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
|
||||
|
||||
u2After, err := repo.GetByID(ctx, u2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
||||
}
|
||||
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-target")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-other")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := newAPIKeyRepositoryWithSQL(entClient, tx)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, userRepo.Create(ctx, u))
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
|
||||
Name: "test key",
|
||||
GroupID: &targetGroup.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, apiKeyRepo.Create(ctx, key))
|
||||
|
||||
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Deleted group should be hidden by default queries (soft-delete semantics).
|
||||
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
|
||||
require.ErrorIs(t, err, service.ErrGroupNotFound)
|
||||
|
||||
activeGroups, err := groupRepo.ListActive(ctx)
|
||||
require.NoError(t, err)
|
||||
for _, g := range activeGroups {
|
||||
require.NotEqual(t, targetGroup.ID, g.ID)
|
||||
}
|
||||
|
||||
// User.allowed_groups should no longer include the deleted group.
|
||||
uAfter, err := userRepo.GetByID(ctx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
|
||||
|
||||
// API keys bound to the deleted group should have group_id cleared.
|
||||
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, keyAfter.GroupID)
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcementread"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type announcementReadRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAnnouncementReadRepository(client *dbent.Client) service.AnnouncementReadRepository {
|
||||
return &announcementReadRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
return client.AnnouncementRead.Create().
|
||||
SetAnnouncementID(announcementID).
|
||||
SetUserID(userID).
|
||||
SetReadAt(readAt).
|
||||
OnConflictColumns(announcementread.FieldAnnouncementID, announcementread.FieldUserID).
|
||||
DoNothing().
|
||||
Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error) {
|
||||
if len(announcementIDs) == 0 {
|
||||
return map[int64]time.Time{}, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.AnnouncementRead.Query().
|
||||
Where(
|
||||
announcementread.UserIDEQ(userID),
|
||||
announcementread.AnnouncementIDIn(announcementIDs...),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(map[int64]time.Time, len(rows))
|
||||
for i := range rows {
|
||||
out[rows[i].AnnouncementID] = rows[i].ReadAt
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error) {
|
||||
if len(userIDs) == 0 {
|
||||
return map[int64]time.Time{}, nil
|
||||
}
|
||||
|
||||
rows, err := r.client.AnnouncementRead.Query().
|
||||
Where(
|
||||
announcementread.AnnouncementIDEQ(announcementID),
|
||||
announcementread.UserIDIn(userIDs...),
|
||||
).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(map[int64]time.Time, len(rows))
|
||||
for i := range rows {
|
||||
out[rows[i].UserID] = rows[i].ReadAt
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *announcementReadRepository) CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error) {
|
||||
count, err := r.client.AnnouncementRead.Query().
|
||||
Where(announcementread.AnnouncementIDEQ(announcementID)).
|
||||
Count(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int64(count), nil
|
||||
}
|
||||
@@ -0,0 +1,197 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/announcement"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type announcementRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewAnnouncementRepository(client *dbent.Client) service.AnnouncementRepository {
|
||||
return &announcementRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Create(ctx context.Context, a *service.Announcement) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.Announcement.Create().
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetNotifyMode(a.NotifyMode).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
builder.SetStartsAt(*a.StartsAt)
|
||||
}
|
||||
if a.EndsAt != nil {
|
||||
builder.SetEndsAt(*a.EndsAt)
|
||||
}
|
||||
if a.CreatedBy != nil {
|
||||
builder.SetCreatedBy(*a.CreatedBy)
|
||||
}
|
||||
if a.UpdatedBy != nil {
|
||||
builder.SetUpdatedBy(*a.UpdatedBy)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
applyAnnouncementEntityToService(a, created)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) GetByID(ctx context.Context, id int64) (*service.Announcement, error) {
|
||||
m, err := r.client.Announcement.Query().
|
||||
Where(announcement.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
|
||||
}
|
||||
return announcementEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Update(ctx context.Context, a *service.Announcement) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.Announcement.UpdateOneID(a.ID).
|
||||
SetTitle(a.Title).
|
||||
SetContent(a.Content).
|
||||
SetStatus(a.Status).
|
||||
SetNotifyMode(a.NotifyMode).
|
||||
SetTargeting(a.Targeting)
|
||||
|
||||
if a.StartsAt != nil {
|
||||
builder.SetStartsAt(*a.StartsAt)
|
||||
} else {
|
||||
builder.ClearStartsAt()
|
||||
}
|
||||
if a.EndsAt != nil {
|
||||
builder.SetEndsAt(*a.EndsAt)
|
||||
} else {
|
||||
builder.ClearEndsAt()
|
||||
}
|
||||
if a.CreatedBy != nil {
|
||||
builder.SetCreatedBy(*a.CreatedBy)
|
||||
} else {
|
||||
builder.ClearCreatedBy()
|
||||
}
|
||||
if a.UpdatedBy != nil {
|
||||
builder.SetUpdatedBy(*a.UpdatedBy)
|
||||
} else {
|
||||
builder.ClearUpdatedBy()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrAnnouncementNotFound, nil)
|
||||
}
|
||||
|
||||
a.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.Announcement.Delete().Where(announcement.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *announcementRepository) List(
|
||||
ctx context.Context,
|
||||
params pagination.PaginationParams,
|
||||
filters service.AnnouncementListFilters,
|
||||
) ([]service.Announcement, *pagination.PaginationResult, error) {
|
||||
q := r.client.Announcement.Query()
|
||||
|
||||
if filters.Status != "" {
|
||||
q = q.Where(announcement.StatusEQ(filters.Status))
|
||||
}
|
||||
if filters.Search != "" {
|
||||
q = q.Where(
|
||||
announcement.Or(
|
||||
announcement.TitleContainsFold(filters.Search),
|
||||
announcement.ContentContainsFold(filters.Search),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
items, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(announcement.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
out := announcementEntitiesToService(items)
|
||||
return out, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *announcementRepository) ListActive(ctx context.Context, now time.Time) ([]service.Announcement, error) {
|
||||
q := r.client.Announcement.Query().
|
||||
Where(
|
||||
announcement.StatusEQ(service.AnnouncementStatusActive),
|
||||
announcement.Or(announcement.StartsAtIsNil(), announcement.StartsAtLTE(now)),
|
||||
announcement.Or(announcement.EndsAtIsNil(), announcement.EndsAtGT(now)),
|
||||
).
|
||||
Order(dbent.Desc(announcement.FieldID))
|
||||
|
||||
items, err := q.All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return announcementEntitiesToService(items), nil
|
||||
}
|
||||
|
||||
func applyAnnouncementEntityToService(dst *service.Announcement, src *dbent.Announcement) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
|
||||
func announcementEntityToService(m *dbent.Announcement) *service.Announcement {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Announcement{
|
||||
ID: m.ID,
|
||||
Title: m.Title,
|
||||
Content: m.Content,
|
||||
Status: m.Status,
|
||||
NotifyMode: m.NotifyMode,
|
||||
Targeting: m.Targeting,
|
||||
StartsAt: m.StartsAt,
|
||||
EndsAt: m.EndsAt,
|
||||
CreatedBy: m.CreatedBy,
|
||||
UpdatedBy: m.UpdatedBy,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func announcementEntitiesToService(models []*dbent.Announcement) []service.Announcement {
|
||||
out := make([]service.Announcement, 0, len(models))
|
||||
for i := range models {
|
||||
if s := announcementEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,137 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyRateLimitKeyPrefix = "apikey:ratelimit:"
|
||||
apiKeyRateLimitDuration = 24 * time.Hour
|
||||
apiKeyAuthCachePrefix = "apikey:auth:"
|
||||
authCacheInvalidateChannel = "auth:cache:invalidate"
|
||||
)
|
||||
|
||||
// apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
|
||||
func apiKeyRateLimitKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func apiKeyAuthCacheKey(key string) string {
|
||||
return fmt.Sprintf("%s%s", apiKeyAuthCachePrefix, key)
|
||||
}
|
||||
|
||||
type apiKeyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewAPIKeyCache(rdb *redis.Client) service.APIKeyCache {
|
||||
return &apiKeyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, apiKeyRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := apiKeyRateLimitKey(userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return c.rdb.Incr(ctx, apiKey).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return c.rdb.Expire(ctx, apiKey, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) GetAuthCache(ctx context.Context, key string) (*service.APIKeyAuthCacheEntry, error) {
|
||||
val, err := c.rdb.Get(ctx, apiKeyAuthCacheKey(key)).Bytes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var entry service.APIKeyAuthCacheEntry
|
||||
if err := json.Unmarshal(val, &entry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &entry, nil
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *service.APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||
if entry == nil {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, apiKeyAuthCacheKey(key), payload, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err()
|
||||
}
|
||||
|
||||
// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances
|
||||
func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err()
|
||||
}
|
||||
|
||||
// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages
|
||||
func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel)
|
||||
|
||||
// Verify subscription is working
|
||||
_, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
return fmt.Errorf("subscribe to auth cache invalidation: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := pubsub.Close(); err != nil {
|
||||
log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
ch := pubsub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg, ok := <-ch:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if msg != nil {
|
||||
handler(msg.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ApiKeyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestCreateAttemptCount() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_zero_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
|
||||
require.NoError(s.T(), err, "expected nil error for missing key")
|
||||
require.Equal(s.T(), 0, count, "expected zero count for missing key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "increment_increases_count_and_sets_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
key := fmt.Sprintf("%s%d", apiKeyRateLimitKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount")
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID), "IncrementCreateAttemptCount 2")
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetCreateAttemptCount")
|
||||
require.Equal(s.T(), 2, count, "count mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, key).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, apiKeyRateLimitDuration)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "delete_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
userID := int64(1)
|
||||
|
||||
require.NoError(s.T(), cache.IncrementCreateAttemptCount(ctx, userID))
|
||||
require.NoError(s.T(), cache.DeleteCreateAttemptCount(ctx, userID), "DeleteCreateAttemptCount")
|
||||
|
||||
count, err := cache.GetCreateAttemptCount(ctx, userID)
|
||||
require.NoError(s.T(), err, "expected nil error after delete")
|
||||
require.Equal(s.T(), 0, count, "expected zero count after delete")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 每个 case 重新获取隔离资源
|
||||
rdb := testRedis(s.T())
|
||||
cache := &apiKeyCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ApiKeyCacheSuite) TestDailyUsage() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache)
|
||||
}{
|
||||
{
|
||||
name: "increment_increases_count",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
dailyKey := "daily:sk-test"
|
||||
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage")
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey), "IncrementDailyUsage 2")
|
||||
|
||||
n, err := rdb.Get(ctx, dailyKey).Int()
|
||||
require.NoError(s.T(), err, "Get dailyKey")
|
||||
require.Equal(s.T(), 2, n, "expected daily usage=2")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_expiry_sets_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache *apiKeyCache) {
|
||||
dailyKey := "daily:sk-test-expiry"
|
||||
|
||||
require.NoError(s.T(), cache.IncrementDailyUsage(ctx, dailyKey))
|
||||
require.NoError(s.T(), cache.SetDailyUsageExpiry(ctx, dailyKey, 1*time.Hour), "SetDailyUsageExpiry")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, dailyKey).Result()
|
||||
require.NoError(s.T(), err, "TTL dailyKey")
|
||||
require.Greater(s.T(), ttl, time.Duration(0), "expected ttl > 0")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := &apiKeyCache{rdb: rdb}
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApiKeyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ApiKeyCacheSuite))
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApiKeyRateLimitKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "apikey:ratelimit:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "apikey:ratelimit:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "apikey:ratelimit:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "apikey:ratelimit:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := apiKeyRateLimitKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,662 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type apiKeyRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewAPIKeyRepository(client *dbent.Client, sqlDB *sql.DB) service.APIKeyRepository {
|
||||
return newAPIKeyRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func newAPIKeyRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *apiKeyRepository {
|
||||
return &apiKeyRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) activeQuery() *dbent.APIKeyQuery {
|
||||
// 默认过滤已软删除记录,避免删除后仍被查询到。
|
||||
return r.client.APIKey.Query().Where(apikey.DeletedAtIsNil())
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Create(ctx context.Context, key *service.APIKey) error {
|
||||
builder := r.client.APIKey.Create().
|
||||
SetUserID(key.UserID).
|
||||
SetKey(key.Key).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetNillableGroupID(key.GroupID).
|
||||
SetNillableLastUsedAt(key.LastUsedAt).
|
||||
SetQuota(key.Quota).
|
||||
SetQuotaUsed(key.QuotaUsed).
|
||||
SetNillableExpiresAt(key.ExpiresAt).
|
||||
SetRateLimit5h(key.RateLimit5h).
|
||||
SetRateLimit1d(key.RateLimit1d).
|
||||
SetRateLimit7d(key.RateLimit7d)
|
||||
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
}
|
||||
if len(key.IPBlacklist) > 0 {
|
||||
builder.SetIPBlacklist(key.IPBlacklist)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
key.ID = created.ID
|
||||
key.LastUsedAt = created.LastUsedAt
|
||||
key.CreatedAt = created.CreatedAt
|
||||
key.UpdatedAt = created.UpdatedAt
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrAPIKeyExists)
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByID(ctx context.Context, id int64) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
// GetKeyAndOwnerID 根据 API Key ID 获取其 key 与所有者(用户)ID。
|
||||
// 相比 GetByID,此方法性能更优,因为:
|
||||
// - 使用 Select() 只查询必要字段,减少数据传输量
|
||||
// - 不加载完整的 API Key 实体及其关联数据(User、Group 等)
|
||||
// - 适用于删除等只需 key 与用户 ID 的场景
|
||||
func (r *apiKeyRepository) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Select(apikey.FieldKey, apikey.FieldUserID).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return "", 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return "", 0, err
|
||||
}
|
||||
return m.Key, m.UserID, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKey(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*service.APIKey, error) {
|
||||
m, err := r.activeQuery().
|
||||
Where(apikey.KeyEQ(key)).
|
||||
Select(
|
||||
apikey.FieldID,
|
||||
apikey.FieldUserID,
|
||||
apikey.FieldGroupID,
|
||||
apikey.FieldStatus,
|
||||
apikey.FieldIPWhitelist,
|
||||
apikey.FieldIPBlacklist,
|
||||
apikey.FieldQuota,
|
||||
apikey.FieldQuotaUsed,
|
||||
apikey.FieldExpiresAt,
|
||||
apikey.FieldRateLimit5h,
|
||||
apikey.FieldRateLimit1d,
|
||||
apikey.FieldRateLimit7d,
|
||||
).
|
||||
WithUser(func(q *dbent.UserQuery) {
|
||||
q.Select(
|
||||
user.FieldID,
|
||||
user.FieldStatus,
|
||||
user.FieldRole,
|
||||
user.FieldBalance,
|
||||
user.FieldConcurrency,
|
||||
)
|
||||
}).
|
||||
WithGroup(func(q *dbent.GroupQuery) {
|
||||
q.Select(
|
||||
group.FieldID,
|
||||
group.FieldName,
|
||||
group.FieldPlatform,
|
||||
group.FieldStatus,
|
||||
group.FieldSubscriptionType,
|
||||
group.FieldRateMultiplier,
|
||||
group.FieldDailyLimitUsd,
|
||||
group.FieldWeeklyLimitUsd,
|
||||
group.FieldMonthlyLimitUsd,
|
||||
group.FieldImagePrice1k,
|
||||
group.FieldImagePrice2k,
|
||||
group.FieldImagePrice4k,
|
||||
group.FieldSoraImagePrice360,
|
||||
group.FieldSoraImagePrice540,
|
||||
group.FieldSoraVideoPricePerRequest,
|
||||
group.FieldSoraVideoPricePerRequestHd,
|
||||
group.FieldClaudeCodeOnly,
|
||||
group.FieldFallbackGroupID,
|
||||
group.FieldFallbackGroupIDOnInvalidRequest,
|
||||
group.FieldModelRoutingEnabled,
|
||||
group.FieldModelRouting,
|
||||
group.FieldMcpXMLInject,
|
||||
group.FieldSupportedModelScopes,
|
||||
group.FieldAllowMessagesDispatch,
|
||||
group.FieldDefaultMappedModel,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return apiKeyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Update(ctx context.Context, key *service.APIKey) error {
|
||||
// 使用原子操作:将软删除检查与更新合并到同一语句,避免竞态条件。
|
||||
// 之前的实现先检查 Exist 再 UpdateOneID,若在两步之间发生软删除,
|
||||
// 则会更新已删除的记录。
|
||||
// 这里选择 Update().Where(),确保只有未软删除记录能被更新。
|
||||
// 同时显式设置 updated_at,避免二次查询带来的并发可见性问题。
|
||||
client := clientFromContext(ctx, r.client)
|
||||
now := time.Now()
|
||||
builder := client.APIKey.Update().
|
||||
Where(apikey.IDEQ(key.ID), apikey.DeletedAtIsNil()).
|
||||
SetName(key.Name).
|
||||
SetStatus(key.Status).
|
||||
SetQuota(key.Quota).
|
||||
SetQuotaUsed(key.QuotaUsed).
|
||||
SetRateLimit5h(key.RateLimit5h).
|
||||
SetRateLimit1d(key.RateLimit1d).
|
||||
SetRateLimit7d(key.RateLimit7d).
|
||||
SetUsage5h(key.Usage5h).
|
||||
SetUsage1d(key.Usage1d).
|
||||
SetUsage7d(key.Usage7d).
|
||||
SetUpdatedAt(now)
|
||||
if key.GroupID != nil {
|
||||
builder.SetGroupID(*key.GroupID)
|
||||
} else {
|
||||
builder.ClearGroupID()
|
||||
}
|
||||
|
||||
// Expiration time
|
||||
if key.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*key.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
// Rate limit window start times
|
||||
if key.Window5hStart != nil {
|
||||
builder.SetWindow5hStart(*key.Window5hStart)
|
||||
} else {
|
||||
builder.ClearWindow5hStart()
|
||||
}
|
||||
if key.Window1dStart != nil {
|
||||
builder.SetWindow1dStart(*key.Window1dStart)
|
||||
} else {
|
||||
builder.ClearWindow1dStart()
|
||||
}
|
||||
if key.Window7dStart != nil {
|
||||
builder.SetWindow7dStart(*key.Window7dStart)
|
||||
} else {
|
||||
builder.ClearWindow7dStart()
|
||||
}
|
||||
|
||||
// IP 限制字段
|
||||
if len(key.IPWhitelist) > 0 {
|
||||
builder.SetIPWhitelist(key.IPWhitelist)
|
||||
} else {
|
||||
builder.ClearIPWhitelist()
|
||||
}
|
||||
if len(key.IPBlacklist) > 0 {
|
||||
builder.SetIPBlacklist(key.IPBlacklist)
|
||||
} else {
|
||||
builder.ClearIPBlacklist()
|
||||
}
|
||||
|
||||
affected, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
// 更新影响行数为 0,说明记录不存在或已被软删除。
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
// 使用同一时间戳回填,避免并发删除导致二次查询失败。
|
||||
key.UpdatedAt = now
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) Delete(ctx context.Context, id int64) error {
|
||||
// 显式软删除:避免依赖 Hook 行为,确保 deleted_at 一定被设置。
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetDeletedAt(time.Now()).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
exists, err := r.client.APIKey.Query().
|
||||
Where(apikey.IDEQ(id)).
|
||||
Exist(mixins.SkipSoftDelete(ctx))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return nil
|
||||
}
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters service.APIKeyListFilters) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.UserIDEQ(userID))
|
||||
|
||||
// Apply filters
|
||||
if filters.Search != "" {
|
||||
q = q.Where(apikey.Or(
|
||||
apikey.NameContainsFold(filters.Search),
|
||||
apikey.KeyContainsFold(filters.Search),
|
||||
))
|
||||
}
|
||||
if filters.Status != "" {
|
||||
q = q.Where(apikey.StatusEQ(filters.Status))
|
||||
}
|
||||
if filters.GroupID != nil {
|
||||
if *filters.GroupID == 0 {
|
||||
q = q.Where(apikey.GroupIDIsNil())
|
||||
} else {
|
||||
q = q.Where(apikey.GroupIDEQ(*filters.GroupID))
|
||||
}
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys, err := q.
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
ids, err := r.client.APIKey.Query().
|
||||
Where(apikey.UserIDEQ(userID), apikey.IDIn(apiKeyIDs...), apikey.DeletedAtIsNil()).
|
||||
IDs(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
count, err := r.activeQuery().Where(apikey.UserIDEQ(userID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
count, err := r.activeQuery().Where(apikey.KeyEQ(key)).Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) {
|
||||
q := r.activeQuery().Where(apikey.GroupIDEQ(groupID))
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
keys, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(apikey.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
|
||||
return outKeys, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SearchAPIKeys searches API keys by user ID and/or keyword (name)
|
||||
func (r *apiKeyRepository) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]service.APIKey, error) {
|
||||
q := r.activeQuery()
|
||||
if userID > 0 {
|
||||
q = q.Where(apikey.UserIDEQ(userID))
|
||||
}
|
||||
|
||||
if keyword != "" {
|
||||
q = q.Where(apikey.NameContainsFold(keyword))
|
||||
}
|
||||
|
||||
keys, err := q.Limit(limit).Order(dbent.Desc(apikey.FieldID)).All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
outKeys := make([]service.APIKey, 0, len(keys))
|
||||
for i := range keys {
|
||||
outKeys = append(outKeys, *apiKeyEntityToService(keys[i]))
|
||||
}
|
||||
return outKeys, nil
|
||||
}
|
||||
|
||||
// ClearGroupIDByGroupID 将指定分组的所有 API Key 的 group_id 设为 nil
|
||||
func (r *apiKeyRepository) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
n, err := r.client.APIKey.Update().
|
||||
Where(apikey.GroupIDEQ(groupID), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx)
|
||||
return int64(n), err
|
||||
}
|
||||
|
||||
// CountByGroupID 获取分组的 API Key 数量
|
||||
func (r *apiKeyRepository) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
count, err := r.activeQuery().Where(apikey.GroupIDEQ(groupID)).Count(ctx)
|
||||
return int64(count), err
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
keys, err := r.activeQuery().
|
||||
Where(apikey.UserIDEQ(userID)).
|
||||
Select(apikey.FieldKey).
|
||||
Strings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
keys, err := r.activeQuery().
|
||||
Where(apikey.GroupIDEQ(groupID)).
|
||||
Select(apikey.FieldKey).
|
||||
Strings(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsed 使用 Ent 原子递增 quota_used 字段并返回新值
|
||||
func (r *apiKeyRepository) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
updated, err := r.client.APIKey.UpdateOneID(id).
|
||||
Where(apikey.DeletedAtIsNil()).
|
||||
AddQuotaUsed(amount).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return 0, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return 0, err
|
||||
}
|
||||
return updated.QuotaUsed, nil
|
||||
}
|
||||
|
||||
// IncrementQuotaUsedAndGetState atomically increments quota_used, conditionally marks the key
|
||||
// as quota_exhausted, and returns the latest quota state in one round trip.
|
||||
func (r *apiKeyRepository) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*service.APIKeyQuotaUsageState, error) {
|
||||
query := `
|
||||
UPDATE api_keys
|
||||
SET
|
||||
quota_used = quota_used + $1,
|
||||
status = CASE
|
||||
WHEN quota > 0 AND quota_used + $1 >= quota THEN $2
|
||||
ELSE status
|
||||
END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3 AND deleted_at IS NULL
|
||||
RETURNING quota_used, quota, key, status
|
||||
`
|
||||
|
||||
state := &service.APIKeyQuotaUsageState{}
|
||||
if err := scanSingleRow(ctx, r.sql, query, []any{amount, service.StatusAPIKeyQuotaExhausted, id}, &state.QuotaUsed, &state.Quota, &state.Key, &state.Status); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (r *apiKeyRepository) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
affected, err := r.client.APIKey.Update().
|
||||
Where(apikey.IDEQ(id), apikey.DeletedAtIsNil()).
|
||||
SetLastUsedAt(usedAt).
|
||||
SetUpdatedAt(usedAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrAPIKeyNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncrementRateLimitUsage atomically increments all rate limit usage counters and initializes
|
||||
// window start times via COALESCE if not already set.
|
||||
func (r *apiKeyRepository) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN $1 ELSE usage_5h + $1 END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN $1 ELSE usage_1d + $1 END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN $1 ELSE usage_7d + $1 END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NULL OR window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NULL OR window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NULL OR window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $2 AND deleted_at IS NULL`,
|
||||
cost, id)
|
||||
return err
|
||||
}
|
||||
|
||||
// ResetRateLimitWindows resets expired rate limit windows atomically.
|
||||
func (r *apiKeyRepository) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
_, err := r.sql.ExecContext(ctx, `
|
||||
UPDATE api_keys SET
|
||||
usage_5h = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN 0 ELSE usage_5h END,
|
||||
window_5h_start = CASE WHEN window_5h_start IS NOT NULL AND window_5h_start + INTERVAL '5 hours' <= NOW() THEN NOW() ELSE window_5h_start END,
|
||||
usage_1d = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN 0 ELSE usage_1d END,
|
||||
window_1d_start = CASE WHEN window_1d_start IS NOT NULL AND window_1d_start + INTERVAL '24 hours' <= NOW() THEN date_trunc('day', NOW()) ELSE window_1d_start END,
|
||||
usage_7d = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN 0 ELSE usage_7d END,
|
||||
window_7d_start = CASE WHEN window_7d_start IS NOT NULL AND window_7d_start + INTERVAL '7 days' <= NOW() THEN date_trunc('day', NOW()) ELSE window_7d_start END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetRateLimitData returns the current rate limit usage and window start times for an API key.
|
||||
func (r *apiKeyRepository) GetRateLimitData(ctx context.Context, id int64) (result *service.APIKeyRateLimitData, err error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT usage_5h, usage_1d, usage_7d, window_5h_start, window_1d_start, window_7d_start
|
||||
FROM api_keys
|
||||
WHERE id = $1 AND deleted_at IS NULL`,
|
||||
id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
if !rows.Next() {
|
||||
return nil, service.ErrAPIKeyNotFound
|
||||
}
|
||||
data := &service.APIKeyRateLimitData{}
|
||||
if err := rows.Scan(&data.Usage5h, &data.Usage1d, &data.Usage7d, &data.Window5hStart, &data.Window1dStart, &data.Window7dStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, rows.Err()
|
||||
}
|
||||
|
||||
func apiKeyEntityToService(m *dbent.APIKey) *service.APIKey {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.APIKey{
|
||||
ID: m.ID,
|
||||
UserID: m.UserID,
|
||||
Key: m.Key,
|
||||
Name: m.Name,
|
||||
Status: m.Status,
|
||||
IPWhitelist: m.IPWhitelist,
|
||||
IPBlacklist: m.IPBlacklist,
|
||||
LastUsedAt: m.LastUsedAt,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
GroupID: m.GroupID,
|
||||
Quota: m.Quota,
|
||||
QuotaUsed: m.QuotaUsed,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
RateLimit5h: m.RateLimit5h,
|
||||
RateLimit1d: m.RateLimit1d,
|
||||
RateLimit7d: m.RateLimit7d,
|
||||
Usage5h: m.Usage5h,
|
||||
Usage1d: m.Usage1d,
|
||||
Usage7d: m.Usage7d,
|
||||
Window5hStart: m.Window5hStart,
|
||||
Window1dStart: m.Window1dStart,
|
||||
Window7dStart: m.Window7dStart,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
if m.Edges.Group != nil {
|
||||
out.Group = groupEntityToService(m.Edges.Group)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func userEntityToService(u *dbent.User) *service.User {
|
||||
if u == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
SoraStorageQuotaBytes: u.SoraStorageQuotaBytes,
|
||||
SoraStorageUsedBytes: u.SoraStorageUsedBytes,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.Group{
|
||||
ID: g.ID,
|
||||
Name: g.Name,
|
||||
Description: derefString(g.Description),
|
||||
Platform: g.Platform,
|
||||
RateMultiplier: g.RateMultiplier,
|
||||
IsExclusive: g.IsExclusive,
|
||||
Status: g.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: g.SubscriptionType,
|
||||
DailyLimitUSD: g.DailyLimitUsd,
|
||||
WeeklyLimitUSD: g.WeeklyLimitUsd,
|
||||
MonthlyLimitUSD: g.MonthlyLimitUsd,
|
||||
ImagePrice1K: g.ImagePrice1k,
|
||||
ImagePrice2K: g.ImagePrice2k,
|
||||
ImagePrice4K: g.ImagePrice4k,
|
||||
SoraImagePrice360: g.SoraImagePrice360,
|
||||
SoraImagePrice540: g.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: g.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd,
|
||||
SoraStorageQuotaBytes: g.SoraStorageQuotaBytes,
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
MCPXMLInject: g.McpXMLInject,
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
SortOrder: g.SortOrder,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func derefString(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
@@ -0,0 +1,493 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type APIKeyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *apiKeyRepository
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = newAPIKeyRepositoryWithSQL(s.client, tx)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(APIKeyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByKey ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestCreate() {
|
||||
user := s.mustCreateUser("create@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-test",
|
||||
Name: "Test Key",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, key)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(key.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("sk-create-test", got.Key)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestGetByKey() {
|
||||
user := s.mustCreateUser("getbykey@test.com")
|
||||
group := s.mustCreateGroup("g-key")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-getbykey",
|
||||
Name: "My Key",
|
||||
GroupID: &group.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
s.Require().Equal(key.ID, got.ID)
|
||||
s.Require().NotNil(got.User, "expected User preload")
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().NotNil(got.Group, "expected Group preload")
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestGetByKey_NotFound() {
|
||||
_, err := s.repo.GetByKey(s.ctx, "non-existent-key")
|
||||
s.Require().Error(err, "expected error for non-existent key")
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser("update@test.com")
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update",
|
||||
Name: "Original",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = service.StatusDisabled
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("sk-update", got.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got.Name)
|
||||
s.Require().Equal(service.StatusDisabled, got.Status)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestUpdate_ClearGroupID() {
|
||||
user := s.mustCreateUser("cleargroup@test.com")
|
||||
group := s.mustCreateGroup("g-clear")
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-clear-group",
|
||||
Name: "Group Key",
|
||||
GroupID: &group.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
key.GroupID = nil
|
||||
err := s.repo.Update(s.ctx, key)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Nil(got.GroupID, "expected GroupID to be cleared")
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser("delete@test.com")
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-delete",
|
||||
Name: "Delete Me",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, key))
|
||||
|
||||
err := s.repo.Delete(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- ListByUserID / CountByUserID ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestListByUserID() {
|
||||
user := s.mustCreateUser("listbyuser@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-list-1", "Key 1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-list-2", "Key 2", nil)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestListByUserID_Pagination() {
|
||||
user := s.mustCreateUser("paging@test.com")
|
||||
for i := 0; i < 5; i++ {
|
||||
s.mustCreateApiKey(user.ID, "sk-page-"+string(rune('a'+i)), "Key", nil)
|
||||
}
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 2}, service.APIKeyListFilters{})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(5), page.Total)
|
||||
s.Require().Equal(3, page.Pages)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestCountByUserID() {
|
||||
user := s.mustCreateUser("count@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-count-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-count-2", "K2", nil)
|
||||
|
||||
count, err := s.repo.CountByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "CountByUserID")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
// --- ListByGroupID / CountByGroupID ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestListByGroupID() {
|
||||
user := s.mustCreateUser("listbygroup@test.com")
|
||||
group := s.mustCreateGroup("g-list")
|
||||
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-1", "K1", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-grp-3", "K3", nil) // no group
|
||||
|
||||
keys, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
s.Require().Len(keys, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
// User preloaded
|
||||
s.Require().NotNil(keys[0].User)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestCountByGroupID() {
|
||||
user := s.mustCreateUser("countgroup@test.com")
|
||||
group := s.mustCreateGroup("g-count")
|
||||
s.mustCreateApiKey(user.ID, "sk-gc-1", "K1", &group.ID)
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(1), count)
|
||||
}
|
||||
|
||||
// --- ExistsByKey ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestExistsByKey() {
|
||||
user := s.mustCreateUser("exists@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-exists", "K", nil)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-exists")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByKey(s.ctx, "sk-not-exists")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- SearchAPIKeys ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys() {
|
||||
user := s.mustCreateUser("search@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-search-1", "Production Key", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-search-2", "Development Key", nil)
|
||||
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "prod", 10)
|
||||
s.Require().NoError(err, "SearchAPIKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Contains(found[0].Name, "Production")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoKeyword() {
|
||||
user := s.mustCreateUser("searchnokw@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-1", "K1", nil)
|
||||
s.mustCreateApiKey(user.ID, "sk-nk-2", "K2", nil)
|
||||
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 2)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestSearchAPIKeys_NoUserID() {
|
||||
user := s.mustCreateUser("searchnouid@test.com")
|
||||
s.mustCreateApiKey(user.ID, "sk-nu-1", "TestKey", nil)
|
||||
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, 0, "testkey", 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(found, 1)
|
||||
}
|
||||
|
||||
// --- ClearGroupIDByGroupID ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestClearGroupIDByGroupID() {
|
||||
user := s.mustCreateUser("cleargrp@test.com")
|
||||
group := s.mustCreateGroup("g-clear-bulk")
|
||||
|
||||
k1 := s.mustCreateApiKey(user.ID, "sk-clr-1", "K1", &group.ID)
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-clr-2", "K2", &group.ID)
|
||||
s.mustCreateApiKey(user.ID, "sk-clr-3", "K3", nil) // no group
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
s.Require().Equal(int64(2), affected)
|
||||
|
||||
got1, _ := s.repo.GetByID(s.ctx, k1.ID)
|
||||
got2, _ := s.repo.GetByID(s.ctx, k2.ID)
|
||||
s.Require().Nil(got1.GroupID)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- Combined CRUD/Search/ClearGroupID (original test preserved as integration) ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestCRUD_Search_ClearGroupID() {
|
||||
user := s.mustCreateUser("k@example.com")
|
||||
group := s.mustCreateGroup("g-k")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-test-1", "My Key", &group.ID)
|
||||
key.GroupID = &group.ID
|
||||
|
||||
got, err := s.repo.GetByKey(s.ctx, key.Key)
|
||||
s.Require().NoError(err, "GetByKey")
|
||||
s.Require().Equal(key.ID, got.ID)
|
||||
s.Require().NotNil(got.User)
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().NotNil(got.Group)
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
|
||||
key.Name = "Renamed"
|
||||
key.Status = service.StatusDisabled
|
||||
key.GroupID = nil
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update")
|
||||
|
||||
got2, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("sk-test-1", got2.Key, "Update should not change key")
|
||||
s.Require().Equal(user.ID, got2.UserID, "Update should not change user_id")
|
||||
s.Require().Equal("Renamed", got2.Name)
|
||||
s.Require().Equal(service.StatusDisabled, got2.Status)
|
||||
s.Require().Nil(got2.GroupID)
|
||||
|
||||
keys, page, err := s.repo.ListByUserID(s.ctx, user.ID, pagination.PaginationParams{Page: 1, PageSize: 10}, service.APIKeyListFilters{})
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(keys, 1)
|
||||
|
||||
exists, err := s.repo.ExistsByKey(s.ctx, "sk-test-1")
|
||||
s.Require().NoError(err, "ExistsByKey")
|
||||
s.Require().True(exists, "expected key to exist")
|
||||
|
||||
found, err := s.repo.SearchAPIKeys(s.ctx, user.ID, "renam", 10)
|
||||
s.Require().NoError(err, "SearchAPIKeys")
|
||||
s.Require().Len(found, 1)
|
||||
s.Require().Equal(key.ID, found[0].ID)
|
||||
|
||||
// ClearGroupIDByGroupID
|
||||
k2 := s.mustCreateApiKey(user.ID, "sk-test-2", "Group Key", &group.ID)
|
||||
k2.GroupID = &group.ID
|
||||
|
||||
countBefore, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(1), countBefore, "expected 1 key in group before clear")
|
||||
|
||||
affected, err := s.repo.ClearGroupIDByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "ClearGroupIDByGroupID")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
got3, err := s.repo.GetByID(s.ctx, k2.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Nil(got3.GroupID, "expected GroupID cleared")
|
||||
|
||||
countAfter, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID after clear")
|
||||
s.Require().Equal(int64(0), countAfter, "expected 0 keys in group after clear")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) mustCreateUser(email string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(service.RoleUser).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create user")
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) mustCreateApiKey(userID int64, key, name string, groupID *int64) *service.APIKey {
|
||||
s.T().Helper()
|
||||
|
||||
k := &service.APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: name,
|
||||
GroupID: groupID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, k), "create api key")
|
||||
return k
|
||||
}
|
||||
|
||||
// --- IncrementQuotaUsed ---
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_Basic() {
|
||||
user := s.mustCreateUser("incr-basic@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-incr-basic", "Incr", nil)
|
||||
|
||||
newQuota, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsed")
|
||||
s.Require().Equal(1.5, newQuota, "第一次递增后应为 1.5")
|
||||
|
||||
newQuota, err = s.repo.IncrementQuotaUsed(s.ctx, key.ID, 2.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsed second")
|
||||
s.Require().Equal(4.0, newQuota, "第二次递增后应为 4.0")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_NotFound() {
|
||||
_, err := s.repo.IncrementQuotaUsed(s.ctx, 999999, 1.0)
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "不存在的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsed_DeletedKey() {
|
||||
user := s.mustCreateUser("incr-deleted@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-incr-del", "Deleted", nil)
|
||||
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, key.ID), "Delete")
|
||||
|
||||
_, err := s.repo.IncrementQuotaUsed(s.ctx, key.ID, 1.0)
|
||||
s.Require().ErrorIs(err, service.ErrAPIKeyNotFound, "已删除的 key 应返回 ErrAPIKeyNotFound")
|
||||
}
|
||||
|
||||
func (s *APIKeyRepoSuite) TestIncrementQuotaUsedAndGetState() {
|
||||
user := s.mustCreateUser("quota-state@test.com")
|
||||
key := s.mustCreateApiKey(user.ID, "sk-quota-state", "QuotaState", nil)
|
||||
key.Quota = 3
|
||||
key.QuotaUsed = 1
|
||||
s.Require().NoError(s.repo.Update(s.ctx, key), "Update quota")
|
||||
|
||||
state, err := s.repo.IncrementQuotaUsedAndGetState(s.ctx, key.ID, 2.5)
|
||||
s.Require().NoError(err, "IncrementQuotaUsedAndGetState")
|
||||
s.Require().NotNil(state)
|
||||
s.Require().Equal(3.5, state.QuotaUsed)
|
||||
s.Require().Equal(3.0, state.Quota)
|
||||
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, state.Status)
|
||||
s.Require().Equal(key.Key, state.Key)
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, key.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(3.5, got.QuotaUsed)
|
||||
s.Require().Equal(service.StatusAPIKeyQuotaExhausted, got.Status)
|
||||
}
|
||||
|
||||
// TestIncrementQuotaUsed_Concurrent 使用真实数据库验证并发原子性。
|
||||
// 注意:此测试使用 testEntClient(非事务隔离),数据会真正写入数据库。
|
||||
func TestIncrementQuotaUsed_Concurrent(t *testing.T) {
|
||||
client := testEntClient(t)
|
||||
repo := NewAPIKeyRepository(client, integrationDB).(*apiKeyRepository)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建测试用户和 API Key
|
||||
u, err := client.User.Create().
|
||||
SetEmail("concurrent-incr-" + time.Now().Format(time.RFC3339Nano) + "@test.com").
|
||||
SetPasswordHash("hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(service.RoleUser).
|
||||
Save(ctx)
|
||||
require.NoError(t, err, "create user")
|
||||
|
||||
k := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: "sk-concurrent-" + time.Now().Format(time.RFC3339Nano),
|
||||
Name: "Concurrent",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, k), "create api key")
|
||||
t.Cleanup(func() {
|
||||
_ = client.APIKey.DeleteOneID(k.ID).Exec(ctx)
|
||||
_ = client.User.DeleteOneID(u.ID).Exec(ctx)
|
||||
})
|
||||
|
||||
// 10 个 goroutine 各递增 1.0,总计应为 10.0
|
||||
const goroutines = 10
|
||||
const increment = 1.0
|
||||
var wg sync.WaitGroup
|
||||
errs := make([]error, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
_, errs[idx] = repo.IncrementQuotaUsed(ctx, k.ID, increment)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
for i, e := range errs {
|
||||
require.NoError(t, e, "goroutine %d failed", i)
|
||||
}
|
||||
|
||||
// 验证最终结果
|
||||
got, err := repo.GetByID(ctx, k.ID)
|
||||
require.NoError(t, err, "GetByID")
|
||||
require.Equal(t, float64(goroutines)*increment, got.QuotaUsed,
|
||||
"并发递增后总和应为 %v,实际为 %v", float64(goroutines)*increment, got.QuotaUsed)
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/enttest"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func newAPIKeyRepoSQLite(t *testing.T) (*apiKeyRepository, *dbent.Client) {
|
||||
t.Helper()
|
||||
|
||||
db, err := sql.Open("sqlite", "file:api_key_repo_last_used?mode=memory&cache=shared")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
|
||||
_, err = db.Exec("PRAGMA foreign_keys = ON")
|
||||
require.NoError(t, err)
|
||||
|
||||
drv := entsql.OpenDB(dialect.SQLite, db)
|
||||
client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv)))
|
||||
t.Cleanup(func() { _ = client.Close() })
|
||||
|
||||
return &apiKeyRepository{client: client}, client
|
||||
}
|
||||
|
||||
func mustCreateAPIKeyRepoUser(t *testing.T, ctx context.Context, client *dbent.Client, email string) *service.User {
|
||||
t.Helper()
|
||||
u, err := client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetRole(service.RoleUser).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_CreateWithLastUsedAt(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "create-last-used@test.com")
|
||||
|
||||
lastUsed := time.Now().UTC().Add(-time.Hour).Truncate(time.Second)
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-create-last-used",
|
||||
Name: "CreateWithLastUsed",
|
||||
Status: service.StatusActive,
|
||||
LastUsedAt: &lastUsed,
|
||||
}
|
||||
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
require.NotNil(t, key.LastUsedAt)
|
||||
require.WithinDuration(t, lastUsed, *key.LastUsedAt, time.Second)
|
||||
|
||||
got, err := repo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.LastUsedAt)
|
||||
require.WithinDuration(t, lastUsed, *got.LastUsedAt, time.Second)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_UpdateLastUsed(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "update-last-used@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update-last-used",
|
||||
Name: "UpdateLastUsed",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
|
||||
before, err := repo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, before.LastUsedAt)
|
||||
|
||||
target := time.Now().UTC().Add(2 * time.Minute).Truncate(time.Second)
|
||||
require.NoError(t, repo.UpdateLastUsed(ctx, key.ID, target))
|
||||
|
||||
after, err := repo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, after.LastUsedAt)
|
||||
require.WithinDuration(t, target, *after.LastUsedAt, time.Second)
|
||||
require.WithinDuration(t, target, after.UpdatedAt, time.Second)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_UpdateLastUsedDeletedKey(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "deleted-last-used@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update-last-used-deleted",
|
||||
Name: "UpdateLastUsedDeleted",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
require.NoError(t, repo.Delete(ctx, key.ID))
|
||||
|
||||
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
|
||||
require.ErrorIs(t, err, service.ErrAPIKeyNotFound)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_UpdateLastUsedDBError(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "db-error-last-used@test.com")
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-update-last-used-db-error",
|
||||
Name: "UpdateLastUsedDBError",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, key))
|
||||
|
||||
require.NoError(t, client.Close())
|
||||
err := repo.UpdateLastUsed(ctx, key.ID, time.Now().UTC())
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestAPIKeyRepository_CreateDuplicateKey(t *testing.T) {
|
||||
repo, client := newAPIKeyRepoSQLite(t)
|
||||
ctx := context.Background()
|
||||
user := mustCreateAPIKeyRepoUser(t, ctx, client, "duplicate-key@test.com")
|
||||
|
||||
first := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-duplicate",
|
||||
Name: "first",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
second := &service.APIKey{
|
||||
UserID: user.ID,
|
||||
Key: "sk-duplicate",
|
||||
Name: "second",
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
require.NoError(t, repo.Create(ctx, first))
|
||||
err := repo.Create(ctx, second)
|
||||
require.ErrorIs(t, err, service.ErrAPIKeyExists)
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// PgDumper implements service.DBDumper using pg_dump/psql
|
||||
type PgDumper struct {
|
||||
cfg *config.DatabaseConfig
|
||||
}
|
||||
|
||||
// NewPgDumper creates a new PgDumper
|
||||
func NewPgDumper(cfg *config.Config) service.DBDumper {
|
||||
return &PgDumper{cfg: &cfg.Database}
|
||||
}
|
||||
|
||||
// Dump executes pg_dump and returns a streaming reader of the output
|
||||
func (d *PgDumper) Dump(ctx context.Context) (io.ReadCloser, error) {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--no-owner",
|
||||
"--no-acl",
|
||||
"--clean",
|
||||
"--if-exists",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "pg_dump", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return nil, fmt.Errorf("start pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 返回一个 ReadCloser:读 stdout,关闭时等待进程退出
|
||||
return &cmdReadCloser{ReadCloser: stdout, cmd: cmd}, nil
|
||||
}
|
||||
|
||||
// Restore executes psql to restore from a streaming reader
|
||||
func (d *PgDumper) Restore(ctx context.Context, data io.Reader) error {
|
||||
args := []string{
|
||||
"-h", d.cfg.Host,
|
||||
"-p", fmt.Sprintf("%d", d.cfg.Port),
|
||||
"-U", d.cfg.User,
|
||||
"-d", d.cfg.DBName,
|
||||
"--single-transaction",
|
||||
}
|
||||
|
||||
cmd := exec.CommandContext(ctx, "psql", args...)
|
||||
if d.cfg.Password != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGPASSWORD="+d.cfg.Password)
|
||||
}
|
||||
if d.cfg.SSLMode != "" {
|
||||
cmd.Env = append(cmd.Environ(), "PGSSLMODE="+d.cfg.SSLMode)
|
||||
}
|
||||
|
||||
cmd.Stdin = data
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%v: %s", err, string(output))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// cmdReadCloser wraps a command stdout pipe and waits for the process on Close
|
||||
type cmdReadCloser struct {
|
||||
io.ReadCloser
|
||||
cmd *exec.Cmd
|
||||
}
|
||||
|
||||
func (c *cmdReadCloser) Close() error {
|
||||
// Close the pipe first
|
||||
_ = c.ReadCloser.Close()
|
||||
// Wait for the process to exit
|
||||
if err := c.cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("pg_dump exited with error: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
awsconfig "github.com/aws/aws-sdk-go-v2/config"
|
||||
"github.com/aws/aws-sdk-go-v2/credentials"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
// S3BackupStore implements service.BackupObjectStore using AWS S3 compatible storage
|
||||
type S3BackupStore struct {
|
||||
client *s3.Client
|
||||
bucket string
|
||||
}
|
||||
|
||||
// NewS3BackupStoreFactory returns a BackupObjectStoreFactory that creates S3-backed stores
|
||||
func NewS3BackupStoreFactory() service.BackupObjectStoreFactory {
|
||||
return func(ctx context.Context, cfg *service.BackupS3Config) (service.BackupObjectStore, error) {
|
||||
region := cfg.Region
|
||||
if region == "" {
|
||||
region = "auto" // Cloudflare R2 默认 region
|
||||
}
|
||||
|
||||
awsCfg, err := awsconfig.LoadDefaultConfig(ctx,
|
||||
awsconfig.WithRegion(region),
|
||||
awsconfig.WithCredentialsProvider(
|
||||
credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load aws config: %w", err)
|
||||
}
|
||||
|
||||
client := s3.NewFromConfig(awsCfg, func(o *s3.Options) {
|
||||
if cfg.Endpoint != "" {
|
||||
o.BaseEndpoint = &cfg.Endpoint
|
||||
}
|
||||
if cfg.ForcePathStyle {
|
||||
o.UsePathStyle = true
|
||||
}
|
||||
o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware)
|
||||
o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired
|
||||
})
|
||||
|
||||
return &S3BackupStore{client: client, bucket: cfg.Bucket}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Upload(ctx context.Context, key string, body io.Reader, contentType string) (int64, error) {
|
||||
// 读取全部内容以获取大小(S3 PutObject 需要知道内容长度)
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("read body: %w", err)
|
||||
}
|
||||
|
||||
_, err = s.client.PutObject(ctx, &s3.PutObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
Body: bytes.NewReader(data),
|
||||
ContentType: &contentType,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S3 PutObject: %w", err)
|
||||
}
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Download(ctx context.Context, key string) (io.ReadCloser, error) {
|
||||
result, err := s.client.GetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("S3 GetObject: %w", err)
|
||||
}
|
||||
return result.Body, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) Delete(ctx context.Context, key string) error {
|
||||
_, err := s.client.DeleteObject(ctx, &s3.DeleteObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error) {
|
||||
presignClient := s3.NewPresignClient(s.client)
|
||||
result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{
|
||||
Bucket: &s.bucket,
|
||||
Key: &key,
|
||||
}, s3.WithPresignExpires(expiry))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return result.URL, nil
|
||||
}
|
||||
|
||||
func (s *S3BackupStore) HeadBucket(ctx context.Context) error {
|
||||
_, err := s.client.HeadBucket(ctx, &s3.HeadBucketInput{
|
||||
Bucket: &s.bucket,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 HeadBucket failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,330 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand/v2"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
billingBalanceKeyPrefix = "billing:balance:"
|
||||
billingSubKeyPrefix = "billing:sub:"
|
||||
billingRateLimitKeyPrefix = "apikey:rate:"
|
||||
billingCacheTTL = 5 * time.Minute
|
||||
billingCacheJitter = 30 * time.Second
|
||||
rateLimitCacheTTL = 7 * 24 * time.Hour // 7 days matches the longest window
|
||||
|
||||
// Rate limit window durations — must match service.RateLimitWindow* constants.
|
||||
rateLimitWindow5h = 5 * time.Hour
|
||||
rateLimitWindow1d = 24 * time.Hour
|
||||
rateLimitWindow7d = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// jitteredTTL 返回带随机抖动的 TTL,防止缓存雪崩
|
||||
func jitteredTTL() time.Duration {
|
||||
// 只做“减法抖动”,确保实际 TTL 不会超过 billingCacheTTL(避免上界预期被打破)。
|
||||
if billingCacheJitter <= 0 {
|
||||
return billingCacheTTL
|
||||
}
|
||||
jitter := time.Duration(rand.IntN(int(billingCacheJitter)))
|
||||
return billingCacheTTL - jitter
|
||||
}
|
||||
|
||||
// billingBalanceKey generates the Redis key for user balance cache.
|
||||
func billingBalanceKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
}
|
||||
|
||||
// billingSubKey generates the Redis key for subscription cache.
|
||||
func billingSubKey(userID, groupID int64) string {
|
||||
return fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
}
|
||||
|
||||
const (
|
||||
subFieldStatus = "status"
|
||||
subFieldExpiresAt = "expires_at"
|
||||
subFieldDailyUsage = "daily_usage"
|
||||
subFieldWeeklyUsage = "weekly_usage"
|
||||
subFieldMonthlyUsage = "monthly_usage"
|
||||
subFieldVersion = "version"
|
||||
)
|
||||
|
||||
// billingRateLimitKey generates the Redis key for API key rate limit cache.
|
||||
func billingRateLimitKey(keyID int64) string {
|
||||
return fmt.Sprintf("%s%d", billingRateLimitKeyPrefix, keyID)
|
||||
}
|
||||
|
||||
const (
|
||||
rateLimitFieldUsage5h = "usage_5h"
|
||||
rateLimitFieldUsage1d = "usage_1d"
|
||||
rateLimitFieldUsage7d = "usage_7d"
|
||||
rateLimitFieldWindow5h = "window_5h"
|
||||
rateLimitFieldWindow1d = "window_1d"
|
||||
rateLimitFieldWindow7d = "window_7d"
|
||||
)
|
||||
|
||||
var (
|
||||
deductBalanceScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
return 0
|
||||
end
|
||||
local newVal = tonumber(current) - tonumber(ARGV[1])
|
||||
redis.call('SET', KEYS[1], newVal)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
updateSubUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'daily_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'weekly_usage', cost)
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], 'monthly_usage', cost)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
|
||||
// updateRateLimitUsageScript atomically increments all three rate limit usage counters
|
||||
// with window expiration checking. If a window has expired, its usage is reset to cost
|
||||
// (instead of accumulated) and the window timestamp is updated, matching the DB-side
|
||||
// IncrementRateLimitUsage semantics.
|
||||
//
|
||||
// ARGV: [1]=cost, [2]=ttl_seconds, [3]=now_unix, [4]=window_5h_seconds, [5]=window_1d_seconds, [6]=window_7d_seconds
|
||||
updateRateLimitUsageScript = redis.NewScript(`
|
||||
local exists = redis.call('EXISTS', KEYS[1])
|
||||
if exists == 0 then
|
||||
return 0
|
||||
end
|
||||
local cost = tonumber(ARGV[1])
|
||||
local now = tonumber(ARGV[3])
|
||||
local win5h = tonumber(ARGV[4])
|
||||
local win1d = tonumber(ARGV[5])
|
||||
local win7d = tonumber(ARGV[6])
|
||||
|
||||
-- Helper: check if window is expired and update usage + window accordingly
|
||||
-- Returns nothing, modifies the hash in-place.
|
||||
local function update_window(usage_field, window_field, window_duration)
|
||||
local w = tonumber(redis.call('HGET', KEYS[1], window_field) or 0)
|
||||
if w == 0 or (now - w) >= window_duration then
|
||||
-- Window expired or never started: reset usage to cost, start new window
|
||||
redis.call('HSET', KEYS[1], usage_field, tostring(cost))
|
||||
redis.call('HSET', KEYS[1], window_field, tostring(now))
|
||||
else
|
||||
-- Window still valid: accumulate
|
||||
redis.call('HINCRBYFLOAT', KEYS[1], usage_field, cost)
|
||||
end
|
||||
end
|
||||
|
||||
update_window('usage_5h', 'window_5h', win5h)
|
||||
update_window('usage_1d', 'window_1d', win1d)
|
||||
update_window('usage_7d', 'window_7d', win7d)
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
return 1
|
||||
`)
|
||||
)
|
||||
|
||||
type billingCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewBillingCache(rdb *redis.Client) service.BillingCache {
|
||||
return &billingCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *billingCache) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
key := billingBalanceKey(userID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return strconv.ParseFloat(val, 64)
|
||||
}
|
||||
|
||||
func (c *billingCache) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Set(ctx, key, balance, jitteredTTL()).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
_, err := deductBalanceScript.Run(ctx, c.rdb, []string{key}, amount, int(jitteredTTL().Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: deduct balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
key := billingBalanceKey(userID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*service.SubscriptionCacheData, error) {
|
||||
key := billingSubKey(userID, groupID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return c.parseSubscriptionCache(result)
|
||||
}
|
||||
|
||||
func (c *billingCache) parseSubscriptionCache(data map[string]string) (*service.SubscriptionCacheData, error) {
|
||||
result := &service.SubscriptionCacheData{}
|
||||
|
||||
result.Status = data[subFieldStatus]
|
||||
if result.Status == "" {
|
||||
return nil, errors.New("invalid cache: missing status")
|
||||
}
|
||||
|
||||
if expiresStr, ok := data[subFieldExpiresAt]; ok {
|
||||
expiresAt, err := strconv.ParseInt(expiresStr, 10, 64)
|
||||
if err == nil {
|
||||
result.ExpiresAt = time.Unix(expiresAt, 0)
|
||||
}
|
||||
}
|
||||
|
||||
if dailyStr, ok := data[subFieldDailyUsage]; ok {
|
||||
result.DailyUsage, _ = strconv.ParseFloat(dailyStr, 64)
|
||||
}
|
||||
|
||||
if weeklyStr, ok := data[subFieldWeeklyUsage]; ok {
|
||||
result.WeeklyUsage, _ = strconv.ParseFloat(weeklyStr, 64)
|
||||
}
|
||||
|
||||
if monthlyStr, ok := data[subFieldMonthlyUsage]; ok {
|
||||
result.MonthlyUsage, _ = strconv.ParseFloat(monthlyStr, 64)
|
||||
}
|
||||
|
||||
if versionStr, ok := data[subFieldVersion]; ok {
|
||||
result.Version, _ = strconv.ParseInt(versionStr, 10, 64)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *service.SubscriptionCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := billingSubKey(userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
subFieldStatus: data.Status,
|
||||
subFieldExpiresAt: data.ExpiresAt.Unix(),
|
||||
subFieldDailyUsage: data.DailyUsage,
|
||||
subFieldWeeklyUsage: data.WeeklyUsage,
|
||||
subFieldMonthlyUsage: data.MonthlyUsage,
|
||||
subFieldVersion: data.Version,
|
||||
}
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, jitteredTTL())
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
_, err := updateSubUsageScript.Run(ctx, c.rdb, []string{key}, cost, int(jitteredTTL().Seconds())).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update subscription usage cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
key := billingSubKey(userID, groupID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *billingCache) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*service.APIKeyRateLimitCacheData, error) {
|
||||
key := billingRateLimitKey(keyID)
|
||||
result, err := c.rdb.HGetAll(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
data := &service.APIKeyRateLimitCacheData{}
|
||||
if v, ok := result[rateLimitFieldUsage5h]; ok {
|
||||
data.Usage5h, _ = strconv.ParseFloat(v, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldUsage1d]; ok {
|
||||
data.Usage1d, _ = strconv.ParseFloat(v, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldUsage7d]; ok {
|
||||
data.Usage7d, _ = strconv.ParseFloat(v, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldWindow5h]; ok {
|
||||
data.Window5h, _ = strconv.ParseInt(v, 10, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldWindow1d]; ok {
|
||||
data.Window1d, _ = strconv.ParseInt(v, 10, 64)
|
||||
}
|
||||
if v, ok := result[rateLimitFieldWindow7d]; ok {
|
||||
data.Window7d, _ = strconv.ParseInt(v, 10, 64)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (c *billingCache) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *service.APIKeyRateLimitCacheData) error {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
key := billingRateLimitKey(keyID)
|
||||
fields := map[string]any{
|
||||
rateLimitFieldUsage5h: data.Usage5h,
|
||||
rateLimitFieldUsage1d: data.Usage1d,
|
||||
rateLimitFieldUsage7d: data.Usage7d,
|
||||
rateLimitFieldWindow5h: data.Window5h,
|
||||
rateLimitFieldWindow1d: data.Window1d,
|
||||
rateLimitFieldWindow7d: data.Window7d,
|
||||
}
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.HSet(ctx, key, fields)
|
||||
pipe.Expire(ctx, key, rateLimitCacheTTL)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *billingCache) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
key := billingRateLimitKey(keyID)
|
||||
now := time.Now().Unix()
|
||||
_, err := updateRateLimitUsageScript.Run(ctx, c.rdb, []string{key},
|
||||
cost,
|
||||
int(rateLimitCacheTTL.Seconds()),
|
||||
now,
|
||||
int(rateLimitWindow5h.Seconds()),
|
||||
int(rateLimitWindow1d.Seconds()),
|
||||
int(rateLimitWindow7d.Seconds()),
|
||||
).Result()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
log.Printf("Warning: update rate limit usage cache failed for api key %d: %v", keyID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *billingCache) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
key := billingRateLimitKey(keyID)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -0,0 +1,367 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type BillingCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestUserBalance() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
_, err := cache.GetUserBalance(ctx, 1)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(1)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
|
||||
|
||||
_, err := rdb.Get(ctx, balanceKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(2)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 10.5, got, "balance mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_reduces_balance",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(3)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance after deduct")
|
||||
require.Equal(s.T(), 8.25, got, "deduct mismatch")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(100)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
|
||||
|
||||
exists, err := rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
|
||||
|
||||
exists, err = rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetUserBalance(ctx, userID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_refreshes_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(103)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
|
||||
|
||||
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL before deduct")
|
||||
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
|
||||
|
||||
balance, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
|
||||
|
||||
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after deduct")
|
||||
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(10)
|
||||
groupID := int64(20)
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(11)
|
||||
groupID := int64(21)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(12)
|
||||
groupID := int64(22)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 7,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache")
|
||||
require.Equal(s.T(), "active", gotSub.Status)
|
||||
require.Equal(s.T(), int64(7), gotSub.Version)
|
||||
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
|
||||
|
||||
ttl, err := rdb.TTL(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "TTL subKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_increments_all_fields",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(13)
|
||||
groupID := int64(23)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache after update")
|
||||
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
|
||||
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
|
||||
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(101)
|
||||
groupID := int64(10)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
|
||||
|
||||
exists, err = rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_status_returns_parsing_error",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(102)
|
||||
groupID := int64(11)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"daily_usage": 1.0,
|
||||
"weekly_usage": 2.0,
|
||||
"monthly_usage": 3.0,
|
||||
"version": 1,
|
||||
}
|
||||
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.Error(s.T(), err, "expected error for missing status field")
|
||||
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
|
||||
require.Equal(s.T(), "invalid cache: missing status", err.Error())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, cache service.BillingCache)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "key_not_exists_returns_nil",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
|
||||
err := cache.DeductUserBalance(ctx, 99999, 1.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing_key_deducts_successfully",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
|
||||
err := cache.DeductUserBalance(ctx, 200, 10.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance should succeed")
|
||||
|
||||
bal, err := cache.GetUserBalance(ctx, 200)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cancelled_context_propagates_error",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel() // 立即取消
|
||||
|
||||
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
tt.fn(ctx, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
|
||||
s.Run("key_not_exists_returns_nil", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
|
||||
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
|
||||
})
|
||||
|
||||
s.Run("cancelled_context_propagates_error", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBillingCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(BillingCacheSuite))
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// --- Task 6.1 验证: math/rand/v2 迁移后 jitteredTTL 行为正确 ---
|
||||
|
||||
func TestJitteredTTL_WithinExpectedRange(t *testing.T) {
|
||||
// jitteredTTL 使用减法抖动: billingCacheTTL - [0, billingCacheJitter)
|
||||
// 所以结果应在 [billingCacheTTL - billingCacheJitter, billingCacheTTL] 范围内
|
||||
lowerBound := billingCacheTTL - billingCacheJitter // 5min - 30s = 4min30s
|
||||
upperBound := billingCacheTTL // 5min
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
ttl := jitteredTTL()
|
||||
assert.GreaterOrEqual(t, int64(ttl), int64(lowerBound),
|
||||
"TTL 不应低于 %v,实际得到 %v", lowerBound, ttl)
|
||||
assert.LessOrEqual(t, int64(ttl), int64(upperBound),
|
||||
"TTL 不应超过 %v(上界不变保证),实际得到 %v", upperBound, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_NeverExceedsBase(t *testing.T) {
|
||||
// 关键安全性测试:jitteredTTL 使用减法抖动,确保永远不超过 billingCacheTTL
|
||||
for i := 0; i < 500; i++ {
|
||||
ttl := jitteredTTL()
|
||||
assert.LessOrEqual(t, int64(ttl), int64(billingCacheTTL),
|
||||
"jitteredTTL 不应超过基础 TTL(上界预期不被打破)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_HasVariance(t *testing.T) {
|
||||
// 验证抖动确实产生了不同的值
|
||||
results := make(map[time.Duration]bool)
|
||||
for i := 0; i < 100; i++ {
|
||||
ttl := jitteredTTL()
|
||||
results[ttl] = true
|
||||
}
|
||||
|
||||
require.Greater(t, len(results), 1,
|
||||
"jitteredTTL 应产生不同的值(抖动生效),但 100 次调用结果全部相同")
|
||||
}
|
||||
|
||||
func TestJitteredTTL_AverageNearCenter(t *testing.T) {
|
||||
// 验证平均值大约在抖动范围中间
|
||||
var sum time.Duration
|
||||
runs := 1000
|
||||
for i := 0; i < runs; i++ {
|
||||
sum += jitteredTTL()
|
||||
}
|
||||
|
||||
avg := sum / time.Duration(runs)
|
||||
expectedCenter := billingCacheTTL - billingCacheJitter/2 // 4min45s
|
||||
|
||||
// 允许 ±5s 的误差
|
||||
tolerance := 5 * time.Second
|
||||
assert.InDelta(t, float64(expectedCenter), float64(avg), float64(tolerance),
|
||||
"平均 TTL 应接近抖动范围中心 %v", expectedCenter)
|
||||
}
|
||||
|
||||
func TestBillingKeyGeneration(t *testing.T) {
|
||||
t.Run("balance_key", func(t *testing.T) {
|
||||
key := billingBalanceKey(12345)
|
||||
assert.Equal(t, "billing:balance:12345", key)
|
||||
})
|
||||
|
||||
t.Run("sub_key", func(t *testing.T) {
|
||||
key := billingSubKey(100, 200)
|
||||
assert.Equal(t, "billing:sub:100:200", key)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkJitteredTTL(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = jitteredTTL()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,111 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBillingBalanceKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "billing:balance:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "billing:balance:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "billing:balance:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "billing:balance:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingBalanceKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingSubKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
groupID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_ids",
|
||||
userID: 123,
|
||||
groupID: 456,
|
||||
expected: "billing:sub:123:456",
|
||||
},
|
||||
{
|
||||
name: "zero_ids",
|
||||
userID: 0,
|
||||
groupID: 0,
|
||||
expected: "billing:sub:0:0",
|
||||
},
|
||||
{
|
||||
name: "negative_ids",
|
||||
userID: -1,
|
||||
groupID: -2,
|
||||
expected: "billing:sub:-1:-2",
|
||||
},
|
||||
{
|
||||
name: "max_int64_ids",
|
||||
userID: math.MaxInt64,
|
||||
groupID: math.MaxInt64,
|
||||
expected: "billing:sub:9223372036854775807:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingSubKey(tc.userID, tc.groupID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL(t *testing.T) {
|
||||
const (
|
||||
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
|
||||
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
|
||||
)
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
ttl := jitteredTTL()
|
||||
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
|
||||
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_HasVariation(t *testing.T) {
|
||||
// 多次调用应该产生不同的值(验证抖动存在)
|
||||
seen := make(map[time.Duration]struct{}, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
seen[jitteredTTL()] = struct{}{}
|
||||
}
|
||||
// 50 次调用中应该至少有 2 个不同的值
|
||||
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
|
||||
}
|
||||
@@ -0,0 +1,285 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/logredact"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
func NewClaudeOAuthClient() service.ClaudeOAuthClient {
|
||||
return &claudeOAuthService{
|
||||
baseURL: "https://claude.ai",
|
||||
tokenURL: oauth.TokenURL,
|
||||
clientFactory: createReqClient,
|
||||
}
|
||||
}
|
||||
|
||||
type claudeOAuthService struct {
|
||||
baseURL string
|
||||
tokenURL string
|
||||
clientFactory func(proxyURL string) (*req.Client, error)
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) {
|
||||
client, err := s.clientFactory(proxyURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create HTTP client: %w", err)
|
||||
}
|
||||
|
||||
var orgs []struct {
|
||||
UUID string `json:"uuid"`
|
||||
Name string `json:"name"`
|
||||
RavenType *string `json:"raven_type"` // nil for personal, "team" for team organization
|
||||
}
|
||||
|
||||
targetURL := s.baseURL + "/api/organizations"
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1: Getting organization UUID from %s", targetURL)
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetSuccessResult(&orgs).
|
||||
Get(targetURL)
|
||||
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 Response - Status: %d", resp.StatusCode)
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get organizations: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if len(orgs) == 0 {
|
||||
return "", fmt.Errorf("no organizations found")
|
||||
}
|
||||
|
||||
// 如果只有一个组织,直接使用
|
||||
if len(orgs) == 1 {
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Single org found, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
// 如果有多个组织,优先选择 raven_type 为 "team" 的组织
|
||||
for _, org := range orgs {
|
||||
if org.RavenType != nil && *org.RavenType == "team" {
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - Selected team org, UUID: %s, Name: %s, RavenType: %s",
|
||||
org.UUID, org.Name, *org.RavenType)
|
||||
return org.UUID, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有 team 类型的组织,使用第一个
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 1 SUCCESS - No team org found, using first org, UUID: %s, Name: %s", orgs[0].UUID, orgs[0].Name)
|
||||
return orgs[0].UUID, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) GetAuthorizationCode(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) {
|
||||
client, err := s.clientFactory(proxyURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create HTTP client: %w", err)
|
||||
}
|
||||
|
||||
authURL := fmt.Sprintf("%s/v1/oauth/%s/authorize", s.baseURL, orgUUID)
|
||||
|
||||
reqBody := map[string]any{
|
||||
"response_type": "code",
|
||||
"client_id": oauth.ClientID,
|
||||
"organization_uuid": orgUUID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"scope": scope,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2: Getting authorization code from %s", authURL)
|
||||
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var result struct {
|
||||
RedirectURI string `json:"redirect_uri"`
|
||||
}
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetCookies(&http.Cookie{
|
||||
Name: "sessionKey",
|
||||
Value: sessionKey,
|
||||
}).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetHeader("Accept-Language", "en-US,en;q=0.9").
|
||||
SetHeader("Cache-Control", "no-cache").
|
||||
SetHeader("Origin", "https://claude.ai").
|
||||
SetHeader("Referer", "https://claude.ai/new").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&result).
|
||||
Post(authURL)
|
||||
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 FAILED - Request error: %v", err)
|
||||
return "", fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return "", fmt.Errorf("failed to get authorization code: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
if result.RedirectURI == "" {
|
||||
return "", fmt.Errorf("no redirect_uri in response")
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(result.RedirectURI)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to parse redirect_uri: %w", err)
|
||||
}
|
||||
|
||||
queryParams := parsedURL.Query()
|
||||
authCode := queryParams.Get("code")
|
||||
responseState := queryParams.Get("state")
|
||||
|
||||
if authCode == "" {
|
||||
return "", fmt.Errorf("no authorization code in redirect_uri")
|
||||
}
|
||||
|
||||
fullCode := authCode
|
||||
if responseState != "" {
|
||||
fullCode = authCode + "#" + responseState
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 2 SUCCESS - Got authorization code")
|
||||
return fullCode, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) {
|
||||
client, err := s.clientFactory(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create HTTP client: %w", err)
|
||||
}
|
||||
|
||||
// Parse code which may contain state in format "authCode#state"
|
||||
authCode := code
|
||||
codeState := ""
|
||||
if idx := strings.Index(code, "#"); idx != -1 {
|
||||
authCode = code[:idx]
|
||||
codeState = code[idx+1:]
|
||||
}
|
||||
|
||||
reqBody := map[string]any{
|
||||
"code": authCode,
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": oauth.ClientID,
|
||||
"redirect_uri": oauth.RedirectURI,
|
||||
"code_verifier": codeVerifier,
|
||||
}
|
||||
|
||||
if codeState != "" {
|
||||
reqBody["state"] = codeState
|
||||
}
|
||||
|
||||
// Setup token requires longer expiration (1 year)
|
||||
if isSetupToken {
|
||||
reqBody["expires_in"] = 31536000 // 365 * 24 * 60 * 60 seconds
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3: Exchanging code for token at %s", s.tokenURL)
|
||||
reqBodyJSON, _ := json.Marshal(logredact.RedactMap(reqBody))
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Request Body: %s", string(reqBodyJSON))
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json, text/plain, */*").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", "axios/1.8.4").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 FAILED - Request error: %v", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 Response - Status: %d, Body: %s", resp.StatusCode, logredact.RedactJSON(resp.Bytes()))
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("repository.claude_oauth", "[OAuth] Step 3 SUCCESS - Got access token")
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) {
|
||||
client, err := s.clientFactory(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create HTTP client: %w", err)
|
||||
}
|
||||
|
||||
reqBody := map[string]any{
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refreshToken,
|
||||
"client_id": oauth.ClientID,
|
||||
}
|
||||
|
||||
var tokenResp oauth.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Accept", "application/json, text/plain, */*").
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", "axios/1.8.4").
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createReqClient(proxyURL string) (*req.Client, error) {
|
||||
// 禁用 CookieJar,确保每次授权都是干净的会话
|
||||
client := req.C().
|
||||
SetTimeout(60 * time.Second).
|
||||
ImpersonateChrome().
|
||||
SetCookieJar(nil) // 禁用 CookieJar
|
||||
|
||||
trimmed, _, err := proxyurl.Parse(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if trimmed != "" {
|
||||
client.SetProxyURL(trimmed)
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
@@ -0,0 +1,396 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/oauth"
|
||||
"github.com/imroc/req/v3"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeOAuthServiceSuite struct {
|
||||
suite.Suite
|
||||
client *claudeOAuthService
|
||||
}
|
||||
|
||||
// requestCapture holds captured request data for assertions in the main goroutine.
|
||||
type requestCapture struct {
|
||||
path string
|
||||
method string
|
||||
cookies []*http.Cookie
|
||||
body []byte
|
||||
bodyJSON map[string]any
|
||||
contentType string
|
||||
}
|
||||
|
||||
func newTestReqClient(rt http.RoundTripper) *req.Client {
|
||||
c := req.C()
|
||||
c.GetClient().Transport = rt
|
||||
return c
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetOrganizationUUID() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
errContain string
|
||||
wantUUID string
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte(`[{"uuid":"org-1"}]`))
|
||||
},
|
||||
wantUUID: "org-1",
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), "/api/organizations", captured.path, "unexpected path")
|
||||
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||
require.Equal(s.T(), "sessionKey", captured.cookies[0].Name)
|
||||
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte("unauthorized"))
|
||||
},
|
||||
wantErr: true,
|
||||
errContain: "401",
|
||||
},
|
||||
{
|
||||
name: "invalid_json_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = w.Write([]byte("not-json"))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.cookies = r.Cookies()
|
||||
tt.handler(w, r)
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = "http://in-process"
|
||||
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
|
||||
|
||||
got, err := s.client.GetOrganizationUUID(context.Background(), "sess", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
if tt.errContain != "" {
|
||||
require.ErrorContains(s.T(), err, tt.errContain)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantUUID, got)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
wantCode string
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "parses_redirect_uri",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"redirect_uri": oauth.RedirectURI + "?code=AUTH&state=STATE",
|
||||
})
|
||||
},
|
||||
wantCode: "AUTH#STATE",
|
||||
validate: func(captured requestCapture) {
|
||||
require.True(s.T(), strings.HasPrefix(captured.path, "/v1/oauth/") && strings.HasSuffix(captured.path, "/authorize"), "unexpected path: %s", captured.path)
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.Len(s.T(), captured.cookies, 1, "expected 1 cookie")
|
||||
require.Equal(s.T(), "sess", captured.cookies[0].Value)
|
||||
require.Equal(s.T(), "org-1", captured.bodyJSON["organization_uuid"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "st", captured.bodyJSON["state"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_code_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"redirect_uri": oauth.RedirectURI + "?state=STATE", // no code
|
||||
})
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.path = r.URL.Path
|
||||
captured.method = r.Method
|
||||
captured.cookies = r.Cookies()
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.baseURL = "http://in-process"
|
||||
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
|
||||
|
||||
code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantCode, code)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestExchangeCodeForToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
code string
|
||||
isSetupToken bool
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_state_when_embedded",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 3600,
|
||||
RefreshToken: "rt",
|
||||
Scope: "s",
|
||||
})
|
||||
},
|
||||
code: "AUTH#STATE2",
|
||||
isSetupToken: false,
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rt",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"), "unexpected content-type")
|
||||
require.Equal(s.T(), "AUTH", captured.bodyJSON["code"])
|
||||
require.Equal(s.T(), "STATE2", captured.bodyJSON["state"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
require.Equal(s.T(), oauth.RedirectURI, captured.bodyJSON["redirect_uri"])
|
||||
require.Equal(s.T(), "ver", captured.bodyJSON["code_verifier"])
|
||||
// Regular OAuth should not include expires_in
|
||||
require.Nil(s.T(), captured.bodyJSON["expires_in"], "regular OAuth should not include expires_in")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "setup_token_includes_expires_in",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 31536000,
|
||||
})
|
||||
},
|
||||
code: "AUTH",
|
||||
isSetupToken: true,
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
// Setup token should include expires_in with 1 year value
|
||||
require.Equal(s.T(), float64(31536000), captured.bodyJSON["expires_in"],
|
||||
"setup token should include expires_in: 31536000")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = w.Write([]byte("bad request"))
|
||||
},
|
||||
code: "AUTH",
|
||||
isSetupToken: false,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = "http://in-process/token"
|
||||
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
|
||||
|
||||
resp, err := s.client.ExchangeCodeForToken(context.Background(), tt.code, "ver", "", "", tt.isSetupToken)
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ClaudeOAuthServiceSuite) TestRefreshToken() {
|
||||
tests := []struct {
|
||||
name string
|
||||
handler http.HandlerFunc
|
||||
wantErr bool
|
||||
wantResp *oauth.TokenResponse
|
||||
validate func(captured requestCapture)
|
||||
}{
|
||||
{
|
||||
name: "sends_json_format",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 28800,
|
||||
RefreshToken: "new_refresh_token",
|
||||
Scope: "user:profile user:inference",
|
||||
})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "new_access_token",
|
||||
RefreshToken: "new_refresh_token",
|
||||
},
|
||||
validate: func(captured requestCapture) {
|
||||
require.Equal(s.T(), http.MethodPost, captured.method, "expected POST")
|
||||
// 验证使用 JSON 格式(不是 form 格式)
|
||||
require.True(s.T(), strings.HasPrefix(captured.contentType, "application/json"),
|
||||
"expected JSON content-type, got: %s", captured.contentType)
|
||||
// 验证 JSON body 内容
|
||||
require.Equal(s.T(), "refresh_token", captured.bodyJSON["grant_type"])
|
||||
require.Equal(s.T(), "rt", captured.bodyJSON["refresh_token"])
|
||||
require.Equal(s.T(), oauth.ClientID, captured.bodyJSON["client_id"])
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "returns_new_refresh_token",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
TokenType: "bearer",
|
||||
ExpiresIn: 28800,
|
||||
RefreshToken: "rotated_rt", // Anthropic rotates refresh tokens
|
||||
})
|
||||
},
|
||||
wantResp: &oauth.TokenResponse{
|
||||
AccessToken: "at",
|
||||
RefreshToken: "rotated_rt",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "non_200_returns_error",
|
||||
handler: func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"invalid_grant"}`))
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
var captured requestCapture
|
||||
|
||||
rt := newInProcessTransport(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.method = r.Method
|
||||
captured.contentType = r.Header.Get("Content-Type")
|
||||
captured.body, _ = io.ReadAll(r.Body)
|
||||
_ = json.Unmarshal(captured.body, &captured.bodyJSON)
|
||||
tt.handler(w, r)
|
||||
}), nil)
|
||||
|
||||
client, ok := NewClaudeOAuthClient().(*claudeOAuthService)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
s.client.tokenURL = "http://in-process/token"
|
||||
s.client.clientFactory = func(string) (*req.Client, error) { return newTestReqClient(rt), nil }
|
||||
|
||||
resp, err := s.client.RefreshToken(context.Background(), "rt", "")
|
||||
|
||||
if tt.wantErr {
|
||||
require.Error(s.T(), err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), tt.wantResp.AccessToken, resp.AccessToken)
|
||||
require.Equal(s.T(), tt.wantResp.RefreshToken, resp.RefreshToken)
|
||||
if tt.validate != nil {
|
||||
tt.validate(captured)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClaudeOAuthServiceSuite))
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage"
|
||||
|
||||
// 默认 User-Agent,与用户抓包的请求一致
|
||||
const defaultUsageUserAgent = "claude-code/2.1.7"
|
||||
|
||||
type claudeUsageService struct {
|
||||
usageURL string
|
||||
allowPrivateHosts bool
|
||||
httpUpstream service.HTTPUpstream
|
||||
}
|
||||
|
||||
// NewClaudeUsageFetcher 创建 Claude 用量获取服务
|
||||
// httpUpstream: 可选,如果提供则支持 TLS 指纹伪装
|
||||
func NewClaudeUsageFetcher(httpUpstream service.HTTPUpstream) service.ClaudeUsageFetcher {
|
||||
return &claudeUsageService{
|
||||
usageURL: defaultClaudeUsageURL,
|
||||
httpUpstream: httpUpstream,
|
||||
}
|
||||
}
|
||||
|
||||
// FetchUsage 简单版本,不支持 TLS 指纹(向后兼容)
|
||||
func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) {
|
||||
return s.FetchUsageWithOptions(ctx, &service.ClaudeUsageFetchOptions{
|
||||
AccessToken: accessToken,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
}
|
||||
|
||||
// FetchUsageWithOptions 完整版本,支持 TLS 指纹和自定义 User-Agent
|
||||
func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *service.ClaudeUsageFetchOptions) (*service.ClaudeUsageResponse, error) {
|
||||
if opts == nil {
|
||||
return nil, fmt.Errorf("options is nil")
|
||||
}
|
||||
|
||||
// 创建请求
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create request failed: %w", err)
|
||||
}
|
||||
|
||||
// 设置请求头(与抓包一致,但不设置 Accept-Encoding,让 Go 自动处理压缩)
|
||||
req.Header.Set("Accept", "application/json, text/plain, */*")
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+opts.AccessToken)
|
||||
req.Header.Set("anthropic-beta", "oauth-2025-04-20")
|
||||
|
||||
// 设置 User-Agent(优先使用缓存的 Fingerprint,否则使用默认值)
|
||||
userAgent := defaultUsageUserAgent
|
||||
if opts.Fingerprint != nil && opts.Fingerprint.UserAgent != "" {
|
||||
userAgent = opts.Fingerprint.UserAgent
|
||||
}
|
||||
req.Header.Set("User-Agent", userAgent)
|
||||
|
||||
var resp *http.Response
|
||||
|
||||
// 如果启用 TLS 指纹且有 HTTPUpstream,使用 DoWithTLS
|
||||
if opts.EnableTLSFingerprint && s.httpUpstream != nil {
|
||||
// accountConcurrency 传 0 使用默认连接池配置,usage 请求不需要特殊的并发设置
|
||||
resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request with TLS fingerprint failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
// 不启用 TLS 指纹,使用普通 HTTP 客户端
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: opts.ProxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
ValidateResolvedIP: true,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create http client failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
msg := fmt.Sprintf("API returned status %d: %s", resp.StatusCode, string(body))
|
||||
return nil, infraerrors.New(http.StatusInternalServerError, "UPSTREAM_ERROR", msg)
|
||||
}
|
||||
|
||||
var usageResp service.ClaudeUsageResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&usageResp); err != nil {
|
||||
return nil, fmt.Errorf("decode response failed: %w", err)
|
||||
}
|
||||
|
||||
return &usageResp, nil
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ClaudeUsageServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
fetcher *claudeUsageService
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
// usageRequestCapture holds captured request data for assertions in the main goroutine.
|
||||
type usageRequestCapture struct {
|
||||
authorization string
|
||||
anthropicBeta string
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_Success() {
|
||||
var captured usageRequestCapture
|
||||
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
captured.authorization = r.Header.Get("Authorization")
|
||||
captured.anthropicBeta = r.Header.Get("anthropic-beta")
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{
|
||||
"five_hour": {"utilization": 12.5, "resets_at": "2025-01-01T00:00:00Z"},
|
||||
"seven_day": {"utilization": 34.0, "resets_at": "2025-01-08T00:00:00Z"},
|
||||
"seven_day_sonnet": {"utilization": 56.0, "resets_at": "2025-01-08T00:00:00Z"}
|
||||
}`)
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{
|
||||
usageURL: s.srv.URL,
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
resp, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.NoError(s.T(), err, "FetchUsage")
|
||||
require.Equal(s.T(), 12.5, resp.FiveHour.Utilization, "FiveHour utilization mismatch")
|
||||
require.Equal(s.T(), 34.0, resp.SevenDay.Utilization, "SevenDay utilization mismatch")
|
||||
require.Equal(s.T(), 56.0, resp.SevenDaySonnet.Utilization, "SevenDaySonnet utilization mismatch")
|
||||
|
||||
// Assertions on captured request data
|
||||
require.Equal(s.T(), "Bearer at", captured.authorization, "Authorization header mismatch")
|
||||
require.Equal(s.T(), "oauth-2025-04-20", captured.anthropicBeta, "anthropic-beta header mismatch")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_NonOK() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "nope")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{
|
||||
usageURL: s.srv.URL,
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
require.ErrorContains(s.T(), err, "nope")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_BadJSON() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{
|
||||
usageURL: s.srv.URL,
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "decode response failed")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_ContextCancel() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Never respond - simulate slow server
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.fetcher = &claudeUsageService{
|
||||
usageURL: s.srv.URL,
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
_, err := s.fetcher.FetchUsage(ctx, "at", "")
|
||||
require.Error(s.T(), err, "expected error for cancelled context")
|
||||
}
|
||||
|
||||
func (s *ClaudeUsageServiceSuite) TestFetchUsage_InvalidProxyReturnsError() {
|
||||
s.fetcher = &claudeUsageService{
|
||||
usageURL: "http://example.com",
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
|
||||
_, err := s.fetcher.FetchUsage(context.Background(), "at", "://bad-proxy-url")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "create http client failed")
|
||||
}
|
||||
|
||||
func TestClaudeUsageServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ClaudeUsageServiceSuite))
|
||||
}
|
||||
@@ -0,0 +1,564 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 并发控制缓存常量定义
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用 SCAN 命令遍历独立的槽位键(concurrency:account:{id}:{requestID}),
|
||||
// 在高并发场景下 SCAN 需要多次往返,且遍历大量键时性能下降明显。
|
||||
//
|
||||
// 新实现改用 Redis 有序集合(Sorted Set):
|
||||
// 1. 每个账号/用户只有一个键,成员为 requestID,分数为时间戳
|
||||
// 2. 使用 ZCARD 原子获取并发数,时间复杂度 O(1)
|
||||
// 3. 使用 ZREMRANGEBYSCORE 清理过期槽位,避免手动管理 TTL
|
||||
// 4. 单次 Redis 调用完成计数,减少网络往返
|
||||
const (
|
||||
// 并发槽位键前缀(有序集合)
|
||||
// 格式: concurrency:account:{accountID}
|
||||
accountSlotKeyPrefix = "concurrency:account:"
|
||||
// 格式: concurrency:user:{userID}
|
||||
userSlotKeyPrefix = "concurrency:user:"
|
||||
// 等待队列计数器格式: concurrency:wait:{userID}
|
||||
waitQueueKeyPrefix = "concurrency:wait:"
|
||||
// 账号级等待队列计数器格式: wait:account:{accountID}
|
||||
accountWaitKeyPrefix = "wait:account:"
|
||||
|
||||
// 默认槽位过期时间(分钟),可通过配置覆盖
|
||||
defaultSlotTTLMinutes = 15
|
||||
)
|
||||
|
||||
var (
|
||||
// acquireScript 使用有序集合计数并在未达上限时添加槽位
|
||||
// 使用 Redis TIME 命令获取服务器时间,避免多实例时钟不同步问题
|
||||
// KEYS[1] = 有序集合键 (concurrency:account:{id} / concurrency:user:{id})
|
||||
// ARGV[1] = maxConcurrency
|
||||
// ARGV[2] = TTL(秒)
|
||||
// ARGV[3] = requestID
|
||||
acquireScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local maxConcurrency = tonumber(ARGV[1])
|
||||
local ttl = tonumber(ARGV[2])
|
||||
local requestID = ARGV[3]
|
||||
|
||||
-- 使用 Redis 服务器时间,确保多实例时钟一致
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
-- 清理过期槽位
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
|
||||
-- 检查是否已存在(支持重试场景刷新时间戳)
|
||||
local exists = redis.call('ZSCORE', key, requestID)
|
||||
if exists ~= false then
|
||||
redis.call('ZADD', key, now, requestID)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
-- 检查是否达到并发上限
|
||||
local count = redis.call('ZCARD', key)
|
||||
if count < maxConcurrency then
|
||||
redis.call('ZADD', key, now, requestID)
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
return 1
|
||||
end
|
||||
|
||||
return 0
|
||||
`)
|
||||
|
||||
// getCountScript 统计有序集合中的槽位数量并清理过期条目
|
||||
// 使用 Redis TIME 命令获取服务器时间
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
getCountScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
-- 使用 Redis 服务器时间
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
return redis.call('ZCARD', key)
|
||||
`)
|
||||
|
||||
// incrementWaitScript - refreshes TTL on each increment to keep queue depth accurate
|
||||
// KEYS[1] = wait queue key
|
||||
// ARGV[1] = maxWait
|
||||
// ARGV[2] = TTL in seconds
|
||||
incrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
|
||||
if current >= tonumber(ARGV[1]) then
|
||||
return 0
|
||||
end
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
|
||||
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
// incrementAccountWaitScript - account-level wait queue count (refresh TTL on each increment)
|
||||
incrementAccountWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current == false then
|
||||
current = 0
|
||||
else
|
||||
current = tonumber(current)
|
||||
end
|
||||
|
||||
if current >= tonumber(ARGV[1]) then
|
||||
return 0
|
||||
end
|
||||
|
||||
local newVal = redis.call('INCR', KEYS[1])
|
||||
|
||||
-- Refresh TTL so long-running traffic doesn't expire active queue counters.
|
||||
redis.call('EXPIRE', KEYS[1], ARGV[2])
|
||||
|
||||
return 1
|
||||
`)
|
||||
|
||||
// decrementWaitScript - same as before
|
||||
decrementWaitScript = redis.NewScript(`
|
||||
local current = redis.call('GET', KEYS[1])
|
||||
if current ~= false and tonumber(current) > 0 then
|
||||
redis.call('DECR', KEYS[1])
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// cleanupExpiredSlotsScript 清理单个账号/用户有序集合中过期槽位
|
||||
// KEYS[1] = 有序集合键
|
||||
// ARGV[1] = TTL(秒)
|
||||
cleanupExpiredSlotsScript = redis.NewScript(`
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
local timeResult = redis.call('TIME')
|
||||
local now = tonumber(timeResult[1])
|
||||
local expireBefore = now - ttl
|
||||
redis.call('ZREMRANGEBYSCORE', key, '-inf', expireBefore)
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
return 1
|
||||
`)
|
||||
|
||||
// startupCleanupScript 清理非当前进程前缀的槽位成员。
|
||||
// KEYS 是有序集合键列表,ARGV[1] 是当前进程前缀,ARGV[2] 是槽位 TTL。
|
||||
// 遍历每个 KEYS[i],移除前缀不匹配的成员,清空后删 key,否则刷新 EXPIRE。
|
||||
startupCleanupScript = redis.NewScript(`
|
||||
local activePrefix = ARGV[1]
|
||||
local slotTTL = tonumber(ARGV[2])
|
||||
local removed = 0
|
||||
for i = 1, #KEYS do
|
||||
local key = KEYS[i]
|
||||
local members = redis.call('ZRANGE', key, 0, -1)
|
||||
for _, member in ipairs(members) do
|
||||
if string.sub(member, 1, string.len(activePrefix)) ~= activePrefix then
|
||||
removed = removed + redis.call('ZREM', key, member)
|
||||
end
|
||||
end
|
||||
if redis.call('ZCARD', key) == 0 then
|
||||
redis.call('DEL', key)
|
||||
else
|
||||
redis.call('EXPIRE', key, slotTTL)
|
||||
end
|
||||
end
|
||||
return removed
|
||||
`)
|
||||
)
|
||||
|
||||
type concurrencyCache struct {
|
||||
rdb *redis.Client
|
||||
slotTTLSeconds int // 槽位过期时间(秒)
|
||||
waitQueueTTLSeconds int // 等待队列过期时间(秒)
|
||||
}
|
||||
|
||||
// NewConcurrencyCache 创建并发控制缓存
|
||||
// slotTTLMinutes: 槽位过期时间(分钟),0 或负数使用默认值 15 分钟
|
||||
// waitQueueTTLSeconds: 等待队列过期时间(秒),0 或负数使用 slot TTL
|
||||
func NewConcurrencyCache(rdb *redis.Client, slotTTLMinutes int, waitQueueTTLSeconds int) service.ConcurrencyCache {
|
||||
if slotTTLMinutes <= 0 {
|
||||
slotTTLMinutes = defaultSlotTTLMinutes
|
||||
}
|
||||
if waitQueueTTLSeconds <= 0 {
|
||||
waitQueueTTLSeconds = slotTTLMinutes * 60
|
||||
}
|
||||
return &concurrencyCache{
|
||||
rdb: rdb,
|
||||
slotTTLSeconds: slotTTLMinutes * 60,
|
||||
waitQueueTTLSeconds: waitQueueTTLSeconds,
|
||||
}
|
||||
}
|
||||
|
||||
// Helper functions for key generation
|
||||
func accountSlotKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
func userSlotKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func waitQueueKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
}
|
||||
|
||||
func accountWaitKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
// Account slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
key := accountSlotKey(accountID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||
key := accountSlotKey(accountID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||
key := accountSlotKey(accountID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]int{}, nil
|
||||
}
|
||||
|
||||
now, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
type accountCmd struct {
|
||||
accountID int64
|
||||
zcardCmd *redis.IntCmd
|
||||
}
|
||||
cmds := make([]accountCmd, 0, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10)
|
||||
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
||||
cmds = append(cmds, accountCmd{
|
||||
accountID: accountID,
|
||||
zcardCmd: pipe.ZCard(ctx, slotKey),
|
||||
})
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
||||
return nil, fmt.Errorf("pipeline exec: %w", err)
|
||||
}
|
||||
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, cmd := range cmds {
|
||||
result[cmd.accountID] = int(cmd.zcardCmd.Val())
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// User slot operations
|
||||
|
||||
func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||
key := userSlotKey(userID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取,确保多实例时钟一致
|
||||
result, err := acquireScript.Run(ctx, c.rdb, []string{key}, maxConcurrency, c.slotTTLSeconds, requestID).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||
key := userSlotKey(userID)
|
||||
return c.rdb.ZRem(ctx, key, requestID).Err()
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||
key := userSlotKey(userID)
|
||||
// 时间戳在 Lua 脚本内使用 Redis TIME 命令获取
|
||||
result, err := getCountScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Wait queue operations
|
||||
|
||||
func (c *concurrencyCache) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
key := waitQueueKey(userID)
|
||||
result, err := incrementWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||
key := waitQueueKey(userID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
// Account wait queue operations
|
||||
|
||||
func (c *concurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
key := accountWaitKey(accountID)
|
||||
result, err := incrementAccountWaitScript.Run(ctx, c.rdb, []string{key}, maxWait, c.waitQueueTTLSeconds).Int()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return result == 1, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||
key := accountWaitKey(accountID)
|
||||
_, err := decrementWaitScript.Run(ctx, c.rdb, []string{key}).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
key := accountWaitKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Int()
|
||||
if err != nil && !errors.Is(err, redis.Nil) {
|
||||
return 0, err
|
||||
}
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return 0, nil
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||
if len(accounts) == 0 {
|
||||
return map[int64]*service.AccountLoadInfo{}, nil
|
||||
}
|
||||
|
||||
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster(Lua 内动态拼 key 会 CROSSSLOT)。
|
||||
// 每个账号执行 3 个命令:ZREMRANGEBYSCORE(清理过期)、ZCARD(并发数)、GET(等待数)。
|
||||
now, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
|
||||
type accountCmds struct {
|
||||
id int64
|
||||
maxConcurrency int
|
||||
zcardCmd *redis.IntCmd
|
||||
getCmd *redis.StringCmd
|
||||
}
|
||||
cmds := make([]accountCmds, 0, len(accounts))
|
||||
for _, acc := range accounts {
|
||||
slotKey := accountSlotKeyPrefix + strconv.FormatInt(acc.ID, 10)
|
||||
waitKey := accountWaitKeyPrefix + strconv.FormatInt(acc.ID, 10)
|
||||
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
||||
ac := accountCmds{
|
||||
id: acc.ID,
|
||||
maxConcurrency: acc.MaxConcurrency,
|
||||
zcardCmd: pipe.ZCard(ctx, slotKey),
|
||||
getCmd: pipe.Get(ctx, waitKey),
|
||||
}
|
||||
cmds = append(cmds, ac)
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
||||
return nil, fmt.Errorf("pipeline exec: %w", err)
|
||||
}
|
||||
|
||||
loadMap := make(map[int64]*service.AccountLoadInfo, len(accounts))
|
||||
for _, ac := range cmds {
|
||||
currentConcurrency := int(ac.zcardCmd.Val())
|
||||
waitingCount := 0
|
||||
if v, err := ac.getCmd.Int(); err == nil {
|
||||
waitingCount = v
|
||||
}
|
||||
loadRate := 0
|
||||
if ac.maxConcurrency > 0 {
|
||||
loadRate = (currentConcurrency + waitingCount) * 100 / ac.maxConcurrency
|
||||
}
|
||||
loadMap[ac.id] = &service.AccountLoadInfo{
|
||||
AccountID: ac.id,
|
||||
CurrentConcurrency: currentConcurrency,
|
||||
WaitingCount: waitingCount,
|
||||
LoadRate: loadRate,
|
||||
}
|
||||
}
|
||||
|
||||
return loadMap, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||
if len(users) == 0 {
|
||||
return map[int64]*service.UserLoadInfo{}, nil
|
||||
}
|
||||
|
||||
// 使用 Pipeline 替代 Lua 脚本,兼容 Redis Cluster。
|
||||
now, err := c.rdb.Time(ctx).Result()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("redis TIME: %w", err)
|
||||
}
|
||||
cutoffTime := now.Unix() - int64(c.slotTTLSeconds)
|
||||
|
||||
pipe := c.rdb.Pipeline()
|
||||
|
||||
type userCmds struct {
|
||||
id int64
|
||||
maxConcurrency int
|
||||
zcardCmd *redis.IntCmd
|
||||
getCmd *redis.StringCmd
|
||||
}
|
||||
cmds := make([]userCmds, 0, len(users))
|
||||
for _, u := range users {
|
||||
slotKey := userSlotKeyPrefix + strconv.FormatInt(u.ID, 10)
|
||||
waitKey := waitQueueKeyPrefix + strconv.FormatInt(u.ID, 10)
|
||||
pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10))
|
||||
uc := userCmds{
|
||||
id: u.ID,
|
||||
maxConcurrency: u.MaxConcurrency,
|
||||
zcardCmd: pipe.ZCard(ctx, slotKey),
|
||||
getCmd: pipe.Get(ctx, waitKey),
|
||||
}
|
||||
cmds = append(cmds, uc)
|
||||
}
|
||||
|
||||
if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) {
|
||||
return nil, fmt.Errorf("pipeline exec: %w", err)
|
||||
}
|
||||
|
||||
loadMap := make(map[int64]*service.UserLoadInfo, len(users))
|
||||
for _, uc := range cmds {
|
||||
currentConcurrency := int(uc.zcardCmd.Val())
|
||||
waitingCount := 0
|
||||
if v, err := uc.getCmd.Int(); err == nil {
|
||||
waitingCount = v
|
||||
}
|
||||
loadRate := 0
|
||||
if uc.maxConcurrency > 0 {
|
||||
loadRate = (currentConcurrency + waitingCount) * 100 / uc.maxConcurrency
|
||||
}
|
||||
loadMap[uc.id] = &service.UserLoadInfo{
|
||||
UserID: uc.id,
|
||||
CurrentConcurrency: currentConcurrency,
|
||||
WaitingCount: waitingCount,
|
||||
LoadRate: loadRate,
|
||||
}
|
||||
}
|
||||
|
||||
return loadMap, nil
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
key := accountSlotKey(accountID)
|
||||
_, err := cleanupExpiredSlotsScript.Run(ctx, c.rdb, []string{key}, c.slotTTLSeconds).Result()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *concurrencyCache) CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error {
|
||||
if activeRequestPrefix == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. 清理有序集合中非当前进程前缀的成员
|
||||
slotPatterns := []string{accountSlotKeyPrefix + "*", userSlotKeyPrefix + "*"}
|
||||
for _, pattern := range slotPatterns {
|
||||
if err := c.cleanupSlotsByPattern(ctx, pattern, activeRequestPrefix); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 删除所有等待队列计数器(重启后计数器失效)
|
||||
waitPatterns := []string{accountWaitKeyPrefix + "*", waitQueueKeyPrefix + "*"}
|
||||
for _, pattern := range waitPatterns {
|
||||
if err := c.deleteKeysByPattern(ctx, pattern); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// cleanupSlotsByPattern 扫描匹配 pattern 的有序集合键,批量调用 Lua 脚本清理非当前进程成员。
|
||||
func (c *concurrencyCache) cleanupSlotsByPattern(ctx context.Context, pattern, activePrefix string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
_, err := startupCleanupScript.Run(ctx, c.rdb, keys, activePrefix, c.slotTTLSeconds).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("cleanup slots %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// deleteKeysByPattern 扫描匹配 pattern 的键并删除。
|
||||
func (c *concurrencyCache) deleteKeysByPattern(ctx context.Context, pattern string) error {
|
||||
const scanCount = 200
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := c.rdb.Scan(ctx, cursor, pattern, scanCount).Result()
|
||||
if err != nil {
|
||||
return fmt.Errorf("scan %s: %w", pattern, err)
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
if err := c.rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
return fmt.Errorf("del %s: %w", pattern, err)
|
||||
}
|
||||
}
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// 基准测试用 TTL 配置
|
||||
const benchSlotTTLMinutes = 15
|
||||
|
||||
var benchSlotTTL = time.Duration(benchSlotTTLMinutes) * time.Minute
|
||||
|
||||
// BenchmarkAccountConcurrency 用于对比 SCAN 与有序集合的计数性能。
|
||||
func BenchmarkAccountConcurrency(b *testing.B) {
|
||||
rdb := newBenchmarkRedisClient(b)
|
||||
defer func() {
|
||||
_ = rdb.Close()
|
||||
}()
|
||||
|
||||
cache, _ := NewConcurrencyCache(rdb, benchSlotTTLMinutes, int(benchSlotTTL.Seconds())).(*concurrencyCache)
|
||||
ctx := context.Background()
|
||||
|
||||
for _, size := range []int{10, 100, 1000} {
|
||||
size := size
|
||||
b.Run(fmt.Sprintf("zset/slots=%d", size), func(b *testing.B) {
|
||||
accountID := time.Now().UnixNano()
|
||||
key := accountSlotKey(accountID)
|
||||
|
||||
b.StopTimer()
|
||||
members := make([]redis.Z, 0, size)
|
||||
now := float64(time.Now().Unix())
|
||||
for i := 0; i < size; i++ {
|
||||
members = append(members, redis.Z{
|
||||
Score: now,
|
||||
Member: fmt.Sprintf("req_%d", i),
|
||||
})
|
||||
}
|
||||
if err := rdb.ZAdd(ctx, key, members...).Err(); err != nil {
|
||||
b.Fatalf("初始化有序集合失败: %v", err)
|
||||
}
|
||||
if err := rdb.Expire(ctx, key, benchSlotTTL).Err(); err != nil {
|
||||
b.Fatalf("设置有序集合 TTL 失败: %v", err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := cache.GetAccountConcurrency(ctx, accountID); err != nil {
|
||||
b.Fatalf("获取并发数量失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
if err := rdb.Del(ctx, key).Err(); err != nil {
|
||||
b.Fatalf("清理有序集合失败: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("scan/slots=%d", size), func(b *testing.B) {
|
||||
accountID := time.Now().UnixNano()
|
||||
pattern := fmt.Sprintf("%s%d:*", accountSlotKeyPrefix, accountID)
|
||||
keys := make([]string, 0, size)
|
||||
|
||||
b.StopTimer()
|
||||
pipe := rdb.Pipeline()
|
||||
for i := 0; i < size; i++ {
|
||||
key := fmt.Sprintf("%s%d:req_%d", accountSlotKeyPrefix, accountID, i)
|
||||
keys = append(keys, key)
|
||||
pipe.Set(ctx, key, "1", benchSlotTTL)
|
||||
}
|
||||
if _, err := pipe.Exec(ctx); err != nil {
|
||||
b.Fatalf("初始化扫描键失败: %v", err)
|
||||
}
|
||||
b.StartTimer()
|
||||
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if _, err := scanSlotCount(ctx, rdb, pattern); err != nil {
|
||||
b.Fatalf("SCAN 计数失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
b.StopTimer()
|
||||
if err := rdb.Del(ctx, keys...).Err(); err != nil {
|
||||
b.Fatalf("清理扫描键失败: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func scanSlotCount(ctx context.Context, rdb *redis.Client, pattern string) (int, error) {
|
||||
var cursor uint64
|
||||
count := 0
|
||||
for {
|
||||
keys, nextCursor, err := rdb.Scan(ctx, cursor, pattern, 100).Result()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
count += len(keys)
|
||||
if nextCursor == 0 {
|
||||
break
|
||||
}
|
||||
cursor = nextCursor
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func newBenchmarkRedisClient(b *testing.B) *redis.Client {
|
||||
b.Helper()
|
||||
|
||||
redisURL := os.Getenv("TEST_REDIS_URL")
|
||||
if redisURL == "" {
|
||||
b.Skip("未设置 TEST_REDIS_URL,跳过 Redis 基准测试")
|
||||
}
|
||||
|
||||
opt, err := redis.ParseURL(redisURL)
|
||||
if err != nil {
|
||||
b.Fatalf("解析 TEST_REDIS_URL 失败: %v", err)
|
||||
}
|
||||
|
||||
client := redis.NewClient(opt)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.Ping(ctx).Err(); err != nil {
|
||||
b.Fatalf("Redis 连接失败: %v", err)
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
@@ -0,0 +1,487 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// 测试用 TTL 配置(15 分钟,与默认值一致)
|
||||
const testSlotTTLMinutes = 15
|
||||
|
||||
// 测试用 TTL Duration,用于 TTL 断言
|
||||
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
|
||||
|
||||
type ConcurrencyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
accountID := int64(10)
|
||||
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 3")
|
||||
require.False(s.T(), ok, "expected third acquire to fail")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency")
|
||||
require.Equal(s.T(), 2, cur, "concurrency mismatch")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
|
||||
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency after release")
|
||||
require.Equal(s.T(), 1, cur, "expected 1 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
accountID := int64(11)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
|
||||
accountID := int64(12)
|
||||
reqID := "dup-req"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Acquiring with same reqID should be idempotent
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
|
||||
accountID := int64(13)
|
||||
reqID := "release-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
|
||||
// Releasing again should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
|
||||
// Releasing non-existent should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
|
||||
accountID := int64(14)
|
||||
reqID := "max-zero-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.False(s.T(), ok, "expected acquire to fail with max=0")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
|
||||
userID := int64(42)
|
||||
reqID1, reqID2 := "req1", "req2"
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot 2")
|
||||
require.False(s.T(), ok, "expected second acquire to fail at max=1")
|
||||
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency")
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
|
||||
// Releasing a non-existent slot should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
|
||||
|
||||
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency after release")
|
||||
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
userID := int64(200)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
userID := int64(20)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
||||
require.False(s.T(), ok, "expected wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||
userID := int64(300)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
// Test decrement on non-existent key - should not error and should not create negative value
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
||||
|
||||
// Verify no key was created or it's not negative
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
||||
|
||||
// Set count to 1, then decrement twice
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Decrement once (1 -> 0)
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
// Decrement again on 0 - should not go negative
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
||||
|
||||
// Verify count is 0, not negative
|
||||
val, err = s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey after double decrement")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
accountID := int64(30)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
|
||||
require.False(s.T(), ok, "expected account wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL account waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
now := time.Now().Unix()
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-1"},
|
||||
redis.Z{Score: float64(now), Member: "keep-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-2"},
|
||||
redis.Z{Score: float64(now), Member: "keep-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
// When no slots exist, GetAccountConcurrency should return 0
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
|
||||
// When no slots exist, GetUserConcurrency should return 0
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
|
||||
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
|
||||
// Setup: Create accounts with different load states
|
||||
account1 := int64(100)
|
||||
account2 := int64(101)
|
||||
account3 := int64(102)
|
||||
|
||||
// Account 1: 2/3 slots used, 1 waiting
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Account 2: 1/2 slots used, 0 waiting
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Account 3: 0/1 slots used, 0 waiting (idle)
|
||||
|
||||
// Query batch load
|
||||
accounts := []service.AccountWithConcurrency{
|
||||
{ID: account1, MaxConcurrency: 3},
|
||||
{ID: account2, MaxConcurrency: 2},
|
||||
{ID: account3, MaxConcurrency: 1},
|
||||
}
|
||||
|
||||
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
|
||||
require.NoError(s.T(), err)
|
||||
require.Len(s.T(), loadMap, 3)
|
||||
|
||||
// Verify account1: (2 + 1) / 3 = 100%
|
||||
load1 := loadMap[account1]
|
||||
require.NotNil(s.T(), load1)
|
||||
require.Equal(s.T(), account1, load1.AccountID)
|
||||
require.Equal(s.T(), 2, load1.CurrentConcurrency)
|
||||
require.Equal(s.T(), 1, load1.WaitingCount)
|
||||
require.Equal(s.T(), 100, load1.LoadRate)
|
||||
|
||||
// Verify account2: (1 + 0) / 2 = 50%
|
||||
load2 := loadMap[account2]
|
||||
require.NotNil(s.T(), load2)
|
||||
require.Equal(s.T(), account2, load2.AccountID)
|
||||
require.Equal(s.T(), 1, load2.CurrentConcurrency)
|
||||
require.Equal(s.T(), 0, load2.WaitingCount)
|
||||
require.Equal(s.T(), 50, load2.LoadRate)
|
||||
|
||||
// Verify account3: (0 + 0) / 1 = 0%
|
||||
load3 := loadMap[account3]
|
||||
require.NotNil(s.T(), load3)
|
||||
require.Equal(s.T(), account3, load3.AccountID)
|
||||
require.Equal(s.T(), 0, load3.CurrentConcurrency)
|
||||
require.Equal(s.T(), 0, load3.WaitingCount)
|
||||
require.Equal(s.T(), 0, load3.LoadRate)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
|
||||
// Test with empty account list
|
||||
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
|
||||
require.NoError(s.T(), err)
|
||||
require.Empty(s.T(), loadMap)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
|
||||
accountID := int64(200)
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
// Acquire 3 slots
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Verify 3 slots exist
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 3, cur)
|
||||
|
||||
// Manually set old timestamps for req1 and req2 (simulate expired slots)
|
||||
now := time.Now().Unix()
|
||||
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
|
||||
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
|
||||
require.NoError(s.T(), err)
|
||||
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Run cleanup
|
||||
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Verify only 1 slot remains (req3)
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur)
|
||||
|
||||
// Verify req3 still exists
|
||||
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Len(s.T(), members, 1)
|
||||
require.Equal(s.T(), "req3", members[0])
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||
accountID := int64(201)
|
||||
|
||||
// Acquire 2 fresh slots
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Run cleanup (should not remove anything)
|
||||
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Verify both slots still exist
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 2, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
now := float64(time.Now().Unix())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-1"},
|
||||
redis.Z{Score: now, Member: "activeproc-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-2"},
|
||||
redis.Z{Score: now, Member: "activeproc-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
|
||||
accountID := int64(903)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.EqualValues(s.T(), 0, exists)
|
||||
}
|
||||
@@ -0,0 +1,533 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type dashboardAggregationRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
const usageLogsCleanupBatchSize = 10000
|
||||
const usageBillingDedupCleanupBatchSize = 10000
|
||||
|
||||
// NewDashboardAggregationRepository 创建仪表盘预聚合仓储。
|
||||
func NewDashboardAggregationRepository(sqlDB *sql.DB) service.DashboardAggregationRepository {
|
||||
if sqlDB == nil {
|
||||
return nil
|
||||
}
|
||||
if !isPostgresDriver(sqlDB) {
|
||||
log.Printf("[DashboardAggregation] 检测到非 PostgreSQL 驱动,已自动禁用预聚合")
|
||||
return nil
|
||||
}
|
||||
return newDashboardAggregationRepositoryWithSQL(sqlDB)
|
||||
}
|
||||
|
||||
func newDashboardAggregationRepositoryWithSQL(sqlq sqlExecutor) *dashboardAggregationRepository {
|
||||
return &dashboardAggregationRepository{sql: sqlq}
|
||||
}
|
||||
|
||||
func isPostgresDriver(db *sql.DB) bool {
|
||||
if db == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := db.Driver().(*pq.Driver)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
if r == nil || r.sql == nil {
|
||||
return nil
|
||||
}
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
if !endLocal.After(startLocal) {
|
||||
return nil
|
||||
}
|
||||
|
||||
hourStart := startLocal.Truncate(time.Hour)
|
||||
hourEnd := endLocal.Truncate(time.Hour)
|
||||
if endLocal.After(hourEnd) {
|
||||
hourEnd = hourEnd.Add(time.Hour)
|
||||
}
|
||||
|
||||
dayStart := truncateToDay(startLocal)
|
||||
dayEnd := truncateToDay(endLocal)
|
||||
if endLocal.After(dayEnd) {
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
if db, ok := r.sql.(*sql.DB); ok {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||
if err := txRepo.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
return r.aggregateRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) aggregateRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||
// 以桶边界聚合,允许覆盖 end 所在桶的剩余区间。
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
if r == nil || r.sql == nil {
|
||||
return nil
|
||||
}
|
||||
loc := timezone.Location()
|
||||
startLocal := start.In(loc)
|
||||
endLocal := end.In(loc)
|
||||
if !endLocal.After(startLocal) {
|
||||
return nil
|
||||
}
|
||||
|
||||
hourStart := startLocal.Truncate(time.Hour)
|
||||
hourEnd := endLocal.Truncate(time.Hour)
|
||||
if endLocal.After(hourEnd) {
|
||||
hourEnd = hourEnd.Add(time.Hour)
|
||||
}
|
||||
|
||||
dayStart := truncateToDay(startLocal)
|
||||
dayEnd := truncateToDay(endLocal)
|
||||
if endLocal.After(dayEnd) {
|
||||
dayEnd = dayEnd.Add(24 * time.Hour)
|
||||
}
|
||||
|
||||
// 尽量使用事务保证范围内的一致性(允许在非 *sql.DB 的情况下退化为非事务执行)。
|
||||
if db, ok := r.sql.(*sql.DB); ok {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
txRepo := newDashboardAggregationRepositoryWithSQL(tx)
|
||||
if err := txRepo.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
return r.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd)
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) recomputeRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error {
|
||||
// 先清空范围内桶,再重建(避免仅增量插入导致活跃用户等指标无法回退)。
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
var ts time.Time
|
||||
query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1"
|
||||
if err := scanSingleRow(ctx, r.sql, query, nil, &ts); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return time.Unix(0, 0).UTC(), nil
|
||||
}
|
||||
return time.Time{}, err
|
||||
}
|
||||
return ts.UTC(), nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
query := `
|
||||
INSERT INTO usage_dashboard_aggregation_watermark (id, last_aggregated_at, updated_at)
|
||||
VALUES (1, $1, NOW())
|
||||
ON CONFLICT (id)
|
||||
DO UPDATE SET last_aggregated_at = EXCLUDED.last_aggregated_at, updated_at = EXCLUDED.updated_at
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, aggregatedAt.UTC())
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
hourlyCutoffUTC := hourlyCutoff.UTC()
|
||||
dailyCutoffUTC := dailyCutoff.UTC()
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start < $1", hourlyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date < $1::date", dailyCutoffUTC); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isPartitioned {
|
||||
return r.dropUsageLogsPartitions(ctx, cutoff)
|
||||
}
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid
|
||||
FROM usage_logs
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM usage_logs
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageLogsCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageLogsCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
for {
|
||||
res, err := r.sql.ExecContext(ctx, `
|
||||
WITH victims AS (
|
||||
SELECT ctid, request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM usage_billing_dedup
|
||||
WHERE created_at < $1
|
||||
LIMIT $2
|
||||
), archived AS (
|
||||
INSERT INTO usage_billing_dedup_archive (request_id, api_key_id, request_fingerprint, created_at)
|
||||
SELECT request_id, api_key_id, request_fingerprint, created_at
|
||||
FROM victims
|
||||
ON CONFLICT (request_id, api_key_id) DO NOTHING
|
||||
)
|
||||
DELETE FROM usage_billing_dedup
|
||||
WHERE ctid IN (SELECT ctid FROM victims)
|
||||
`, cutoff.UTC(), usageBillingDedupCleanupBatchSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected < usageBillingDedupCleanupBatchSize {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
isPartitioned, err := r.isUsageLogsPartitioned(ctx)
|
||||
if err != nil || !isPartitioned {
|
||||
return err
|
||||
}
|
||||
monthStart := truncateToMonthUTC(now)
|
||||
prevMonth := monthStart.AddDate(0, -1, 0)
|
||||
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||
|
||||
for _, m := range []time.Time{prevMonth, monthStart, nextMonth} {
|
||||
if err := r.createUsageLogsPartition(ctx, m); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) insertHourlyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||
tzName := timezone.Name()
|
||||
query := `
|
||||
INSERT INTO usage_dashboard_hourly_users (bucket_start, user_id)
|
||||
SELECT DISTINCT
|
||||
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
|
||||
user_id
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
ON CONFLICT DO NOTHING
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) insertDailyActiveUsers(ctx context.Context, start, end time.Time) error {
|
||||
tzName := timezone.Name()
|
||||
query := `
|
||||
INSERT INTO usage_dashboard_daily_users (bucket_date, user_id)
|
||||
SELECT DISTINCT
|
||||
(bucket_start AT TIME ZONE $3)::date AS bucket_date,
|
||||
user_id
|
||||
FROM usage_dashboard_hourly_users
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
ON CONFLICT DO NOTHING
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) upsertHourlyAggregates(ctx context.Context, start, end time.Time) error {
|
||||
tzName := timezone.Name()
|
||||
query := `
|
||||
WITH hourly AS (
|
||||
SELECT
|
||||
date_trunc('hour', created_at AT TIME ZONE $3) AT TIME ZONE $3 AS bucket_start,
|
||||
COUNT(*) AS total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||
COALESCE(SUM(COALESCE(duration_ms, 0)), 0) AS total_duration_ms
|
||||
FROM usage_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
GROUP BY 1
|
||||
),
|
||||
user_counts AS (
|
||||
SELECT bucket_start, COUNT(*) AS active_users
|
||||
FROM usage_dashboard_hourly_users
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
GROUP BY bucket_start
|
||||
)
|
||||
INSERT INTO usage_dashboard_hourly (
|
||||
bucket_start,
|
||||
total_requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
total_duration_ms,
|
||||
active_users,
|
||||
computed_at
|
||||
)
|
||||
SELECT
|
||||
hourly.bucket_start,
|
||||
hourly.total_requests,
|
||||
hourly.input_tokens,
|
||||
hourly.output_tokens,
|
||||
hourly.cache_creation_tokens,
|
||||
hourly.cache_read_tokens,
|
||||
hourly.total_cost,
|
||||
hourly.actual_cost,
|
||||
hourly.total_duration_ms,
|
||||
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||
NOW()
|
||||
FROM hourly
|
||||
LEFT JOIN user_counts ON user_counts.bucket_start = hourly.bucket_start
|
||||
ON CONFLICT (bucket_start)
|
||||
DO UPDATE SET
|
||||
total_requests = EXCLUDED.total_requests,
|
||||
input_tokens = EXCLUDED.input_tokens,
|
||||
output_tokens = EXCLUDED.output_tokens,
|
||||
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
|
||||
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||
total_cost = EXCLUDED.total_cost,
|
||||
actual_cost = EXCLUDED.actual_cost,
|
||||
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||
active_users = EXCLUDED.active_users,
|
||||
computed_at = EXCLUDED.computed_at
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start, end, tzName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) upsertDailyAggregates(ctx context.Context, start, end time.Time) error {
|
||||
tzName := timezone.Name()
|
||||
query := `
|
||||
WITH daily AS (
|
||||
SELECT
|
||||
(bucket_start AT TIME ZONE $5)::date AS bucket_date,
|
||||
COALESCE(SUM(total_requests), 0) AS total_requests,
|
||||
COALESCE(SUM(input_tokens), 0) AS input_tokens,
|
||||
COALESCE(SUM(output_tokens), 0) AS output_tokens,
|
||||
COALESCE(SUM(cache_creation_tokens), 0) AS cache_creation_tokens,
|
||||
COALESCE(SUM(cache_read_tokens), 0) AS cache_read_tokens,
|
||||
COALESCE(SUM(total_cost), 0) AS total_cost,
|
||||
COALESCE(SUM(actual_cost), 0) AS actual_cost,
|
||||
COALESCE(SUM(total_duration_ms), 0) AS total_duration_ms
|
||||
FROM usage_dashboard_hourly
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
GROUP BY (bucket_start AT TIME ZONE $5)::date
|
||||
),
|
||||
user_counts AS (
|
||||
SELECT bucket_date, COUNT(*) AS active_users
|
||||
FROM usage_dashboard_daily_users
|
||||
WHERE bucket_date >= $3::date AND bucket_date < $4::date
|
||||
GROUP BY bucket_date
|
||||
)
|
||||
INSERT INTO usage_dashboard_daily (
|
||||
bucket_date,
|
||||
total_requests,
|
||||
input_tokens,
|
||||
output_tokens,
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
total_cost,
|
||||
actual_cost,
|
||||
total_duration_ms,
|
||||
active_users,
|
||||
computed_at
|
||||
)
|
||||
SELECT
|
||||
daily.bucket_date,
|
||||
daily.total_requests,
|
||||
daily.input_tokens,
|
||||
daily.output_tokens,
|
||||
daily.cache_creation_tokens,
|
||||
daily.cache_read_tokens,
|
||||
daily.total_cost,
|
||||
daily.actual_cost,
|
||||
daily.total_duration_ms,
|
||||
COALESCE(user_counts.active_users, 0) AS active_users,
|
||||
NOW()
|
||||
FROM daily
|
||||
LEFT JOIN user_counts ON user_counts.bucket_date = daily.bucket_date
|
||||
ON CONFLICT (bucket_date)
|
||||
DO UPDATE SET
|
||||
total_requests = EXCLUDED.total_requests,
|
||||
input_tokens = EXCLUDED.input_tokens,
|
||||
output_tokens = EXCLUDED.output_tokens,
|
||||
cache_creation_tokens = EXCLUDED.cache_creation_tokens,
|
||||
cache_read_tokens = EXCLUDED.cache_read_tokens,
|
||||
total_cost = EXCLUDED.total_cost,
|
||||
actual_cost = EXCLUDED.actual_cost,
|
||||
total_duration_ms = EXCLUDED.total_duration_ms,
|
||||
active_users = EXCLUDED.active_users,
|
||||
computed_at = EXCLUDED.computed_at
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query, start, end, start, end, tzName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) isUsageLogsPartitioned(ctx context.Context) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS(
|
||||
SELECT 1
|
||||
FROM pg_partitioned_table pt
|
||||
JOIN pg_class c ON c.oid = pt.partrelid
|
||||
WHERE c.relname = 'usage_logs'
|
||||
)
|
||||
`
|
||||
var partitioned bool
|
||||
if err := scanSingleRow(ctx, r.sql, query, nil, &partitioned); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return partitioned, nil
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) dropUsageLogsPartitions(ctx context.Context, cutoff time.Time) error {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT c.relname
|
||||
FROM pg_inherits
|
||||
JOIN pg_class c ON c.oid = pg_inherits.inhrelid
|
||||
JOIN pg_class p ON p.oid = pg_inherits.inhparent
|
||||
WHERE p.relname = 'usage_logs'
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
_ = rows.Close()
|
||||
}()
|
||||
|
||||
cutoffMonth := truncateToMonthUTC(cutoff)
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return err
|
||||
}
|
||||
if !strings.HasPrefix(name, "usage_logs_") {
|
||||
continue
|
||||
}
|
||||
suffix := strings.TrimPrefix(name, "usage_logs_")
|
||||
month, err := time.Parse("200601", suffix)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
month = month.UTC()
|
||||
if month.Before(cutoffMonth) {
|
||||
if _, err := r.sql.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(name))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (r *dashboardAggregationRepository) createUsageLogsPartition(ctx context.Context, month time.Time) error {
|
||||
monthStart := truncateToMonthUTC(month)
|
||||
nextMonth := monthStart.AddDate(0, 1, 0)
|
||||
name := fmt.Sprintf("usage_logs_%s", monthStart.Format("200601"))
|
||||
query := fmt.Sprintf(
|
||||
"CREATE TABLE IF NOT EXISTS %s PARTITION OF usage_logs FOR VALUES FROM (%s) TO (%s)",
|
||||
pq.QuoteIdentifier(name),
|
||||
pq.QuoteLiteral(monthStart.Format("2006-01-02")),
|
||||
pq.QuoteLiteral(nextMonth.Format("2006-01-02")),
|
||||
)
|
||||
_, err := r.sql.ExecContext(ctx, query)
|
||||
return err
|
||||
}
|
||||
|
||||
func truncateToDay(t time.Time) time.Time {
|
||||
return timezone.StartOfDay(t)
|
||||
}
|
||||
|
||||
func truncateToMonthUTC(t time.Time) time.Time {
|
||||
t = t.UTC()
|
||||
return time.Date(t.Year(), t.Month(), 1, 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
@@ -0,0 +1,58 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const dashboardStatsCacheKey = "dashboard:stats:v1"
|
||||
|
||||
type dashboardCache struct {
|
||||
rdb *redis.Client
|
||||
keyPrefix string
|
||||
}
|
||||
|
||||
func NewDashboardCache(rdb *redis.Client, cfg *config.Config) service.DashboardStatsCache {
|
||||
prefix := "sub2api:"
|
||||
if cfg != nil {
|
||||
prefix = strings.TrimSpace(cfg.Dashboard.KeyPrefix)
|
||||
}
|
||||
if prefix != "" && !strings.HasSuffix(prefix, ":") {
|
||||
prefix += ":"
|
||||
}
|
||||
return &dashboardCache{
|
||||
rdb: rdb,
|
||||
keyPrefix: prefix,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *dashboardCache) GetDashboardStats(ctx context.Context) (string, error) {
|
||||
val, err := c.rdb.Get(ctx, c.buildKey()).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return "", service.ErrDashboardStatsCacheMiss
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *dashboardCache) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||
return c.rdb.Set(ctx, c.buildKey(), data, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *dashboardCache) buildKey() string {
|
||||
if c.keyPrefix == "" {
|
||||
return dashboardStatsCacheKey
|
||||
}
|
||||
return c.keyPrefix + dashboardStatsCacheKey
|
||||
}
|
||||
|
||||
func (c *dashboardCache) DeleteDashboardStats(ctx context.Context) error {
|
||||
return c.rdb.Del(ctx, c.buildKey()).Err()
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewDashboardCacheKeyPrefix(t *testing.T) {
|
||||
cache := NewDashboardCache(nil, &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{
|
||||
KeyPrefix: "prod",
|
||||
},
|
||||
})
|
||||
impl, ok := cache.(*dashboardCache)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "prod:", impl.keyPrefix)
|
||||
|
||||
cache = NewDashboardCache(nil, &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{
|
||||
KeyPrefix: "staging:",
|
||||
},
|
||||
})
|
||||
impl, ok = cache.(*dashboardCache)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "staging:", impl.keyPrefix)
|
||||
}
|
||||
@@ -0,0 +1,32 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
type dbPoolSettings struct {
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
}
|
||||
|
||||
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
|
||||
return dbPoolSettings{
|
||||
MaxOpenConns: cfg.Database.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Database.MaxIdleConns,
|
||||
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
|
||||
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
db.SetMaxOpenConns(settings.MaxOpenConns)
|
||||
db.SetMaxIdleConns(settings.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
|
||||
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func TestBuildDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 50,
|
||||
MaxIdleConns: 10,
|
||||
ConnMaxLifetimeMinutes: 30,
|
||||
ConnMaxIdleTimeMinutes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
require.Equal(t, 50, settings.MaxOpenConns)
|
||||
require.Equal(t, 10, settings.MaxIdleConns)
|
||||
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
|
||||
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
|
||||
}
|
||||
|
||||
func TestApplyDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 40,
|
||||
MaxIdleConns: 8,
|
||||
ConnMaxLifetimeMinutes: 15,
|
||||
ConnMaxIdleTimeMinutes: 3,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
|
||||
applyDBPoolSettings(db, cfg)
|
||||
stats := db.Stats()
|
||||
require.Equal(t, 40, stats.MaxOpenConnections)
|
||||
}
|
||||
@@ -0,0 +1,108 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
verifyCodeKeyPrefix = "verify_code:"
|
||||
passwordResetKeyPrefix = "password_reset:"
|
||||
passwordResetSentAtKeyPrefix = "password_reset_sent:"
|
||||
)
|
||||
|
||||
// verifyCodeKey generates the Redis key for email verification code.
|
||||
func verifyCodeKey(email string) string {
|
||||
return verifyCodeKeyPrefix + email
|
||||
}
|
||||
|
||||
// passwordResetKey generates the Redis key for password reset token.
|
||||
func passwordResetKey(email string) string {
|
||||
return passwordResetKeyPrefix + email
|
||||
}
|
||||
|
||||
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
||||
func passwordResetSentAtKey(email string) string {
|
||||
return passwordResetSentAtKeyPrefix + email
|
||||
}
|
||||
|
||||
type emailCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewEmailCache(rdb *redis.Client) service.EmailCache {
|
||||
return &emailCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *emailCache) GetVerificationCode(ctx context.Context, email string) (*service.VerificationCodeData, error) {
|
||||
key := verifyCodeKey(email)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data service.VerificationCodeData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetVerificationCode(ctx context.Context, email string, data *service.VerificationCodeData, ttl time.Duration) error {
|
||||
key := verifyCodeKey(email)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||
key := verifyCodeKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Password reset token methods
|
||||
|
||||
func (c *emailCache) GetPasswordResetToken(ctx context.Context, email string) (*service.PasswordResetTokenData, error) {
|
||||
key := passwordResetKey(email)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var data service.PasswordResetTokenData
|
||||
if err := json.Unmarshal([]byte(val), &data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &data, nil
|
||||
}
|
||||
|
||||
func (c *emailCache) SetPasswordResetToken(ctx context.Context, email string, data *service.PasswordResetTokenData, ttl time.Duration) error {
|
||||
key := passwordResetKey(email)
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *emailCache) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||
key := passwordResetKey(email)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
// Password reset email cooldown methods
|
||||
|
||||
func (c *emailCache) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||
key := passwordResetSentAtKey(email)
|
||||
exists, err := c.rdb.Exists(ctx, key).Result()
|
||||
return err == nil && exists > 0
|
||||
}
|
||||
|
||||
func (c *emailCache) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||
key := passwordResetSentAtKey(email)
|
||||
return c.rdb.Set(ctx, key, "1", ttl).Err()
|
||||
}
|
||||
@@ -0,0 +1,92 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type EmailCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.EmailCache
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewEmailCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
|
||||
email := "a@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
got, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode")
|
||||
require.Equal(s.T(), "123456", got.Code)
|
||||
require.Equal(s.T(), 1, got.Attempts)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
|
||||
email := "ttl@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
emailKey := verifyCodeKeyPrefix + email
|
||||
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
|
||||
require.NoError(s.T(), err, "TTL emailKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
|
||||
email := "delete@example.com"
|
||||
data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
|
||||
|
||||
// Verify it exists
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode before delete")
|
||||
|
||||
// Delete
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
|
||||
|
||||
// Verify it's gone
|
||||
_, err = s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
|
||||
// Deleting a non-existent key should not error
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
|
||||
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
|
||||
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
|
||||
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
|
||||
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestEmailCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(EmailCacheSuite))
|
||||
}
|
||||
@@ -0,0 +1,45 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyCodeKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_email",
|
||||
email: "user@example.com",
|
||||
expected: "verify_code:user@example.com",
|
||||
},
|
||||
{
|
||||
name: "empty_email",
|
||||
email: "",
|
||||
expected: "verify_code:",
|
||||
},
|
||||
{
|
||||
name: "email_with_plus",
|
||||
email: "user+tag@example.com",
|
||||
expected: "verify_code:user+tag@example.com",
|
||||
},
|
||||
{
|
||||
name: "email_with_special_chars",
|
||||
email: "user.name+tag@sub.domain.com",
|
||||
expected: "verify_code:user.name+tag@sub.domain.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := verifyCodeKey(tc.email)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
// Package repository 提供应用程序的基础设施层组件。
|
||||
// 包括数据库连接初始化、ORM 客户端管理、Redis 连接、数据库迁移等核心功能。
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/Wei-Shaw/sub2api/migrations"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "github.com/lib/pq" // PostgreSQL 驱动,通过副作用导入注册驱动
|
||||
)
|
||||
|
||||
// InitEnt 初始化 Ent ORM 客户端并返回客户端实例和底层的 *sql.DB。
|
||||
//
|
||||
// 该函数执行以下操作:
|
||||
// 1. 初始化全局时区设置,确保时间处理一致性
|
||||
// 2. 建立 PostgreSQL 数据库连接
|
||||
// 3. 自动执行数据库迁移,确保 schema 与代码同步
|
||||
// 4. 创建并返回 Ent 客户端实例
|
||||
//
|
||||
// 重要提示:调用者必须负责关闭返回的 ent.Client(关闭时会自动关闭底层的 driver/db)。
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 应用程序配置,包含数据库连接信息和时区设置
|
||||
//
|
||||
// 返回:
|
||||
// - *ent.Client: Ent ORM 客户端,用于执行数据库操作
|
||||
// - *sql.DB: 底层的 SQL 数据库连接,可用于直接执行原生 SQL
|
||||
// - error: 初始化过程中的错误
|
||||
func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
|
||||
// 优先初始化时区设置,确保所有时间操作使用统一的时区。
|
||||
// 这对于跨时区部署和日志时间戳的一致性至关重要。
|
||||
if err := timezone.Init(cfg.Timezone); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建包含时区信息的数据库连接字符串 (DSN)。
|
||||
// 时区信息会传递给 PostgreSQL,确保数据库层面的时间处理正确。
|
||||
dsn := cfg.Database.DSNWithTimezone(cfg.Timezone)
|
||||
|
||||
// 使用 Ent 的 SQL 驱动打开 PostgreSQL 连接。
|
||||
// dialect.Postgres 指定使用 PostgreSQL 方言进行 SQL 生成。
|
||||
drv, err := entsql.Open(dialect.Postgres, dsn)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
applyDBPoolSettings(drv.DB(), cfg)
|
||||
|
||||
// 确保数据库 schema 已准备就绪。
|
||||
// SQL 迁移文件是 schema 的权威来源(source of truth)。
|
||||
// 这种方式比 Ent 的自动迁移更可控,支持复杂的迁移场景。
|
||||
migrationCtx, cancel := context.WithTimeout(context.Background(), 10*time.Minute)
|
||||
defer cancel()
|
||||
if err := applyMigrationsFS(migrationCtx, drv.DB(), migrations.FS); err != nil {
|
||||
_ = drv.Close() // 迁移失败时关闭驱动,避免资源泄露
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
|
||||
client := ent.NewClient(ent.Driver(drv))
|
||||
|
||||
// 启动阶段:从配置或数据库中确保系统密钥可用。
|
||||
if err := ensureBootstrapSecrets(migrationCtx, client, cfg); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 在密钥补齐后执行完整配置校验,避免空 jwt.secret 导致服务运行时失败。
|
||||
if err := cfg.Validate(); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, fmt.Errorf("validate config after secret bootstrap: %w", err)
|
||||
}
|
||||
|
||||
// SIMPLE 模式:启动时补齐各平台默认分组。
|
||||
// - anthropic/openai/gemini: 确保存在 <platform>-default
|
||||
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
|
||||
if cfg.RunMode == config.RunModeSimple {
|
||||
seedCtx, seedCancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer seedCancel()
|
||||
if err := ensureSimpleModeDefaultGroups(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
if err := ensureSimpleModeAdminConcurrency(seedCtx, client); err != nil {
|
||||
_ = client.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return client, drv.DB(), nil
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
errorPassthroughCacheKey = "error_passthrough_rules"
|
||||
errorPassthroughPubSubKey = "error_passthrough_rules_updated"
|
||||
errorPassthroughCacheTTL = 24 * time.Hour
|
||||
)
|
||||
|
||||
type errorPassthroughCache struct {
|
||||
rdb *redis.Client
|
||||
localCache []*model.ErrorPassthroughRule
|
||||
localMu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewErrorPassthroughCache 创建错误透传规则缓存
|
||||
func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache {
|
||||
return &errorPassthroughCache{
|
||||
rdb: rdb,
|
||||
}
|
||||
}
|
||||
|
||||
// Get 从缓存获取规则列表
|
||||
func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) {
|
||||
// 先检查本地缓存
|
||||
c.localMu.RLock()
|
||||
if c.localCache != nil {
|
||||
rules := c.localCache
|
||||
c.localMu.RUnlock()
|
||||
return rules, true
|
||||
}
|
||||
c.localMu.RUnlock()
|
||||
|
||||
// 从 Redis 获取
|
||||
data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes()
|
||||
if err != nil {
|
||||
if err != redis.Nil {
|
||||
log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err)
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
var rules []*model.ErrorPassthroughRule
|
||||
if err := json.Unmarshal(data, &rules); err != nil {
|
||||
log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 更新本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = rules
|
||||
c.localMu.Unlock()
|
||||
|
||||
return rules, true
|
||||
}
|
||||
|
||||
// Set 设置缓存
|
||||
func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error {
|
||||
data, err := json.Marshal(rules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 更新本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = rules
|
||||
c.localMu.Unlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Invalidate 使缓存失效
|
||||
func (c *errorPassthroughCache) Invalidate(ctx context.Context) error {
|
||||
// 清除本地缓存
|
||||
c.localMu.Lock()
|
||||
c.localCache = nil
|
||||
c.localMu.Unlock()
|
||||
|
||||
// 清除 Redis 缓存
|
||||
return c.rdb.Del(ctx, errorPassthroughCacheKey).Err()
|
||||
}
|
||||
|
||||
// NotifyUpdate 通知其他实例刷新缓存
|
||||
func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error {
|
||||
return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err()
|
||||
}
|
||||
|
||||
// SubscribeUpdates 订阅缓存更新通知
|
||||
func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) {
|
||||
go func() {
|
||||
sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey)
|
||||
defer func() { _ = sub.Close() }()
|
||||
|
||||
ch := sub.Channel()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case msg := <-ch:
|
||||
if msg == nil {
|
||||
return
|
||||
}
|
||||
// 清除本地缓存,下次访问时会从 Redis 或数据库重新加载
|
||||
c.localMu.Lock()
|
||||
c.localCache = nil
|
||||
c.localMu.Unlock()
|
||||
|
||||
// 调用处理函数
|
||||
handler()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
@@ -0,0 +1,181 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule"
|
||||
"github.com/Wei-Shaw/sub2api/internal/model"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type errorPassthroughRepository struct {
|
||||
client *ent.Client
|
||||
}
|
||||
|
||||
// NewErrorPassthroughRepository 创建错误透传规则仓库
|
||||
func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository {
|
||||
return &errorPassthroughRepository{client: client}
|
||||
}
|
||||
|
||||
// List 获取所有规则
|
||||
func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) {
|
||||
rules, err := r.client.ErrorPassthroughRule.Query().
|
||||
Order(ent.Asc(errorpassthroughrule.FieldPriority)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]*model.ErrorPassthroughRule, len(rules))
|
||||
for i, rule := range rules {
|
||||
result[i] = r.toModel(rule)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByID 根据 ID 获取规则
|
||||
func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) {
|
||||
rule, err := r.client.ErrorPassthroughRule.Get(ctx, id)
|
||||
if err != nil {
|
||||
if ent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(rule), nil
|
||||
}
|
||||
|
||||
// Create 创建规则
|
||||
func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
builder := r.client.ErrorPassthroughRule.Create().
|
||||
SetName(rule.Name).
|
||||
SetEnabled(rule.Enabled).
|
||||
SetPriority(rule.Priority).
|
||||
SetMatchMode(rule.MatchMode).
|
||||
SetPassthroughCode(rule.PassthroughCode).
|
||||
SetPassthroughBody(rule.PassthroughBody).
|
||||
SetSkipMonitoring(rule.SkipMonitoring)
|
||||
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
builder.SetErrorCodes(rule.ErrorCodes)
|
||||
}
|
||||
if len(rule.Keywords) > 0 {
|
||||
builder.SetKeywords(rule.Keywords)
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
builder.SetPlatforms(rule.Platforms)
|
||||
}
|
||||
if rule.ResponseCode != nil {
|
||||
builder.SetResponseCode(*rule.ResponseCode)
|
||||
}
|
||||
if rule.CustomMessage != nil {
|
||||
builder.SetCustomMessage(*rule.CustomMessage)
|
||||
}
|
||||
if rule.Description != nil {
|
||||
builder.SetDescription(*rule.Description)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(created), nil
|
||||
}
|
||||
|
||||
// Update 更新规则
|
||||
func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) {
|
||||
builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID).
|
||||
SetName(rule.Name).
|
||||
SetEnabled(rule.Enabled).
|
||||
SetPriority(rule.Priority).
|
||||
SetMatchMode(rule.MatchMode).
|
||||
SetPassthroughCode(rule.PassthroughCode).
|
||||
SetPassthroughBody(rule.PassthroughBody).
|
||||
SetSkipMonitoring(rule.SkipMonitoring)
|
||||
|
||||
// 处理可选字段
|
||||
if len(rule.ErrorCodes) > 0 {
|
||||
builder.SetErrorCodes(rule.ErrorCodes)
|
||||
} else {
|
||||
builder.ClearErrorCodes()
|
||||
}
|
||||
if len(rule.Keywords) > 0 {
|
||||
builder.SetKeywords(rule.Keywords)
|
||||
} else {
|
||||
builder.ClearKeywords()
|
||||
}
|
||||
if len(rule.Platforms) > 0 {
|
||||
builder.SetPlatforms(rule.Platforms)
|
||||
} else {
|
||||
builder.ClearPlatforms()
|
||||
}
|
||||
if rule.ResponseCode != nil {
|
||||
builder.SetResponseCode(*rule.ResponseCode)
|
||||
} else {
|
||||
builder.ClearResponseCode()
|
||||
}
|
||||
if rule.CustomMessage != nil {
|
||||
builder.SetCustomMessage(*rule.CustomMessage)
|
||||
} else {
|
||||
builder.ClearCustomMessage()
|
||||
}
|
||||
if rule.Description != nil {
|
||||
builder.SetDescription(*rule.Description)
|
||||
} else {
|
||||
builder.ClearDescription()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return r.toModel(updated), nil
|
||||
}
|
||||
|
||||
// Delete 删除规则
|
||||
func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx)
|
||||
}
|
||||
|
||||
// toModel 将 Ent 实体转换为服务模型
|
||||
func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule {
|
||||
rule := &model.ErrorPassthroughRule{
|
||||
ID: int64(e.ID),
|
||||
Name: e.Name,
|
||||
Enabled: e.Enabled,
|
||||
Priority: e.Priority,
|
||||
ErrorCodes: e.ErrorCodes,
|
||||
Keywords: e.Keywords,
|
||||
MatchMode: e.MatchMode,
|
||||
Platforms: e.Platforms,
|
||||
PassthroughCode: e.PassthroughCode,
|
||||
PassthroughBody: e.PassthroughBody,
|
||||
SkipMonitoring: e.SkipMonitoring,
|
||||
CreatedAt: e.CreatedAt,
|
||||
UpdatedAt: e.UpdatedAt,
|
||||
}
|
||||
|
||||
if e.ResponseCode != nil {
|
||||
rule.ResponseCode = e.ResponseCode
|
||||
}
|
||||
if e.CustomMessage != nil {
|
||||
rule.CustomMessage = e.CustomMessage
|
||||
}
|
||||
if e.Description != nil {
|
||||
rule.Description = e.Description
|
||||
}
|
||||
|
||||
// 确保切片不为 nil
|
||||
if rule.ErrorCodes == nil {
|
||||
rule.ErrorCodes = []int{}
|
||||
}
|
||||
if rule.Keywords == nil {
|
||||
rule.Keywords = []string{}
|
||||
}
|
||||
if rule.Platforms == nil {
|
||||
rule.Platforms = []string{}
|
||||
}
|
||||
|
||||
return rule
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// clientFromContext 从 context 中获取事务 client,如果不存在则返回默认 client。
|
||||
//
|
||||
// 这个辅助函数支持 repository 方法在事务上下文中工作:
|
||||
// - 如果 context 中存在事务(通过 ent.NewTxContext 设置),返回事务的 client
|
||||
// - 否则返回传入的默认 client
|
||||
//
|
||||
// 使用示例:
|
||||
//
|
||||
// func (r *someRepo) SomeMethod(ctx context.Context) error {
|
||||
// client := clientFromContext(ctx, r.client)
|
||||
// return client.SomeEntity.Create().Save(ctx)
|
||||
// }
|
||||
func clientFromContext(ctx context.Context, defaultClient *dbent.Client) *dbent.Client {
|
||||
if tx := dbent.TxFromContext(ctx); tx != nil {
|
||||
return tx.Client()
|
||||
}
|
||||
return defaultClient
|
||||
}
|
||||
|
||||
// translatePersistenceError 将数据库层错误翻译为业务层错误。
|
||||
//
|
||||
// 这是 Repository 层的核心错误处理函数,确保数据库细节不会泄露到业务层。
|
||||
// 通过统一的错误翻译,业务层可以使用语义明确的错误类型(如 ErrUserNotFound)
|
||||
// 而不是依赖于特定数据库的错误(如 sql.ErrNoRows)。
|
||||
//
|
||||
// 参数:
|
||||
// - err: 原始数据库错误
|
||||
// - notFound: 当记录不存在时返回的业务错误(可为 nil 表示不处理)
|
||||
// - conflict: 当违反唯一约束时返回的业务错误(可为 nil 表示不处理)
|
||||
//
|
||||
// 返回:
|
||||
// - 翻译后的业务错误,或原始错误(如果不匹配任何规则)
|
||||
//
|
||||
// 示例:
|
||||
//
|
||||
// err := translatePersistenceError(dbErr, service.ErrUserNotFound, service.ErrEmailExists)
|
||||
func translatePersistenceError(err error, notFound, conflict *infraerrors.ApplicationError) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 兼容 Ent ORM 和标准 database/sql 的 NotFound 行为。
|
||||
// Ent 使用自定义的 NotFoundError,而标准库使用 sql.ErrNoRows。
|
||||
// 这里同时处理两种情况,保持业务错误映射一致。
|
||||
if notFound != nil && (errors.Is(err, sql.ErrNoRows) || dbent.IsNotFound(err)) {
|
||||
return notFound.WithCause(err)
|
||||
}
|
||||
|
||||
// 处理唯一约束冲突(如邮箱已存在、名称重复等)
|
||||
if conflict != nil && isUniqueConstraintViolation(err) {
|
||||
return conflict.WithCause(err)
|
||||
}
|
||||
|
||||
// 未匹配任何规则,返回原始错误
|
||||
return err
|
||||
}
|
||||
|
||||
// isUniqueConstraintViolation 判断错误是否为唯一约束冲突。
|
||||
//
|
||||
// 支持多种检测方式:
|
||||
// 1. PostgreSQL 特定错误码 23505(唯一约束冲突)
|
||||
// 2. 错误消息中包含的通用关键词
|
||||
//
|
||||
// 这种多层次的检测确保了对不同数据库驱动和 ORM 的兼容性。
|
||||
func isUniqueConstraintViolation(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 优先检测 PostgreSQL 特定错误码(最精确)。
|
||||
// 错误码 23505 对应 unique_violation。
|
||||
// 参考:https://www.postgresql.org/docs/current/errcodes-appendix.html
|
||||
var pgErr *pq.Error
|
||||
if errors.As(err, &pgErr) {
|
||||
return pgErr.Code == "23505"
|
||||
}
|
||||
|
||||
// 回退到错误消息检测(兼容其他场景)。
|
||||
// 这些关键词覆盖了 PostgreSQL、MySQL 等主流数据库的错误消息。
|
||||
msg := strings.ToLower(err.Error())
|
||||
return strings.Contains(msg, "duplicate key") ||
|
||||
strings.Contains(msg, "unique constraint") ||
|
||||
strings.Contains(msg, "duplicate entry")
|
||||
}
|
||||
@@ -0,0 +1,427 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func mustCreateUser(t *testing.T, client *dbent.Client, u *service.User) *service.User {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if u.Email == "" {
|
||||
u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
|
||||
}
|
||||
if u.PasswordHash == "" {
|
||||
u.PasswordHash = "test-password-hash"
|
||||
}
|
||||
if u.Role == "" {
|
||||
u.Role = service.RoleUser
|
||||
}
|
||||
if u.Status == "" {
|
||||
u.Status = service.StatusActive
|
||||
}
|
||||
if u.Concurrency == 0 {
|
||||
u.Concurrency = 5
|
||||
}
|
||||
|
||||
create := client.User.Create().
|
||||
SetEmail(u.Email).
|
||||
SetPasswordHash(u.PasswordHash).
|
||||
SetRole(u.Role).
|
||||
SetStatus(u.Status).
|
||||
SetBalance(u.Balance).
|
||||
SetConcurrency(u.Concurrency).
|
||||
SetUsername(u.Username).
|
||||
SetNotes(u.Notes)
|
||||
if !u.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(u.CreatedAt)
|
||||
}
|
||||
if !u.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(u.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create user")
|
||||
|
||||
u.ID = created.ID
|
||||
u.CreatedAt = created.CreatedAt
|
||||
u.UpdatedAt = created.UpdatedAt
|
||||
|
||||
if len(u.AllowedGroups) > 0 {
|
||||
for _, groupID := range u.AllowedGroups {
|
||||
_, err := client.UserAllowedGroup.Create().
|
||||
SetUserID(u.ID).
|
||||
SetGroupID(groupID).
|
||||
Save(ctx)
|
||||
require.NoError(t, err, "create user_allowed_groups row")
|
||||
}
|
||||
}
|
||||
|
||||
return u
|
||||
}
|
||||
|
||||
func mustCreateGroup(t *testing.T, client *dbent.Client, g *service.Group) *service.Group {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if g.Platform == "" {
|
||||
g.Platform = service.PlatformAnthropic
|
||||
}
|
||||
if g.Status == "" {
|
||||
g.Status = service.StatusActive
|
||||
}
|
||||
if g.SubscriptionType == "" {
|
||||
g.SubscriptionType = service.SubscriptionTypeStandard
|
||||
}
|
||||
|
||||
create := client.Group.Create().
|
||||
SetName(g.Name).
|
||||
SetPlatform(g.Platform).
|
||||
SetStatus(g.Status).
|
||||
SetSubscriptionType(g.SubscriptionType).
|
||||
SetRateMultiplier(g.RateMultiplier).
|
||||
SetIsExclusive(g.IsExclusive)
|
||||
if g.Description != "" {
|
||||
create.SetDescription(g.Description)
|
||||
}
|
||||
if g.DailyLimitUSD != nil {
|
||||
create.SetDailyLimitUsd(*g.DailyLimitUSD)
|
||||
}
|
||||
if g.WeeklyLimitUSD != nil {
|
||||
create.SetWeeklyLimitUsd(*g.WeeklyLimitUSD)
|
||||
}
|
||||
if g.MonthlyLimitUSD != nil {
|
||||
create.SetMonthlyLimitUsd(*g.MonthlyLimitUSD)
|
||||
}
|
||||
if !g.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(g.CreatedAt)
|
||||
}
|
||||
if !g.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(g.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create group")
|
||||
|
||||
g.ID = created.ID
|
||||
g.CreatedAt = created.CreatedAt
|
||||
g.UpdatedAt = created.UpdatedAt
|
||||
return g
|
||||
}
|
||||
|
||||
func mustCreateProxy(t *testing.T, client *dbent.Client, p *service.Proxy) *service.Proxy {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if p.Protocol == "" {
|
||||
p.Protocol = "http"
|
||||
}
|
||||
if p.Host == "" {
|
||||
p.Host = "127.0.0.1"
|
||||
}
|
||||
if p.Port == 0 {
|
||||
p.Port = 8080
|
||||
}
|
||||
if p.Status == "" {
|
||||
p.Status = service.StatusActive
|
||||
}
|
||||
|
||||
create := client.Proxy.Create().
|
||||
SetName(p.Name).
|
||||
SetProtocol(p.Protocol).
|
||||
SetHost(p.Host).
|
||||
SetPort(p.Port).
|
||||
SetStatus(p.Status)
|
||||
if p.Username != "" {
|
||||
create.SetUsername(p.Username)
|
||||
}
|
||||
if p.Password != "" {
|
||||
create.SetPassword(p.Password)
|
||||
}
|
||||
if !p.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(p.CreatedAt)
|
||||
}
|
||||
if !p.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(p.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create proxy")
|
||||
|
||||
p.ID = created.ID
|
||||
p.CreatedAt = created.CreatedAt
|
||||
p.UpdatedAt = created.UpdatedAt
|
||||
return p
|
||||
}
|
||||
|
||||
func mustCreateAccount(t *testing.T, client *dbent.Client, a *service.Account) *service.Account {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if a.Platform == "" {
|
||||
a.Platform = service.PlatformAnthropic
|
||||
}
|
||||
if a.Type == "" {
|
||||
a.Type = service.AccountTypeOAuth
|
||||
}
|
||||
if a.Status == "" {
|
||||
a.Status = service.StatusActive
|
||||
}
|
||||
if a.Concurrency == 0 {
|
||||
a.Concurrency = 3
|
||||
}
|
||||
if a.Priority == 0 {
|
||||
a.Priority = 50
|
||||
}
|
||||
if !a.Schedulable {
|
||||
a.Schedulable = true
|
||||
}
|
||||
if a.Credentials == nil {
|
||||
a.Credentials = map[string]any{}
|
||||
}
|
||||
if a.Extra == nil {
|
||||
a.Extra = map[string]any{}
|
||||
}
|
||||
|
||||
create := client.Account.Create().
|
||||
SetName(a.Name).
|
||||
SetPlatform(a.Platform).
|
||||
SetType(a.Type).
|
||||
SetCredentials(a.Credentials).
|
||||
SetExtra(a.Extra).
|
||||
SetConcurrency(a.Concurrency).
|
||||
SetPriority(a.Priority).
|
||||
SetStatus(a.Status).
|
||||
SetSchedulable(a.Schedulable).
|
||||
SetErrorMessage(a.ErrorMessage)
|
||||
|
||||
if a.ProxyID != nil {
|
||||
create.SetProxyID(*a.ProxyID)
|
||||
}
|
||||
if a.LastUsedAt != nil {
|
||||
create.SetLastUsedAt(*a.LastUsedAt)
|
||||
}
|
||||
if a.RateLimitedAt != nil {
|
||||
create.SetRateLimitedAt(*a.RateLimitedAt)
|
||||
}
|
||||
if a.RateLimitResetAt != nil {
|
||||
create.SetRateLimitResetAt(*a.RateLimitResetAt)
|
||||
}
|
||||
if a.OverloadUntil != nil {
|
||||
create.SetOverloadUntil(*a.OverloadUntil)
|
||||
}
|
||||
if a.SessionWindowStart != nil {
|
||||
create.SetSessionWindowStart(*a.SessionWindowStart)
|
||||
}
|
||||
if a.SessionWindowEnd != nil {
|
||||
create.SetSessionWindowEnd(*a.SessionWindowEnd)
|
||||
}
|
||||
if a.SessionWindowStatus != "" {
|
||||
create.SetSessionWindowStatus(a.SessionWindowStatus)
|
||||
}
|
||||
if !a.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(a.CreatedAt)
|
||||
}
|
||||
if !a.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(a.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create account")
|
||||
|
||||
a.ID = created.ID
|
||||
a.CreatedAt = created.CreatedAt
|
||||
a.UpdatedAt = created.UpdatedAt
|
||||
return a
|
||||
}
|
||||
|
||||
func mustCreateApiKey(t *testing.T, client *dbent.Client, k *service.APIKey) *service.APIKey {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if k.Status == "" {
|
||||
k.Status = service.StatusActive
|
||||
}
|
||||
if k.Key == "" {
|
||||
k.Key = "sk-" + time.Now().Format("150405.000000")
|
||||
}
|
||||
if k.Name == "" {
|
||||
k.Name = "default"
|
||||
}
|
||||
|
||||
create := client.APIKey.Create().
|
||||
SetUserID(k.UserID).
|
||||
SetKey(k.Key).
|
||||
SetName(k.Name).
|
||||
SetStatus(k.Status)
|
||||
if k.Quota != 0 {
|
||||
create.SetQuota(k.Quota)
|
||||
}
|
||||
if k.QuotaUsed != 0 {
|
||||
create.SetQuotaUsed(k.QuotaUsed)
|
||||
}
|
||||
if k.RateLimit5h != 0 {
|
||||
create.SetRateLimit5h(k.RateLimit5h)
|
||||
}
|
||||
if k.RateLimit1d != 0 {
|
||||
create.SetRateLimit1d(k.RateLimit1d)
|
||||
}
|
||||
if k.RateLimit7d != 0 {
|
||||
create.SetRateLimit7d(k.RateLimit7d)
|
||||
}
|
||||
if k.Usage5h != 0 {
|
||||
create.SetUsage5h(k.Usage5h)
|
||||
}
|
||||
if k.Usage1d != 0 {
|
||||
create.SetUsage1d(k.Usage1d)
|
||||
}
|
||||
if k.Usage7d != 0 {
|
||||
create.SetUsage7d(k.Usage7d)
|
||||
}
|
||||
if k.Window5hStart != nil {
|
||||
create.SetWindow5hStart(*k.Window5hStart)
|
||||
}
|
||||
if k.Window1dStart != nil {
|
||||
create.SetWindow1dStart(*k.Window1dStart)
|
||||
}
|
||||
if k.Window7dStart != nil {
|
||||
create.SetWindow7dStart(*k.Window7dStart)
|
||||
}
|
||||
if k.ExpiresAt != nil {
|
||||
create.SetExpiresAt(*k.ExpiresAt)
|
||||
}
|
||||
if k.GroupID != nil {
|
||||
create.SetGroupID(*k.GroupID)
|
||||
}
|
||||
if !k.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(k.CreatedAt)
|
||||
}
|
||||
if !k.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(k.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create api key")
|
||||
|
||||
k.ID = created.ID
|
||||
k.CreatedAt = created.CreatedAt
|
||||
k.UpdatedAt = created.UpdatedAt
|
||||
return k
|
||||
}
|
||||
|
||||
func mustCreateRedeemCode(t *testing.T, client *dbent.Client, c *service.RedeemCode) *service.RedeemCode {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if c.Status == "" {
|
||||
c.Status = service.StatusUnused
|
||||
}
|
||||
if c.Type == "" {
|
||||
c.Type = service.RedeemTypeBalance
|
||||
}
|
||||
if c.Code == "" {
|
||||
c.Code = "rc-" + time.Now().Format("150405.000000")
|
||||
}
|
||||
|
||||
create := client.RedeemCode.Create().
|
||||
SetCode(c.Code).
|
||||
SetType(c.Type).
|
||||
SetValue(c.Value).
|
||||
SetStatus(c.Status).
|
||||
SetNotes(c.Notes).
|
||||
SetValidityDays(c.ValidityDays)
|
||||
if c.UsedBy != nil {
|
||||
create.SetUsedBy(*c.UsedBy)
|
||||
}
|
||||
if c.UsedAt != nil {
|
||||
create.SetUsedAt(*c.UsedAt)
|
||||
}
|
||||
if c.GroupID != nil {
|
||||
create.SetGroupID(*c.GroupID)
|
||||
}
|
||||
if !c.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(c.CreatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create redeem code")
|
||||
|
||||
c.ID = created.ID
|
||||
c.CreatedAt = created.CreatedAt
|
||||
return c
|
||||
}
|
||||
|
||||
func mustCreateSubscription(t *testing.T, client *dbent.Client, s *service.UserSubscription) *service.UserSubscription {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
if s.Status == "" {
|
||||
s.Status = service.SubscriptionStatusActive
|
||||
}
|
||||
now := time.Now()
|
||||
if s.StartsAt.IsZero() {
|
||||
s.StartsAt = now.Add(-1 * time.Hour)
|
||||
}
|
||||
if s.ExpiresAt.IsZero() {
|
||||
s.ExpiresAt = now.Add(24 * time.Hour)
|
||||
}
|
||||
if s.AssignedAt.IsZero() {
|
||||
s.AssignedAt = now
|
||||
}
|
||||
if s.CreatedAt.IsZero() {
|
||||
s.CreatedAt = now
|
||||
}
|
||||
if s.UpdatedAt.IsZero() {
|
||||
s.UpdatedAt = now
|
||||
}
|
||||
|
||||
create := client.UserSubscription.Create().
|
||||
SetUserID(s.UserID).
|
||||
SetGroupID(s.GroupID).
|
||||
SetStartsAt(s.StartsAt).
|
||||
SetExpiresAt(s.ExpiresAt).
|
||||
SetStatus(s.Status).
|
||||
SetAssignedAt(s.AssignedAt).
|
||||
SetNotes(s.Notes).
|
||||
SetDailyUsageUsd(s.DailyUsageUSD).
|
||||
SetWeeklyUsageUsd(s.WeeklyUsageUSD).
|
||||
SetMonthlyUsageUsd(s.MonthlyUsageUSD)
|
||||
|
||||
if s.AssignedBy != nil {
|
||||
create.SetAssignedBy(*s.AssignedBy)
|
||||
}
|
||||
if !s.CreatedAt.IsZero() {
|
||||
create.SetCreatedAt(s.CreatedAt)
|
||||
}
|
||||
if !s.UpdatedAt.IsZero() {
|
||||
create.SetUpdatedAt(s.UpdatedAt)
|
||||
}
|
||||
|
||||
created, err := create.Save(ctx)
|
||||
require.NoError(t, err, "create user subscription")
|
||||
|
||||
s.ID = created.ID
|
||||
s.CreatedAt = created.CreatedAt
|
||||
s.UpdatedAt = created.UpdatedAt
|
||||
return s
|
||||
}
|
||||
|
||||
func mustBindAccountToGroup(t *testing.T, client *dbent.Client, accountID, groupID int64, priority int) {
|
||||
t.Helper()
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := client.AccountGroup.Create().
|
||||
SetAccountID(accountID).
|
||||
SetGroupID(groupID).
|
||||
SetPriority(priority).
|
||||
Save(ctx)
|
||||
require.NoError(t, err, "create account_group")
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const stickySessionPrefix = "sticky_session:"
|
||||
|
||||
type gatewayCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewGatewayCache(rdb *redis.Client) service.GatewayCache {
|
||||
return &gatewayCache{rdb: rdb}
|
||||
}
|
||||
|
||||
// buildSessionKey 构建 session key,包含 groupID 实现分组隔离
|
||||
// 格式: sticky_session:{groupID}:{sessionHash}
|
||||
func buildSessionKey(groupID int64, sessionHash string) string {
|
||||
return fmt.Sprintf("%s%d:%s", stickySessionPrefix, groupID, sessionHash)
|
||||
}
|
||||
|
||||
func (c *gatewayCache) GetSessionAccountID(ctx context.Context, groupID int64, sessionHash string) (int64, error) {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Get(ctx, key).Int64()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) SetSessionAccountID(ctx context.Context, groupID int64, sessionHash string, accountID int64, ttl time.Duration) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Set(ctx, key, accountID, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, sessionHash string, ttl time.Duration) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Expire(ctx, key, ttl).Err()
|
||||
}
|
||||
|
||||
// DeleteSessionAccountID 删除粘性会话与账号的绑定关系。
|
||||
// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用,
|
||||
// 以便下次请求能够重新选择可用账号。
|
||||
//
|
||||
// DeleteSessionAccountID removes the sticky session binding for the given session.
|
||||
// Called when the bound account becomes unavailable (e.g., error status, disabled,
|
||||
// or unschedulable), allowing subsequent requests to select a new available account.
|
||||
func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error {
|
||||
key := buildSessionKey(groupID, sessionHash)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -0,0 +1,109 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GatewayCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.GatewayCache
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewGatewayCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
|
||||
sessionID := "s1"
|
||||
accountID := int64(99)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.NoError(s.T(), err, "GetSessionAccountID")
|
||||
require.Equal(s.T(), accountID, sid, "session id mismatch")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
|
||||
sessionID := "s2"
|
||||
accountID := int64(100)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL sessionKey after Set")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
|
||||
sessionID := "s3"
|
||||
accountID := int64(101)
|
||||
groupID := int64(1)
|
||||
initialTTL := 1 * time.Minute
|
||||
refreshTTL := 3 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
|
||||
|
||||
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
|
||||
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after Refresh")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||
// RefreshSessionTTL on a missing key should not error (no-op)
|
||||
err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
|
||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
|
||||
sessionID := "openai:s4"
|
||||
accountID := int64(102)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
sessionID := "corrupted"
|
||||
groupID := int64(1)
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
|
||||
// Set a non-integer value
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.Error(s.T(), err, "expected error for corrupted value")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
}
|
||||
@@ -0,0 +1,250 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// GatewayRoutingSuite 测试网关路由相关的数据库查询
|
||||
// 验证账户选择和分流逻辑在真实数据库环境下的行为
|
||||
type GatewayRoutingSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
accountRepo *accountRepository
|
||||
}
|
||||
|
||||
func (s *GatewayRoutingSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||
}
|
||||
|
||||
func TestGatewayRoutingSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayRoutingSuite))
|
||||
}
|
||||
|
||||
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() {
|
||||
// 创建各平台账户
|
||||
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "gemini-oauth",
|
||||
Platform: service.PlatformGemini,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 1,
|
||||
})
|
||||
|
||||
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "antigravity-oauth",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 2,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"refresh_token": "test-refresh",
|
||||
"project_id": "test-project",
|
||||
},
|
||||
})
|
||||
|
||||
// 创建不应被选中的 anthropic 账户
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "anthropic-oauth",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 0,
|
||||
})
|
||||
|
||||
// 查询 gemini + antigravity 平台
|
||||
accounts, err := s.accountRepo.ListSchedulableByPlatforms(s.ctx, []string{
|
||||
service.PlatformGemini,
|
||||
service.PlatformAntigravity,
|
||||
})
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 2, "应返回 gemini 和 antigravity 两个账户")
|
||||
|
||||
// 验证返回的账户平台
|
||||
platforms := make(map[string]bool)
|
||||
for _, acc := range accounts {
|
||||
platforms[acc.Platform] = true
|
||||
}
|
||||
s.Require().True(platforms[service.PlatformGemini], "应包含 gemini 账户")
|
||||
s.Require().True(platforms[service.PlatformAntigravity], "应包含 antigravity 账户")
|
||||
s.Require().False(platforms[service.PlatformAnthropic], "不应包含 anthropic 账户")
|
||||
|
||||
// 验证账户 ID 匹配
|
||||
ids := make(map[int64]bool)
|
||||
for _, acc := range accounts {
|
||||
ids[acc.ID] = true
|
||||
}
|
||||
s.Require().True(ids[geminiAcc.ID])
|
||||
s.Require().True(ids[antigravityAcc.ID])
|
||||
}
|
||||
|
||||
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() {
|
||||
// 创建 gemini 分组
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{
|
||||
Name: "gemini-group",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
// 创建账户
|
||||
boundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "bound-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
unboundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "unbound-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
// 只绑定一个账户到分组
|
||||
mustBindAccountToGroup(s.T(), s.client, boundAcc.ID, group.ID, 1)
|
||||
|
||||
// 查询分组内的账户
|
||||
accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{
|
||||
service.PlatformGemini,
|
||||
service.PlatformAntigravity,
|
||||
})
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1, "应只返回绑定到分组的账户")
|
||||
s.Require().Equal(boundAcc.ID, accounts[0].ID)
|
||||
|
||||
// 确认未绑定的账户不在结果中
|
||||
for _, acc := range accounts {
|
||||
s.Require().NotEqual(unboundAcc.ID, acc.ID, "不应包含未绑定的账户")
|
||||
}
|
||||
}
|
||||
|
||||
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
|
||||
// 创建多种平台账户
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "gemini-1",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
antigravity := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "antigravity-1",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
// 只查询 antigravity 平台
|
||||
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(antigravity.ID, accounts[0].ID)
|
||||
s.Require().Equal(service.PlatformAntigravity, accounts[0].Platform)
|
||||
}
|
||||
|
||||
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
|
||||
func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
|
||||
// 创建可调度账户
|
||||
activeAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "active-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true)
|
||||
inactiveAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "inactive-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
s.Require().NoError(s.client.Account.UpdateOneID(inactiveAcc.ID).SetSchedulable(false).Exec(s.ctx))
|
||||
|
||||
// 创建错误状态账户
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "error-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusError,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1, "应只返回可调度的 active 账户")
|
||||
s.Require().Equal(activeAcc.ID, accounts[0].ID)
|
||||
}
|
||||
|
||||
// TestPlatformRoutingDecision 验证平台路由决策
|
||||
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
|
||||
func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() {
|
||||
// 创建两种平台的账户
|
||||
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "gemini-route-test",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "antigravity-route-test",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID int64
|
||||
expectedService string
|
||||
}{
|
||||
{
|
||||
name: "Gemini账户路由到ForwardNative",
|
||||
accountID: geminiAcc.ID,
|
||||
expectedService: "GeminiMessagesCompatService.ForwardNative",
|
||||
},
|
||||
{
|
||||
name: "Antigravity账户路由到ForwardGemini",
|
||||
accountID: antigravityAcc.ID,
|
||||
expectedService: "AntigravityGatewayService.ForwardGemini",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 从数据库获取账户
|
||||
account, err := s.accountRepo.GetByID(s.ctx, tt.accountID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// 模拟 Handler 层的路由决策
|
||||
var routedService string
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
routedService = "AntigravityGatewayService.ForwardGemini"
|
||||
} else {
|
||||
routedService = "GeminiMessagesCompatService.ForwardNative"
|
||||
}
|
||||
|
||||
s.Require().Equal(tt.expectedService, routedService)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
package repository
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
|
||||
// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations.
|
||||
// Returned as geminicli.DriveClient interface for DI (Strategy A).
|
||||
func NewGeminiDriveClient() geminicli.DriveClient {
|
||||
return geminicli.NewDriveClient()
|
||||
}
|
||||
@@ -0,0 +1,125 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type geminiOAuthClient struct {
|
||||
tokenURL string
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
func NewGeminiOAuthClient(cfg *config.Config) service.GeminiOAuthClient {
|
||||
return &geminiOAuthClient{
|
||||
tokenURL: geminicli.TokenURL,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *geminiOAuthClient) ExchangeCode(ctx context.Context, oauthType, code, codeVerifier, redirectURI, proxyURL string) (*geminicli.TokenResponse, error) {
|
||||
client, err := createGeminiReqClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create HTTP client: %w", err)
|
||||
}
|
||||
|
||||
// Use different OAuth clients based on oauthType:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public)
|
||||
// - google_one: always use built-in Gemini CLI OAuth client (public)
|
||||
// - ai_studio: requires a user-provided OAuth client
|
||||
oauthCfgInput := geminicli.OAuthConfig{
|
||||
ClientID: c.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: c.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfgInput.ClientID = ""
|
||||
oauthCfgInput.ClientSecret = ""
|
||||
}
|
||||
|
||||
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("client_id", oauthCfg.ClientID)
|
||||
formData.Set("client_secret", oauthCfg.ClientSecret)
|
||||
formData.Set("code", code)
|
||||
formData.Set("code_verifier", codeVerifier)
|
||||
formData.Set("redirect_uri", redirectURI)
|
||||
|
||||
var tokenResp geminicli.TokenResponse
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(c.tokenURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token exchange failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (c *geminiOAuthClient) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*geminicli.TokenResponse, error) {
|
||||
client, err := createGeminiReqClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create HTTP client: %w", err)
|
||||
}
|
||||
|
||||
oauthCfgInput := geminicli.OAuthConfig{
|
||||
ClientID: c.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: c.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: c.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
// Force use of built-in Gemini CLI OAuth client
|
||||
oauthCfgInput.ClientID = ""
|
||||
oauthCfgInput.ClientSecret = ""
|
||||
}
|
||||
|
||||
oauthCfg, err := geminicli.EffectiveOAuthConfig(oauthCfgInput, oauthType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", oauthCfg.ClientID)
|
||||
formData.Set("client_secret", oauthCfg.ClientSecret)
|
||||
|
||||
var tokenResp geminicli.TokenResponse
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(c.tokenURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, fmt.Errorf("token refresh failed: status %d, body: %s", resp.StatusCode, geminicli.SanitizeBodyForLogs(resp.String()))
|
||||
}
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createGeminiReqClient(proxyURL string) (*req.Client, error) {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 60 * time.Second,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthTokenKeyPrefix = "oauth:token:"
|
||||
oauthRefreshLockKeyPrefix = "oauth:refresh_lock:"
|
||||
)
|
||||
|
||||
type geminiTokenCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewGeminiTokenCache(rdb *redis.Client) service.GeminiTokenCache {
|
||||
return &geminiTokenCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Get(ctx, key).Result()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Set(ctx, key, token, ttl).Err()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
key := fmt.Sprintf("%s%s", oauthTokenKeyPrefix, cacheKey)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
|
||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||
}
|
||||
|
||||
func (c *geminiTokenCache) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
key := fmt.Sprintf("%s%s", oauthRefreshLockKeyPrefix, cacheKey)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -0,0 +1,47 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GeminiTokenCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.GeminiTokenCache
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewGeminiTokenCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
|
||||
cacheKey := "project-123"
|
||||
token := "token-value"
|
||||
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
|
||||
|
||||
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), token, got)
|
||||
|
||||
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
|
||||
|
||||
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
|
||||
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
|
||||
}
|
||||
|
||||
func TestGeminiTokenCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GeminiTokenCacheSuite))
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
cache := NewGeminiTokenCache(rdb)
|
||||
err := cache.DeleteAccessToken(context.Background(), "broken")
|
||||
require.Error(t, err)
|
||||
}
|
||||
@@ -0,0 +1,135 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/googleapi"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type geminiCliCodeAssistClient struct {
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func NewGeminiCliCodeAssistClient() service.GeminiCliCodeAssistClient {
|
||||
return &geminiCliCodeAssistClient{baseURL: geminicli.GeminiCliBaseURL}
|
||||
}
|
||||
|
||||
func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.LoadCodeAssistRequest) (*geminicli.LoadCodeAssistResponse, error) {
|
||||
if reqBody == nil {
|
||||
reqBody = defaultLoadCodeAssistRequest()
|
||||
}
|
||||
|
||||
var out geminicli.LoadCodeAssistResponse
|
||||
client, err := createGeminiCliReqClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", geminicli.GeminiCLIUserAgent).
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&out).
|
||||
Post(c.baseURL + "/v1internal:loadCodeAssist")
|
||||
if err != nil {
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist request error: %v\n", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := resp.String()
|
||||
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
|
||||
|
||||
// Check if this is a SERVICE_DISABLED error and extract activation URL
|
||||
if googleapi.IsServiceDisabledError(body) {
|
||||
activationURL := googleapi.ExtractActivationURL(body)
|
||||
if activationURL != "" {
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
|
||||
}
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken, proxyURL string, reqBody *geminicli.OnboardUserRequest) (*geminicli.OnboardUserResponse, error) {
|
||||
if reqBody == nil {
|
||||
reqBody = defaultOnboardUserRequest()
|
||||
}
|
||||
|
||||
fmt.Printf("[CodeAssist] OnboardUser request body: %+v\n", reqBody)
|
||||
|
||||
var out geminicli.OnboardUserResponse
|
||||
client, err := createGeminiCliReqClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create HTTP client: %w", err)
|
||||
}
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
SetHeader("Content-Type", "application/json").
|
||||
SetHeader("User-Agent", geminicli.GeminiCLIUserAgent).
|
||||
SetBody(reqBody).
|
||||
SetSuccessResult(&out).
|
||||
Post(c.baseURL + "/v1internal:onboardUser")
|
||||
if err != nil {
|
||||
fmt.Printf("[CodeAssist] OnboardUser request error: %v\n", err)
|
||||
return nil, fmt.Errorf("request failed: %w", err)
|
||||
}
|
||||
if !resp.IsSuccessState() {
|
||||
body := resp.String()
|
||||
sanitizedBody := geminicli.SanitizeBodyForLogs(body)
|
||||
fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody)
|
||||
|
||||
// Check if this is a SERVICE_DISABLED error and extract activation URL
|
||||
if googleapi.IsServiceDisabledError(body) {
|
||||
activationURL := googleapi.ExtractActivationURL(body)
|
||||
if activationURL != "" {
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL)
|
||||
}
|
||||
return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com")
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody)
|
||||
}
|
||||
fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out)
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func createGeminiCliReqClient(proxyURL string) (*req.Client, error) {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 30 * time.Second,
|
||||
})
|
||||
}
|
||||
|
||||
func defaultLoadCodeAssistRequest() *geminicli.LoadCodeAssistRequest {
|
||||
return &geminicli.LoadCodeAssistRequest{
|
||||
Metadata: geminicli.LoadCodeAssistMetadata{
|
||||
IDEType: "ANTIGRAVITY",
|
||||
Platform: "PLATFORM_UNSPECIFIED",
|
||||
PluginType: "GEMINI",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func defaultOnboardUserRequest() *geminicli.OnboardUserRequest {
|
||||
return &geminicli.OnboardUserRequest{
|
||||
TierID: "LEGACY",
|
||||
Metadata: geminicli.LoadCodeAssistMetadata{
|
||||
IDEType: "ANTIGRAVITY",
|
||||
Platform: "PLATFORM_UNSPECIFIED",
|
||||
PluginType: "GEMINI",
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type githubReleaseClient struct {
|
||||
httpClient *http.Client
|
||||
downloadHTTPClient *http.Client
|
||||
}
|
||||
|
||||
type githubReleaseClientError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// NewGitHubReleaseClient 创建 GitHub Release 客户端
|
||||
// proxyURL 为空时直连 GitHub,支持 http/https/socks5/socks5h 协议
|
||||
// 代理配置失败时行为由 allowDirectOnProxyError 控制:
|
||||
// - false(默认):返回错误占位客户端,禁止回退到直连
|
||||
// - true:回退到直连(仅限管理员显式开启)
|
||||
func NewGitHubReleaseClient(proxyURL string, allowDirectOnProxyError bool) service.GitHubReleaseClient {
|
||||
// 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据,
|
||||
// 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
|
||||
slog.Warn("proxy client init failed, all requests will fail", "service", "github_release", "error", err)
|
||||
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
|
||||
}
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
|
||||
// 下载客户端需要更长的超时时间
|
||||
downloadClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 10 * time.Minute,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
|
||||
slog.Warn("proxy download client init failed, all requests will fail", "service", "github_release", "error", err)
|
||||
return &githubReleaseClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
|
||||
}
|
||||
downloadClient = &http.Client{Timeout: 10 * time.Minute}
|
||||
}
|
||||
|
||||
return &githubReleaseClient{
|
||||
httpClient: sharedClient,
|
||||
downloadHTTPClient: downloadClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *githubReleaseClientError) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
|
||||
return nil, c.err
|
||||
}
|
||||
|
||||
func (c *githubReleaseClientError) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
|
||||
return c.err
|
||||
}
|
||||
|
||||
func (c *githubReleaseClientError) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
|
||||
return nil, c.err
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) FetchLatestRelease(ctx context.Context, repo string) (*service.GitHubRelease, error) {
|
||||
url := fmt.Sprintf("https://api.github.com/repos/%s/releases/latest", repo)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Accept", "application/vnd.github.v3+json")
|
||||
req.Header.Set("User-Agent", "Sub2API-Updater")
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("GitHub API returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
var release service.GitHubRelease
|
||||
if err := json.NewDecoder(resp.Body).Decode(&release); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &release, nil
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) DownloadFile(ctx context.Context, url, dest string, maxSize int64) error {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 使用预配置的下载客户端(已包含代理配置)
|
||||
resp, err := c.downloadHTTPClient.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("download returned %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
// SECURITY: Check Content-Length if available
|
||||
if resp.ContentLength > maxSize {
|
||||
return fmt.Errorf("file too large: %d bytes (max %d)", resp.ContentLength, maxSize)
|
||||
}
|
||||
|
||||
out, err := os.Create(dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// SECURITY: Use LimitReader to enforce max download size even if Content-Length is missing/wrong
|
||||
limited := io.LimitReader(resp.Body, maxSize+1)
|
||||
written, err := io.Copy(out, limited)
|
||||
|
||||
// Close file before attempting to remove (required on Windows)
|
||||
_ = out.Close()
|
||||
|
||||
if err != nil {
|
||||
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||
return err
|
||||
}
|
||||
|
||||
// Check if we hit the limit (downloaded more than maxSize)
|
||||
if written > maxSize {
|
||||
_ = os.Remove(dest) // Clean up partial file (best-effort)
|
||||
return fmt.Errorf("download exceeded maximum size of %d bytes", maxSize)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *githubReleaseClient) FetchChecksumFile(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
@@ -0,0 +1,317 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GitHubReleaseServiceSuite struct {
|
||||
suite.Suite
|
||||
srv *httptest.Server
|
||||
client *githubReleaseClient
|
||||
tempDir string
|
||||
}
|
||||
|
||||
// testTransport redirects requests to the test server
|
||||
type testTransport struct {
|
||||
testServerURL string
|
||||
}
|
||||
|
||||
func (t *testTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
// Rewrite the URL to point to our test server
|
||||
testURL := t.testServerURL + req.URL.Path
|
||||
newReq, err := http.NewRequestWithContext(req.Context(), req.Method, testURL, req.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newReq.Header = req.Header
|
||||
return http.DefaultTransport.RoundTrip(newReq)
|
||||
}
|
||||
|
||||
func newTestGitHubReleaseClient() *githubReleaseClient {
|
||||
return &githubReleaseClient{
|
||||
httpClient: &http.Client{},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) SetupTest() {
|
||||
s.tempDir = s.T().TempDir()
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_ContentLength() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Length", "100")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(bytes.Repeat([]byte("a"), 100))
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file1.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||
require.Error(s.T(), err, "expected error for oversized download with Content-Length")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to not exist for rejected download")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_EnforcesMaxSize_Chunked() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Force chunked encoding (unknown Content-Length) by flushing headers before writing.
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file2.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 10)
|
||||
require.Error(s.T(), err, "expected error for oversized chunked download")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to be cleaned up for oversized chunked download")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_Success() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
for i := 0; i < 10; i++ {
|
||||
_, _ = w.Write(bytes.Repeat([]byte("b"), 10))
|
||||
if fl, ok := w.(http.Flusher); ok {
|
||||
fl.Flush()
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "file3.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 200)
|
||||
require.NoError(s.T(), err, "expected success")
|
||||
|
||||
b, err := os.ReadFile(dest)
|
||||
require.NoError(s.T(), err, "read")
|
||||
require.True(s.T(), strings.HasPrefix(string(b), "b"), "downloaded content should start with 'b'")
|
||||
require.Len(s.T(), b, 100, "downloaded content length mismatch")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_404() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "notfound.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for 404")
|
||||
|
||||
_, statErr := os.Stat(dest)
|
||||
require.Error(s.T(), statErr, "expected file to not exist for 404")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Success() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("sum"))
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
body, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||
require.NoError(s.T(), err, "FetchChecksumFile")
|
||||
require.Equal(s.T(), "sum", string(body), "checksum body mismatch")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_Non200() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
_, err := s.client.FetchChecksumFile(context.Background(), s.srv.URL)
|
||||
require.Error(s.T(), err, "expected error for non-200")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_ContextCancel() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "cancelled.bin")
|
||||
err := s.client.DownloadFile(ctx, s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for cancelled context")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidURL() {
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
dest := filepath.Join(s.tempDir, "invalid.bin")
|
||||
err := s.client.DownloadFile(context.Background(), "://invalid-url", dest, 100)
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestDownloadFile_InvalidDestPath() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("content"))
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
// Use a path that cannot be created (directory doesn't exist)
|
||||
dest := filepath.Join(s.tempDir, "nonexistent", "subdir", "file.bin")
|
||||
err := s.client.DownloadFile(context.Background(), s.srv.URL, dest, 100)
|
||||
require.Error(s.T(), err, "expected error for invalid destination path")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_InvalidURL() {
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
_, err := s.client.FetchChecksumFile(context.Background(), "://invalid-url")
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Success() {
|
||||
releaseJSON := `{
|
||||
"tag_name": "v1.0.0",
|
||||
"name": "Release 1.0.0",
|
||||
"body": "Release notes",
|
||||
"html_url": "https://github.com/test/repo/releases/v1.0.0",
|
||||
"assets": [
|
||||
{
|
||||
"name": "app-linux-amd64.tar.gz",
|
||||
"browser_download_url": "https://github.com/test/repo/releases/download/v1.0.0/app-linux-amd64.tar.gz"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
require.Equal(s.T(), "/repos/test/repo/releases/latest", r.URL.Path)
|
||||
require.Equal(s.T(), "application/vnd.github.v3+json", r.Header.Get("Accept"))
|
||||
require.Equal(s.T(), "Sub2API-Updater", r.Header.Get("User-Agent"))
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(releaseJSON))
|
||||
}))
|
||||
|
||||
// Use custom transport to redirect requests to test server
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
release, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), "v1.0.0", release.TagName)
|
||||
require.Equal(s.T(), "Release 1.0.0", release.Name)
|
||||
require.Len(s.T(), release.Assets, 1)
|
||||
require.Equal(s.T(), "app-linux-amd64.tar.gz", release.Assets[0].Name)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_Non200() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
require.Contains(s.T(), err.Error(), "404")
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_InvalidJSON() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("not valid json"))
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
_, err := s.client.FetchLatestRelease(context.Background(), "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchLatestRelease_ContextCancel() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.client = &githubReleaseClient{
|
||||
httpClient: &http.Client{
|
||||
Transport: &testTransport{testServerURL: s.srv.URL},
|
||||
},
|
||||
downloadHTTPClient: &http.Client{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := s.client.FetchLatestRelease(ctx, "test/repo")
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *GitHubReleaseServiceSuite) TestFetchChecksumFile_ContextCancel() {
|
||||
s.srv = newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
s.client = newTestGitHubReleaseClient()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := s.client.FetchChecksumFile(ctx, s.srv.URL)
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func TestGitHubReleaseServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(GitHubReleaseServiceSuite))
|
||||
}
|
||||
@@ -0,0 +1,671 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/apikey"
|
||||
"github.com/Wei-Shaw/sub2api/ent/group"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type sqlExecutor interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
type groupRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewGroupRepository(client *dbent.Client, sqlDB *sql.DB) service.GroupRepository {
|
||||
return newGroupRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func newGroupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *groupRepository {
|
||||
return &groupRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) error {
|
||||
builder := r.client.Group.Create().
|
||||
SetName(groupIn.Name).
|
||||
SetDescription(groupIn.Description).
|
||||
SetPlatform(groupIn.Platform).
|
||||
SetRateMultiplier(groupIn.RateMultiplier).
|
||||
SetIsExclusive(groupIn.IsExclusive).
|
||||
SetStatus(groupIn.Status).
|
||||
SetSubscriptionType(groupIn.SubscriptionType).
|
||||
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
|
||||
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
|
||||
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
|
||||
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
|
||||
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
|
||||
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||
SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
builder = builder.SetModelRouting(groupIn.ModelRouting)
|
||||
}
|
||||
|
||||
// 设置支持的模型系列(始终设置,空数组表示不限制)
|
||||
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
groupIn.ID = created.ID
|
||||
groupIn.CreatedAt = created.CreatedAt
|
||||
groupIn.UpdatedAt = created.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group create failed: group=%d err=%v", groupIn.ID, err)
|
||||
}
|
||||
}
|
||||
return translatePersistenceError(err, nil, service.ErrGroupExists)
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByID(ctx context.Context, id int64) (*service.Group, error) {
|
||||
out, err := r.GetByIDLite(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
count, _ := r.GetAccountCount(ctx, out.ID)
|
||||
out.AccountCount = count
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetByIDLite(ctx context.Context, id int64) (*service.Group, error) {
|
||||
// AccountCount is intentionally not loaded here; use GetByID when needed.
|
||||
m, err := r.client.Group.Query().
|
||||
Where(group.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
return groupEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) error {
|
||||
builder := r.client.Group.UpdateOneID(groupIn.ID).
|
||||
SetName(groupIn.Name).
|
||||
SetDescription(groupIn.Description).
|
||||
SetPlatform(groupIn.Platform).
|
||||
SetRateMultiplier(groupIn.RateMultiplier).
|
||||
SetIsExclusive(groupIn.IsExclusive).
|
||||
SetStatus(groupIn.Status).
|
||||
SetSubscriptionType(groupIn.SubscriptionType).
|
||||
SetNillableDailyLimitUsd(groupIn.DailyLimitUSD).
|
||||
SetNillableWeeklyLimitUsd(groupIn.WeeklyLimitUSD).
|
||||
SetNillableMonthlyLimitUsd(groupIn.MonthlyLimitUSD).
|
||||
SetNillableImagePrice1k(groupIn.ImagePrice1K).
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetNillableSoraImagePrice360(groupIn.SoraImagePrice360).
|
||||
SetNillableSoraImagePrice540(groupIn.SoraImagePrice540).
|
||||
SetNillableSoraVideoPricePerRequest(groupIn.SoraVideoPricePerRequest).
|
||||
SetNillableSoraVideoPricePerRequestHd(groupIn.SoraVideoPricePerRequestHD).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
if groupIn.DailyLimitUSD != nil {
|
||||
builder = builder.SetDailyLimitUsd(*groupIn.DailyLimitUSD)
|
||||
} else {
|
||||
builder = builder.ClearDailyLimitUsd()
|
||||
}
|
||||
if groupIn.WeeklyLimitUSD != nil {
|
||||
builder = builder.SetWeeklyLimitUsd(*groupIn.WeeklyLimitUSD)
|
||||
} else {
|
||||
builder = builder.ClearWeeklyLimitUsd()
|
||||
}
|
||||
if groupIn.MonthlyLimitUSD != nil {
|
||||
builder = builder.SetMonthlyLimitUsd(*groupIn.MonthlyLimitUSD)
|
||||
} else {
|
||||
builder = builder.ClearMonthlyLimitUsd()
|
||||
}
|
||||
if groupIn.ImagePrice1K != nil {
|
||||
builder = builder.SetImagePrice1k(*groupIn.ImagePrice1K)
|
||||
} else {
|
||||
builder = builder.ClearImagePrice1k()
|
||||
}
|
||||
if groupIn.ImagePrice2K != nil {
|
||||
builder = builder.SetImagePrice2k(*groupIn.ImagePrice2K)
|
||||
} else {
|
||||
builder = builder.ClearImagePrice2k()
|
||||
}
|
||||
if groupIn.ImagePrice4K != nil {
|
||||
builder = builder.SetImagePrice4k(*groupIn.ImagePrice4K)
|
||||
} else {
|
||||
builder = builder.ClearImagePrice4k()
|
||||
}
|
||||
|
||||
// 处理 FallbackGroupID:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupID != nil {
|
||||
builder = builder.SetFallbackGroupID(*groupIn.FallbackGroupID)
|
||||
} else {
|
||||
builder = builder.ClearFallbackGroupID()
|
||||
}
|
||||
// 处理 FallbackGroupIDOnInvalidRequest:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupIDOnInvalidRequest != nil {
|
||||
builder = builder.SetFallbackGroupIDOnInvalidRequest(*groupIn.FallbackGroupIDOnInvalidRequest)
|
||||
} else {
|
||||
builder = builder.ClearFallbackGroupIDOnInvalidRequest()
|
||||
}
|
||||
|
||||
// 处理 ModelRouting:nil 时清除,否则设置
|
||||
if groupIn.ModelRouting != nil {
|
||||
builder = builder.SetModelRouting(groupIn.ModelRouting)
|
||||
} else {
|
||||
builder = builder.ClearModelRouting()
|
||||
}
|
||||
|
||||
// 处理 SupportedModelScopes(始终设置,空数组表示不限制)
|
||||
builder = builder.SetSupportedModelScopes(groupIn.SupportedModelScopes)
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||
}
|
||||
groupIn.UpdatedAt = updated.UpdatedAt
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupIn.ID, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group update failed: group=%d err=%v", groupIn.ID, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Group.Delete().Where(group.IDEQ(id)).Exec(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group delete failed: group=%d err=%v", id, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "", nil)
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]service.Group, *pagination.PaginationResult, error) {
|
||||
q := r.client.Group.Query()
|
||||
|
||||
if platform != "" {
|
||||
q = q.Where(group.PlatformEQ(platform))
|
||||
}
|
||||
if status != "" {
|
||||
q = q.Where(group.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(group.Or(
|
||||
group.NameContainsFold(search),
|
||||
group.DescriptionContainsFold(search),
|
||||
))
|
||||
}
|
||||
if isExclusive != nil {
|
||||
q = q.Where(group.IsExclusiveEQ(*isExclusive))
|
||||
}
|
||||
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groups, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]int64, 0, len(groups))
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
g := groupEntityToService(groups[i])
|
||||
outGroups = append(outGroups, *g)
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err == nil {
|
||||
for i := range outGroups {
|
||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
return outGroups, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListActive(ctx context.Context) ([]service.Group, error) {
|
||||
groups, err := r.client.Group.Query().
|
||||
Where(group.StatusEQ(service.StatusActive)).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]int64, 0, len(groups))
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
g := groupEntityToService(groups[i])
|
||||
outGroups = append(outGroups, *g)
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err == nil {
|
||||
for i := range outGroups {
|
||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
return outGroups, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) {
|
||||
groups, err := r.client.Group.Query().
|
||||
Where(group.StatusEQ(service.StatusActive), group.PlatformEQ(platform)).
|
||||
Order(dbent.Asc(group.FieldSortOrder), dbent.Asc(group.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupIDs := make([]int64, 0, len(groups))
|
||||
outGroups := make([]service.Group, 0, len(groups))
|
||||
for i := range groups {
|
||||
g := groupEntityToService(groups[i])
|
||||
outGroups = append(outGroups, *g)
|
||||
groupIDs = append(groupIDs, g.ID)
|
||||
}
|
||||
|
||||
counts, err := r.loadAccountCounts(ctx, groupIDs)
|
||||
if err == nil {
|
||||
for i := range outGroups {
|
||||
outGroups[i].AccountCount = counts[outGroups[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
return outGroups, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx)
|
||||
}
|
||||
|
||||
// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。
|
||||
// 返回结构:map[groupID]exists。
|
||||
func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) {
|
||||
result := make(map[int64]bool, len(ids))
|
||||
if len(ids) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
uniqueIDs := make([]int64, 0, len(ids))
|
||||
seen := make(map[int64]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
if id <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
uniqueIDs = append(uniqueIDs, id)
|
||||
result[id] = false
|
||||
}
|
||||
if len(uniqueIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id
|
||||
FROM groups
|
||||
WHERE id = ANY($1) AND deleted_at IS NULL
|
||||
`, pq.Array(uniqueIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[id] = true
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
var count int64
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
res, err := r.sql.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", groupID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
affected, _ := res.RowsAffected()
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group account clear failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
return affected, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
g, err := r.client.Group.Query().Where(group.IDEQ(id)).Only(ctx)
|
||||
if err != nil {
|
||||
return nil, translatePersistenceError(err, service.ErrGroupNotFound, nil)
|
||||
}
|
||||
groupSvc := groupEntityToService(g)
|
||||
|
||||
// 使用 ent 事务统一包裹:避免手工基于 *sql.Tx 构造 ent client 带来的驱动断言问题,
|
||||
// 同时保证级联删除的原子性。
|
||||
tx, err := r.client.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return nil, err
|
||||
}
|
||||
exec := r.client
|
||||
txClient := r.client
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
exec = tx.Client()
|
||||
txClient = exec
|
||||
}
|
||||
// err 为 dbent.ErrTxStarted 时,复用当前 client 参与同一事务。
|
||||
|
||||
// Lock the group row to avoid concurrent writes while we cascade.
|
||||
// 这里使用 exec.QueryContext 手动扫描,确保同一事务内加锁并能区分"未找到"与其他错误。
|
||||
rows, err := exec.QueryContext(ctx, "SELECT id FROM groups WHERE id = $1 AND deleted_at IS NULL FOR UPDATE", id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var lockedID int64
|
||||
if rows.Next() {
|
||||
if err := rows.Scan(&lockedID); err != nil {
|
||||
_ = rows.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lockedID == 0 {
|
||||
return nil, service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
var affectedUserIDs []int64
|
||||
if groupSvc.IsSubscriptionType() {
|
||||
// 只查询未软删除的订阅,避免通知已取消订阅的用户
|
||||
rows, err := exec.QueryContext(ctx, "SELECT user_id FROM user_subscriptions WHERE group_id = $1 AND deleted_at IS NULL", id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var userID int64
|
||||
if scanErr := rows.Scan(&userID); scanErr != nil {
|
||||
_ = rows.Close()
|
||||
return nil, scanErr
|
||||
}
|
||||
affectedUserIDs = append(affectedUserIDs, userID)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 软删除订阅:设置 deleted_at 而非硬删除
|
||||
if _, err := exec.ExecContext(ctx, "UPDATE user_subscriptions SET deleted_at = NOW() WHERE group_id = $1 AND deleted_at IS NULL", id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Clear group_id for api keys bound to this group.
|
||||
// 仅更新未软删除的记录,避免修改已删除数据,保证审计与历史回溯一致性。
|
||||
// 与 APIKeyRepository 的软删除语义保持一致,减少跨模块行为差异。
|
||||
if _, err := txClient.APIKey.Update().
|
||||
Where(apikey.GroupIDEQ(id), apikey.DeletedAtIsNil()).
|
||||
ClearGroupID().
|
||||
Save(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 3. Remove the group id from user_allowed_groups join table.
|
||||
// Legacy users.allowed_groups 列已弃用,不再同步。
|
||||
if _, err := exec.ExecContext(ctx, "DELETE FROM user_allowed_groups WHERE group_id = $1", id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 4. Delete account_groups join rows.
|
||||
if _, err := exec.ExecContext(ctx, "DELETE FROM account_groups WHERE group_id = $1", id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 5. Soft-delete group itself.
|
||||
if _, err := txClient.Group.Delete().Where(group.IDEQ(id)).Exec(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group cascade delete failed: group=%d err=%v", id, err)
|
||||
}
|
||||
|
||||
return affectedUserIDs, nil
|
||||
}
|
||||
|
||||
func (r *groupRepository) loadAccountCounts(ctx context.Context, groupIDs []int64) (counts map[int64]int64, err error) {
|
||||
counts = make(map[int64]int64, len(groupIDs))
|
||||
if len(groupIDs) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(
|
||||
ctx,
|
||||
"SELECT group_id, COUNT(*) FROM account_groups WHERE group_id = ANY($1) GROUP BY group_id",
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
counts = nil
|
||||
}
|
||||
}()
|
||||
|
||||
for rows.Next() {
|
||||
var groupID int64
|
||||
var count int64
|
||||
if err = rows.Scan(&groupID, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[groupID] = count
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
// GetAccountIDsByGroupIDs 获取多个分组的所有账号 ID(去重)
|
||||
func (r *groupRepository) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
rows, err := r.sql.QueryContext(
|
||||
ctx,
|
||||
"SELECT DISTINCT account_id FROM account_groups WHERE group_id = ANY($1) ORDER BY account_id",
|
||||
pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var accountIDs []int64
|
||||
for rows.Next() {
|
||||
var accountID int64
|
||||
if err := rows.Scan(&accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accountIDs = append(accountIDs, accountID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return accountIDs, nil
|
||||
}
|
||||
|
||||
// BindAccountsToGroup 将多个账号绑定到指定分组(批量插入,忽略已存在的绑定)
|
||||
func (r *groupRepository) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
if len(accountIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 使用 INSERT ... ON CONFLICT DO NOTHING 忽略已存在的绑定
|
||||
_, err := r.sql.ExecContext(
|
||||
ctx,
|
||||
`INSERT INTO account_groups (account_id, group_id, priority, created_at)
|
||||
SELECT unnest($1::bigint[]), $2, 50, NOW()
|
||||
ON CONFLICT (account_id, group_id) DO NOTHING`,
|
||||
pq.Array(accountIDs),
|
||||
groupID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 发送调度器事件
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &groupID, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue bind accounts to group failed: group=%d err=%v", groupID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateSortOrders 批量更新分组排序
|
||||
func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []service.GroupSortOrderUpdate) error {
|
||||
if len(updates) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。
|
||||
sortOrderByID := make(map[int64]int, len(updates))
|
||||
groupIDs := make([]int64, 0, len(updates))
|
||||
for _, u := range updates {
|
||||
if u.ID <= 0 {
|
||||
continue
|
||||
}
|
||||
if _, exists := sortOrderByID[u.ID]; !exists {
|
||||
groupIDs = append(groupIDs, u.ID)
|
||||
}
|
||||
sortOrderByID[u.ID] = u.SortOrder
|
||||
}
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 与旧实现保持一致:任何不存在/已删除的分组都返回 not found,且不执行更新。
|
||||
var existingCount int
|
||||
if err := scanSingleRow(
|
||||
ctx,
|
||||
r.sql,
|
||||
`SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`,
|
||||
[]any{pq.Array(groupIDs)},
|
||||
&existingCount,
|
||||
); err != nil {
|
||||
return err
|
||||
}
|
||||
if existingCount != len(groupIDs) {
|
||||
return service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
args := make([]any, 0, len(groupIDs)*2+1)
|
||||
caseClauses := make([]string, 0, len(groupIDs))
|
||||
placeholder := 1
|
||||
for _, id := range groupIDs {
|
||||
caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1))
|
||||
args = append(args, id, sortOrderByID[id])
|
||||
placeholder += 2
|
||||
}
|
||||
args = append(args, pq.Array(groupIDs))
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE groups
|
||||
SET sort_order = CASE id
|
||||
%s
|
||||
ELSE sort_order
|
||||
END
|
||||
WHERE deleted_at IS NULL AND id = ANY($%d)
|
||||
`, strings.Join(caseClauses, "\n\t\t\t"), placeholder)
|
||||
|
||||
result, err := r.sql.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected != int64(len(groupIDs)) {
|
||||
return service.ErrGroupNotFound
|
||||
}
|
||||
|
||||
for _, id := range groupIDs {
|
||||
if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil {
|
||||
logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,752 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GroupRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
tx *dbent.Tx
|
||||
repo *groupRepository
|
||||
}
|
||||
|
||||
type forbidSQLExecutor struct {
|
||||
called bool
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql exec")
|
||||
}
|
||||
|
||||
func (s *forbidSQLExecutor) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
s.called = true
|
||||
return nil, errors.New("unexpected sql query")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.tx = tx
|
||||
s.repo = newGroupRepositoryWithSQL(tx.Client(), tx)
|
||||
}
|
||||
|
||||
func TestGroupRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(GroupRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *GroupRepoSuite) TestCreate() {
|
||||
group := &service.Group{
|
||||
Name: "test-create",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, group)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(group.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetByIDLite_DoesNotUseAccountCount() {
|
||||
group := &service.Group{
|
||||
Name: "lite-group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
spy := &forbidSQLExecutor{}
|
||||
repo := newGroupRepositoryWithSQL(s.tx.Client(), spy)
|
||||
|
||||
got, err := repo.GetByIDLite(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(group.ID, got.ID)
|
||||
s.Require().False(spy.called, "expected no direct sql executor usage")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdate() {
|
||||
group := &service.Group{
|
||||
Name: "original",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
group.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, group)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDelete() {
|
||||
group := &service.Group{
|
||||
Name: "to-delete",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
err := s.repo.Delete(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *GroupRepoSuite) TestList() {
|
||||
baseGroups, basePage, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List base")
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(groups, len(baseGroups)+2)
|
||||
s.Require().Equal(basePage.Total+2, page.Total)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Platform() {
|
||||
baseGroups, _, err := s.repo.ListWithFilters(
|
||||
s.ctx,
|
||||
pagination.PaginationParams{Page: 1, PageSize: 10},
|
||||
service.PlatformOpenAI,
|
||||
"",
|
||||
"",
|
||||
nil,
|
||||
)
|
||||
s.Require().NoError(err, "ListWithFilters base")
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformOpenAI,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformOpenAI, "", "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, len(baseGroups)+1)
|
||||
// Verify all groups are OpenAI platform
|
||||
for _, g := range groups {
|
||||
s.Require().Equal(service.PlatformOpenAI, g.Platform)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Status() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusDisabled,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(service.StatusDisabled, groups[0].Status)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_IsExclusive() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: true,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
isExclusive := true
|
||||
groups, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "", &isExclusive)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().True(groups[0].IsExclusive)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_Search() {
|
||||
newRepo := func() (*groupRepository, context.Context) {
|
||||
tx := testEntTx(s.T())
|
||||
return newGroupRepositoryWithSQL(tx.Client(), tx), context.Background()
|
||||
}
|
||||
|
||||
containsID := func(groups []service.Group, id int64) bool {
|
||||
for i := range groups {
|
||||
if groups[i].ID == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
mustCreate := func(repo *groupRepository, ctx context.Context, g *service.Group) *service.Group {
|
||||
s.Require().NoError(repo.Create(ctx, g))
|
||||
s.Require().NotZero(g.ID)
|
||||
return g
|
||||
}
|
||||
|
||||
newGroup := func(name string) *service.Group {
|
||||
return &service.Group{
|
||||
Name: name,
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
}
|
||||
|
||||
s.Run("search_name_should_match", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := mustCreate(repo, ctx, newGroup("it-group-search-name-target"))
|
||||
other := mustCreate(repo, ctx, newGroup("it-group-search-name-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "name-target", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected target group to match by name")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_description_should_match", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := newGroup("it-group-search-desc-target")
|
||||
target.Description = "something about desc-needle in here"
|
||||
target = mustCreate(repo, ctx, target)
|
||||
|
||||
other := newGroup("it-group-search-desc-other")
|
||||
other.Description = "nothing to see here"
|
||||
other = mustCreate(repo, ctx, other)
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "desc-needle", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected target group to match by description")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_nonexistent_should_return_empty", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
_ = mustCreate(repo, ctx, newGroup("it-group-search-nonexistent-baseline"))
|
||||
|
||||
search := s.T().Name() + "__no_such_group__"
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", search, nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(groups)
|
||||
})
|
||||
|
||||
s.Run("search_should_be_case_insensitive", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
target := mustCreate(repo, ctx, newGroup("MiXeDCaSe-Needle"))
|
||||
other := mustCreate(repo, ctx, newGroup("it-group-search-case-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "mixedcase-needle", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, target.ID), "expected case-insensitive match")
|
||||
s.Require().False(containsID(groups, other.ID), "expected other group to be filtered out")
|
||||
})
|
||||
|
||||
s.Run("search_should_escape_like_wildcards", func() {
|
||||
repo, ctx := newRepo()
|
||||
|
||||
percentTarget := mustCreate(repo, ctx, newGroup("it-group-search-100%-target"))
|
||||
percentOther := mustCreate(repo, ctx, newGroup("it-group-search-100X-other"))
|
||||
|
||||
groups, _, err := repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "100%", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, percentTarget.ID), "expected literal %% match")
|
||||
s.Require().False(containsID(groups, percentOther.ID), "expected %% not to act as wildcard")
|
||||
|
||||
underscoreTarget := mustCreate(repo, ctx, newGroup("it-group-search-ab_cd-target"))
|
||||
underscoreOther := mustCreate(repo, ctx, newGroup("it-group-search-abXcd-other"))
|
||||
|
||||
groups, _, err = repo.ListWithFilters(ctx, pagination.PaginationParams{Page: 1, PageSize: 50}, "", "", "ab_cd", nil)
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(containsID(groups, underscoreTarget.ID), "expected literal _ match")
|
||||
s.Require().False(containsID(groups, underscoreOther.ID), "expected _ not to act as wildcard")
|
||||
})
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() {
|
||||
g1 := &service.Group{
|
||||
Name: "sort-g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
g2 := &service.Group{
|
||||
Name: "sort-g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
g3 := &service.Group{
|
||||
Name: "sort-g3",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g1))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g2))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g3))
|
||||
|
||||
err := s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
|
||||
{ID: g1.ID, SortOrder: 30},
|
||||
{ID: g2.ID, SortOrder: 10},
|
||||
{ID: g3.ID, SortOrder: 20},
|
||||
{ID: g2.ID, SortOrder: 15}, // 重复 ID 应以最后一次为准
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
got1, err := s.repo.GetByID(s.ctx, g1.ID)
|
||||
s.Require().NoError(err)
|
||||
got2, err := s.repo.GetByID(s.ctx, g2.ID)
|
||||
s.Require().NoError(err)
|
||||
got3, err := s.repo.GetByID(s.ctx, g3.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(30, got1.SortOrder)
|
||||
s.Require().Equal(15, got2.SortOrder)
|
||||
s.Require().Equal(20, got3.SortOrder)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() {
|
||||
g1 := &service.Group{
|
||||
Name: "sort-no-partial",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g1))
|
||||
|
||||
before, err := s.repo.GetByID(s.ctx, g1.ID)
|
||||
s.Require().NoError(err)
|
||||
beforeSort := before.SortOrder
|
||||
|
||||
err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{
|
||||
{ID: g1.ID, SortOrder: 99},
|
||||
{ID: 99999999, SortOrder: 1},
|
||||
})
|
||||
s.Require().Error(err)
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
|
||||
after, err := s.repo.GetByID(s.ctx, g1.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(beforeSort, after.SortOrder)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListWithFilters_AccountCount() {
|
||||
g1 := &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
g2 := &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: true,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g1))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g2))
|
||||
|
||||
var accountID int64
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
[]any{"acc1", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&accountID,
|
||||
))
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g1.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g2.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
isExclusive := true
|
||||
groups, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.PlatformAnthropic, service.StatusActive, "", &isExclusive)
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(groups, 1)
|
||||
s.Require().Equal(g2.ID, groups[0].ID, "ListWithFilters returned wrong group")
|
||||
s.Require().Equal(int64(1), groups[0].AccountCount, "AccountCount mismatch")
|
||||
}
|
||||
|
||||
// --- ListActive / ListActiveByPlatform ---
|
||||
|
||||
func (s *GroupRepoSuite) TestListActive() {
|
||||
baseGroups, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive base")
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "active1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "inactive1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusDisabled,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(groups, len(baseGroups)+1)
|
||||
// Verify our test group is in the results
|
||||
var found bool
|
||||
for _, g := range groups {
|
||||
if g.Name == "active1" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.Require().True(found, "active1 group should be in results")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestListActiveByPlatform() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g1",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g2",
|
||||
Platform: service.PlatformOpenAI,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "g3",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusDisabled,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
groups, err := s.repo.ListActiveByPlatform(s.ctx, service.PlatformAnthropic)
|
||||
s.Require().NoError(err, "ListActiveByPlatform")
|
||||
// 1 default anthropic group + 1 test active anthropic group = 2 total
|
||||
s.Require().Len(groups, 2)
|
||||
// Verify our test group is in the results
|
||||
var found bool
|
||||
for _, g := range groups {
|
||||
if g.Name == "g1" {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
s.Require().True(found, "g1 group should be in results")
|
||||
}
|
||||
|
||||
// --- ExistsByName ---
|
||||
|
||||
func (s *GroupRepoSuite) TestExistsByName() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.Group{
|
||||
Name: "existing-group",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}))
|
||||
|
||||
exists, err := s.repo.ExistsByName(s.ctx, "existing-group")
|
||||
s.Require().NoError(err, "ExistsByName")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByName(s.ctx, "non-existing")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- GetAccountCount ---
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount() {
|
||||
group := &service.Group{
|
||||
Name: "g-count",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
var a1 int64
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
[]any{"a1", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&a1,
|
||||
))
|
||||
var a2 int64
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
[]any{"a2", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&a2,
|
||||
))
|
||||
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, group.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, group.ID, 2)
|
||||
s.Require().NoError(err)
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "GetAccountCount")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestGetAccountCount_Empty() {
|
||||
group := &service.Group{
|
||||
Name: "g-empty",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- DeleteAccountGroupsByGroupID ---
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID() {
|
||||
g := &service.Group{
|
||||
Name: "g-del",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||
var accountID int64
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
[]any{"acc-del", service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&accountID,
|
||||
))
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", accountID, g.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
|
||||
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||
s.Require().NoError(err, "DeleteAccountGroupsByGroupID")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
count, err := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().NoError(err, "GetAccountCount")
|
||||
s.Require().Equal(int64(0), count, "expected 0 account groups")
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDeleteAccountGroupsByGroupID_MultipleAccounts() {
|
||||
g := &service.Group{
|
||||
Name: "g-multi",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, g))
|
||||
|
||||
insertAccount := func(name string) int64 {
|
||||
var id int64
|
||||
s.Require().NoError(scanSingleRow(
|
||||
s.ctx,
|
||||
s.tx,
|
||||
"INSERT INTO accounts (name, platform, type) VALUES ($1, $2, $3) RETURNING id",
|
||||
[]any{name, service.PlatformAnthropic, service.AccountTypeOAuth},
|
||||
&id,
|
||||
))
|
||||
return id
|
||||
}
|
||||
a1 := insertAccount("a1")
|
||||
a2 := insertAccount("a2")
|
||||
a3 := insertAccount("a3")
|
||||
_, err := s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a1, g.ID, 1)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a2, g.ID, 2)
|
||||
s.Require().NoError(err)
|
||||
_, err = s.tx.ExecContext(s.ctx, "INSERT INTO account_groups (account_id, group_id, priority, created_at) VALUES ($1, $2, $3, NOW())", a3, g.ID, 3)
|
||||
s.Require().NoError(err)
|
||||
|
||||
affected, err := s.repo.DeleteAccountGroupsByGroupID(s.ctx, g.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(int64(3), affected)
|
||||
|
||||
count, _ := s.repo.GetAccountCount(s.ctx, g.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- 软删除过滤测试 ---
|
||||
|
||||
func (s *GroupRepoSuite) TestDelete_SoftDelete_NotVisibleInList() {
|
||||
group := &service.Group{
|
||||
Name: "to-soft-delete",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
// 获取删除前的列表数量
|
||||
listBefore, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
|
||||
s.Require().NoError(err)
|
||||
beforeCount := len(listBefore)
|
||||
|
||||
// 软删除
|
||||
err = s.repo.Delete(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "Delete (soft delete)")
|
||||
|
||||
// 验证列表中不再包含软删除的 group
|
||||
listAfter, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 100})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(listAfter, beforeCount-1, "soft deleted group should not appear in list")
|
||||
|
||||
// 验证 GetByID 也无法找到
|
||||
_, err = s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().Error(err)
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func (s *GroupRepoSuite) TestDelete_SoftDeletedGroup_lockForUpdate() {
|
||||
group := &service.Group{
|
||||
Name: "lock-soft-delete",
|
||||
Platform: service.PlatformAnthropic,
|
||||
RateMultiplier: 1.0,
|
||||
IsExclusive: false,
|
||||
Status: service.StatusActive,
|
||||
SubscriptionType: service.SubscriptionTypeStandard,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, group))
|
||||
|
||||
// 软删除
|
||||
err := s.repo.Delete(s.ctx, group.ID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// 验证软删除的 group 在 GetByID 时返回 ErrGroupNotFound
|
||||
// 这证明 lockForUpdate 的 deleted_at IS NULL 过滤正在工作
|
||||
_, err = s.repo.GetByID(s.ctx, group.ID)
|
||||
s.Require().Error(err, "should fail to get soft-deleted group")
|
||||
s.Require().ErrorIs(err, service.ErrGroupNotFound)
|
||||
}
|
||||
@@ -0,0 +1,886 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||
)
|
||||
|
||||
// 默认配置常量
|
||||
// 这些值在配置文件未指定时作为回退默认值使用
|
||||
const (
|
||||
// directProxyKey: 无代理时的缓存键标识
|
||||
directProxyKey = "direct"
|
||||
// defaultMaxIdleConns: 默认最大空闲连接总数
|
||||
// HTTP/2 场景下,单连接可多路复用,240 足以支撑高并发
|
||||
defaultMaxIdleConns = 240
|
||||
// defaultMaxIdleConnsPerHost: 默认每主机最大空闲连接数
|
||||
defaultMaxIdleConnsPerHost = 120
|
||||
// defaultMaxConnsPerHost: 默认每主机最大连接数(含活跃连接)
|
||||
// 达到上限后新请求会等待,而非无限创建连接
|
||||
defaultMaxConnsPerHost = 240
|
||||
// defaultIdleConnTimeout: 默认空闲连接超时时间(90秒)
|
||||
// 超时后连接会被关闭,释放系统资源(建议小于上游 LB 超时)
|
||||
defaultIdleConnTimeout = 90 * time.Second
|
||||
// defaultResponseHeaderTimeout: 默认等待响应头超时时间(5分钟)
|
||||
// LLM 请求可能排队较久,需要较长超时
|
||||
defaultResponseHeaderTimeout = 300 * time.Second
|
||||
// defaultMaxUpstreamClients: 默认最大客户端缓存数量
|
||||
// 超出后会淘汰最久未使用的客户端
|
||||
defaultMaxUpstreamClients = 5000
|
||||
// defaultClientIdleTTLSeconds: 默认客户端空闲回收阈值(15分钟)
|
||||
defaultClientIdleTTLSeconds = 900
|
||||
)
|
||||
|
||||
var errUpstreamClientLimitReached = errors.New("upstream client cache limit reached")
|
||||
|
||||
// poolSettings 连接池配置参数
|
||||
// 封装 Transport 所需的各项连接池参数
|
||||
type poolSettings struct {
|
||||
maxIdleConns int // 最大空闲连接总数
|
||||
maxIdleConnsPerHost int // 每主机最大空闲连接数
|
||||
maxConnsPerHost int // 每主机最大连接数(含活跃)
|
||||
idleConnTimeout time.Duration // 空闲连接超时时间
|
||||
responseHeaderTimeout time.Duration // 等待响应头超时时间
|
||||
}
|
||||
|
||||
// upstreamClientEntry 上游客户端缓存条目
|
||||
// 记录客户端实例及其元数据,用于连接池管理和淘汰策略
|
||||
type upstreamClientEntry struct {
|
||||
client *http.Client // HTTP 客户端实例
|
||||
proxyKey string // 代理标识(用于检测代理变更)
|
||||
poolKey string // 连接池配置标识(用于检测配置变更)
|
||||
lastUsed int64 // 最后使用时间戳(纳秒),用于 LRU 淘汰
|
||||
inFlight int64 // 当前进行中的请求数,>0 时不可淘汰
|
||||
}
|
||||
|
||||
// httpUpstreamService 通用 HTTP 上游服务
|
||||
// 用于向任意 HTTP API(Claude、OpenAI 等)发送请求,支持可选代理
|
||||
//
|
||||
// 架构设计:
|
||||
// - 根据隔离策略(proxy/account/account_proxy)缓存客户端实例
|
||||
// - 每个客户端拥有独立的 Transport 连接池
|
||||
// - 支持 LRU + 空闲时间双重淘汰策略
|
||||
//
|
||||
// 性能优化:
|
||||
// 1. 根据隔离策略缓存客户端实例,避免频繁创建 http.Client
|
||||
// 2. 复用 Transport 连接池,减少 TCP 握手和 TLS 协商开销
|
||||
// 3. 支持账号级隔离与空闲回收,降低连接层关联风险
|
||||
// 4. 达到最大连接数后等待可用连接,而非无限创建
|
||||
// 5. 仅回收空闲客户端,避免中断活跃请求
|
||||
// 6. HTTP/2 多路复用,连接上限不等于并发请求上限
|
||||
// 7. 代理变更时清空旧连接池,避免复用错误代理
|
||||
// 8. 账号并发数与连接池上限对应(账号隔离策略下)
|
||||
type httpUpstreamService struct {
|
||||
cfg *config.Config // 全局配置
|
||||
mu sync.RWMutex // 保护 clients map 的读写锁
|
||||
clients map[string]*upstreamClientEntry // 客户端缓存池,key 由隔离策略决定
|
||||
}
|
||||
|
||||
// NewHTTPUpstream 创建通用 HTTP 上游服务
|
||||
// 使用配置中的连接池参数构建 Transport
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 全局配置,包含连接池参数和隔离策略
|
||||
//
|
||||
// 返回:
|
||||
// - service.HTTPUpstream 接口实现
|
||||
func NewHTTPUpstream(cfg *config.Config) service.HTTPUpstream {
|
||||
return &httpUpstreamService{
|
||||
cfg: cfg,
|
||||
clients: make(map[string]*upstreamClientEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// Do 执行 HTTP 请求
|
||||
// 根据隔离策略获取或创建客户端,并跟踪请求生命周期
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象
|
||||
// - proxyURL: 代理地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于账户级隔离
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Response: HTTP 响应(Body 已包装,关闭时自动更新计数)
|
||||
// - error: 请求错误
|
||||
//
|
||||
// 注意:
|
||||
// - 调用方必须关闭 resp.Body,否则会导致 inFlight 计数泄漏
|
||||
// - inFlight > 0 的客户端不会被淘汰,确保活跃请求不被中断
|
||||
func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) {
|
||||
if err := s.validateRequestHost(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取或创建对应的客户端,并标记请求占用
|
||||
entry, err := s.acquireClient(proxyURL, accountID, accountConcurrency)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
resp, err := entry.client.Do(req)
|
||||
if err != nil {
|
||||
// 请求失败,立即减少计数
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 包装响应体,在关闭时自动减少计数并更新时间戳
|
||||
// 这确保了流式响应(如 SSE)在完全读取前不会被淘汰
|
||||
resp.Body = wrapTrackedBody(resp.Body, func() {
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
})
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求
|
||||
// 根据 enableTLSFingerprint 参数决定是否使用 TLS 指纹
|
||||
//
|
||||
// 参数:
|
||||
// - req: HTTP 请求对象
|
||||
// - proxyURL: 代理地址,空字符串表示直连
|
||||
// - accountID: 账户 ID,用于账户级隔离和 TLS 指纹模板选择
|
||||
// - accountConcurrency: 账户并发限制,用于动态调整连接池大小
|
||||
// - enableTLSFingerprint: 是否启用 TLS 指纹伪装
|
||||
//
|
||||
// TLS 指纹说明:
|
||||
// - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
// - 指纹模板根据 accountID % len(profiles) 自动选择
|
||||
// - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景
|
||||
func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
// 如果未启用 TLS 指纹,直接使用标准请求路径
|
||||
if !enableTLSFingerprint {
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
// TLS 指纹已启用,记录调试日志
|
||||
targetHost := ""
|
||||
if req != nil && req.URL != nil {
|
||||
targetHost = req.URL.Host
|
||||
}
|
||||
proxyInfo := "direct"
|
||||
if proxyURL != "" {
|
||||
proxyInfo = proxyURL
|
||||
}
|
||||
slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo)
|
||||
|
||||
if err := s.validateRequestHost(req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取 TLS 指纹 Profile
|
||||
registry := tlsfingerprint.GlobalRegistry()
|
||||
profile := registry.GetProfileByAccountID(accountID)
|
||||
if profile == nil {
|
||||
// 如果获取不到 profile,回退到普通请求
|
||||
slog.Debug("tls_fingerprint_no_profile", "account_id", accountID, "fallback", "standard_request")
|
||||
return s.Do(req, proxyURL, accountID, accountConcurrency)
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_using_profile", "account_id", accountID, "profile", profile.Name, "grease", profile.EnableGREASE)
|
||||
|
||||
// 获取或创建带 TLS 指纹的客户端
|
||||
entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile)
|
||||
if err != nil {
|
||||
slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 执行请求
|
||||
resp, err := entry.client.Do(req)
|
||||
if err != nil {
|
||||
// 请求失败,立即减少计数
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
slog.Debug("tls_fingerprint_request_failed", "account_id", accountID, "error", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Debug("tls_fingerprint_request_success", "account_id", accountID, "status", resp.StatusCode)
|
||||
|
||||
// 包装响应体,在关闭时自动减少计数并更新时间戳
|
||||
resp.Body = wrapTrackedBody(resp.Body, func() {
|
||||
atomic.AddInt64(&entry.inFlight, -1)
|
||||
atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano())
|
||||
})
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// acquireClientWithTLS 获取或创建带 TLS 指纹的客户端
|
||||
func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true)
|
||||
}
|
||||
|
||||
// getClientEntryWithTLS 获取或创建带 TLS 指纹的客户端条目
|
||||
// TLS 指纹客户端使用独立的缓存键,与普通客户端隔离
|
||||
func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
|
||||
isolation := s.getIsolationMode()
|
||||
proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TLS 指纹客户端使用独立的缓存键,加 "tls:" 前缀
|
||||
cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID)
|
||||
poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls"
|
||||
|
||||
now := time.Now()
|
||||
nowUnix := now.UnixNano()
|
||||
|
||||
// 读锁快速路径
|
||||
s.mu.RLock()
|
||||
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
|
||||
return entry, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// 写锁慢路径
|
||||
s.mu.Lock()
|
||||
if entry, ok := s.clients[cacheKey]; ok {
|
||||
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey)
|
||||
return entry, nil
|
||||
}
|
||||
slog.Debug("tls_fingerprint_evicting_stale_client",
|
||||
"account_id", accountID,
|
||||
"cache_key", cacheKey,
|
||||
"proxy_changed", entry.proxyKey != proxyKey,
|
||||
"pool_changed", entry.poolKey != poolKey)
|
||||
s.removeClientLocked(cacheKey, entry)
|
||||
}
|
||||
|
||||
// 超出缓存上限时尝试淘汰
|
||||
if enforceLimit && s.maxUpstreamClients() > 0 {
|
||||
s.evictIdleLocked(now)
|
||||
if len(s.clients) >= s.maxUpstreamClients() {
|
||||
if !s.evictOldestIdleLocked() {
|
||||
s.mu.Unlock()
|
||||
return nil, errUpstreamClientLimitReached
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 创建带 TLS 指纹的 Transport
|
||||
slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey)
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("build TLS fingerprint transport: %w", err)
|
||||
}
|
||||
|
||||
client := &http.Client{Transport: transport}
|
||||
if s.shouldValidateResolvedIP() {
|
||||
client.CheckRedirect = s.redirectChecker
|
||||
}
|
||||
|
||||
entry := &upstreamClientEntry{
|
||||
client: client,
|
||||
proxyKey: proxyKey,
|
||||
poolKey: poolKey,
|
||||
}
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.StoreInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.clients[cacheKey] = entry
|
||||
|
||||
s.evictIdleLocked(now)
|
||||
s.evictOverLimitLocked()
|
||||
s.mu.Unlock()
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
|
||||
if s.cfg == nil {
|
||||
return false
|
||||
}
|
||||
if !s.cfg.Security.URLAllowlist.Enabled {
|
||||
return false
|
||||
}
|
||||
return !s.cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) validateRequestHost(req *http.Request) error {
|
||||
if !s.shouldValidateResolvedIP() {
|
||||
return nil
|
||||
}
|
||||
if req == nil || req.URL == nil {
|
||||
return errors.New("request url is nil")
|
||||
}
|
||||
host := strings.TrimSpace(req.URL.Hostname())
|
||||
if host == "" {
|
||||
return errors.New("request host is empty")
|
||||
}
|
||||
if err := urlvalidator.ValidateResolvedIP(host); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *httpUpstreamService) redirectChecker(req *http.Request, via []*http.Request) error {
|
||||
if len(via) >= 10 {
|
||||
return errors.New("stopped after 10 redirects")
|
||||
}
|
||||
return s.validateRequestHost(req)
|
||||
}
|
||||
|
||||
// acquireClient 获取或创建客户端,并标记为进行中请求
|
||||
// 用于请求路径,避免在获取后被淘汰
|
||||
func (s *httpUpstreamService) acquireClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntry(proxyURL, accountID, accountConcurrency, true, true)
|
||||
}
|
||||
|
||||
// getOrCreateClient 获取或创建客户端
|
||||
// 根据隔离策略和参数决定缓存键,处理代理变更和配置变更
|
||||
//
|
||||
// 参数:
|
||||
// - proxyURL: 代理地址
|
||||
// - accountID: 账户 ID
|
||||
// - accountConcurrency: 账户并发限制
|
||||
//
|
||||
// 返回:
|
||||
// - *upstreamClientEntry: 客户端缓存条目
|
||||
//
|
||||
// 隔离策略说明:
|
||||
// - proxy: 按代理地址隔离,同一代理共享客户端
|
||||
// - account: 按账户隔离,同一账户共享客户端(代理变更时重建)
|
||||
// - account_proxy: 按账户+代理组合隔离,最细粒度
|
||||
func (s *httpUpstreamService) getOrCreateClient(proxyURL string, accountID int64, accountConcurrency int) (*upstreamClientEntry, error) {
|
||||
return s.getClientEntry(proxyURL, accountID, accountConcurrency, false, false)
|
||||
}
|
||||
|
||||
// getClientEntry 获取或创建客户端条目
|
||||
// markInFlight=true 时会标记进行中请求,用于请求路径防止被淘汰
|
||||
// enforceLimit=true 时会限制客户端数量,超限且无法淘汰时返回错误
|
||||
func (s *httpUpstreamService) getClientEntry(proxyURL string, accountID int64, accountConcurrency int, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) {
|
||||
// 获取隔离模式
|
||||
isolation := s.getIsolationMode()
|
||||
// 标准化代理 URL 并解析
|
||||
proxyKey, parsedProxy, err := normalizeProxyURL(proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 构建缓存键(根据隔离策略不同)
|
||||
cacheKey := buildCacheKey(isolation, proxyKey, accountID)
|
||||
// 构建连接池配置键(用于检测配置变更)
|
||||
poolKey := s.buildPoolKey(isolation, accountConcurrency)
|
||||
|
||||
now := time.Now()
|
||||
nowUnix := now.UnixNano()
|
||||
|
||||
// 读锁快速路径:命中缓存直接返回,减少锁竞争
|
||||
s.mu.RLock()
|
||||
if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
return entry, nil
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
// 写锁慢路径:创建或重建客户端
|
||||
s.mu.Lock()
|
||||
if entry, ok := s.clients[cacheKey]; ok {
|
||||
if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) {
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.AddInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
return entry, nil
|
||||
}
|
||||
s.removeClientLocked(cacheKey, entry)
|
||||
}
|
||||
|
||||
// 超出缓存上限时尝试淘汰,无法淘汰则拒绝新建
|
||||
if enforceLimit && s.maxUpstreamClients() > 0 {
|
||||
s.evictIdleLocked(now)
|
||||
if len(s.clients) >= s.maxUpstreamClients() {
|
||||
if !s.evictOldestIdleLocked() {
|
||||
s.mu.Unlock()
|
||||
return nil, errUpstreamClientLimitReached
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 缓存未命中或需要重建,创建新客户端
|
||||
settings := s.resolvePoolSettings(isolation, accountConcurrency)
|
||||
transport, err := buildUpstreamTransport(settings, parsedProxy)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return nil, fmt.Errorf("build transport: %w", err)
|
||||
}
|
||||
client := &http.Client{Transport: transport}
|
||||
if s.shouldValidateResolvedIP() {
|
||||
client.CheckRedirect = s.redirectChecker
|
||||
}
|
||||
entry := &upstreamClientEntry{
|
||||
client: client,
|
||||
proxyKey: proxyKey,
|
||||
poolKey: poolKey,
|
||||
}
|
||||
atomic.StoreInt64(&entry.lastUsed, nowUnix)
|
||||
if markInFlight {
|
||||
atomic.StoreInt64(&entry.inFlight, 1)
|
||||
}
|
||||
s.clients[cacheKey] = entry
|
||||
|
||||
// 执行淘汰策略:先淘汰空闲超时的,再淘汰超出数量限制的
|
||||
s.evictIdleLocked(now)
|
||||
s.evictOverLimitLocked()
|
||||
s.mu.Unlock()
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
// shouldReuseEntry 判断缓存条目是否可复用
|
||||
// 若代理或连接池配置发生变化,则需要重建客户端
|
||||
func (s *httpUpstreamService) shouldReuseEntry(entry *upstreamClientEntry, isolation, proxyKey, poolKey string) bool {
|
||||
if entry == nil {
|
||||
return false
|
||||
}
|
||||
if isolation == config.ConnectionPoolIsolationAccount && entry.proxyKey != proxyKey {
|
||||
return false
|
||||
}
|
||||
if entry.poolKey != poolKey {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// removeClientLocked 移除客户端(需持有锁)
|
||||
// 从缓存中删除并关闭空闲连接
|
||||
//
|
||||
// 参数:
|
||||
// - key: 缓存键
|
||||
// - entry: 客户端条目
|
||||
func (s *httpUpstreamService) removeClientLocked(key string, entry *upstreamClientEntry) {
|
||||
delete(s.clients, key)
|
||||
if entry != nil && entry.client != nil {
|
||||
// 关闭空闲连接,释放系统资源
|
||||
// 注意:这不会中断活跃连接
|
||||
entry.client.CloseIdleConnections()
|
||||
}
|
||||
}
|
||||
|
||||
// evictIdleLocked 淘汰空闲超时的客户端(需持有锁)
|
||||
// 遍历所有客户端,移除超过 TTL 且无活跃请求的条目
|
||||
//
|
||||
// 参数:
|
||||
// - now: 当前时间
|
||||
func (s *httpUpstreamService) evictIdleLocked(now time.Time) {
|
||||
ttl := s.clientIdleTTL()
|
||||
if ttl <= 0 {
|
||||
return
|
||||
}
|
||||
// 计算淘汰截止时间
|
||||
cutoff := now.Add(-ttl).UnixNano()
|
||||
for key, entry := range s.clients {
|
||||
// 跳过有活跃请求的客户端
|
||||
if atomic.LoadInt64(&entry.inFlight) != 0 {
|
||||
continue
|
||||
}
|
||||
// 淘汰超时的空闲客户端
|
||||
if atomic.LoadInt64(&entry.lastUsed) <= cutoff {
|
||||
s.removeClientLocked(key, entry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldestIdleLocked 淘汰最久未使用且无活跃请求的客户端(需持有锁)
|
||||
func (s *httpUpstreamService) evictOldestIdleLocked() bool {
|
||||
var (
|
||||
oldestKey string
|
||||
oldestEntry *upstreamClientEntry
|
||||
oldestTime int64
|
||||
)
|
||||
// 查找最久未使用且无活跃请求的客户端
|
||||
for key, entry := range s.clients {
|
||||
// 跳过有活跃请求的客户端
|
||||
if atomic.LoadInt64(&entry.inFlight) != 0 {
|
||||
continue
|
||||
}
|
||||
lastUsed := atomic.LoadInt64(&entry.lastUsed)
|
||||
if oldestEntry == nil || lastUsed < oldestTime {
|
||||
oldestKey = key
|
||||
oldestEntry = entry
|
||||
oldestTime = lastUsed
|
||||
}
|
||||
}
|
||||
// 所有客户端都有活跃请求,无法淘汰
|
||||
if oldestEntry == nil {
|
||||
return false
|
||||
}
|
||||
s.removeClientLocked(oldestKey, oldestEntry)
|
||||
return true
|
||||
}
|
||||
|
||||
// evictOverLimitLocked 淘汰超出数量限制的客户端(需持有锁)
|
||||
// 使用 LRU 策略,优先淘汰最久未使用且无活跃请求的客户端
|
||||
func (s *httpUpstreamService) evictOverLimitLocked() bool {
|
||||
maxClients := s.maxUpstreamClients()
|
||||
if maxClients <= 0 {
|
||||
return false
|
||||
}
|
||||
evicted := false
|
||||
// 循环淘汰直到满足数量限制
|
||||
for len(s.clients) > maxClients {
|
||||
if !s.evictOldestIdleLocked() {
|
||||
return evicted
|
||||
}
|
||||
evicted = true
|
||||
}
|
||||
return evicted
|
||||
}
|
||||
|
||||
// getIsolationMode 获取连接池隔离模式
|
||||
// 从配置中读取,无效值回退到 account_proxy 模式
|
||||
//
|
||||
// 返回:
|
||||
// - string: 隔离模式(proxy/account/account_proxy)
|
||||
func (s *httpUpstreamService) getIsolationMode() string {
|
||||
if s.cfg == nil {
|
||||
return config.ConnectionPoolIsolationAccountProxy
|
||||
}
|
||||
mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.ConnectionPoolIsolation))
|
||||
if mode == "" {
|
||||
return config.ConnectionPoolIsolationAccountProxy
|
||||
}
|
||||
switch mode {
|
||||
case config.ConnectionPoolIsolationProxy, config.ConnectionPoolIsolationAccount, config.ConnectionPoolIsolationAccountProxy:
|
||||
return mode
|
||||
default:
|
||||
return config.ConnectionPoolIsolationAccountProxy
|
||||
}
|
||||
}
|
||||
|
||||
// maxUpstreamClients 获取最大客户端缓存数量
|
||||
// 从配置中读取,无效值使用默认值
|
||||
func (s *httpUpstreamService) maxUpstreamClients() int {
|
||||
if s.cfg == nil {
|
||||
return defaultMaxUpstreamClients
|
||||
}
|
||||
if s.cfg.Gateway.MaxUpstreamClients > 0 {
|
||||
return s.cfg.Gateway.MaxUpstreamClients
|
||||
}
|
||||
return defaultMaxUpstreamClients
|
||||
}
|
||||
|
||||
// clientIdleTTL 获取客户端空闲回收阈值
|
||||
// 从配置中读取,无效值使用默认值
|
||||
func (s *httpUpstreamService) clientIdleTTL() time.Duration {
|
||||
if s.cfg == nil {
|
||||
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
|
||||
}
|
||||
if s.cfg.Gateway.ClientIdleTTLSeconds > 0 {
|
||||
return time.Duration(s.cfg.Gateway.ClientIdleTTLSeconds) * time.Second
|
||||
}
|
||||
return time.Duration(defaultClientIdleTTLSeconds) * time.Second
|
||||
}
|
||||
|
||||
// resolvePoolSettings 解析连接池配置
|
||||
// 根据隔离策略和账户并发数动态调整连接池参数
|
||||
//
|
||||
// 参数:
|
||||
// - isolation: 隔离模式
|
||||
// - accountConcurrency: 账户并发限制
|
||||
//
|
||||
// 返回:
|
||||
// - poolSettings: 连接池配置
|
||||
//
|
||||
// 说明:
|
||||
// - 账户隔离模式下,连接池大小与账户并发数对应
|
||||
// - 这确保了单账户不会占用过多连接资源
|
||||
func (s *httpUpstreamService) resolvePoolSettings(isolation string, accountConcurrency int) poolSettings {
|
||||
settings := defaultPoolSettings(s.cfg)
|
||||
// 账户隔离模式下,根据账户并发数调整连接池大小
|
||||
if (isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy) && accountConcurrency > 0 {
|
||||
settings.maxIdleConns = accountConcurrency
|
||||
settings.maxIdleConnsPerHost = accountConcurrency
|
||||
settings.maxConnsPerHost = accountConcurrency
|
||||
}
|
||||
return settings
|
||||
}
|
||||
|
||||
// buildPoolKey 构建连接池配置键
|
||||
// 用于检测配置变更,配置变更时需要重建客户端
|
||||
//
|
||||
// 参数:
|
||||
// - isolation: 隔离模式
|
||||
// - accountConcurrency: 账户并发限制
|
||||
//
|
||||
// 返回:
|
||||
// - string: 配置键
|
||||
func (s *httpUpstreamService) buildPoolKey(isolation string, accountConcurrency int) string {
|
||||
if isolation == config.ConnectionPoolIsolationAccount || isolation == config.ConnectionPoolIsolationAccountProxy {
|
||||
if accountConcurrency > 0 {
|
||||
return fmt.Sprintf("account:%d", accountConcurrency)
|
||||
}
|
||||
}
|
||||
return "default"
|
||||
}
|
||||
|
||||
// buildCacheKey 构建客户端缓存键
|
||||
// 根据隔离策略决定缓存键的组成
|
||||
//
|
||||
// 参数:
|
||||
// - isolation: 隔离模式
|
||||
// - proxyKey: 代理标识
|
||||
// - accountID: 账户 ID
|
||||
//
|
||||
// 返回:
|
||||
// - string: 缓存键
|
||||
//
|
||||
// 缓存键格式:
|
||||
// - proxy 模式: "proxy:{proxyKey}"
|
||||
// - account 模式: "account:{accountID}"
|
||||
// - account_proxy 模式: "account:{accountID}|proxy:{proxyKey}"
|
||||
func buildCacheKey(isolation, proxyKey string, accountID int64) string {
|
||||
switch isolation {
|
||||
case config.ConnectionPoolIsolationAccount:
|
||||
return fmt.Sprintf("account:%d", accountID)
|
||||
case config.ConnectionPoolIsolationAccountProxy:
|
||||
return fmt.Sprintf("account:%d|proxy:%s", accountID, proxyKey)
|
||||
default:
|
||||
return fmt.Sprintf("proxy:%s", proxyKey)
|
||||
}
|
||||
}
|
||||
|
||||
// normalizeProxyURL 标准化代理 URL
|
||||
// 处理空值和解析错误,返回标准化的键和解析后的 URL
|
||||
//
|
||||
// 参数:
|
||||
// - raw: 原始代理 URL 字符串
|
||||
//
|
||||
// 返回:
|
||||
// - string: 标准化的代理键(空返回 "direct")
|
||||
// - *url.URL: 解析后的 URL(空返回 nil)
|
||||
// - error: 非空代理 URL 解析失败时返回错误(禁止回退到直连)
|
||||
func normalizeProxyURL(raw string) (string, *url.URL, error) {
|
||||
_, parsed, err := proxyurl.Parse(raw)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
if parsed == nil {
|
||||
return directProxyKey, nil, nil
|
||||
}
|
||||
// 规范化:小写 scheme/host,去除路径和查询参数
|
||||
parsed.Scheme = strings.ToLower(parsed.Scheme)
|
||||
parsed.Host = strings.ToLower(parsed.Host)
|
||||
parsed.Path = ""
|
||||
parsed.RawPath = ""
|
||||
parsed.RawQuery = ""
|
||||
parsed.Fragment = ""
|
||||
parsed.ForceQuery = false
|
||||
if hostname := parsed.Hostname(); hostname != "" {
|
||||
port := parsed.Port()
|
||||
if (parsed.Scheme == "http" && port == "80") || (parsed.Scheme == "https" && port == "443") {
|
||||
port = ""
|
||||
}
|
||||
hostname = strings.ToLower(hostname)
|
||||
if port != "" {
|
||||
parsed.Host = net.JoinHostPort(hostname, port)
|
||||
} else {
|
||||
parsed.Host = hostname
|
||||
}
|
||||
}
|
||||
return parsed.String(), parsed, nil
|
||||
}
|
||||
|
||||
// defaultPoolSettings 获取默认连接池配置
|
||||
// 从全局配置中读取,无效值使用常量默认值
|
||||
//
|
||||
// 参数:
|
||||
// - cfg: 全局配置
|
||||
//
|
||||
// 返回:
|
||||
// - poolSettings: 连接池配置
|
||||
func defaultPoolSettings(cfg *config.Config) poolSettings {
|
||||
maxIdleConns := defaultMaxIdleConns
|
||||
maxIdleConnsPerHost := defaultMaxIdleConnsPerHost
|
||||
maxConnsPerHost := defaultMaxConnsPerHost
|
||||
idleConnTimeout := defaultIdleConnTimeout
|
||||
responseHeaderTimeout := defaultResponseHeaderTimeout
|
||||
|
||||
if cfg != nil {
|
||||
if cfg.Gateway.MaxIdleConns > 0 {
|
||||
maxIdleConns = cfg.Gateway.MaxIdleConns
|
||||
}
|
||||
if cfg.Gateway.MaxIdleConnsPerHost > 0 {
|
||||
maxIdleConnsPerHost = cfg.Gateway.MaxIdleConnsPerHost
|
||||
}
|
||||
if cfg.Gateway.MaxConnsPerHost >= 0 {
|
||||
maxConnsPerHost = cfg.Gateway.MaxConnsPerHost
|
||||
}
|
||||
if cfg.Gateway.IdleConnTimeoutSeconds > 0 {
|
||||
idleConnTimeout = time.Duration(cfg.Gateway.IdleConnTimeoutSeconds) * time.Second
|
||||
}
|
||||
if cfg.Gateway.ResponseHeaderTimeout > 0 {
|
||||
responseHeaderTimeout = time.Duration(cfg.Gateway.ResponseHeaderTimeout) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
return poolSettings{
|
||||
maxIdleConns: maxIdleConns,
|
||||
maxIdleConnsPerHost: maxIdleConnsPerHost,
|
||||
maxConnsPerHost: maxConnsPerHost,
|
||||
idleConnTimeout: idleConnTimeout,
|
||||
responseHeaderTimeout: responseHeaderTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
// buildUpstreamTransport 构建上游请求的 Transport
|
||||
// 使用配置文件中的连接池参数,支持生产环境调优
|
||||
//
|
||||
// 参数:
|
||||
// - settings: 连接池配置
|
||||
// - proxyURL: 代理 URL(nil 表示直连)
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Transport: 配置好的 Transport 实例
|
||||
// - error: 代理配置错误
|
||||
//
|
||||
// Transport 参数说明:
|
||||
// - MaxIdleConns: 所有主机的最大空闲连接总数
|
||||
// - MaxIdleConnsPerHost: 每主机最大空闲连接数(影响连接复用率)
|
||||
// - MaxConnsPerHost: 每主机最大连接数(达到后新请求等待)
|
||||
// - IdleConnTimeout: 空闲连接超时(超时后关闭)
|
||||
// - ResponseHeaderTimeout: 等待响应头超时(不影响流式传输)
|
||||
func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Transport, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: settings.maxIdleConns,
|
||||
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: settings.maxConnsPerHost,
|
||||
IdleConnTimeout: settings.idleConnTimeout,
|
||||
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
||||
}
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// buildUpstreamTransportWithTLSFingerprint 构建带 TLS 指纹伪装的 Transport
|
||||
// 使用 utls 库模拟 Claude CLI 的 TLS 指纹
|
||||
//
|
||||
// 参数:
|
||||
// - settings: 连接池配置
|
||||
// - proxyURL: 代理 URL(nil 表示直连)
|
||||
// - profile: TLS 指纹配置
|
||||
//
|
||||
// 返回:
|
||||
// - *http.Transport: 配置好的 Transport 实例
|
||||
// - error: 配置错误
|
||||
//
|
||||
// 代理类型处理:
|
||||
// - nil/空: 直连,使用 TLSFingerprintDialer
|
||||
// - http/https: HTTP 代理,使用 HTTPProxyDialer(CONNECT 隧道 + utls 握手)
|
||||
// - socks5: SOCKS5 代理,使用 SOCKS5ProxyDialer(SOCKS5 隧道 + utls 握手)
|
||||
func buildUpstreamTransportWithTLSFingerprint(settings poolSettings, proxyURL *url.URL, profile *tlsfingerprint.Profile) (*http.Transport, error) {
|
||||
transport := &http.Transport{
|
||||
MaxIdleConns: settings.maxIdleConns,
|
||||
MaxIdleConnsPerHost: settings.maxIdleConnsPerHost,
|
||||
MaxConnsPerHost: settings.maxConnsPerHost,
|
||||
IdleConnTimeout: settings.idleConnTimeout,
|
||||
ResponseHeaderTimeout: settings.responseHeaderTimeout,
|
||||
// 禁用默认的 TLS,我们使用自定义的 DialTLSContext
|
||||
ForceAttemptHTTP2: false,
|
||||
}
|
||||
|
||||
// 根据代理类型选择合适的 TLS 指纹 Dialer
|
||||
if proxyURL == nil {
|
||||
// 直连:使用 TLSFingerprintDialer
|
||||
slog.Debug("tls_fingerprint_transport_direct")
|
||||
dialer := tlsfingerprint.NewDialer(profile, nil)
|
||||
transport.DialTLSContext = dialer.DialTLSContext
|
||||
} else {
|
||||
scheme := strings.ToLower(proxyURL.Scheme)
|
||||
switch scheme {
|
||||
case "socks5", "socks5h":
|
||||
// SOCKS5 代理:使用 SOCKS5ProxyDialer
|
||||
slog.Debug("tls_fingerprint_transport_socks5", "proxy", proxyURL.Host)
|
||||
socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, proxyURL)
|
||||
transport.DialTLSContext = socks5Dialer.DialTLSContext
|
||||
case "http", "https":
|
||||
// HTTP/HTTPS 代理:使用 HTTPProxyDialer(CONNECT 隧道)
|
||||
slog.Debug("tls_fingerprint_transport_http_connect", "proxy", proxyURL.Host)
|
||||
httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, proxyURL)
|
||||
transport.DialTLSContext = httpDialer.DialTLSContext
|
||||
default:
|
||||
// 未知代理类型,回退到普通代理配置(无 TLS 指纹)
|
||||
slog.Debug("tls_fingerprint_transport_unknown_scheme_fallback", "scheme", scheme)
|
||||
if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return transport, nil
|
||||
}
|
||||
|
||||
// trackedBody 带跟踪功能的响应体包装器
|
||||
// 在 Close 时执行回调,用于更新请求计数
|
||||
type trackedBody struct {
|
||||
io.ReadCloser // 原始响应体
|
||||
once sync.Once
|
||||
onClose func() // 关闭时的回调函数
|
||||
}
|
||||
|
||||
// Close 关闭响应体并执行回调
|
||||
// 使用 sync.Once 确保回调只执行一次
|
||||
func (b *trackedBody) Close() error {
|
||||
err := b.ReadCloser.Close()
|
||||
if b.onClose != nil {
|
||||
b.once.Do(b.onClose)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// wrapTrackedBody 包装响应体以跟踪关闭事件
|
||||
// 用于在响应体关闭时更新 inFlight 计数
|
||||
//
|
||||
// 参数:
|
||||
// - body: 原始响应体
|
||||
// - onClose: 关闭时的回调函数
|
||||
//
|
||||
// 返回:
|
||||
// - io.ReadCloser: 包装后的响应体
|
||||
func wrapTrackedBody(body io.ReadCloser, onClose func()) io.ReadCloser {
|
||||
if body == nil {
|
||||
return body
|
||||
}
|
||||
return &trackedBody{ReadCloser: body, onClose: onClose}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// httpClientSink 用于防止编译器优化掉基准测试中的赋值操作
|
||||
// 这是 Go 基准测试的常见模式,确保测试结果准确
|
||||
var httpClientSink *http.Client
|
||||
|
||||
// BenchmarkHTTPUpstreamProxyClient 对比重复创建与复用代理客户端的开销
|
||||
//
|
||||
// 测试目的:
|
||||
// - 验证连接池复用相比每次新建的性能提升
|
||||
// - 量化内存分配差异
|
||||
//
|
||||
// 预期结果:
|
||||
// - "复用" 子测试应显著快于 "新建"
|
||||
// - "复用" 子测试应零内存分配
|
||||
func BenchmarkHTTPUpstreamProxyClient(b *testing.B) {
|
||||
// 创建测试配置
|
||||
cfg := &config.Config{
|
||||
Gateway: config.GatewayConfig{ResponseHeaderTimeout: 300},
|
||||
}
|
||||
upstream := NewHTTPUpstream(cfg)
|
||||
svc, ok := upstream.(*httpUpstreamService)
|
||||
if !ok {
|
||||
b.Fatalf("类型断言失败,无法获取 httpUpstreamService")
|
||||
}
|
||||
|
||||
proxyURL := "http://127.0.0.1:8080"
|
||||
b.ReportAllocs() // 报告内存分配统计
|
||||
|
||||
// 子测试:每次新建客户端
|
||||
// 模拟未优化前的行为,每次请求都创建新的 http.Client
|
||||
b.Run("新建", func(b *testing.B) {
|
||||
parsedProxy, err := url.Parse(proxyURL)
|
||||
if err != nil {
|
||||
b.Fatalf("解析代理地址失败: %v", err)
|
||||
}
|
||||
settings := defaultPoolSettings(cfg)
|
||||
for i := 0; i < b.N; i++ {
|
||||
// 每次迭代都创建新客户端,包含 Transport 分配
|
||||
transport, err := buildUpstreamTransport(settings, parsedProxy)
|
||||
if err != nil {
|
||||
b.Fatalf("创建 Transport 失败: %v", err)
|
||||
}
|
||||
httpClientSink = &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 子测试:复用已缓存的客户端
|
||||
// 模拟优化后的行为,从缓存获取客户端
|
||||
b.Run("复用", func(b *testing.B) {
|
||||
// 预热:确保客户端已缓存
|
||||
entry, err := svc.getOrCreateClient(proxyURL, 1, 1)
|
||||
if err != nil {
|
||||
b.Fatalf("getOrCreateClient: %v", err)
|
||||
}
|
||||
client := entry.client
|
||||
b.ResetTimer() // 重置计时器,排除预热时间
|
||||
for i := 0; i < b.N; i++ {
|
||||
// 直接使用缓存的客户端,无内存分配
|
||||
httpClientSink = client
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,301 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// HTTPUpstreamSuite HTTP 上游服务测试套件
|
||||
// 使用 testify/suite 组织测试,支持 SetupTest 初始化
|
||||
type HTTPUpstreamSuite struct {
|
||||
suite.Suite
|
||||
cfg *config.Config // 测试用配置
|
||||
}
|
||||
|
||||
// SetupTest 每个测试用例执行前的初始化
|
||||
// 创建空配置,各测试用例可按需覆盖
|
||||
func (s *HTTPUpstreamSuite) SetupTest() {
|
||||
s.cfg = &config.Config{
|
||||
Security: config.SecurityConfig{
|
||||
URLAllowlist: config.URLAllowlistConfig{
|
||||
AllowPrivateHosts: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// newService 创建测试用的 httpUpstreamService 实例
|
||||
// 返回具体类型以便访问内部状态进行断言
|
||||
func (s *HTTPUpstreamSuite) newService() *httpUpstreamService {
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
svc, ok := up.(*httpUpstreamService)
|
||||
require.True(s.T(), ok, "expected *httpUpstreamService")
|
||||
return svc
|
||||
}
|
||||
|
||||
// TestDefaultResponseHeaderTimeout 测试默认响应头超时配置
|
||||
// 验证未配置时使用 300 秒默认值
|
||||
func (s *HTTPUpstreamSuite) TestDefaultResponseHeaderTimeout() {
|
||||
svc := s.newService()
|
||||
entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0)
|
||||
transport, ok := entry.client.Transport.(*http.Transport)
|
||||
require.True(s.T(), ok, "expected *http.Transport")
|
||||
require.Equal(s.T(), 300*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||
}
|
||||
|
||||
// TestCustomResponseHeaderTimeout 测试自定义响应头超时配置
|
||||
// 验证配置值能正确应用到 Transport
|
||||
func (s *HTTPUpstreamSuite) TestCustomResponseHeaderTimeout() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 7}
|
||||
svc := s.newService()
|
||||
entry := mustGetOrCreateClient(s.T(), svc, "", 0, 0)
|
||||
transport, ok := entry.client.Transport.(*http.Transport)
|
||||
require.True(s.T(), ok, "expected *http.Transport")
|
||||
require.Equal(s.T(), 7*time.Second, transport.ResponseHeaderTimeout, "ResponseHeaderTimeout mismatch")
|
||||
}
|
||||
|
||||
// TestGetOrCreateClient_InvalidURLReturnsError 测试无效代理 URL 返回错误
|
||||
// 验证解析失败时拒绝回退到直连模式
|
||||
func (s *HTTPUpstreamSuite) TestGetOrCreateClient_InvalidURLReturnsError() {
|
||||
svc := s.newService()
|
||||
_, err := svc.getClientEntry("://bad-proxy-url", 1, 1, false, false)
|
||||
require.Error(s.T(), err, "expected error for invalid proxy URL")
|
||||
}
|
||||
|
||||
// TestNormalizeProxyURL_Canonicalizes 测试代理 URL 规范化
|
||||
// 验证等价地址能够映射到同一缓存键
|
||||
func (s *HTTPUpstreamSuite) TestNormalizeProxyURL_Canonicalizes() {
|
||||
key1, _, err1 := normalizeProxyURL("http://proxy.local:8080")
|
||||
require.NoError(s.T(), err1)
|
||||
key2, _, err2 := normalizeProxyURL("http://proxy.local:8080/")
|
||||
require.NoError(s.T(), err2)
|
||||
require.Equal(s.T(), key1, key2, "expected normalized proxy keys to match")
|
||||
}
|
||||
|
||||
// TestAcquireClient_OverLimitReturnsError 测试连接池缓存上限保护
|
||||
// 验证超限且无可淘汰条目时返回错误
|
||||
func (s *HTTPUpstreamSuite) TestAcquireClient_OverLimitReturnsError() {
|
||||
s.cfg.Gateway = config.GatewayConfig{
|
||||
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
|
||||
MaxUpstreamClients: 1,
|
||||
}
|
||||
svc := s.newService()
|
||||
entry1, err := svc.acquireClient("http://proxy-a:8080", 1, 1)
|
||||
require.NoError(s.T(), err, "expected first acquire to succeed")
|
||||
require.NotNil(s.T(), entry1, "expected entry")
|
||||
|
||||
entry2, err := svc.acquireClient("http://proxy-b:8080", 2, 1)
|
||||
require.Error(s.T(), err, "expected error when cache limit reached")
|
||||
require.Nil(s.T(), entry2, "expected nil entry when cache limit reached")
|
||||
}
|
||||
|
||||
// TestDo_WithoutProxy_GoesDirect 测试无代理时直连
|
||||
// 验证空代理 URL 时请求直接发送到目标服务器
|
||||
func (s *HTTPUpstreamSuite) TestDo_WithoutProxy_GoesDirect() {
|
||||
// 创建模拟上游服务器
|
||||
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct")
|
||||
}))
|
||||
s.T().Cleanup(upstream.Close)
|
||||
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
|
||||
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/x", nil)
|
||||
require.NoError(s.T(), err, "NewRequest")
|
||||
resp, err := up.Do(req, "", 1, 1)
|
||||
require.NoError(s.T(), err, "Do")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(s.T(), "direct", string(b), "unexpected body")
|
||||
}
|
||||
|
||||
// TestDo_WithHTTPProxy_UsesProxy 测试 HTTP 代理功能
|
||||
// 验证请求通过代理服务器转发,使用绝对 URI 格式
|
||||
func (s *HTTPUpstreamSuite) TestDo_WithHTTPProxy_UsesProxy() {
|
||||
// 用于接收代理请求的通道
|
||||
seen := make(chan string, 1)
|
||||
// 创建模拟代理服务器
|
||||
proxySrv := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
seen <- r.RequestURI // 记录请求 URI
|
||||
_, _ = io.WriteString(w, "proxied")
|
||||
}))
|
||||
s.T().Cleanup(proxySrv.Close)
|
||||
|
||||
s.cfg.Gateway = config.GatewayConfig{ResponseHeaderTimeout: 1}
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
|
||||
// 发送请求到外部地址,应通过代理
|
||||
req, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil)
|
||||
require.NoError(s.T(), err, "NewRequest")
|
||||
resp, err := up.Do(req, proxySrv.URL, 1, 1)
|
||||
require.NoError(s.T(), err, "Do")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(s.T(), "proxied", string(b), "unexpected body")
|
||||
|
||||
// 验证代理收到的是绝对 URI 格式(HTTP 代理规范要求)
|
||||
select {
|
||||
case uri := <-seen:
|
||||
require.Equal(s.T(), "http://example.com/test", uri, "expected absolute-form request URI")
|
||||
default:
|
||||
require.Fail(s.T(), "expected proxy to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDo_EmptyProxy_UsesDirect 测试空代理字符串
|
||||
// 验证空字符串代理等同于直连
|
||||
func (s *HTTPUpstreamSuite) TestDo_EmptyProxy_UsesDirect() {
|
||||
upstream := newLocalTestServer(s.T(), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_, _ = io.WriteString(w, "direct-empty")
|
||||
}))
|
||||
s.T().Cleanup(upstream.Close)
|
||||
|
||||
up := NewHTTPUpstream(s.cfg)
|
||||
req, err := http.NewRequest(http.MethodGet, upstream.URL+"/y", nil)
|
||||
require.NoError(s.T(), err, "NewRequest")
|
||||
resp, err := up.Do(req, "", 1, 1)
|
||||
require.NoError(s.T(), err, "Do with empty proxy")
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
b, _ := io.ReadAll(resp.Body)
|
||||
require.Equal(s.T(), "direct-empty", string(b))
|
||||
}
|
||||
|
||||
// TestAccountIsolation_DifferentAccounts 测试账户隔离模式
|
||||
// 验证不同账户使用独立的连接池
|
||||
func (s *HTTPUpstreamSuite) TestAccountIsolation_DifferentAccounts() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
|
||||
svc := s.newService()
|
||||
// 同一代理,不同账户
|
||||
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 1, 3)
|
||||
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy.local:8080", 2, 3)
|
||||
require.NotSame(s.T(), entry1, entry2, "不同账号不应共享连接池")
|
||||
require.Equal(s.T(), 2, len(svc.clients), "账号隔离应缓存两个客户端")
|
||||
}
|
||||
|
||||
// TestAccountProxyIsolation_DifferentProxy 测试账户+代理组合隔离模式
|
||||
// 验证同一账户使用不同代理时创建独立连接池
|
||||
func (s *HTTPUpstreamSuite) TestAccountProxyIsolation_DifferentProxy() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy}
|
||||
svc := s.newService()
|
||||
// 同一账户,不同代理
|
||||
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3)
|
||||
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3)
|
||||
require.NotSame(s.T(), entry1, entry2, "账号+代理隔离应区分不同代理")
|
||||
require.Equal(s.T(), 2, len(svc.clients), "账号+代理隔离应缓存两个客户端")
|
||||
}
|
||||
|
||||
// TestAccountModeProxyChangeClearsPool 测试账户模式下代理变更
|
||||
// 验证账户切换代理时清理旧连接池,避免复用错误代理
|
||||
func (s *HTTPUpstreamSuite) TestAccountModeProxyChangeClearsPool() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
|
||||
svc := s.newService()
|
||||
// 同一账户,先后使用不同代理
|
||||
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 3)
|
||||
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 1, 3)
|
||||
require.NotSame(s.T(), entry1, entry2, "账号切换代理应创建新连接池")
|
||||
require.Equal(s.T(), 1, len(svc.clients), "账号模式下应仅保留一个连接池")
|
||||
require.False(s.T(), hasEntry(svc, entry1), "旧连接池应被清理")
|
||||
}
|
||||
|
||||
// TestAccountConcurrencyOverridesPoolSettings 测试账户并发数覆盖连接池配置
|
||||
// 验证账户隔离模式下,连接池大小与账户并发数对应
|
||||
func (s *HTTPUpstreamSuite) TestAccountConcurrencyOverridesPoolSettings() {
|
||||
s.cfg.Gateway = config.GatewayConfig{ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount}
|
||||
svc := s.newService()
|
||||
// 账户并发数为 12
|
||||
entry := mustGetOrCreateClient(s.T(), svc, "", 1, 12)
|
||||
transport, ok := entry.client.Transport.(*http.Transport)
|
||||
require.True(s.T(), ok, "expected *http.Transport")
|
||||
// 连接池参数应与并发数一致
|
||||
require.Equal(s.T(), 12, transport.MaxConnsPerHost, "MaxConnsPerHost mismatch")
|
||||
require.Equal(s.T(), 12, transport.MaxIdleConns, "MaxIdleConns mismatch")
|
||||
require.Equal(s.T(), 12, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost mismatch")
|
||||
}
|
||||
|
||||
// TestAccountConcurrencyFallbackToDefault 测试账户并发数为 0 时回退到默认配置
|
||||
// 验证未指定并发数时使用全局配置值
|
||||
func (s *HTTPUpstreamSuite) TestAccountConcurrencyFallbackToDefault() {
|
||||
s.cfg.Gateway = config.GatewayConfig{
|
||||
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
|
||||
MaxIdleConns: 77,
|
||||
MaxIdleConnsPerHost: 55,
|
||||
MaxConnsPerHost: 66,
|
||||
}
|
||||
svc := s.newService()
|
||||
// 账户并发数为 0,应使用全局配置
|
||||
entry := mustGetOrCreateClient(s.T(), svc, "", 1, 0)
|
||||
transport, ok := entry.client.Transport.(*http.Transport)
|
||||
require.True(s.T(), ok, "expected *http.Transport")
|
||||
require.Equal(s.T(), 66, transport.MaxConnsPerHost, "MaxConnsPerHost fallback mismatch")
|
||||
require.Equal(s.T(), 77, transport.MaxIdleConns, "MaxIdleConns fallback mismatch")
|
||||
require.Equal(s.T(), 55, transport.MaxIdleConnsPerHost, "MaxIdleConnsPerHost fallback mismatch")
|
||||
}
|
||||
|
||||
// TestEvictOverLimitRemovesOldestIdle 测试超出数量限制时的 LRU 淘汰
|
||||
// 验证优先淘汰最久未使用的空闲客户端
|
||||
func (s *HTTPUpstreamSuite) TestEvictOverLimitRemovesOldestIdle() {
|
||||
s.cfg.Gateway = config.GatewayConfig{
|
||||
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccountProxy,
|
||||
MaxUpstreamClients: 2, // 最多缓存 2 个客户端
|
||||
}
|
||||
svc := s.newService()
|
||||
// 创建两个客户端,设置不同的最后使用时间
|
||||
entry1 := mustGetOrCreateClient(s.T(), svc, "http://proxy-a:8080", 1, 1)
|
||||
entry2 := mustGetOrCreateClient(s.T(), svc, "http://proxy-b:8080", 2, 1)
|
||||
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Hour).UnixNano()) // 最久
|
||||
atomic.StoreInt64(&entry2.lastUsed, time.Now().Add(-time.Hour).UnixNano())
|
||||
// 创建第三个客户端,触发淘汰
|
||||
_ = mustGetOrCreateClient(s.T(), svc, "http://proxy-c:8080", 3, 1)
|
||||
|
||||
require.LessOrEqual(s.T(), len(svc.clients), 2, "应保持在缓存上限内")
|
||||
require.False(s.T(), hasEntry(svc, entry1), "最久未使用的连接池应被清理")
|
||||
}
|
||||
|
||||
// TestIdleTTLDoesNotEvictActive 测试活跃请求保护
|
||||
// 验证有进行中请求的客户端不会被空闲超时淘汰
|
||||
func (s *HTTPUpstreamSuite) TestIdleTTLDoesNotEvictActive() {
|
||||
s.cfg.Gateway = config.GatewayConfig{
|
||||
ConnectionPoolIsolation: config.ConnectionPoolIsolationAccount,
|
||||
ClientIdleTTLSeconds: 1, // 1 秒空闲超时
|
||||
}
|
||||
svc := s.newService()
|
||||
entry1 := mustGetOrCreateClient(s.T(), svc, "", 1, 1)
|
||||
// 设置为很久之前使用,但有活跃请求
|
||||
atomic.StoreInt64(&entry1.lastUsed, time.Now().Add(-2*time.Minute).UnixNano())
|
||||
atomic.StoreInt64(&entry1.inFlight, 1) // 模拟有活跃请求
|
||||
// 创建新客户端,触发淘汰检查
|
||||
_, _ = svc.getOrCreateClient("", 2, 1)
|
||||
|
||||
require.True(s.T(), hasEntry(svc, entry1), "有活跃请求时不应回收")
|
||||
}
|
||||
|
||||
// TestHTTPUpstreamSuite 运行测试套件
|
||||
func TestHTTPUpstreamSuite(t *testing.T) {
|
||||
suite.Run(t, new(HTTPUpstreamSuite))
|
||||
}
|
||||
|
||||
// mustGetOrCreateClient 测试辅助函数,调用 getOrCreateClient 并断言无错误
|
||||
func mustGetOrCreateClient(t *testing.T, svc *httpUpstreamService, proxyURL string, accountID int64, concurrency int) *upstreamClientEntry {
|
||||
t.Helper()
|
||||
entry, err := svc.getOrCreateClient(proxyURL, accountID, concurrency)
|
||||
require.NoError(t, err, "getOrCreateClient(%q, %d, %d)", proxyURL, accountID, concurrency)
|
||||
return entry
|
||||
}
|
||||
|
||||
// hasEntry 检查客户端是否存在于缓存中
|
||||
// 辅助函数,用于验证淘汰逻辑
|
||||
func hasEntry(svc *httpUpstreamService, target *upstreamClientEntry) bool {
|
||||
for _, entry := range svc.clients {
|
||||
if entry == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
@@ -0,0 +1,237 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type idempotencyRepository struct {
|
||||
sql sqlExecutor
|
||||
}
|
||||
|
||||
func NewIdempotencyRepository(_ *dbent.Client, sqlDB *sql.DB) service.IdempotencyRepository {
|
||||
return &idempotencyRepository{sql: sqlDB}
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) CreateProcessing(ctx context.Context, record *service.IdempotencyRecord) (bool, error) {
|
||||
if record == nil {
|
||||
return false, nil
|
||||
}
|
||||
query := `
|
||||
INSERT INTO idempotency_records (
|
||||
scope, idempotency_key_hash, request_fingerprint, status, locked_until, expires_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6)
|
||||
ON CONFLICT (scope, idempotency_key_hash) DO NOTHING
|
||||
RETURNING id, created_at, updated_at
|
||||
`
|
||||
var createdAt time.Time
|
||||
var updatedAt time.Time
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{
|
||||
record.Scope,
|
||||
record.IdempotencyKeyHash,
|
||||
record.RequestFingerprint,
|
||||
record.Status,
|
||||
record.LockedUntil,
|
||||
record.ExpiresAt,
|
||||
}, &record.ID, &createdAt, &updatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
record.CreatedAt = createdAt
|
||||
record.UpdatedAt = updatedAt
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) GetByScopeAndKeyHash(ctx context.Context, scope, keyHash string) (*service.IdempotencyRecord, error) {
|
||||
query := `
|
||||
SELECT
|
||||
id, scope, idempotency_key_hash, request_fingerprint, status, response_status,
|
||||
response_body, error_reason, locked_until, expires_at, created_at, updated_at
|
||||
FROM idempotency_records
|
||||
WHERE scope = $1 AND idempotency_key_hash = $2
|
||||
`
|
||||
record := &service.IdempotencyRecord{}
|
||||
var responseStatus sql.NullInt64
|
||||
var responseBody sql.NullString
|
||||
var errorReason sql.NullString
|
||||
var lockedUntil sql.NullTime
|
||||
err := scanSingleRow(ctx, r.sql, query, []any{scope, keyHash},
|
||||
&record.ID,
|
||||
&record.Scope,
|
||||
&record.IdempotencyKeyHash,
|
||||
&record.RequestFingerprint,
|
||||
&record.Status,
|
||||
&responseStatus,
|
||||
&responseBody,
|
||||
&errorReason,
|
||||
&lockedUntil,
|
||||
&record.ExpiresAt,
|
||||
&record.CreatedAt,
|
||||
&record.UpdatedAt,
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if responseStatus.Valid {
|
||||
v := int(responseStatus.Int64)
|
||||
record.ResponseStatus = &v
|
||||
}
|
||||
if responseBody.Valid {
|
||||
v := responseBody.String
|
||||
record.ResponseBody = &v
|
||||
}
|
||||
if errorReason.Valid {
|
||||
v := errorReason.String
|
||||
record.ErrorReason = &v
|
||||
}
|
||||
if lockedUntil.Valid {
|
||||
v := lockedUntil.Time
|
||||
record.LockedUntil = &v
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) TryReclaim(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
fromStatus string,
|
||||
now, newLockedUntil, newExpiresAt time.Time,
|
||||
) (bool, error) {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET status = $2,
|
||||
locked_until = $3,
|
||||
error_reason = NULL,
|
||||
updated_at = NOW(),
|
||||
expires_at = $4
|
||||
WHERE id = $1
|
||||
AND status = $5
|
||||
AND (locked_until IS NULL OR locked_until <= $6)
|
||||
`
|
||||
res, err := r.sql.ExecContext(ctx, query,
|
||||
id,
|
||||
service.IdempotencyStatusProcessing,
|
||||
newLockedUntil,
|
||||
newExpiresAt,
|
||||
fromStatus,
|
||||
now,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) ExtendProcessingLock(
|
||||
ctx context.Context,
|
||||
id int64,
|
||||
requestFingerprint string,
|
||||
newLockedUntil,
|
||||
newExpiresAt time.Time,
|
||||
) (bool, error) {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET locked_until = $2,
|
||||
expires_at = $3,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
AND status = $4
|
||||
AND request_fingerprint = $5
|
||||
`
|
||||
res, err := r.sql.ExecContext(
|
||||
ctx,
|
||||
query,
|
||||
id,
|
||||
newLockedUntil,
|
||||
newExpiresAt,
|
||||
service.IdempotencyStatusProcessing,
|
||||
requestFingerprint,
|
||||
)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return affected > 0, nil
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) MarkSucceeded(ctx context.Context, id int64, responseStatus int, responseBody string, expiresAt time.Time) error {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET status = $2,
|
||||
response_status = $3,
|
||||
response_body = $4,
|
||||
error_reason = NULL,
|
||||
locked_until = NULL,
|
||||
expires_at = $5,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query,
|
||||
id,
|
||||
service.IdempotencyStatusSucceeded,
|
||||
responseStatus,
|
||||
responseBody,
|
||||
expiresAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) MarkFailedRetryable(ctx context.Context, id int64, errorReason string, lockedUntil, expiresAt time.Time) error {
|
||||
query := `
|
||||
UPDATE idempotency_records
|
||||
SET status = $2,
|
||||
error_reason = $3,
|
||||
locked_until = $4,
|
||||
expires_at = $5,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
`
|
||||
_, err := r.sql.ExecContext(ctx, query,
|
||||
id,
|
||||
service.IdempotencyStatusFailedRetryable,
|
||||
errorReason,
|
||||
lockedUntil,
|
||||
expiresAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *idempotencyRepository) DeleteExpired(ctx context.Context, now time.Time, limit int) (int64, error) {
|
||||
if limit <= 0 {
|
||||
limit = 500
|
||||
}
|
||||
query := `
|
||||
WITH victims AS (
|
||||
SELECT id
|
||||
FROM idempotency_records
|
||||
WHERE expires_at <= $1
|
||||
ORDER BY expires_at ASC
|
||||
LIMIT $2
|
||||
)
|
||||
DELETE FROM idempotency_records
|
||||
WHERE id IN (SELECT id FROM victims)
|
||||
`
|
||||
res, err := r.sql.ExecContext(ctx, query, now, limit)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// hashedTestValue returns a unique SHA-256 hex string (64 chars) that fits VARCHAR(64) columns.
|
||||
func hashedTestValue(t *testing.T, prefix string) string {
|
||||
t.Helper()
|
||||
sum := sha256.Sum256([]byte(uniqueTestValue(t, prefix)))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func TestIdempotencyRepo_CreateProcessing_CompeteSameKey(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
repo := &idempotencyRepository{sql: tx}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
record := &service.IdempotencyRecord{
|
||||
Scope: uniqueTestValue(t, "idem-scope-create"),
|
||||
IdempotencyKeyHash: hashedTestValue(t, "idem-hash"),
|
||||
RequestFingerprint: hashedTestValue(t, "idem-fp"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(30 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err := repo.CreateProcessing(ctx, record)
|
||||
require.NoError(t, err)
|
||||
require.True(t, owner)
|
||||
require.NotZero(t, record.ID)
|
||||
|
||||
duplicate := &service.IdempotencyRecord{
|
||||
Scope: record.Scope,
|
||||
IdempotencyKeyHash: record.IdempotencyKeyHash,
|
||||
RequestFingerprint: hashedTestValue(t, "idem-fp-other"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(30 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err = repo.CreateProcessing(ctx, duplicate)
|
||||
require.NoError(t, err)
|
||||
require.False(t, owner, "same scope+key hash should be de-duplicated")
|
||||
}
|
||||
|
||||
func TestIdempotencyRepo_TryReclaim_StatusAndLockWindow(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
repo := &idempotencyRepository{sql: tx}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
record := &service.IdempotencyRecord{
|
||||
Scope: uniqueTestValue(t, "idem-scope-reclaim"),
|
||||
IdempotencyKeyHash: hashedTestValue(t, "idem-hash-reclaim"),
|
||||
RequestFingerprint: hashedTestValue(t, "idem-fp-reclaim"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(10 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err := repo.CreateProcessing(ctx, record)
|
||||
require.NoError(t, err)
|
||||
require.True(t, owner)
|
||||
|
||||
require.NoError(t, repo.MarkFailedRetryable(
|
||||
ctx,
|
||||
record.ID,
|
||||
"RETRYABLE_FAILURE",
|
||||
now.Add(-2*time.Second),
|
||||
now.Add(24*time.Hour),
|
||||
))
|
||||
|
||||
newLockedUntil := now.Add(20 * time.Second)
|
||||
reclaimed, err := repo.TryReclaim(
|
||||
ctx,
|
||||
record.ID,
|
||||
service.IdempotencyStatusFailedRetryable,
|
||||
now,
|
||||
newLockedUntil,
|
||||
now.Add(24*time.Hour),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.True(t, reclaimed, "failed_retryable + expired lock should allow reclaim")
|
||||
|
||||
got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, service.IdempotencyStatusProcessing, got.Status)
|
||||
require.NotNil(t, got.LockedUntil)
|
||||
require.True(t, got.LockedUntil.After(now))
|
||||
|
||||
require.NoError(t, repo.MarkFailedRetryable(
|
||||
ctx,
|
||||
record.ID,
|
||||
"RETRYABLE_FAILURE",
|
||||
now.Add(20*time.Second),
|
||||
now.Add(24*time.Hour),
|
||||
))
|
||||
|
||||
reclaimed, err = repo.TryReclaim(
|
||||
ctx,
|
||||
record.ID,
|
||||
service.IdempotencyStatusFailedRetryable,
|
||||
now,
|
||||
now.Add(40*time.Second),
|
||||
now.Add(24*time.Hour),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.False(t, reclaimed, "within lock window should not reclaim")
|
||||
}
|
||||
|
||||
func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
repo := &idempotencyRepository{sql: tx}
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now().UTC()
|
||||
record := &service.IdempotencyRecord{
|
||||
Scope: uniqueTestValue(t, "idem-scope-success"),
|
||||
IdempotencyKeyHash: hashedTestValue(t, "idem-hash-success"),
|
||||
RequestFingerprint: hashedTestValue(t, "idem-fp-success"),
|
||||
Status: service.IdempotencyStatusProcessing,
|
||||
LockedUntil: ptrTime(now.Add(10 * time.Second)),
|
||||
ExpiresAt: now.Add(24 * time.Hour),
|
||||
}
|
||||
owner, err := repo.CreateProcessing(ctx, record)
|
||||
require.NoError(t, err)
|
||||
require.True(t, owner)
|
||||
|
||||
require.NoError(t, repo.MarkSucceeded(ctx, record.ID, 200, `{"ok":true}`, now.Add(24*time.Hour)))
|
||||
|
||||
got, err := repo.GetByScopeAndKeyHash(ctx, record.Scope, record.IdempotencyKeyHash)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, service.IdempotencyStatusSucceeded, got.Status)
|
||||
require.NotNil(t, got.ResponseStatus)
|
||||
require.Equal(t, 200, *got.ResponseStatus)
|
||||
require.NotNil(t, got.ResponseBody)
|
||||
require.Equal(t, `{"ok":true}`, *got.ResponseBody)
|
||||
require.Nil(t, got.LockedUntil)
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
fingerprintKeyPrefix = "fingerprint:"
|
||||
fingerprintTTL = 7 * 24 * time.Hour // 7天,配合每24小时懒续期可保持活跃账号永不过期
|
||||
maskedSessionKeyPrefix = "masked_session:"
|
||||
maskedSessionTTL = 15 * time.Minute
|
||||
)
|
||||
|
||||
// fingerprintKey generates the Redis key for account fingerprint cache.
|
||||
func fingerprintKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
// maskedSessionKey generates the Redis key for masked session ID cache.
|
||||
func maskedSessionKey(accountID int64) string {
|
||||
return fmt.Sprintf("%s%d", maskedSessionKeyPrefix, accountID)
|
||||
}
|
||||
|
||||
type identityCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewIdentityCache(rdb *redis.Client) service.IdentityCache {
|
||||
return &identityCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *identityCache) GetFingerprint(ctx context.Context, accountID int64) (*service.Fingerprint, error) {
|
||||
key := fingerprintKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var fp service.Fingerprint
|
||||
if err := json.Unmarshal([]byte(val), &fp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &fp, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp *service.Fingerprint) error {
|
||||
key := fingerprintKey(accountID)
|
||||
val, err := json.Marshal(fp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
|
||||
}
|
||||
|
||||
func (c *identityCache) GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) {
|
||||
key := maskedSessionKey(accountID)
|
||||
val, err := c.rdb.Get(ctx, key).Result()
|
||||
if err != nil {
|
||||
if err == redis.Nil {
|
||||
return "", nil
|
||||
}
|
||||
return "", err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (c *identityCache) SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error {
|
||||
key := maskedSessionKey(accountID)
|
||||
return c.rdb.Set(ctx, key, sessionID, maskedSessionTTL).Err()
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type IdentityCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache *identityCache
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewIdentityCache(s.rdb).(*identityCache)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
|
||||
_, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint")
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
|
||||
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
|
||||
gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||
require.NoError(s.T(), err, "GetFingerprint")
|
||||
require.Equal(s.T(), "c1", gotFP.ClientID)
|
||||
require.Equal(s.T(), "ua", gotFP.UserAgent)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
|
||||
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
|
||||
|
||||
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
|
||||
ttl, err := s.rdb.TTL(s.ctx, fpKey).Result()
|
||||
require.NoError(s.T(), err, "TTL fpKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() {
|
||||
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999)
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON")
|
||||
|
||||
_, err := s.cache.GetFingerprint(s.ctx, 999)
|
||||
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestSetFingerprint_Nil() {
|
||||
err := s.cache.SetFingerprint(s.ctx, 100, nil)
|
||||
require.NoError(s.T(), err, "SetFingerprint(nil) should succeed")
|
||||
}
|
||||
|
||||
func TestIdentityCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(IdentityCacheSuite))
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFingerprintKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_account_id",
|
||||
accountID: 123,
|
||||
expected: "fingerprint:123",
|
||||
},
|
||||
{
|
||||
name: "zero_account_id",
|
||||
accountID: 0,
|
||||
expected: "fingerprint:0",
|
||||
},
|
||||
{
|
||||
name: "negative_account_id",
|
||||
accountID: -1,
|
||||
expected: "fingerprint:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
accountID: math.MaxInt64,
|
||||
expected: "fingerprint:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := fingerprintKey(tc.accountID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
|
||||
// It captures the request body (if any) and then rewinds it before invoking the handler.
|
||||
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
|
||||
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
var body []byte
|
||||
if r.Body != nil {
|
||||
body, _ = io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
if capture != nil {
|
||||
capture(r, body)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, r)
|
||||
return rec.Result(), nil
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
canListenOnce sync.Once
|
||||
canListen bool
|
||||
canListenErr error
|
||||
)
|
||||
|
||||
func localListenerAvailable() bool {
|
||||
canListenOnce.Do(func() {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
canListenErr = err
|
||||
canListen = false
|
||||
return
|
||||
}
|
||||
_ = ln.Close()
|
||||
canListen = true
|
||||
})
|
||||
return canListen
|
||||
}
|
||||
|
||||
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
|
||||
tb.Helper()
|
||||
if !localListenerAvailable() {
|
||||
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
|
||||
}
|
||||
return httptest.NewServer(handler)
|
||||
}
|
||||
@@ -0,0 +1,408 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
_ "github.com/Wei-Shaw/sub2api/ent/runtime"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
|
||||
"entgo.io/ent/dialect"
|
||||
entsql "entgo.io/ent/dialect/sql"
|
||||
_ "github.com/lib/pq"
|
||||
redisclient "github.com/redis/go-redis/v9"
|
||||
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
|
||||
tcredis "github.com/testcontainers/testcontainers-go/modules/redis"
|
||||
)
|
||||
|
||||
const (
|
||||
redisImageTag = "redis:8.4-alpine"
|
||||
postgresImageTag = "postgres:18.1-alpine3.23"
|
||||
)
|
||||
|
||||
var (
|
||||
integrationDB *sql.DB
|
||||
integrationEntClient *dbent.Client
|
||||
integrationRedis *redisclient.Client
|
||||
|
||||
redisNamespaceSeq uint64
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
ctx := context.Background()
|
||||
|
||||
if err := timezone.Init("UTC"); err != nil {
|
||||
log.Printf("failed to init timezone: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if !dockerIsAvailable(ctx) {
|
||||
// In CI we expect Docker to be available so integration tests should fail loudly.
|
||||
if os.Getenv("CI") != "" {
|
||||
log.Printf("docker is not available (CI=true); failing integration tests")
|
||||
os.Exit(1)
|
||||
}
|
||||
log.Printf("docker is not available; skipping integration tests (start Docker to enable)")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
postgresImage := selectDockerImage(ctx, postgresImageTag)
|
||||
pgContainer, err := tcpostgres.Run(
|
||||
ctx,
|
||||
postgresImage,
|
||||
tcpostgres.WithDatabase("sub2api_test"),
|
||||
tcpostgres.WithUsername("postgres"),
|
||||
tcpostgres.WithPassword("postgres"),
|
||||
tcpostgres.BasicWaitStrategies(),
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("failed to start postgres container: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() { _ = pgContainer.Terminate(ctx) }()
|
||||
|
||||
redisContainer, err := tcredis.Run(
|
||||
ctx,
|
||||
redisImageTag,
|
||||
)
|
||||
if err != nil {
|
||||
log.Printf("failed to start redis container: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer func() { _ = redisContainer.Terminate(ctx) }()
|
||||
|
||||
dsn, err := pgContainer.ConnectionString(ctx, "sslmode=disable", "TimeZone=UTC")
|
||||
if err != nil {
|
||||
log.Printf("failed to get postgres dsn: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
integrationDB, err = openSQLWithRetry(ctx, dsn, 30*time.Second)
|
||||
if err != nil {
|
||||
log.Printf("failed to open sql db: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := ApplyMigrations(ctx, integrationDB); err != nil {
|
||||
log.Printf("failed to apply db migrations: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// 创建 ent client 用于集成测试
|
||||
drv := entsql.OpenDB(dialect.Postgres, integrationDB)
|
||||
integrationEntClient = dbent.NewClient(dbent.Driver(drv))
|
||||
|
||||
redisHost, err := redisContainer.Host(ctx)
|
||||
if err != nil {
|
||||
log.Printf("failed to get redis host: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
redisPort, err := redisContainer.MappedPort(ctx, "6379/tcp")
|
||||
if err != nil {
|
||||
log.Printf("failed to get redis port: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
integrationRedis = redisclient.NewClient(&redisclient.Options{
|
||||
Addr: fmt.Sprintf("%s:%d", redisHost, redisPort.Int()),
|
||||
DB: 0,
|
||||
})
|
||||
if err := integrationRedis.Ping(ctx).Err(); err != nil {
|
||||
log.Printf("failed to ping redis: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
_ = integrationEntClient.Close()
|
||||
_ = integrationRedis.Close()
|
||||
_ = integrationDB.Close()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func dockerIsAvailable(ctx context.Context) bool {
|
||||
cmd := exec.CommandContext(ctx, "docker", "info")
|
||||
cmd.Env = os.Environ()
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func selectDockerImage(ctx context.Context, preferred string) string {
|
||||
if dockerImageExists(ctx, preferred) {
|
||||
return preferred
|
||||
}
|
||||
|
||||
return preferred
|
||||
}
|
||||
|
||||
func dockerImageExists(ctx context.Context, image string) bool {
|
||||
cmd := exec.CommandContext(ctx, "docker", "image", "inspect", image)
|
||||
cmd.Env = os.Environ()
|
||||
cmd.Stdout = nil
|
||||
cmd.Stderr = nil
|
||||
return cmd.Run() == nil
|
||||
}
|
||||
|
||||
func openSQLWithRetry(ctx context.Context, dsn string, timeout time.Duration) (*sql.DB, error) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
var lastErr error
|
||||
|
||||
for time.Now().Before(deadline) {
|
||||
db, err := sql.Open("postgres", dsn)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := pingWithTimeout(ctx, db, 2*time.Second); err != nil {
|
||||
lastErr = err
|
||||
_ = db.Close()
|
||||
time.Sleep(250 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("db not ready after %s: %w", timeout, lastErr)
|
||||
}
|
||||
|
||||
func pingWithTimeout(ctx context.Context, db *sql.DB, timeout time.Duration) error {
|
||||
pingCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
return db.PingContext(pingCtx)
|
||||
}
|
||||
|
||||
func testTx(t *testing.T) *sql.Tx {
|
||||
t.Helper()
|
||||
|
||||
tx, err := integrationDB.BeginTx(context.Background(), nil)
|
||||
require.NoError(t, err, "begin tx")
|
||||
t.Cleanup(func() {
|
||||
_ = tx.Rollback()
|
||||
})
|
||||
return tx
|
||||
}
|
||||
|
||||
// testEntClient 返回全局的 ent client,用于测试需要内部管理事务的代码(如 Create/Update 方法)。
|
||||
// 注意:此 client 的操作会真正写入数据库,测试结束后不会自动回滚。
|
||||
func testEntClient(t *testing.T) *dbent.Client {
|
||||
t.Helper()
|
||||
return integrationEntClient
|
||||
}
|
||||
|
||||
// testEntTx 返回一个 ent 事务,用于需要事务隔离的测试。
|
||||
// 测试结束后会自动回滚,不会影响数据库状态。
|
||||
func testEntTx(t *testing.T) *dbent.Tx {
|
||||
t.Helper()
|
||||
|
||||
tx, err := integrationEntClient.Tx(context.Background())
|
||||
require.NoError(t, err, "begin ent tx")
|
||||
t.Cleanup(func() {
|
||||
_ = tx.Rollback()
|
||||
})
|
||||
return tx
|
||||
}
|
||||
|
||||
// testEntSQLTx 已弃用:不要在新测试中使用此函数。
|
||||
// 基于 *sql.Tx 创建的 ent client 在调用 client.Tx() 时会 panic。
|
||||
// 对于需要测试内部使用事务的代码,请使用 testEntClient。
|
||||
// 对于需要事务隔离的测试,请使用 testEntTx。
|
||||
//
|
||||
// Deprecated: Use testEntClient or testEntTx instead.
|
||||
func testEntSQLTx(t *testing.T) (*dbent.Client, *sql.Tx) {
|
||||
t.Helper()
|
||||
|
||||
// 直接失败,避免旧测试误用导致的事务嵌套 panic。
|
||||
t.Fatalf("testEntSQLTx 已弃用:请使用 testEntClient 或 testEntTx")
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func testRedis(t *testing.T) *redisclient.Client {
|
||||
t.Helper()
|
||||
|
||||
prefix := fmt.Sprintf(
|
||||
"it:%s:%d:%d:",
|
||||
sanitizeRedisNamespace(t.Name()),
|
||||
time.Now().UnixNano(),
|
||||
atomic.AddUint64(&redisNamespaceSeq, 1),
|
||||
)
|
||||
|
||||
opts := *integrationRedis.Options()
|
||||
rdb := redisclient.NewClient(&opts)
|
||||
rdb.AddHook(prefixHook{prefix: prefix})
|
||||
|
||||
t.Cleanup(func() {
|
||||
ctx := context.Background()
|
||||
|
||||
var cursor uint64
|
||||
for {
|
||||
keys, nextCursor, err := integrationRedis.Scan(ctx, cursor, prefix+"*", 500).Result()
|
||||
require.NoError(t, err, "scan redis keys for cleanup")
|
||||
if len(keys) > 0 {
|
||||
require.NoError(t, integrationRedis.Unlink(ctx, keys...).Err(), "unlink redis keys for cleanup")
|
||||
}
|
||||
|
||||
cursor = nextCursor
|
||||
if cursor == 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
return rdb
|
||||
}
|
||||
|
||||
func assertTTLWithin(t *testing.T, ttl time.Duration, min, max time.Duration) {
|
||||
t.Helper()
|
||||
require.GreaterOrEqual(t, ttl, min, "ttl should be >= min")
|
||||
require.LessOrEqual(t, ttl, max, "ttl should be <= max")
|
||||
}
|
||||
|
||||
func sanitizeRedisNamespace(name string) string {
|
||||
name = strings.ReplaceAll(name, "/", "_")
|
||||
name = strings.ReplaceAll(name, " ", "_")
|
||||
return name
|
||||
}
|
||||
|
||||
type prefixHook struct {
|
||||
prefix string
|
||||
}
|
||||
|
||||
func (h prefixHook) DialHook(next redisclient.DialHook) redisclient.DialHook { return next }
|
||||
|
||||
func (h prefixHook) ProcessHook(next redisclient.ProcessHook) redisclient.ProcessHook {
|
||||
return func(ctx context.Context, cmd redisclient.Cmder) error {
|
||||
h.prefixCmd(cmd)
|
||||
return next(ctx, cmd)
|
||||
}
|
||||
}
|
||||
|
||||
func (h prefixHook) ProcessPipelineHook(next redisclient.ProcessPipelineHook) redisclient.ProcessPipelineHook {
|
||||
return func(ctx context.Context, cmds []redisclient.Cmder) error {
|
||||
for _, cmd := range cmds {
|
||||
h.prefixCmd(cmd)
|
||||
}
|
||||
return next(ctx, cmds)
|
||||
}
|
||||
}
|
||||
|
||||
func (h prefixHook) prefixCmd(cmd redisclient.Cmder) {
|
||||
args := cmd.Args()
|
||||
if len(args) < 2 {
|
||||
return
|
||||
}
|
||||
|
||||
prefixOne := func(i int) {
|
||||
if i < 0 || i >= len(args) {
|
||||
return
|
||||
}
|
||||
|
||||
switch v := args[i].(type) {
|
||||
case string:
|
||||
if v != "" && !strings.HasPrefix(v, h.prefix) {
|
||||
args[i] = h.prefix + v
|
||||
}
|
||||
case []byte:
|
||||
s := string(v)
|
||||
if s != "" && !strings.HasPrefix(s, h.prefix) {
|
||||
args[i] = []byte(h.prefix + s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch strings.ToLower(cmd.Name()) {
|
||||
case "get", "set", "setnx", "setex", "psetex", "incr", "decr", "incrby", "expire", "pexpire", "ttl", "pttl",
|
||||
"hgetall", "hget", "hset", "hdel", "hincrbyfloat", "exists",
|
||||
"zadd", "zcard", "zrange", "zrangebyscore", "zrem", "zremrangebyscore", "zrevrange", "zrevrangebyscore", "zscore":
|
||||
prefixOne(1)
|
||||
case "del", "unlink":
|
||||
for i := 1; i < len(args); i++ {
|
||||
prefixOne(i)
|
||||
}
|
||||
case "eval", "evalsha", "eval_ro", "evalsha_ro":
|
||||
if len(args) < 3 {
|
||||
return
|
||||
}
|
||||
numKeys, err := strconv.Atoi(fmt.Sprint(args[2]))
|
||||
if err != nil || numKeys <= 0 {
|
||||
return
|
||||
}
|
||||
for i := 0; i < numKeys && 3+i < len(args); i++ {
|
||||
prefixOne(3 + i)
|
||||
}
|
||||
case "scan":
|
||||
for i := 2; i+1 < len(args); i++ {
|
||||
if strings.EqualFold(fmt.Sprint(args[i]), "match") {
|
||||
prefixOne(i + 1)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IntegrationRedisSuite provides a base suite for Redis integration tests.
|
||||
// Embedding suites should call SetupTest to initialize ctx and rdb.
|
||||
type IntegrationRedisSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
rdb *redisclient.Client
|
||||
}
|
||||
|
||||
// SetupTest initializes ctx and rdb for each test method.
|
||||
func (s *IntegrationRedisSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.rdb = testRedis(s.T())
|
||||
}
|
||||
|
||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||
func (s *IntegrationRedisSuite) RequireNoError(err error, msgAndArgs ...any) {
|
||||
s.T().Helper()
|
||||
require.NoError(s.T(), err, msgAndArgs...)
|
||||
}
|
||||
|
||||
// AssertTTLWithin asserts that ttl is within [min, max].
|
||||
func (s *IntegrationRedisSuite) AssertTTLWithin(ttl, min, max time.Duration) {
|
||||
s.T().Helper()
|
||||
assertTTLWithin(s.T(), ttl, min, max)
|
||||
}
|
||||
|
||||
// IntegrationDBSuite provides a base suite for DB integration tests.
|
||||
// Embedding suites should call SetupTest to initialize ctx and client.
|
||||
type IntegrationDBSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
tx *dbent.Tx
|
||||
}
|
||||
|
||||
// SetupTest initializes ctx and client for each test method.
|
||||
func (s *IntegrationDBSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
// 统一使用 ent.Tx,确保每个测试都有独立事务并自动回滚。
|
||||
tx := testEntTx(s.T())
|
||||
s.tx = tx
|
||||
s.client = tx.Client()
|
||||
}
|
||||
|
||||
// RequireNoError is a convenience method wrapping require.NoError with s.T().
|
||||
func (s *IntegrationDBSuite) RequireNoError(err error, msgAndArgs ...any) {
|
||||
s.T().Helper()
|
||||
require.NoError(s.T(), err, msgAndArgs...)
|
||||
}
|
||||
@@ -0,0 +1,434 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/migrations"
|
||||
)
|
||||
|
||||
// schemaMigrationsTableDDL 定义迁移记录表的 DDL。
|
||||
// 该表用于跟踪已应用的迁移文件及其校验和。
|
||||
// - filename: 迁移文件名,作为主键唯一标识每个迁移
|
||||
// - checksum: 文件内容的 SHA256 哈希值,用于检测迁移文件是否被篡改
|
||||
// - applied_at: 迁移应用时间戳
|
||||
const schemaMigrationsTableDDL = `
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
filename TEXT PRIMARY KEY,
|
||||
checksum TEXT NOT NULL,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
);
|
||||
`
|
||||
|
||||
const atlasSchemaRevisionsTableDDL = `
|
||||
CREATE TABLE IF NOT EXISTS atlas_schema_revisions (
|
||||
version TEXT PRIMARY KEY,
|
||||
description TEXT NOT NULL,
|
||||
type INTEGER NOT NULL,
|
||||
applied INTEGER NOT NULL DEFAULT 0,
|
||||
total INTEGER NOT NULL DEFAULT 0,
|
||||
executed_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
|
||||
execution_time BIGINT NOT NULL DEFAULT 0,
|
||||
error TEXT NULL,
|
||||
error_stmt TEXT NULL,
|
||||
hash TEXT NOT NULL DEFAULT '',
|
||||
partial_hashes TEXT[] NULL,
|
||||
operator_version TEXT NULL
|
||||
);
|
||||
`
|
||||
|
||||
// migrationsAdvisoryLockID 是用于序列化迁移操作的 PostgreSQL Advisory Lock ID。
|
||||
// 在多实例部署场景下,该锁确保同一时间只有一个实例执行迁移。
|
||||
// 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。
|
||||
const migrationsAdvisoryLockID int64 = 694208311321144027
|
||||
const migrationsLockRetryInterval = 500 * time.Millisecond
|
||||
const nonTransactionalMigrationSuffix = "_notx.sql"
|
||||
|
||||
type migrationChecksumCompatibilityRule struct {
|
||||
fileChecksum string
|
||||
acceptedDBChecksum map[string]struct{}
|
||||
}
|
||||
|
||||
// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。
|
||||
// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。
|
||||
var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{
|
||||
"054_drop_legacy_cache_columns.sql": {
|
||||
fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
|
||||
acceptedDBChecksum: map[string]struct{}{
|
||||
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {},
|
||||
},
|
||||
},
|
||||
"061_add_usage_log_request_type.sql": {
|
||||
fileChecksum: "66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
|
||||
acceptedDBChecksum: map[string]struct{}{
|
||||
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0": {},
|
||||
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3": {},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。
|
||||
//
|
||||
// 该函数可以在每次应用启动时安全调用:
|
||||
// - 已应用的迁移会被自动跳过(通过校验 filename 判断)
|
||||
// - 如果迁移文件内容被修改(checksum 不匹配),会返回错误
|
||||
// - 使用 PostgreSQL Advisory Lock 确保多实例并发安全
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文,用于超时控制和取消
|
||||
// - db: 数据库连接
|
||||
//
|
||||
// 返回:
|
||||
// - error: 迁移过程中的任何错误
|
||||
func ApplyMigrations(ctx context.Context, db *sql.DB) error {
|
||||
if db == nil {
|
||||
return errors.New("nil sql db")
|
||||
}
|
||||
return applyMigrationsFS(ctx, db, migrations.FS)
|
||||
}
|
||||
|
||||
// applyMigrationsFS 是迁移执行的核心实现。
|
||||
// 它从指定的文件系统读取 SQL 迁移文件并按顺序应用。
|
||||
//
|
||||
// 迁移执行流程:
|
||||
// 1. 获取 PostgreSQL Advisory Lock,防止多实例并发迁移
|
||||
// 2. 确保 schema_migrations 表存在
|
||||
// 3. 按文件名排序读取所有 .sql 文件
|
||||
// 4. 对于每个迁移文件:
|
||||
// - 计算文件内容的 SHA256 校验和
|
||||
// - 检查该迁移是否已应用(通过 filename 查询)
|
||||
// - 如果已应用,验证校验和是否匹配
|
||||
// - 如果未应用,在事务中执行迁移并记录
|
||||
// 5. 释放 Advisory Lock
|
||||
//
|
||||
// 参数:
|
||||
// - ctx: 上下文
|
||||
// - db: 数据库连接
|
||||
// - fsys: 包含迁移文件的文件系统(通常是 embed.FS)
|
||||
func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
if db == nil {
|
||||
return errors.New("nil sql db")
|
||||
}
|
||||
|
||||
// 获取分布式锁,确保多实例部署时只有一个实例执行迁移。
|
||||
// 这是 PostgreSQL 特有的 Advisory Lock 机制。
|
||||
if err := pgAdvisoryLock(ctx, db); err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
// 无论迁移是否成功,都要释放锁。
|
||||
// 使用 context.Background() 确保即使原 ctx 已取消也能释放锁。
|
||||
_ = pgAdvisoryUnlock(context.Background(), db)
|
||||
}()
|
||||
|
||||
// 创建迁移记录表(如果不存在)。
|
||||
// 该表记录所有已应用的迁移及其校验和。
|
||||
if _, err := db.ExecContext(ctx, schemaMigrationsTableDDL); err != nil {
|
||||
return fmt.Errorf("create schema_migrations: %w", err)
|
||||
}
|
||||
|
||||
// 自动对齐 Atlas 基线(如果检测到 legacy schema_migrations 且缺失 atlas_schema_revisions)。
|
||||
if err := ensureAtlasBaselineAligned(ctx, db, fsys); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取所有 .sql 迁移文件并按文件名排序。
|
||||
// 命名规范:使用零填充数字前缀(如 001_init.sql, 002_add_users.sql)。
|
||||
files, err := fs.Glob(fsys, "*.sql")
|
||||
if err != nil {
|
||||
return fmt.Errorf("list migrations: %w", err)
|
||||
}
|
||||
sort.Strings(files) // 确保按文件名顺序执行迁移
|
||||
|
||||
for _, name := range files {
|
||||
// 读取迁移文件内容
|
||||
contentBytes, err := fs.ReadFile(fsys, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
content := strings.TrimSpace(string(contentBytes))
|
||||
if content == "" {
|
||||
continue // 跳过空文件
|
||||
}
|
||||
|
||||
// 计算文件内容的 SHA256 校验和,用于检测文件是否被修改。
|
||||
// 这是一种防篡改机制:如果有人修改了已应用的迁移文件,系统会拒绝启动。
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
checksum := hex.EncodeToString(sum[:])
|
||||
|
||||
// 检查该迁移是否已经应用
|
||||
var existing string
|
||||
rowErr := db.QueryRowContext(ctx, "SELECT checksum FROM schema_migrations WHERE filename = $1", name).Scan(&existing)
|
||||
if rowErr == nil {
|
||||
// 迁移已应用,验证校验和是否匹配
|
||||
if existing != checksum {
|
||||
// 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。
|
||||
if isMigrationChecksumCompatible(name, existing, checksum) {
|
||||
continue
|
||||
}
|
||||
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
|
||||
// 正确的做法是创建新的迁移文件来进行变更。
|
||||
return fmt.Errorf(
|
||||
"migration %s checksum mismatch (db=%s file=%s)\n"+
|
||||
"This means the migration file was modified after being applied to the database.\n"+
|
||||
"Solutions:\n"+
|
||||
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
|
||||
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
|
||||
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
|
||||
name, existing, checksum, name, name,
|
||||
)
|
||||
}
|
||||
continue // 迁移已应用且校验和匹配,跳过
|
||||
}
|
||||
if !errors.Is(rowErr, sql.ErrNoRows) {
|
||||
return fmt.Errorf("check migration %s: %w", name, rowErr)
|
||||
}
|
||||
|
||||
nonTx, err := validateMigrationExecutionMode(name, content)
|
||||
if err != nil {
|
||||
return fmt.Errorf("validate migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
if nonTx {
|
||||
// *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。
|
||||
// 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。
|
||||
statements := splitSQLStatements(content)
|
||||
for i, stmt := range statements {
|
||||
trimmed := strings.TrimSpace(stmt)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if stripSQLLineComment(trimmed) == "" {
|
||||
continue
|
||||
}
|
||||
if _, err := db.ExecContext(ctx, trimmed); err != nil {
|
||||
return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err)
|
||||
}
|
||||
}
|
||||
if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
|
||||
return fmt.Errorf("record migration %s (non-tx): %w", name, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
// 执行迁移 SQL
|
||||
if _, err := tx.ExecContext(ctx, content); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("apply migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
// 记录迁移已完成,保存文件名和校验和
|
||||
if _, err := tx.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("record migration %s: %w", name, err)
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit(); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("commit migration %s: %w", name, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ensureAtlasBaselineAligned(ctx context.Context, db *sql.DB, fsys fs.FS) error {
|
||||
hasLegacy, err := tableExists(ctx, db, "schema_migrations")
|
||||
if err != nil {
|
||||
return fmt.Errorf("check schema_migrations: %w", err)
|
||||
}
|
||||
if !hasLegacy {
|
||||
return nil
|
||||
}
|
||||
|
||||
hasAtlas, err := tableExists(ctx, db, "atlas_schema_revisions")
|
||||
if err != nil {
|
||||
return fmt.Errorf("check atlas_schema_revisions: %w", err)
|
||||
}
|
||||
if !hasAtlas {
|
||||
if _, err := db.ExecContext(ctx, atlasSchemaRevisionsTableDDL); err != nil {
|
||||
return fmt.Errorf("create atlas_schema_revisions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
var count int
|
||||
if err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM atlas_schema_revisions").Scan(&count); err != nil {
|
||||
return fmt.Errorf("count atlas_schema_revisions: %w", err)
|
||||
}
|
||||
if count > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
version, description, hash, err := latestMigrationBaseline(fsys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("atlas baseline version: %w", err)
|
||||
}
|
||||
|
||||
if _, err := db.ExecContext(ctx, `
|
||||
INSERT INTO atlas_schema_revisions (version, description, type, applied, total, executed_at, execution_time, hash)
|
||||
VALUES ($1, $2, $3, 0, 0, NOW(), 0, $4)
|
||||
`, version, description, 1, hash); err != nil {
|
||||
return fmt.Errorf("insert atlas baseline: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func tableExists(ctx context.Context, db *sql.DB, tableName string) (bool, error) {
|
||||
var exists bool
|
||||
err := db.QueryRowContext(ctx, `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = 'public' AND table_name = $1
|
||||
)
|
||||
`, tableName).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) {
|
||||
files, err := fs.Glob(fsys, "*.sql")
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
if len(files) == 0 {
|
||||
return "baseline", "baseline", "", nil
|
||||
}
|
||||
sort.Strings(files)
|
||||
name := files[len(files)-1]
|
||||
contentBytes, err := fs.ReadFile(fsys, name)
|
||||
if err != nil {
|
||||
return "", "", "", err
|
||||
}
|
||||
content := strings.TrimSpace(string(contentBytes))
|
||||
sum := sha256.Sum256([]byte(content))
|
||||
hash := hex.EncodeToString(sum[:])
|
||||
version := strings.TrimSuffix(name, ".sql")
|
||||
return version, version, hash, nil
|
||||
}
|
||||
|
||||
func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool {
|
||||
rule, ok := migrationChecksumCompatibilityRules[name]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if rule.fileChecksum != fileChecksum {
|
||||
return false
|
||||
}
|
||||
_, ok = rule.acceptedDBChecksum[dbChecksum]
|
||||
return ok
|
||||
}
|
||||
|
||||
func validateMigrationExecutionMode(name, content string) (bool, error) {
|
||||
normalizedName := strings.ToLower(strings.TrimSpace(name))
|
||||
upperContent := strings.ToUpper(content)
|
||||
nonTx := strings.HasSuffix(normalizedName, nonTransactionalMigrationSuffix)
|
||||
|
||||
if !nonTx {
|
||||
if strings.Contains(upperContent, "CONCURRENTLY") {
|
||||
return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations")
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if strings.Contains(upperContent, "BEGIN") || strings.Contains(upperContent, "COMMIT") || strings.Contains(upperContent, "ROLLBACK") {
|
||||
return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)")
|
||||
}
|
||||
|
||||
statements := splitSQLStatements(content)
|
||||
for _, stmt := range statements {
|
||||
normalizedStmt := strings.ToUpper(stripSQLLineComment(strings.TrimSpace(stmt)))
|
||||
if normalizedStmt == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.Contains(normalizedStmt, "CONCURRENTLY") {
|
||||
isCreateIndex := strings.Contains(normalizedStmt, "CREATE") && strings.Contains(normalizedStmt, "INDEX")
|
||||
isDropIndex := strings.Contains(normalizedStmt, "DROP") && strings.Contains(normalizedStmt, "INDEX")
|
||||
if !isCreateIndex && !isDropIndex {
|
||||
return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements")
|
||||
}
|
||||
if isCreateIndex && !strings.Contains(normalizedStmt, "IF NOT EXISTS") {
|
||||
return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency")
|
||||
}
|
||||
if isDropIndex && !strings.Contains(normalizedStmt, "IF EXISTS") {
|
||||
return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency")
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements")
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func splitSQLStatements(content string) []string {
|
||||
parts := strings.Split(content, ";")
|
||||
out := make([]string, 0, len(parts))
|
||||
for _, part := range parts {
|
||||
if strings.TrimSpace(part) == "" {
|
||||
continue
|
||||
}
|
||||
out = append(out, part)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func stripSQLLineComment(s string) string {
|
||||
lines := strings.Split(s, "\n")
|
||||
for i, line := range lines {
|
||||
if idx := strings.Index(line, "--"); idx >= 0 {
|
||||
lines[i] = line[:idx]
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(lines, "\n"))
|
||||
}
|
||||
|
||||
// pgAdvisoryLock 获取 PostgreSQL Advisory Lock。
|
||||
// Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。
|
||||
// 它非常适合用于应用层面的分布式锁场景,如迁移序列化。
|
||||
func pgAdvisoryLock(ctx context.Context, db *sql.DB) error {
|
||||
ticker := time.NewTicker(migrationsLockRetryInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
var locked bool
|
||||
if err := db.QueryRowContext(ctx, "SELECT pg_try_advisory_lock($1)", migrationsAdvisoryLockID).Scan(&locked); err != nil {
|
||||
return fmt.Errorf("acquire migrations lock: %w", err)
|
||||
}
|
||||
if locked {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("acquire migrations lock: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// pgAdvisoryUnlock 释放 PostgreSQL Advisory Lock。
|
||||
// 必须在获取锁后确保释放,否则会阻塞其他实例的迁移操作。
|
||||
func pgAdvisoryUnlock(ctx context.Context, db *sql.DB) error {
|
||||
_, err := db.ExecContext(ctx, "SELECT pg_advisory_unlock($1)", migrationsAdvisoryLockID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("release migrations lock: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsMigrationChecksumCompatible(t *testing.T) {
|
||||
t.Run("054历史checksum可兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"054_drop_legacy_cache_columns.sql",
|
||||
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
|
||||
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
|
||||
)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("054在未知文件checksum下不兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"054_drop_legacy_cache_columns.sql",
|
||||
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
|
||||
"0000000000000000000000000000000000000000000000000000000000000000",
|
||||
)
|
||||
require.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("061历史checksum可兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"061_add_usage_log_request_type.sql",
|
||||
"08a248652cbab7cfde147fc6ef8cda464f2477674e20b718312faa252e0481c0",
|
||||
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
|
||||
)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("061第二个历史checksum可兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"061_add_usage_log_request_type.sql",
|
||||
"222b4a09c797c22e5922b6b172327c824f5463aaa8760e4f621bc5c22e2be0f3",
|
||||
"66207e7aa5dd0429c2e2c0fabdaf79783ff157fa0af2e81adff2ee03790ec65c",
|
||||
)
|
||||
require.True(t, ok)
|
||||
})
|
||||
|
||||
t.Run("非白名单迁移不兼容", func(t *testing.T) {
|
||||
ok := isMigrationChecksumCompatible(
|
||||
"001_init.sql",
|
||||
"182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4",
|
||||
"82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d",
|
||||
)
|
||||
require.False(t, ok)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,368 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io/fs"
|
||||
"strings"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
"time"
|
||||
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestApplyMigrations_NilDB(t *testing.T) {
|
||||
err := ApplyMigrations(context.Background(), nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "nil sql db")
|
||||
}
|
||||
|
||||
func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnError(errors.New("lock failed"))
|
||||
|
||||
err = ApplyMigrations(context.Background(), db)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "acquire migrations lock")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestLatestMigrationBaseline(t *testing.T) {
|
||||
t.Run("empty_fs_returns_baseline", func(t *testing.T) {
|
||||
version, description, hash, err := latestMigrationBaseline(fstest.MapFS{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "baseline", version)
|
||||
require.Equal(t, "baseline", description)
|
||||
require.Equal(t, "", hash)
|
||||
})
|
||||
|
||||
t.Run("uses_latest_sorted_sql_file", func(t *testing.T) {
|
||||
fsys := fstest.MapFS{
|
||||
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
|
||||
"010_final.sql": &fstest.MapFile{
|
||||
Data: []byte("CREATE TABLE t2(id int);"),
|
||||
},
|
||||
}
|
||||
version, description, hash, err := latestMigrationBaseline(fsys)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "010_final", version)
|
||||
require.Equal(t, "010_final", description)
|
||||
require.Len(t, hash, 64)
|
||||
})
|
||||
|
||||
t.Run("read_file_error", func(t *testing.T) {
|
||||
fsys := fstest.MapFS{
|
||||
"010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
|
||||
}
|
||||
_, _, _, err := latestMigrationBaseline(fsys)
|
||||
require.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) {
|
||||
require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file"))
|
||||
|
||||
var (
|
||||
name string
|
||||
rule migrationChecksumCompatibilityRule
|
||||
)
|
||||
for n, r := range migrationChecksumCompatibilityRules {
|
||||
name = n
|
||||
rule = r
|
||||
break
|
||||
}
|
||||
require.NotEmpty(t, name)
|
||||
|
||||
require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match"))
|
||||
require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum))
|
||||
|
||||
var accepted string
|
||||
for checksum := range rule.acceptedDBChecksum {
|
||||
accepted = checksum
|
||||
break
|
||||
}
|
||||
require.NotEmpty(t, accepted)
|
||||
require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum))
|
||||
}
|
||||
|
||||
func TestEnsureAtlasBaselineAligned(t *testing.T) {
|
||||
t.Run("skip_when_no_legacy_table", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec("INSERT INTO atlas_schema_revisions").
|
||||
WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")},
|
||||
"002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")},
|
||||
}
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("error_when_checking_legacy_table", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnError(errors.New("exists failed"))
|
||||
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "check schema_migrations")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("error_when_counting_atlas_rows", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
||||
WillReturnError(errors.New("count failed"))
|
||||
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "count atlas_schema_revisions")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("error_when_creating_atlas_table", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false))
|
||||
mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions").
|
||||
WillReturnError(errors.New("create failed"))
|
||||
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "create atlas_schema_revisions")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("error_when_inserting_baseline", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
|
||||
mock.ExpectExec("INSERT INTO atlas_schema_revisions").
|
||||
WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()).
|
||||
WillReturnError(errors.New("insert failed"))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
|
||||
}
|
||||
err = ensureAtlasBaselineAligned(context.Background(), db, fsys)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "insert atlas baseline")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_init.sql").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum"))
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")},
|
||||
}
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "checksum mismatch")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_err.sql").
|
||||
WillReturnError(errors.New("query failed"))
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")},
|
||||
}
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "check migration 001_err.sql")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
|
||||
alreadySQL := "CREATE TABLE t(id int);"
|
||||
checksum := migrationChecksum(alreadySQL)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_already.sql").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum))
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"000_empty.sql": &fstest.MapFile{Data: []byte(" \n\t ")},
|
||||
"001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)},
|
||||
}
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir},
|
||||
}
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "read migration 001_bad.sql")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) {
|
||||
t.Run("context_cancelled_while_not_locked", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond)
|
||||
defer cancel()
|
||||
err = pgAdvisoryLock(ctx, db)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "acquire migrations lock")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("unlock_exec_error", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnError(errors.New("unlock failed"))
|
||||
|
||||
err = pgAdvisoryUnlock(context.Background(), db)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "release migrations lock")
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
|
||||
t.Run("acquire_lock_after_retry", func(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false))
|
||||
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3)
|
||||
defer cancel()
|
||||
start := time.Now()
|
||||
err = pgAdvisoryLock(ctx, db)
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
})
|
||||
}
|
||||
|
||||
func migrationChecksum(content string) string {
|
||||
sum := sha256.Sum256([]byte(strings.TrimSpace(content)))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
@@ -0,0 +1,164 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
"testing/fstest"
|
||||
|
||||
sqlmock "github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateMigrationExecutionMode(t *testing.T) {
|
||||
t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
|
||||
require.False(t, nonTx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("notx迁移要求CREATE使用IF NOT EXISTS", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);")
|
||||
require.False(t, nonTx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;")
|
||||
require.False(t, nonTx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("notx迁移禁止事务控制语句", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;")
|
||||
require.False(t, nonTx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;")
|
||||
require.False(t, nonTx)
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) {
|
||||
nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", `
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a);
|
||||
DROP INDEX CONCURRENTLY IF EXISTS idx_b;
|
||||
`)
|
||||
require.True(t, nonTx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_add_idx_notx.sql").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
|
||||
WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_add_idx_notx.sql": &fstest.MapFile{
|
||||
Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"),
|
||||
},
|
||||
}
|
||||
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_add_multi_idx_notx.sql").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
|
||||
WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_add_multi_idx_notx.sql": &fstest.MapFile{
|
||||
Data: []byte(`
|
||||
-- first
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);
|
||||
-- second
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b);
|
||||
`),
|
||||
},
|
||||
}
|
||||
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) {
|
||||
db, mock, err := sqlmock.New()
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = db.Close() }()
|
||||
|
||||
prepareMigrationsBootstrapExpectations(mock)
|
||||
mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1").
|
||||
WithArgs("001_add_col.sql").
|
||||
WillReturnError(sql.ErrNoRows)
|
||||
mock.ExpectBegin()
|
||||
mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)").
|
||||
WithArgs("001_add_col.sql", sqlmock.AnyArg()).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
mock.ExpectCommit()
|
||||
mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnResult(sqlmock.NewResult(0, 1))
|
||||
|
||||
fsys := fstest.MapFS{
|
||||
"001_add_col.sql": &fstest.MapFile{
|
||||
Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"),
|
||||
},
|
||||
}
|
||||
|
||||
err = applyMigrationsFS(context.Background(), db, fsys)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) {
|
||||
mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)").
|
||||
WithArgs(migrationsAdvisoryLockID).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true))
|
||||
mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations").
|
||||
WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("schema_migrations").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT EXISTS \\(").
|
||||
WithArgs("atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true))
|
||||
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
|
||||
}
|
||||
@@ -0,0 +1,141 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) {
|
||||
tx := testTx(t)
|
||||
|
||||
// Re-apply migrations to verify idempotency (no errors, no duplicate rows).
|
||||
require.NoError(t, ApplyMigrations(context.Background(), integrationDB))
|
||||
|
||||
// schema_migrations should have at least the current migration set.
|
||||
var applied int
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM schema_migrations").Scan(&applied))
|
||||
require.GreaterOrEqual(t, applied, 7, "expected schema_migrations to contain applied migrations")
|
||||
|
||||
// users: columns required by repository queries
|
||||
requireColumn(t, tx, "users", "username", "character varying", 100, false)
|
||||
requireColumn(t, tx, "users", "notes", "text", 0, false)
|
||||
|
||||
// accounts: schedulable and rate-limit fields
|
||||
requireColumn(t, tx, "accounts", "notes", "text", 0, true)
|
||||
requireColumn(t, tx, "accounts", "schedulable", "boolean", 0, false)
|
||||
requireColumn(t, tx, "accounts", "rate_limited_at", "timestamp with time zone", 0, true)
|
||||
requireColumn(t, tx, "accounts", "rate_limit_reset_at", "timestamp with time zone", 0, true)
|
||||
requireColumn(t, tx, "accounts", "overload_until", "timestamp with time zone", 0, true)
|
||||
requireColumn(t, tx, "accounts", "session_window_status", "character varying", 20, true)
|
||||
|
||||
// api_keys: key length should be 128
|
||||
requireColumn(t, tx, "api_keys", "key", "character varying", 128, false)
|
||||
|
||||
// redeem_codes: subscription fields
|
||||
requireColumn(t, tx, "redeem_codes", "group_id", "bigint", 0, true)
|
||||
requireColumn(t, tx, "redeem_codes", "validity_days", "integer", 0, false)
|
||||
|
||||
// usage_logs: billing_type used by filters/stats
|
||||
requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false)
|
||||
requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false)
|
||||
|
||||
// usage_billing_dedup: billing idempotency narrow table
|
||||
var usageBillingDedupRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup')").Scan(&usageBillingDedupRegclass))
|
||||
require.True(t, usageBillingDedupRegclass.Valid, "expected usage_billing_dedup table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_request_api_key")
|
||||
requireIndex(t, tx, "usage_billing_dedup", "idx_usage_billing_dedup_created_at_brin")
|
||||
|
||||
var usageBillingDedupArchiveRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.usage_billing_dedup_archive')").Scan(&usageBillingDedupArchiveRegclass))
|
||||
require.True(t, usageBillingDedupArchiveRegclass.Valid, "expected usage_billing_dedup_archive table to exist")
|
||||
requireColumn(t, tx, "usage_billing_dedup_archive", "request_fingerprint", "character varying", 64, false)
|
||||
requireIndex(t, tx, "usage_billing_dedup_archive", "usage_billing_dedup_archive_pkey")
|
||||
|
||||
// settings table should exist
|
||||
var settingsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.settings')").Scan(&settingsRegclass))
|
||||
require.True(t, settingsRegclass.Valid, "expected settings table to exist")
|
||||
|
||||
// security_secrets table should exist
|
||||
var securitySecretsRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.security_secrets')").Scan(&securitySecretsRegclass))
|
||||
require.True(t, securitySecretsRegclass.Valid, "expected security_secrets table to exist")
|
||||
|
||||
// user_allowed_groups table should exist
|
||||
var uagRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.user_allowed_groups')").Scan(&uagRegclass))
|
||||
require.True(t, uagRegclass.Valid, "expected user_allowed_groups table to exist")
|
||||
|
||||
// user_subscriptions: deleted_at for soft delete support (migration 012)
|
||||
requireColumn(t, tx, "user_subscriptions", "deleted_at", "timestamp with time zone", 0, true)
|
||||
|
||||
// orphan_allowed_groups_audit table should exist (migration 013)
|
||||
var orphanAuditRegclass sql.NullString
|
||||
require.NoError(t, tx.QueryRowContext(context.Background(), "SELECT to_regclass('public.orphan_allowed_groups_audit')").Scan(&orphanAuditRegclass))
|
||||
require.True(t, orphanAuditRegclass.Valid, "expected orphan_allowed_groups_audit table to exist")
|
||||
|
||||
// account_groups: created_at should be timestamptz
|
||||
requireColumn(t, tx, "account_groups", "created_at", "timestamp with time zone", 0, false)
|
||||
|
||||
// user_allowed_groups: created_at should be timestamptz
|
||||
requireColumn(t, tx, "user_allowed_groups", "created_at", "timestamp with time zone", 0, false)
|
||||
}
|
||||
|
||||
func requireIndex(t *testing.T, tx *sql.Tx, table, index string) {
|
||||
t.Helper()
|
||||
|
||||
var exists bool
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE schemaname = 'public'
|
||||
AND tablename = $1
|
||||
AND indexname = $2
|
||||
)
|
||||
`, table, index).Scan(&exists)
|
||||
require.NoError(t, err, "query pg_indexes for %s.%s", table, index)
|
||||
require.True(t, exists, "expected index %s on %s", index, table)
|
||||
}
|
||||
|
||||
func requireColumn(t *testing.T, tx *sql.Tx, table, column, dataType string, maxLen int, nullable bool) {
|
||||
t.Helper()
|
||||
|
||||
var row struct {
|
||||
DataType string
|
||||
MaxLen sql.NullInt64
|
||||
Nullable string
|
||||
}
|
||||
|
||||
err := tx.QueryRowContext(context.Background(), `
|
||||
SELECT
|
||||
data_type,
|
||||
character_maximum_length,
|
||||
is_nullable
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = $1
|
||||
AND column_name = $2
|
||||
`, table, column).Scan(&row.DataType, &row.MaxLen, &row.Nullable)
|
||||
require.NoError(t, err, "query information_schema.columns for %s.%s", table, column)
|
||||
require.Equal(t, dataType, row.DataType, "data_type mismatch for %s.%s", table, column)
|
||||
|
||||
if maxLen > 0 {
|
||||
require.True(t, row.MaxLen.Valid, "expected maxLen for %s.%s", table, column)
|
||||
require.Equal(t, int64(maxLen), row.MaxLen.Int64, "maxLen mismatch for %s.%s", table, column)
|
||||
}
|
||||
|
||||
if nullable {
|
||||
require.Equal(t, "YES", row.Nullable, "nullable mismatch for %s.%s", table, column)
|
||||
} else {
|
||||
require.Equal(t, "NO", row.Nullable, "nullable mismatch for %s.%s", table, column)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
// NewOpenAIOAuthClient creates a new OpenAI OAuth client
|
||||
func NewOpenAIOAuthClient() service.OpenAIOAuthClient {
|
||||
return &openaiOAuthService{tokenURL: openai.TokenURL}
|
||||
}
|
||||
|
||||
type openaiOAuthService struct {
|
||||
tokenURL string
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
client, err := createOpenAIReqClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
|
||||
}
|
||||
|
||||
if redirectURI == "" {
|
||||
redirectURI = openai.DefaultRedirectURI
|
||||
}
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
if clientID == "" {
|
||||
clientID = openai.ClientID
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("client_id", clientID)
|
||||
formData.Set("code", code)
|
||||
formData.Set("redirect_uri", redirectURI)
|
||||
formData.Set("code_verifier", codeVerifier)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("User-Agent", "codex-cli/0.91.0").
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_EXCHANGE_FAILED", "token exchange failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) {
|
||||
return s.RefreshTokenWithClientID(ctx, refreshToken, proxyURL, "")
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) {
|
||||
// 调用方应始终传入正确的 client_id;为兼容旧数据,未指定时默认使用 OpenAI ClientID
|
||||
clientID = strings.TrimSpace(clientID)
|
||||
if clientID == "" {
|
||||
clientID = openai.ClientID
|
||||
}
|
||||
return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID)
|
||||
}
|
||||
|
||||
func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) {
|
||||
client, err := createOpenAIReqClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_CLIENT_INIT_FAILED", "create HTTP client: %v", err)
|
||||
}
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "refresh_token")
|
||||
formData.Set("refresh_token", refreshToken)
|
||||
formData.Set("client_id", clientID)
|
||||
formData.Set("scope", openai.RefreshScopes)
|
||||
|
||||
var tokenResp openai.TokenResponse
|
||||
|
||||
resp, err := client.R().
|
||||
SetContext(ctx).
|
||||
SetHeader("User-Agent", "codex-cli/0.91.0").
|
||||
SetFormDataFromValues(formData).
|
||||
SetSuccessResult(&tokenResp).
|
||||
Post(s.tokenURL)
|
||||
|
||||
if err != nil {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err)
|
||||
}
|
||||
|
||||
if !resp.IsSuccessState() {
|
||||
return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed: status %d, body: %s", resp.StatusCode, resp.String())
|
||||
}
|
||||
|
||||
return &tokenResp, nil
|
||||
}
|
||||
|
||||
func createOpenAIReqClient(proxyURL string) (*req.Client, error) {
|
||||
return getSharedReqClient(reqClientOptions{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: 120 * time.Second,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,350 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type OpenAIOAuthServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
srv *httptest.Server
|
||||
svc *openaiOAuthService
|
||||
received chan url.Values
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.received = make(chan url.Values, 1)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = newLocalTestServer(s.T(), handler)
|
||||
s.svc = &openaiOAuthService{tokenURL: s.srv.URL}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() {
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
errCh <- "method mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
errCh <- "ParseForm failed"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
|
||||
errCh <- "grant_type mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
||||
errCh <- "client_id mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("code"); got != "code" {
|
||||
errCh <- "code mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("redirect_uri"); got != openai.DefaultRedirectURI {
|
||||
errCh <- "redirect_uri mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("code_verifier"); got != "ver" {
|
||||
errCh <- "code_verifier mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
require.Equal(s.T(), "at", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt", resp.RefreshToken)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() {
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
errCh <- "ParseForm failed"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("grant_type"); got != "refresh_token" {
|
||||
errCh <- "grant_type mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("refresh_token"); got != "rt" {
|
||||
errCh <- "refresh_token mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("client_id"); got != openai.ClientID {
|
||||
errCh <- "client_id mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if got := r.PostForm.Get("scope"); got != openai.RefreshScopes {
|
||||
errCh <- "scope mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at2","refresh_token":"rt2","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.NoError(s.T(), err, "RefreshToken")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
require.Equal(s.T(), "at2", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt2", resp.RefreshToken)
|
||||
}
|
||||
|
||||
// TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID,
|
||||
// 且只发送一次请求(不再盲猜多个 client_id)。
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() {
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.NoError(s.T(), err, "RefreshToken")
|
||||
require.Equal(s.T(), "at", resp.AccessToken)
|
||||
// 只发送了一次请求,使用默认的 OpenAI ClientID
|
||||
require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() {
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID == openai.SoraClientID {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID)
|
||||
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
||||
require.Equal(s.T(), "at-sora", resp.AccessToken)
|
||||
require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() {
|
||||
const customClientID = "custom-client-id"
|
||||
var seenClientIDs []string
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.ParseForm(); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
clientID := r.PostForm.Get("client_id")
|
||||
seenClientIDs = append(seenClientIDs, clientID)
|
||||
if clientID != customClientID {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at-custom","refresh_token":"rt-custom","token_type":"bearer","expires_in":3600}`)
|
||||
}))
|
||||
|
||||
resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", customClientID)
|
||||
require.NoError(s.T(), err, "RefreshTokenWithClientID")
|
||||
require.Equal(s.T(), "at-custom", resp.AccessToken)
|
||||
require.Equal(s.T(), "rt-custom", resp.RefreshToken)
|
||||
require.Equal(s.T(), []string{customClientID}, seenClientIDs)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
_, _ = io.WriteString(w, "bad")
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "status 400")
|
||||
require.ErrorContains(s.T(), err, "bad")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.srv.Close()
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "request failed")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestContextCancel() {
|
||||
started := make(chan struct{})
|
||||
block := make(chan struct{})
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-block
|
||||
}))
|
||||
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
done <- err
|
||||
}()
|
||||
|
||||
<-started
|
||||
cancel()
|
||||
close(block)
|
||||
|
||||
err := <-done
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() {
|
||||
want := "http://localhost:9999/cb"
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
if got := r.PostForm.Get("redirect_uri"); got != want {
|
||||
errCh <- "redirect_uri mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() {
|
||||
wantClientID := openai.SoraClientID
|
||||
errCh := make(chan string, 1)
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
if got := r.PostForm.Get("client_id"); got != wantClientID {
|
||||
errCh <- "client_id mismatch"
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID)
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case msg := <-errCh:
|
||||
require.Fail(s.T(), msg)
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
_ = r.ParseForm()
|
||||
s.received <- r.PostForm
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`)
|
||||
}))
|
||||
s.svc.tokenURL = s.srv.URL + "?x=1"
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
require.NoError(s.T(), err, "ExchangeCode")
|
||||
select {
|
||||
case <-s.received:
|
||||
default:
|
||||
require.Fail(s.T(), "expected server to receive request")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = io.WriteString(w, "not-valid-json")
|
||||
}))
|
||||
|
||||
_, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "")
|
||||
require.Error(s.T(), err, "expected error for invalid JSON response")
|
||||
}
|
||||
|
||||
func (s *OpenAIOAuthServiceSuite) TestRefreshToken_NonSuccessStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = io.WriteString(w, "unauthorized")
|
||||
}))
|
||||
|
||||
_, err := s.svc.RefreshToken(s.ctx, "rt", "")
|
||||
require.Error(s.T(), err, "expected error for non-2xx status")
|
||||
require.ErrorContains(s.T(), err, "status 401")
|
||||
}
|
||||
|
||||
func TestNewOpenAIOAuthClient_DefaultTokenURL(t *testing.T) {
|
||||
client := NewOpenAIOAuthClient()
|
||||
svc, ok := client.(*openaiOAuthService)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, openai.TokenURL, svc.tokenURL)
|
||||
}
|
||||
|
||||
func TestOpenAIOAuthServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(OpenAIOAuthServiceSuite))
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,853 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) ListAlertRules(ctx context.Context) ([]*service.OpsAlertRule, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
name,
|
||||
COALESCE(description, ''),
|
||||
enabled,
|
||||
COALESCE(severity, ''),
|
||||
metric_type,
|
||||
operator,
|
||||
threshold,
|
||||
window_minutes,
|
||||
sustained_minutes,
|
||||
cooldown_minutes,
|
||||
COALESCE(notify_email, true),
|
||||
filters,
|
||||
last_triggered_at,
|
||||
created_at,
|
||||
updated_at
|
||||
FROM ops_alert_rules
|
||||
ORDER BY id DESC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := []*service.OpsAlertRule{}
|
||||
for rows.Next() {
|
||||
var rule service.OpsAlertRule
|
||||
var filtersRaw []byte
|
||||
var lastTriggeredAt sql.NullTime
|
||||
if err := rows.Scan(
|
||||
&rule.ID,
|
||||
&rule.Name,
|
||||
&rule.Description,
|
||||
&rule.Enabled,
|
||||
&rule.Severity,
|
||||
&rule.MetricType,
|
||||
&rule.Operator,
|
||||
&rule.Threshold,
|
||||
&rule.WindowMinutes,
|
||||
&rule.SustainedMinutes,
|
||||
&rule.CooldownMinutes,
|
||||
&rule.NotifyEmail,
|
||||
&filtersRaw,
|
||||
&lastTriggeredAt,
|
||||
&rule.CreatedAt,
|
||||
&rule.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lastTriggeredAt.Valid {
|
||||
v := lastTriggeredAt.Time
|
||||
rule.LastTriggeredAt = &v
|
||||
}
|
||||
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
|
||||
rule.Filters = decoded
|
||||
}
|
||||
}
|
||||
out = append(out, &rule)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) CreateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("nil input")
|
||||
}
|
||||
|
||||
filtersArg, err := opsNullJSONMap(input.Filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_alert_rules (
|
||||
name,
|
||||
description,
|
||||
enabled,
|
||||
severity,
|
||||
metric_type,
|
||||
operator,
|
||||
threshold,
|
||||
window_minutes,
|
||||
sustained_minutes,
|
||||
cooldown_minutes,
|
||||
notify_email,
|
||||
filters,
|
||||
created_at,
|
||||
updated_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,NOW(),NOW()
|
||||
)
|
||||
RETURNING
|
||||
id,
|
||||
name,
|
||||
COALESCE(description, ''),
|
||||
enabled,
|
||||
COALESCE(severity, ''),
|
||||
metric_type,
|
||||
operator,
|
||||
threshold,
|
||||
window_minutes,
|
||||
sustained_minutes,
|
||||
cooldown_minutes,
|
||||
COALESCE(notify_email, true),
|
||||
filters,
|
||||
last_triggered_at,
|
||||
created_at,
|
||||
updated_at`
|
||||
|
||||
var out service.OpsAlertRule
|
||||
var filtersRaw []byte
|
||||
var lastTriggeredAt sql.NullTime
|
||||
|
||||
if err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
strings.TrimSpace(input.Name),
|
||||
strings.TrimSpace(input.Description),
|
||||
input.Enabled,
|
||||
strings.TrimSpace(input.Severity),
|
||||
strings.TrimSpace(input.MetricType),
|
||||
strings.TrimSpace(input.Operator),
|
||||
input.Threshold,
|
||||
input.WindowMinutes,
|
||||
input.SustainedMinutes,
|
||||
input.CooldownMinutes,
|
||||
input.NotifyEmail,
|
||||
filtersArg,
|
||||
).Scan(
|
||||
&out.ID,
|
||||
&out.Name,
|
||||
&out.Description,
|
||||
&out.Enabled,
|
||||
&out.Severity,
|
||||
&out.MetricType,
|
||||
&out.Operator,
|
||||
&out.Threshold,
|
||||
&out.WindowMinutes,
|
||||
&out.SustainedMinutes,
|
||||
&out.CooldownMinutes,
|
||||
&out.NotifyEmail,
|
||||
&filtersRaw,
|
||||
&lastTriggeredAt,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if lastTriggeredAt.Valid {
|
||||
v := lastTriggeredAt.Time
|
||||
out.LastTriggeredAt = &v
|
||||
}
|
||||
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
|
||||
out.Filters = decoded
|
||||
}
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpdateAlertRule(ctx context.Context, input *service.OpsAlertRule) (*service.OpsAlertRule, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("nil input")
|
||||
}
|
||||
if input.ID <= 0 {
|
||||
return nil, fmt.Errorf("invalid id")
|
||||
}
|
||||
|
||||
filtersArg, err := opsNullJSONMap(input.Filters)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := `
|
||||
UPDATE ops_alert_rules
|
||||
SET
|
||||
name = $2,
|
||||
description = $3,
|
||||
enabled = $4,
|
||||
severity = $5,
|
||||
metric_type = $6,
|
||||
operator = $7,
|
||||
threshold = $8,
|
||||
window_minutes = $9,
|
||||
sustained_minutes = $10,
|
||||
cooldown_minutes = $11,
|
||||
notify_email = $12,
|
||||
filters = $13,
|
||||
updated_at = NOW()
|
||||
WHERE id = $1
|
||||
RETURNING
|
||||
id,
|
||||
name,
|
||||
COALESCE(description, ''),
|
||||
enabled,
|
||||
COALESCE(severity, ''),
|
||||
metric_type,
|
||||
operator,
|
||||
threshold,
|
||||
window_minutes,
|
||||
sustained_minutes,
|
||||
cooldown_minutes,
|
||||
COALESCE(notify_email, true),
|
||||
filters,
|
||||
last_triggered_at,
|
||||
created_at,
|
||||
updated_at`
|
||||
|
||||
var out service.OpsAlertRule
|
||||
var filtersRaw []byte
|
||||
var lastTriggeredAt sql.NullTime
|
||||
|
||||
if err := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
input.ID,
|
||||
strings.TrimSpace(input.Name),
|
||||
strings.TrimSpace(input.Description),
|
||||
input.Enabled,
|
||||
strings.TrimSpace(input.Severity),
|
||||
strings.TrimSpace(input.MetricType),
|
||||
strings.TrimSpace(input.Operator),
|
||||
input.Threshold,
|
||||
input.WindowMinutes,
|
||||
input.SustainedMinutes,
|
||||
input.CooldownMinutes,
|
||||
input.NotifyEmail,
|
||||
filtersArg,
|
||||
).Scan(
|
||||
&out.ID,
|
||||
&out.Name,
|
||||
&out.Description,
|
||||
&out.Enabled,
|
||||
&out.Severity,
|
||||
&out.MetricType,
|
||||
&out.Operator,
|
||||
&out.Threshold,
|
||||
&out.WindowMinutes,
|
||||
&out.SustainedMinutes,
|
||||
&out.CooldownMinutes,
|
||||
&out.NotifyEmail,
|
||||
&filtersRaw,
|
||||
&lastTriggeredAt,
|
||||
&out.CreatedAt,
|
||||
&out.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if lastTriggeredAt.Valid {
|
||||
v := lastTriggeredAt.Time
|
||||
out.LastTriggeredAt = &v
|
||||
}
|
||||
if len(filtersRaw) > 0 && string(filtersRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(filtersRaw, &decoded); err == nil {
|
||||
out.Filters = decoded
|
||||
}
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) DeleteAlertRule(ctx context.Context, id int64) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if id <= 0 {
|
||||
return fmt.Errorf("invalid id")
|
||||
}
|
||||
|
||||
res, err := r.db.ExecContext(ctx, "DELETE FROM ops_alert_rules WHERE id = $1", id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
affected, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListAlertEvents(ctx context.Context, filter *service.OpsAlertEventFilter) ([]*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
filter = &service.OpsAlertEventFilter{}
|
||||
}
|
||||
|
||||
limit := filter.Limit
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
if limit > 500 {
|
||||
limit = 500
|
||||
}
|
||||
|
||||
where, args := buildOpsAlertEventsWhere(filter)
|
||||
args = append(args, limit)
|
||||
limitArg := "$" + itoa(len(args))
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
` + where + `
|
||||
ORDER BY fired_at DESC, id DESC
|
||||
LIMIT ` + limitArg
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := []*service.OpsAlertEvent{}
|
||||
for rows.Next() {
|
||||
var ev service.OpsAlertEvent
|
||||
var metricValue sql.NullFloat64
|
||||
var thresholdValue sql.NullFloat64
|
||||
var dimensionsRaw []byte
|
||||
var resolvedAt sql.NullTime
|
||||
if err := rows.Scan(
|
||||
&ev.ID,
|
||||
&ev.RuleID,
|
||||
&ev.Severity,
|
||||
&ev.Status,
|
||||
&ev.Title,
|
||||
&ev.Description,
|
||||
&metricValue,
|
||||
&thresholdValue,
|
||||
&dimensionsRaw,
|
||||
&ev.FiredAt,
|
||||
&resolvedAt,
|
||||
&ev.EmailSent,
|
||||
&ev.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if metricValue.Valid {
|
||||
v := metricValue.Float64
|
||||
ev.MetricValue = &v
|
||||
}
|
||||
if thresholdValue.Valid {
|
||||
v := thresholdValue.Float64
|
||||
ev.ThresholdValue = &v
|
||||
}
|
||||
if resolvedAt.Valid {
|
||||
v := resolvedAt.Time
|
||||
ev.ResolvedAt = &v
|
||||
}
|
||||
if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
|
||||
ev.Dimensions = decoded
|
||||
}
|
||||
}
|
||||
out = append(out, &ev)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetAlertEventByID(ctx context.Context, eventID int64) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return nil, fmt.Errorf("invalid event id")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
WHERE id = $1`
|
||||
|
||||
row := r.db.QueryRowContext(ctx, q, eventID)
|
||||
ev, err := scanOpsAlertEvent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetActiveAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return nil, fmt.Errorf("invalid rule id")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
WHERE rule_id = $1 AND status = $2
|
||||
ORDER BY fired_at DESC
|
||||
LIMIT 1`
|
||||
|
||||
row := r.db.QueryRowContext(ctx, q, ruleID, service.OpsAlertStatusFiring)
|
||||
ev, err := scanOpsAlertEvent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetLatestAlertEvent(ctx context.Context, ruleID int64) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return nil, fmt.Errorf("invalid rule id")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
FROM ops_alert_events
|
||||
WHERE rule_id = $1
|
||||
ORDER BY fired_at DESC
|
||||
LIMIT 1`
|
||||
|
||||
row := r.db.QueryRowContext(ctx, q, ruleID)
|
||||
ev, err := scanOpsAlertEvent(row)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return ev, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) CreateAlertEvent(ctx context.Context, event *service.OpsAlertEvent) (*service.OpsAlertEvent, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if event == nil {
|
||||
return nil, fmt.Errorf("nil event")
|
||||
}
|
||||
|
||||
dimensionsArg, err := opsNullJSONMap(event.Dimensions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_alert_events (
|
||||
rule_id,
|
||||
severity,
|
||||
status,
|
||||
title,
|
||||
description,
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,NOW()
|
||||
)
|
||||
RETURNING
|
||||
id,
|
||||
COALESCE(rule_id, 0),
|
||||
COALESCE(severity, ''),
|
||||
COALESCE(status, ''),
|
||||
COALESCE(title, ''),
|
||||
COALESCE(description, ''),
|
||||
metric_value,
|
||||
threshold_value,
|
||||
dimensions,
|
||||
fired_at,
|
||||
resolved_at,
|
||||
email_sent,
|
||||
created_at`
|
||||
|
||||
row := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
opsNullInt64(&event.RuleID),
|
||||
opsNullString(event.Severity),
|
||||
opsNullString(event.Status),
|
||||
opsNullString(event.Title),
|
||||
opsNullString(event.Description),
|
||||
opsNullFloat64(event.MetricValue),
|
||||
opsNullFloat64(event.ThresholdValue),
|
||||
dimensionsArg,
|
||||
event.FiredAt,
|
||||
opsNullTime(event.ResolvedAt),
|
||||
event.EmailSent,
|
||||
)
|
||||
return scanOpsAlertEvent(row)
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpdateAlertEventStatus(ctx context.Context, eventID int64, status string, resolvedAt *time.Time) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return fmt.Errorf("invalid event id")
|
||||
}
|
||||
if strings.TrimSpace(status) == "" {
|
||||
return fmt.Errorf("invalid status")
|
||||
}
|
||||
|
||||
q := `
|
||||
UPDATE ops_alert_events
|
||||
SET status = $2,
|
||||
resolved_at = $3
|
||||
WHERE id = $1`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, q, eventID, strings.TrimSpace(status), opsNullTime(resolvedAt))
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpdateAlertEventEmailSent(ctx context.Context, eventID int64, emailSent bool) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if eventID <= 0 {
|
||||
return fmt.Errorf("invalid event id")
|
||||
}
|
||||
|
||||
_, err := r.db.ExecContext(ctx, "UPDATE ops_alert_events SET email_sent = $2 WHERE id = $1", eventID, emailSent)
|
||||
return err
|
||||
}
|
||||
|
||||
type opsAlertEventRow interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
func (r *opsRepository) CreateAlertSilence(ctx context.Context, input *service.OpsAlertSilence) (*service.OpsAlertSilence, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("nil input")
|
||||
}
|
||||
if input.RuleID <= 0 {
|
||||
return nil, fmt.Errorf("invalid rule_id")
|
||||
}
|
||||
platform := strings.TrimSpace(input.Platform)
|
||||
if platform == "" {
|
||||
return nil, fmt.Errorf("invalid platform")
|
||||
}
|
||||
if input.Until.IsZero() {
|
||||
return nil, fmt.Errorf("invalid until")
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_alert_silences (
|
||||
rule_id,
|
||||
platform,
|
||||
group_id,
|
||||
region,
|
||||
until,
|
||||
reason,
|
||||
created_by,
|
||||
created_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,NOW()
|
||||
)
|
||||
RETURNING id, rule_id, platform, group_id, region, until, COALESCE(reason,''), created_by, created_at`
|
||||
|
||||
row := r.db.QueryRowContext(
|
||||
ctx,
|
||||
q,
|
||||
input.RuleID,
|
||||
platform,
|
||||
opsNullInt64(input.GroupID),
|
||||
opsNullString(input.Region),
|
||||
input.Until,
|
||||
opsNullString(input.Reason),
|
||||
opsNullInt64(input.CreatedBy),
|
||||
)
|
||||
|
||||
var out service.OpsAlertSilence
|
||||
var groupID sql.NullInt64
|
||||
var region sql.NullString
|
||||
var createdBy sql.NullInt64
|
||||
if err := row.Scan(
|
||||
&out.ID,
|
||||
&out.RuleID,
|
||||
&out.Platform,
|
||||
&groupID,
|
||||
®ion,
|
||||
&out.Until,
|
||||
&out.Reason,
|
||||
&createdBy,
|
||||
&out.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if groupID.Valid {
|
||||
v := groupID.Int64
|
||||
out.GroupID = &v
|
||||
}
|
||||
if region.Valid {
|
||||
v := strings.TrimSpace(region.String)
|
||||
if v != "" {
|
||||
out.Region = &v
|
||||
}
|
||||
}
|
||||
if createdBy.Valid {
|
||||
v := createdBy.Int64
|
||||
out.CreatedBy = &v
|
||||
}
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) IsAlertSilenced(ctx context.Context, ruleID int64, platform string, groupID *int64, region *string, now time.Time) (bool, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return false, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if ruleID <= 0 {
|
||||
return false, fmt.Errorf("invalid rule id")
|
||||
}
|
||||
platform = strings.TrimSpace(platform)
|
||||
if platform == "" {
|
||||
return false, nil
|
||||
}
|
||||
if now.IsZero() {
|
||||
now = time.Now().UTC()
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT 1
|
||||
FROM ops_alert_silences
|
||||
WHERE rule_id = $1
|
||||
AND platform = $2
|
||||
AND (group_id IS NOT DISTINCT FROM $3)
|
||||
AND (region IS NOT DISTINCT FROM $4)
|
||||
AND until > $5
|
||||
LIMIT 1`
|
||||
|
||||
var dummy int
|
||||
err := r.db.QueryRowContext(ctx, q, ruleID, platform, opsNullInt64(groupID), opsNullString(region), now).Scan(&dummy)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func scanOpsAlertEvent(row opsAlertEventRow) (*service.OpsAlertEvent, error) {
|
||||
var ev service.OpsAlertEvent
|
||||
var metricValue sql.NullFloat64
|
||||
var thresholdValue sql.NullFloat64
|
||||
var dimensionsRaw []byte
|
||||
var resolvedAt sql.NullTime
|
||||
|
||||
if err := row.Scan(
|
||||
&ev.ID,
|
||||
&ev.RuleID,
|
||||
&ev.Severity,
|
||||
&ev.Status,
|
||||
&ev.Title,
|
||||
&ev.Description,
|
||||
&metricValue,
|
||||
&thresholdValue,
|
||||
&dimensionsRaw,
|
||||
&ev.FiredAt,
|
||||
&resolvedAt,
|
||||
&ev.EmailSent,
|
||||
&ev.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if metricValue.Valid {
|
||||
v := metricValue.Float64
|
||||
ev.MetricValue = &v
|
||||
}
|
||||
if thresholdValue.Valid {
|
||||
v := thresholdValue.Float64
|
||||
ev.ThresholdValue = &v
|
||||
}
|
||||
if resolvedAt.Valid {
|
||||
v := resolvedAt.Time
|
||||
ev.ResolvedAt = &v
|
||||
}
|
||||
if len(dimensionsRaw) > 0 && string(dimensionsRaw) != "null" {
|
||||
var decoded map[string]any
|
||||
if err := json.Unmarshal(dimensionsRaw, &decoded); err == nil {
|
||||
ev.Dimensions = decoded
|
||||
}
|
||||
}
|
||||
return &ev, nil
|
||||
}
|
||||
|
||||
func buildOpsAlertEventsWhere(filter *service.OpsAlertEventFilter) (string, []any) {
|
||||
clauses := []string{"1=1"}
|
||||
args := []any{}
|
||||
|
||||
if filter == nil {
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
if status := strings.TrimSpace(filter.Status); status != "" {
|
||||
args = append(args, status)
|
||||
clauses = append(clauses, "status = $"+itoa(len(args)))
|
||||
}
|
||||
if severity := strings.TrimSpace(filter.Severity); severity != "" {
|
||||
args = append(args, severity)
|
||||
clauses = append(clauses, "severity = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.EmailSent != nil {
|
||||
args = append(args, *filter.EmailSent)
|
||||
clauses = append(clauses, "email_sent = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.StartTime != nil && !filter.StartTime.IsZero() {
|
||||
args = append(args, *filter.StartTime)
|
||||
clauses = append(clauses, "fired_at >= $"+itoa(len(args)))
|
||||
}
|
||||
if filter.EndTime != nil && !filter.EndTime.IsZero() {
|
||||
args = append(args, *filter.EndTime)
|
||||
clauses = append(clauses, "fired_at < $"+itoa(len(args)))
|
||||
}
|
||||
|
||||
// Cursor pagination (descending by fired_at, then id)
|
||||
if filter.BeforeFiredAt != nil && !filter.BeforeFiredAt.IsZero() && filter.BeforeID != nil && *filter.BeforeID > 0 {
|
||||
args = append(args, *filter.BeforeFiredAt)
|
||||
tsArg := "$" + itoa(len(args))
|
||||
args = append(args, *filter.BeforeID)
|
||||
idArg := "$" + itoa(len(args))
|
||||
clauses = append(clauses, fmt.Sprintf("(fired_at < %s OR (fired_at = %s AND id < %s))", tsArg, tsArg, idArg))
|
||||
}
|
||||
// Dimensions are stored in JSONB. We filter best-effort without requiring GIN indexes.
|
||||
if platform := strings.TrimSpace(filter.Platform); platform != "" {
|
||||
args = append(args, platform)
|
||||
clauses = append(clauses, "(dimensions->>'platform') = $"+itoa(len(args)))
|
||||
}
|
||||
if filter.GroupID != nil && *filter.GroupID > 0 {
|
||||
args = append(args, fmt.Sprintf("%d", *filter.GroupID))
|
||||
clauses = append(clauses, "(dimensions->>'group_id') = $"+itoa(len(args)))
|
||||
}
|
||||
|
||||
return "WHERE " + strings.Join(clauses, " AND "), args
|
||||
}
|
||||
|
||||
func opsNullJSONMap(v map[string]any) (any, error) {
|
||||
if v == nil {
|
||||
return sql.NullString{}, nil
|
||||
}
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(b) == 0 {
|
||||
return sql.NullString{}, nil
|
||||
}
|
||||
return sql.NullString{String: string(b), Valid: true}, nil
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,22 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsQueryTimeoutErr(t *testing.T) {
|
||||
if !isQueryTimeoutErr(context.DeadlineExceeded) {
|
||||
t.Fatalf("context.DeadlineExceeded should be treated as query timeout")
|
||||
}
|
||||
if !isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.DeadlineExceeded)) {
|
||||
t.Fatalf("wrapped context.DeadlineExceeded should be treated as query timeout")
|
||||
}
|
||||
if isQueryTimeoutErr(context.Canceled) {
|
||||
t.Fatalf("context.Canceled should not be treated as query timeout")
|
||||
}
|
||||
if isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.Canceled)) {
|
||||
t.Fatalf("wrapped context.Canceled should not be treated as query timeout")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestBuildOpsErrorLogsWhere_QueryUsesQualifiedColumns(t *testing.T) {
|
||||
filter := &service.OpsErrorLogFilter{
|
||||
Query: "ACCESS_DENIED",
|
||||
}
|
||||
|
||||
where, args := buildOpsErrorLogsWhere(filter)
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 1 {
|
||||
t.Fatalf("args len = %d, want 1", len(args))
|
||||
}
|
||||
if !strings.Contains(where, "e.request_id ILIKE $") {
|
||||
t.Fatalf("where should include qualified request_id condition: %s", where)
|
||||
}
|
||||
if !strings.Contains(where, "e.client_request_id ILIKE $") {
|
||||
t.Fatalf("where should include qualified client_request_id condition: %s", where)
|
||||
}
|
||||
if !strings.Contains(where, "e.error_message ILIKE $") {
|
||||
t.Fatalf("where should include qualified error_message condition: %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpsErrorLogsWhere_UserQueryUsesExistsSubquery(t *testing.T) {
|
||||
filter := &service.OpsErrorLogFilter{
|
||||
UserQuery: "admin@",
|
||||
}
|
||||
|
||||
where, args := buildOpsErrorLogsWhere(filter)
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 1 {
|
||||
t.Fatalf("args len = %d, want 1", len(args))
|
||||
}
|
||||
if !strings.Contains(where, "EXISTS (SELECT 1 FROM users u WHERE u.id = e.user_id AND u.email ILIKE $") {
|
||||
t.Fatalf("where should include EXISTS user email condition: %s", where)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetLatencyHistogram(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsLatencyHistogramResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
|
||||
join, where, args, _ := buildUsageWhere(filter, start, end, 1)
|
||||
rangeExpr := latencyHistogramRangeCaseExpr("ul.duration_ms")
|
||||
orderExpr := latencyHistogramRangeOrderCaseExpr("ul.duration_ms")
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
` + rangeExpr + ` AS range,
|
||||
COALESCE(COUNT(*), 0) AS count,
|
||||
` + orderExpr + ` AS ord
|
||||
FROM usage_logs ul
|
||||
` + join + `
|
||||
` + where + `
|
||||
AND ul.duration_ms IS NOT NULL
|
||||
GROUP BY 1, 3
|
||||
ORDER BY 3 ASC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
counts := make(map[string]int64, len(latencyHistogramOrderedRanges))
|
||||
var total int64
|
||||
for rows.Next() {
|
||||
var label string
|
||||
var count int64
|
||||
var _ord int
|
||||
if err := rows.Scan(&label, &count, &_ord); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[label] = count
|
||||
total += count
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buckets := make([]*service.OpsLatencyHistogramBucket, 0, len(latencyHistogramOrderedRanges))
|
||||
for _, label := range latencyHistogramOrderedRanges {
|
||||
buckets = append(buckets, &service.OpsLatencyHistogramBucket{
|
||||
Range: label,
|
||||
Count: counts[label],
|
||||
})
|
||||
}
|
||||
|
||||
return &service.OpsLatencyHistogramResponse{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: strings.TrimSpace(filter.Platform),
|
||||
GroupID: filter.GroupID,
|
||||
TotalRequests: total,
|
||||
Buckets: buckets,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,64 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type latencyHistogramBucket struct {
|
||||
upperMs int
|
||||
label string
|
||||
}
|
||||
|
||||
var latencyHistogramBuckets = []latencyHistogramBucket{
|
||||
{upperMs: 100, label: "0-100ms"},
|
||||
{upperMs: 200, label: "100-200ms"},
|
||||
{upperMs: 500, label: "200-500ms"},
|
||||
{upperMs: 1000, label: "500-1000ms"},
|
||||
{upperMs: 2000, label: "1000-2000ms"},
|
||||
{upperMs: 0, label: "2000ms+"}, // default bucket
|
||||
}
|
||||
|
||||
var latencyHistogramOrderedRanges = func() []string {
|
||||
out := make([]string, 0, len(latencyHistogramBuckets))
|
||||
for _, b := range latencyHistogramBuckets {
|
||||
out = append(out, b.label)
|
||||
}
|
||||
return out
|
||||
}()
|
||||
|
||||
func latencyHistogramRangeCaseExpr(column string) string {
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString("CASE\n")
|
||||
|
||||
for _, b := range latencyHistogramBuckets {
|
||||
if b.upperMs <= 0 {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(&sb, "\tWHEN %s < %d THEN '%s'\n", column, b.upperMs, b.label)
|
||||
}
|
||||
|
||||
// Default bucket.
|
||||
last := latencyHistogramBuckets[len(latencyHistogramBuckets)-1]
|
||||
fmt.Fprintf(&sb, "\tELSE '%s'\n", last.label)
|
||||
_, _ = sb.WriteString("END")
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
func latencyHistogramRangeOrderCaseExpr(column string) string {
|
||||
var sb strings.Builder
|
||||
_, _ = sb.WriteString("CASE\n")
|
||||
|
||||
order := 1
|
||||
for _, b := range latencyHistogramBuckets {
|
||||
if b.upperMs <= 0 {
|
||||
continue
|
||||
}
|
||||
fmt.Fprintf(&sb, "\tWHEN %s < %d THEN %d\n", column, b.upperMs, order)
|
||||
order++
|
||||
}
|
||||
|
||||
fmt.Fprintf(&sb, "\tELSE %d\n", order)
|
||||
_, _ = sb.WriteString("END")
|
||||
return sb.String()
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLatencyHistogramBuckets_AreConsistent(t *testing.T) {
|
||||
require.Equal(t, len(latencyHistogramBuckets), len(latencyHistogramOrderedRanges))
|
||||
for i, b := range latencyHistogramBuckets {
|
||||
require.Equal(t, b.label, latencyHistogramOrderedRanges[i])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,445 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) InsertSystemMetrics(ctx context.Context, input *service.OpsInsertSystemMetricsInput) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return fmt.Errorf("nil input")
|
||||
}
|
||||
|
||||
window := input.WindowMinutes
|
||||
if window <= 0 {
|
||||
window = 1
|
||||
}
|
||||
createdAt := input.CreatedAt
|
||||
if createdAt.IsZero() {
|
||||
createdAt = time.Now().UTC()
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_system_metrics (
|
||||
created_at,
|
||||
window_minutes,
|
||||
platform,
|
||||
group_id,
|
||||
|
||||
success_count,
|
||||
error_count_total,
|
||||
business_limited_count,
|
||||
error_count_sla,
|
||||
|
||||
upstream_error_count_excl_429_529,
|
||||
upstream_429_count,
|
||||
upstream_529_count,
|
||||
|
||||
token_consumed,
|
||||
account_switch_count,
|
||||
qps,
|
||||
tps,
|
||||
|
||||
duration_p50_ms,
|
||||
duration_p90_ms,
|
||||
duration_p95_ms,
|
||||
duration_p99_ms,
|
||||
duration_avg_ms,
|
||||
duration_max_ms,
|
||||
|
||||
ttft_p50_ms,
|
||||
ttft_p90_ms,
|
||||
ttft_p95_ms,
|
||||
ttft_p99_ms,
|
||||
ttft_avg_ms,
|
||||
ttft_max_ms,
|
||||
|
||||
cpu_usage_percent,
|
||||
memory_used_mb,
|
||||
memory_total_mb,
|
||||
memory_usage_percent,
|
||||
|
||||
db_ok,
|
||||
redis_ok,
|
||||
|
||||
redis_conn_total,
|
||||
redis_conn_idle,
|
||||
|
||||
db_conn_active,
|
||||
db_conn_idle,
|
||||
db_conn_waiting,
|
||||
|
||||
goroutine_count,
|
||||
concurrency_queue_depth
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,
|
||||
$5,$6,$7,$8,
|
||||
$9,$10,$11,
|
||||
$12,$13,$14,$15,
|
||||
$16,$17,$18,$19,$20,$21,
|
||||
$22,$23,$24,$25,$26,$27,
|
||||
$28,$29,$30,$31,
|
||||
$32,$33,
|
||||
$34,$35,
|
||||
$36,$37,$38,
|
||||
$39,$40
|
||||
)`
|
||||
|
||||
_, err := r.db.ExecContext(
|
||||
ctx,
|
||||
q,
|
||||
createdAt,
|
||||
window,
|
||||
opsNullString(input.Platform),
|
||||
opsNullInt64(input.GroupID),
|
||||
|
||||
input.SuccessCount,
|
||||
input.ErrorCountTotal,
|
||||
input.BusinessLimitedCount,
|
||||
input.ErrorCountSLA,
|
||||
|
||||
input.UpstreamErrorCountExcl429529,
|
||||
input.Upstream429Count,
|
||||
input.Upstream529Count,
|
||||
|
||||
input.TokenConsumed,
|
||||
input.AccountSwitchCount,
|
||||
opsNullFloat64(input.QPS),
|
||||
opsNullFloat64(input.TPS),
|
||||
|
||||
opsNullInt(input.DurationP50Ms),
|
||||
opsNullInt(input.DurationP90Ms),
|
||||
opsNullInt(input.DurationP95Ms),
|
||||
opsNullInt(input.DurationP99Ms),
|
||||
opsNullFloat64(input.DurationAvgMs),
|
||||
opsNullInt(input.DurationMaxMs),
|
||||
|
||||
opsNullInt(input.TTFTP50Ms),
|
||||
opsNullInt(input.TTFTP90Ms),
|
||||
opsNullInt(input.TTFTP95Ms),
|
||||
opsNullInt(input.TTFTP99Ms),
|
||||
opsNullFloat64(input.TTFTAvgMs),
|
||||
opsNullInt(input.TTFTMaxMs),
|
||||
|
||||
opsNullFloat64(input.CPUUsagePercent),
|
||||
opsNullInt(input.MemoryUsedMB),
|
||||
opsNullInt(input.MemoryTotalMB),
|
||||
opsNullFloat64(input.MemoryUsagePercent),
|
||||
|
||||
opsNullBool(input.DBOK),
|
||||
opsNullBool(input.RedisOK),
|
||||
|
||||
opsNullInt(input.RedisConnTotal),
|
||||
opsNullInt(input.RedisConnIdle),
|
||||
|
||||
opsNullInt(input.DBConnActive),
|
||||
opsNullInt(input.DBConnIdle),
|
||||
opsNullInt(input.DBConnWaiting),
|
||||
|
||||
opsNullInt(input.GoroutineCount),
|
||||
opsNullInt(input.ConcurrencyQueueDepth),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetLatestSystemMetrics(ctx context.Context, windowMinutes int) (*service.OpsSystemMetricsSnapshot, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if windowMinutes <= 0 {
|
||||
windowMinutes = 1
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
id,
|
||||
created_at,
|
||||
window_minutes,
|
||||
|
||||
cpu_usage_percent,
|
||||
memory_used_mb,
|
||||
memory_total_mb,
|
||||
memory_usage_percent,
|
||||
|
||||
db_ok,
|
||||
redis_ok,
|
||||
|
||||
redis_conn_total,
|
||||
redis_conn_idle,
|
||||
|
||||
db_conn_active,
|
||||
db_conn_idle,
|
||||
db_conn_waiting,
|
||||
|
||||
goroutine_count,
|
||||
concurrency_queue_depth,
|
||||
account_switch_count
|
||||
FROM ops_system_metrics
|
||||
WHERE window_minutes = $1
|
||||
AND platform IS NULL
|
||||
AND group_id IS NULL
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 1`
|
||||
|
||||
var out service.OpsSystemMetricsSnapshot
|
||||
var cpu sql.NullFloat64
|
||||
var memUsed sql.NullInt64
|
||||
var memTotal sql.NullInt64
|
||||
var memPct sql.NullFloat64
|
||||
var dbOK sql.NullBool
|
||||
var redisOK sql.NullBool
|
||||
var redisTotal sql.NullInt64
|
||||
var redisIdle sql.NullInt64
|
||||
var dbActive sql.NullInt64
|
||||
var dbIdle sql.NullInt64
|
||||
var dbWaiting sql.NullInt64
|
||||
var goroutines sql.NullInt64
|
||||
var queueDepth sql.NullInt64
|
||||
var accountSwitchCount sql.NullInt64
|
||||
|
||||
if err := r.db.QueryRowContext(ctx, q, windowMinutes).Scan(
|
||||
&out.ID,
|
||||
&out.CreatedAt,
|
||||
&out.WindowMinutes,
|
||||
&cpu,
|
||||
&memUsed,
|
||||
&memTotal,
|
||||
&memPct,
|
||||
&dbOK,
|
||||
&redisOK,
|
||||
&redisTotal,
|
||||
&redisIdle,
|
||||
&dbActive,
|
||||
&dbIdle,
|
||||
&dbWaiting,
|
||||
&goroutines,
|
||||
&queueDepth,
|
||||
&accountSwitchCount,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cpu.Valid {
|
||||
v := cpu.Float64
|
||||
out.CPUUsagePercent = &v
|
||||
}
|
||||
if memUsed.Valid {
|
||||
v := memUsed.Int64
|
||||
out.MemoryUsedMB = &v
|
||||
}
|
||||
if memTotal.Valid {
|
||||
v := memTotal.Int64
|
||||
out.MemoryTotalMB = &v
|
||||
}
|
||||
if memPct.Valid {
|
||||
v := memPct.Float64
|
||||
out.MemoryUsagePercent = &v
|
||||
}
|
||||
if dbOK.Valid {
|
||||
v := dbOK.Bool
|
||||
out.DBOK = &v
|
||||
}
|
||||
if redisOK.Valid {
|
||||
v := redisOK.Bool
|
||||
out.RedisOK = &v
|
||||
}
|
||||
if redisTotal.Valid {
|
||||
v := int(redisTotal.Int64)
|
||||
out.RedisConnTotal = &v
|
||||
}
|
||||
if redisIdle.Valid {
|
||||
v := int(redisIdle.Int64)
|
||||
out.RedisConnIdle = &v
|
||||
}
|
||||
if dbActive.Valid {
|
||||
v := int(dbActive.Int64)
|
||||
out.DBConnActive = &v
|
||||
}
|
||||
if dbIdle.Valid {
|
||||
v := int(dbIdle.Int64)
|
||||
out.DBConnIdle = &v
|
||||
}
|
||||
if dbWaiting.Valid {
|
||||
v := int(dbWaiting.Int64)
|
||||
out.DBConnWaiting = &v
|
||||
}
|
||||
if goroutines.Valid {
|
||||
v := int(goroutines.Int64)
|
||||
out.GoroutineCount = &v
|
||||
}
|
||||
if queueDepth.Valid {
|
||||
v := int(queueDepth.Int64)
|
||||
out.ConcurrencyQueueDepth = &v
|
||||
}
|
||||
if accountSwitchCount.Valid {
|
||||
v := accountSwitchCount.Int64
|
||||
out.AccountSwitchCount = &v
|
||||
}
|
||||
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpsertJobHeartbeat(ctx context.Context, input *service.OpsUpsertJobHeartbeatInput) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if input == nil {
|
||||
return fmt.Errorf("nil input")
|
||||
}
|
||||
if input.JobName == "" {
|
||||
return fmt.Errorf("job_name required")
|
||||
}
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_job_heartbeats (
|
||||
job_name,
|
||||
last_run_at,
|
||||
last_success_at,
|
||||
last_error_at,
|
||||
last_error,
|
||||
last_duration_ms,
|
||||
last_result,
|
||||
updated_at
|
||||
) VALUES (
|
||||
$1,$2,$3,$4,$5,$6,$7,NOW()
|
||||
)
|
||||
ON CONFLICT (job_name) DO UPDATE SET
|
||||
last_run_at = COALESCE(EXCLUDED.last_run_at, ops_job_heartbeats.last_run_at),
|
||||
last_success_at = COALESCE(EXCLUDED.last_success_at, ops_job_heartbeats.last_success_at),
|
||||
last_error_at = CASE
|
||||
WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
|
||||
ELSE COALESCE(EXCLUDED.last_error_at, ops_job_heartbeats.last_error_at)
|
||||
END,
|
||||
last_error = CASE
|
||||
WHEN EXCLUDED.last_success_at IS NOT NULL THEN NULL
|
||||
ELSE COALESCE(EXCLUDED.last_error, ops_job_heartbeats.last_error)
|
||||
END,
|
||||
last_duration_ms = COALESCE(EXCLUDED.last_duration_ms, ops_job_heartbeats.last_duration_ms),
|
||||
last_result = CASE
|
||||
WHEN EXCLUDED.last_success_at IS NOT NULL THEN COALESCE(EXCLUDED.last_result, ops_job_heartbeats.last_result)
|
||||
ELSE ops_job_heartbeats.last_result
|
||||
END,
|
||||
updated_at = NOW()`
|
||||
|
||||
_, err := r.db.ExecContext(
|
||||
ctx,
|
||||
q,
|
||||
input.JobName,
|
||||
opsNullTime(input.LastRunAt),
|
||||
opsNullTime(input.LastSuccessAt),
|
||||
opsNullTime(input.LastErrorAt),
|
||||
opsNullString(input.LastError),
|
||||
opsNullInt(input.LastDurationMs),
|
||||
opsNullString(input.LastResult),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) ListJobHeartbeats(ctx context.Context) ([]*service.OpsJobHeartbeat, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
job_name,
|
||||
last_run_at,
|
||||
last_success_at,
|
||||
last_error_at,
|
||||
last_error,
|
||||
last_duration_ms,
|
||||
last_result,
|
||||
updated_at
|
||||
FROM ops_job_heartbeats
|
||||
ORDER BY job_name ASC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := make([]*service.OpsJobHeartbeat, 0, 8)
|
||||
for rows.Next() {
|
||||
var item service.OpsJobHeartbeat
|
||||
var lastRun sql.NullTime
|
||||
var lastSuccess sql.NullTime
|
||||
var lastErrorAt sql.NullTime
|
||||
var lastError sql.NullString
|
||||
var lastDuration sql.NullInt64
|
||||
|
||||
var lastResult sql.NullString
|
||||
|
||||
if err := rows.Scan(
|
||||
&item.JobName,
|
||||
&lastRun,
|
||||
&lastSuccess,
|
||||
&lastErrorAt,
|
||||
&lastError,
|
||||
&lastDuration,
|
||||
&lastResult,
|
||||
&item.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if lastRun.Valid {
|
||||
v := lastRun.Time
|
||||
item.LastRunAt = &v
|
||||
}
|
||||
if lastSuccess.Valid {
|
||||
v := lastSuccess.Time
|
||||
item.LastSuccessAt = &v
|
||||
}
|
||||
if lastErrorAt.Valid {
|
||||
v := lastErrorAt.Time
|
||||
item.LastErrorAt = &v
|
||||
}
|
||||
if lastError.Valid {
|
||||
v := lastError.String
|
||||
item.LastError = &v
|
||||
}
|
||||
if lastDuration.Valid {
|
||||
v := lastDuration.Int64
|
||||
item.LastDurationMs = &v
|
||||
}
|
||||
if lastResult.Valid {
|
||||
v := lastResult.String
|
||||
item.LastResult = &v
|
||||
}
|
||||
|
||||
out = append(out, &item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func opsNullBool(v *bool) any {
|
||||
if v == nil {
|
||||
return sql.NullBool{}
|
||||
}
|
||||
return sql.NullBool{Bool: *v, Valid: true}
|
||||
}
|
||||
|
||||
func opsNullFloat64(v *float64) any {
|
||||
if v == nil {
|
||||
return sql.NullFloat64{}
|
||||
}
|
||||
return sql.NullFloat64{Float64: *v, Valid: true}
|
||||
}
|
||||
|
||||
func opsNullTime(v *time.Time) any {
|
||||
if v == nil || v.IsZero() {
|
||||
return sql.NullTime{}
|
||||
}
|
||||
return sql.NullTime{Time: *v, Valid: true}
|
||||
}
|
||||
@@ -0,0 +1,145 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetOpenAITokenStats(ctx context.Context, filter *service.OpsOpenAITokenStatsFilter) (*service.OpsOpenAITokenStatsResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
// 允许 start_time == end_time(结果为空),与 service 层校验口径保持一致。
|
||||
if filter.StartTime.After(filter.EndTime) {
|
||||
return nil, fmt.Errorf("start_time must be <= end_time")
|
||||
}
|
||||
|
||||
dashboardFilter := &service.OpsDashboardFilter{
|
||||
StartTime: filter.StartTime.UTC(),
|
||||
EndTime: filter.EndTime.UTC(),
|
||||
Platform: strings.TrimSpace(strings.ToLower(filter.Platform)),
|
||||
GroupID: filter.GroupID,
|
||||
}
|
||||
|
||||
join, where, baseArgs, next := buildUsageWhere(dashboardFilter, dashboardFilter.StartTime, dashboardFilter.EndTime, 1)
|
||||
where += " AND ul.model LIKE 'gpt%'"
|
||||
|
||||
baseCTE := `
|
||||
WITH stats AS (
|
||||
SELECT
|
||||
ul.model AS model,
|
||||
COUNT(*)::bigint AS request_count,
|
||||
ROUND(
|
||||
AVG(
|
||||
CASE
|
||||
WHEN ul.duration_ms > 0 AND ul.output_tokens > 0
|
||||
THEN ul.output_tokens * 1000.0 / ul.duration_ms
|
||||
END
|
||||
)::numeric,
|
||||
2
|
||||
)::float8 AS avg_tokens_per_sec,
|
||||
ROUND(AVG(ul.first_token_ms)::numeric, 2)::float8 AS avg_first_token_ms,
|
||||
COALESCE(SUM(ul.output_tokens), 0)::bigint AS total_output_tokens,
|
||||
COALESCE(ROUND(AVG(ul.duration_ms)::numeric, 0), 0)::bigint AS avg_duration_ms,
|
||||
COUNT(CASE WHEN ul.first_token_ms IS NOT NULL THEN 1 END)::bigint AS requests_with_first_token
|
||||
FROM usage_logs ul
|
||||
` + join + `
|
||||
` + where + `
|
||||
GROUP BY ul.model
|
||||
)
|
||||
`
|
||||
|
||||
countSQL := baseCTE + `SELECT COUNT(*) FROM stats`
|
||||
var total int64
|
||||
if err := r.db.QueryRowContext(ctx, countSQL, baseArgs...).Scan(&total); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
querySQL := baseCTE + `
|
||||
SELECT
|
||||
model,
|
||||
request_count,
|
||||
avg_tokens_per_sec,
|
||||
avg_first_token_ms,
|
||||
total_output_tokens,
|
||||
avg_duration_ms,
|
||||
requests_with_first_token
|
||||
FROM stats
|
||||
ORDER BY request_count DESC, model ASC`
|
||||
|
||||
args := make([]any, 0, len(baseArgs)+2)
|
||||
args = append(args, baseArgs...)
|
||||
|
||||
if filter.IsTopNMode() {
|
||||
querySQL += fmt.Sprintf("\nLIMIT $%d", next)
|
||||
args = append(args, filter.TopN)
|
||||
} else {
|
||||
offset := (filter.Page - 1) * filter.PageSize
|
||||
querySQL += fmt.Sprintf("\nLIMIT $%d OFFSET $%d", next, next+1)
|
||||
args = append(args, filter.PageSize, offset)
|
||||
}
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, querySQL, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]*service.OpsOpenAITokenStatsItem, 0, 32)
|
||||
for rows.Next() {
|
||||
item := &service.OpsOpenAITokenStatsItem{}
|
||||
var avgTPS sql.NullFloat64
|
||||
var avgFirstToken sql.NullFloat64
|
||||
if err := rows.Scan(
|
||||
&item.Model,
|
||||
&item.RequestCount,
|
||||
&avgTPS,
|
||||
&avgFirstToken,
|
||||
&item.TotalOutputTokens,
|
||||
&item.AvgDurationMs,
|
||||
&item.RequestsWithFirstToken,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if avgTPS.Valid {
|
||||
v := avgTPS.Float64
|
||||
item.AvgTokensPerSec = &v
|
||||
}
|
||||
if avgFirstToken.Valid {
|
||||
v := avgFirstToken.Float64
|
||||
item.AvgFirstTokenMs = &v
|
||||
}
|
||||
items = append(items, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp := &service.OpsOpenAITokenStatsResponse{
|
||||
TimeRange: strings.TrimSpace(filter.TimeRange),
|
||||
StartTime: dashboardFilter.StartTime,
|
||||
EndTime: dashboardFilter.EndTime,
|
||||
Platform: dashboardFilter.Platform,
|
||||
GroupID: dashboardFilter.GroupID,
|
||||
Items: items,
|
||||
Total: total,
|
||||
}
|
||||
if filter.IsTopNMode() {
|
||||
topN := filter.TopN
|
||||
resp.TopN = &topN
|
||||
} else {
|
||||
resp.Page = filter.Page
|
||||
resp.PageSize = filter.PageSize
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpsRepositoryGetOpenAITokenStats_PaginationMode(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &opsRepository{db: db}
|
||||
|
||||
start := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(24 * time.Hour)
|
||||
groupID := int64(9)
|
||||
|
||||
filter := &service.OpsOpenAITokenStatsFilter{
|
||||
TimeRange: "1d",
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: " OpenAI ",
|
||||
GroupID: &groupID,
|
||||
Page: 2,
|
||||
PageSize: 10,
|
||||
}
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
|
||||
WithArgs(start, end, groupID, "openai").
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(3)))
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"model",
|
||||
"request_count",
|
||||
"avg_tokens_per_sec",
|
||||
"avg_first_token_ms",
|
||||
"total_output_tokens",
|
||||
"avg_duration_ms",
|
||||
"requests_with_first_token",
|
||||
}).
|
||||
AddRow("gpt-4o-mini", int64(20), 21.56, 120.34, int64(3000), int64(850), int64(18)).
|
||||
AddRow("gpt-4.1", int64(20), 10.2, 240.0, int64(2500), int64(900), int64(20))
|
||||
|
||||
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$5 OFFSET \$6`).
|
||||
WithArgs(start, end, groupID, "openai", 10, 10).
|
||||
WillReturnRows(rows)
|
||||
|
||||
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, int64(3), resp.Total)
|
||||
require.Equal(t, 2, resp.Page)
|
||||
require.Equal(t, 10, resp.PageSize)
|
||||
require.Nil(t, resp.TopN)
|
||||
require.Equal(t, "openai", resp.Platform)
|
||||
require.NotNil(t, resp.GroupID)
|
||||
require.Equal(t, groupID, *resp.GroupID)
|
||||
require.Len(t, resp.Items, 2)
|
||||
require.Equal(t, "gpt-4o-mini", resp.Items[0].Model)
|
||||
require.NotNil(t, resp.Items[0].AvgTokensPerSec)
|
||||
require.InDelta(t, 21.56, *resp.Items[0].AvgTokensPerSec, 0.0001)
|
||||
require.NotNil(t, resp.Items[0].AvgFirstTokenMs)
|
||||
require.InDelta(t, 120.34, *resp.Items[0].AvgFirstTokenMs, 0.0001)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestOpsRepositoryGetOpenAITokenStats_TopNMode(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &opsRepository{db: db}
|
||||
|
||||
start := time.Date(2026, 1, 1, 10, 0, 0, 0, time.UTC)
|
||||
end := start.Add(time.Hour)
|
||||
filter := &service.OpsOpenAITokenStatsFilter{
|
||||
TimeRange: "1h",
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
TopN: 5,
|
||||
}
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
|
||||
WithArgs(start, end).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1)))
|
||||
|
||||
rows := sqlmock.NewRows([]string{
|
||||
"model",
|
||||
"request_count",
|
||||
"avg_tokens_per_sec",
|
||||
"avg_first_token_ms",
|
||||
"total_output_tokens",
|
||||
"avg_duration_ms",
|
||||
"requests_with_first_token",
|
||||
}).
|
||||
AddRow("gpt-4o", int64(5), nil, nil, int64(0), int64(0), int64(0))
|
||||
|
||||
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3`).
|
||||
WithArgs(start, end, 5).
|
||||
WillReturnRows(rows)
|
||||
|
||||
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.NotNil(t, resp.TopN)
|
||||
require.Equal(t, 5, *resp.TopN)
|
||||
require.Equal(t, 0, resp.Page)
|
||||
require.Equal(t, 0, resp.PageSize)
|
||||
require.Len(t, resp.Items, 1)
|
||||
require.Nil(t, resp.Items[0].AvgTokensPerSec)
|
||||
require.Nil(t, resp.Items[0].AvgFirstTokenMs)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
|
||||
func TestOpsRepositoryGetOpenAITokenStats_EmptyResult(t *testing.T) {
|
||||
db, mock := newSQLMock(t)
|
||||
repo := &opsRepository{db: db}
|
||||
|
||||
start := time.Date(2026, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
end := start.Add(30 * time.Minute)
|
||||
filter := &service.OpsOpenAITokenStatsFilter{
|
||||
TimeRange: "30m",
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Page: 1,
|
||||
PageSize: 20,
|
||||
}
|
||||
|
||||
mock.ExpectQuery(`SELECT COUNT\(\*\) FROM stats`).
|
||||
WithArgs(start, end).
|
||||
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0)))
|
||||
|
||||
mock.ExpectQuery(`ORDER BY request_count DESC, model ASC\s+LIMIT \$3 OFFSET \$4`).
|
||||
WithArgs(start, end, 20, 0).
|
||||
WillReturnRows(sqlmock.NewRows([]string{
|
||||
"model",
|
||||
"request_count",
|
||||
"avg_tokens_per_sec",
|
||||
"avg_first_token_ms",
|
||||
"total_output_tokens",
|
||||
"avg_duration_ms",
|
||||
"requests_with_first_token",
|
||||
}))
|
||||
|
||||
resp, err := repo.GetOpenAITokenStats(context.Background(), filter)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
require.Equal(t, int64(0), resp.Total)
|
||||
require.Len(t, resp.Items, 0)
|
||||
require.Equal(t, 1, resp.Page)
|
||||
require.Equal(t, 20, resp.PageSize)
|
||||
|
||||
require.NoError(t, mock.ExpectationsWereMet())
|
||||
}
|
||||
@@ -0,0 +1,363 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (r *opsRepository) UpsertHourlyMetrics(ctx context.Context, startTime, endTime time.Time) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := startTime.UTC()
|
||||
end := endTime.UTC()
|
||||
|
||||
// NOTE:
|
||||
// - We aggregate usage_logs + ops_error_logs into ops_metrics_hourly.
|
||||
// - We emit three dimension granularities via GROUPING SETS:
|
||||
// 1) overall: (bucket_start)
|
||||
// 2) platform: (bucket_start, platform)
|
||||
// 3) group: (bucket_start, platform, group_id)
|
||||
//
|
||||
// IMPORTANT: Postgres UNIQUE treats NULLs as distinct, so the table uses a COALESCE-based
|
||||
// unique index; our ON CONFLICT target must match that expression set.
|
||||
q := `
|
||||
WITH usage_base AS (
|
||||
SELECT
|
||||
date_trunc('hour', ul.created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||
g.platform AS platform,
|
||||
ul.group_id AS group_id,
|
||||
ul.duration_ms AS duration_ms,
|
||||
ul.first_token_ms AS first_token_ms,
|
||||
(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens) AS tokens
|
||||
FROM usage_logs ul
|
||||
JOIN groups g ON g.id = ul.group_id
|
||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||
),
|
||||
usage_agg AS (
|
||||
SELECT
|
||||
bucket_start,
|
||||
CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
|
||||
CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
|
||||
COUNT(*) AS success_count,
|
||||
COALESCE(SUM(tokens), 0) AS token_consumed,
|
||||
|
||||
percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50_ms,
|
||||
percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90_ms,
|
||||
percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95_ms,
|
||||
percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99_ms,
|
||||
AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg_ms,
|
||||
MAX(duration_ms) AS duration_max_ms,
|
||||
|
||||
percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50_ms,
|
||||
percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90_ms,
|
||||
percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95_ms,
|
||||
percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99_ms,
|
||||
AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg_ms,
|
||||
MAX(first_token_ms) AS ttft_max_ms
|
||||
FROM usage_base
|
||||
GROUP BY GROUPING SETS (
|
||||
(bucket_start),
|
||||
(bucket_start, platform),
|
||||
(bucket_start, platform, group_id)
|
||||
)
|
||||
),
|
||||
error_base AS (
|
||||
SELECT
|
||||
date_trunc('hour', created_at AT TIME ZONE 'UTC') AT TIME ZONE 'UTC' AS bucket_start,
|
||||
-- platform is NULL for some early-phase errors (e.g. before routing); map to a sentinel
|
||||
-- value so platform-level GROUPING SETS don't collide with the overall (platform=NULL) row.
|
||||
COALESCE(platform, 'unknown') AS platform,
|
||||
group_id AS group_id,
|
||||
is_business_limited AS is_business_limited,
|
||||
error_owner AS error_owner,
|
||||
status_code AS client_status_code,
|
||||
COALESCE(upstream_status_code, status_code, 0) AS effective_status_code
|
||||
FROM ops_error_logs
|
||||
-- Exclude count_tokens requests from error metrics as they are informational probes
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
AND is_count_tokens = FALSE
|
||||
),
|
||||
error_agg AS (
|
||||
SELECT
|
||||
bucket_start,
|
||||
CASE WHEN GROUPING(platform) = 1 THEN NULL ELSE platform END AS platform,
|
||||
CASE WHEN GROUPING(group_id) = 1 THEN NULL ELSE group_id END AS group_id,
|
||||
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400) AS error_count_total,
|
||||
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND is_business_limited) AS business_limited_count,
|
||||
COUNT(*) FILTER (WHERE COALESCE(client_status_code, 0) >= 400 AND NOT is_business_limited) AS error_count_sla,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) NOT IN (429, 529)) AS upstream_error_count_excl_429_529,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 429) AS upstream_429_count,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(effective_status_code, 0) = 529) AS upstream_529_count
|
||||
FROM error_base
|
||||
GROUP BY GROUPING SETS (
|
||||
(bucket_start),
|
||||
(bucket_start, platform),
|
||||
(bucket_start, platform, group_id)
|
||||
)
|
||||
HAVING GROUPING(group_id) = 1 OR group_id IS NOT NULL
|
||||
),
|
||||
combined AS (
|
||||
SELECT
|
||||
COALESCE(u.bucket_start, e.bucket_start) AS bucket_start,
|
||||
COALESCE(u.platform, e.platform) AS platform,
|
||||
COALESCE(u.group_id, e.group_id) AS group_id,
|
||||
|
||||
COALESCE(u.success_count, 0) AS success_count,
|
||||
COALESCE(e.error_count_total, 0) AS error_count_total,
|
||||
COALESCE(e.business_limited_count, 0) AS business_limited_count,
|
||||
COALESCE(e.error_count_sla, 0) AS error_count_sla,
|
||||
COALESCE(e.upstream_error_count_excl_429_529, 0) AS upstream_error_count_excl_429_529,
|
||||
COALESCE(e.upstream_429_count, 0) AS upstream_429_count,
|
||||
COALESCE(e.upstream_529_count, 0) AS upstream_529_count,
|
||||
|
||||
COALESCE(u.token_consumed, 0) AS token_consumed,
|
||||
|
||||
u.duration_p50_ms,
|
||||
u.duration_p90_ms,
|
||||
u.duration_p95_ms,
|
||||
u.duration_p99_ms,
|
||||
u.duration_avg_ms,
|
||||
u.duration_max_ms,
|
||||
|
||||
u.ttft_p50_ms,
|
||||
u.ttft_p90_ms,
|
||||
u.ttft_p95_ms,
|
||||
u.ttft_p99_ms,
|
||||
u.ttft_avg_ms,
|
||||
u.ttft_max_ms
|
||||
FROM usage_agg u
|
||||
FULL OUTER JOIN error_agg e
|
||||
ON u.bucket_start = e.bucket_start
|
||||
AND COALESCE(u.platform, '') = COALESCE(e.platform, '')
|
||||
AND COALESCE(u.group_id, 0) = COALESCE(e.group_id, 0)
|
||||
)
|
||||
INSERT INTO ops_metrics_hourly (
|
||||
bucket_start,
|
||||
platform,
|
||||
group_id,
|
||||
success_count,
|
||||
error_count_total,
|
||||
business_limited_count,
|
||||
error_count_sla,
|
||||
upstream_error_count_excl_429_529,
|
||||
upstream_429_count,
|
||||
upstream_529_count,
|
||||
token_consumed,
|
||||
duration_p50_ms,
|
||||
duration_p90_ms,
|
||||
duration_p95_ms,
|
||||
duration_p99_ms,
|
||||
duration_avg_ms,
|
||||
duration_max_ms,
|
||||
ttft_p50_ms,
|
||||
ttft_p90_ms,
|
||||
ttft_p95_ms,
|
||||
ttft_p99_ms,
|
||||
ttft_avg_ms,
|
||||
ttft_max_ms,
|
||||
computed_at
|
||||
)
|
||||
SELECT
|
||||
bucket_start,
|
||||
NULLIF(platform, '') AS platform,
|
||||
group_id,
|
||||
success_count,
|
||||
error_count_total,
|
||||
business_limited_count,
|
||||
error_count_sla,
|
||||
upstream_error_count_excl_429_529,
|
||||
upstream_429_count,
|
||||
upstream_529_count,
|
||||
token_consumed,
|
||||
duration_p50_ms::int,
|
||||
duration_p90_ms::int,
|
||||
duration_p95_ms::int,
|
||||
duration_p99_ms::int,
|
||||
duration_avg_ms,
|
||||
duration_max_ms::int,
|
||||
ttft_p50_ms::int,
|
||||
ttft_p90_ms::int,
|
||||
ttft_p95_ms::int,
|
||||
ttft_p99_ms::int,
|
||||
ttft_avg_ms,
|
||||
ttft_max_ms::int,
|
||||
NOW()
|
||||
FROM combined
|
||||
WHERE bucket_start IS NOT NULL
|
||||
AND (platform IS NULL OR platform <> '')
|
||||
ON CONFLICT (bucket_start, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
|
||||
success_count = EXCLUDED.success_count,
|
||||
error_count_total = EXCLUDED.error_count_total,
|
||||
business_limited_count = EXCLUDED.business_limited_count,
|
||||
error_count_sla = EXCLUDED.error_count_sla,
|
||||
upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
|
||||
upstream_429_count = EXCLUDED.upstream_429_count,
|
||||
upstream_529_count = EXCLUDED.upstream_529_count,
|
||||
token_consumed = EXCLUDED.token_consumed,
|
||||
|
||||
duration_p50_ms = EXCLUDED.duration_p50_ms,
|
||||
duration_p90_ms = EXCLUDED.duration_p90_ms,
|
||||
duration_p95_ms = EXCLUDED.duration_p95_ms,
|
||||
duration_p99_ms = EXCLUDED.duration_p99_ms,
|
||||
duration_avg_ms = EXCLUDED.duration_avg_ms,
|
||||
duration_max_ms = EXCLUDED.duration_max_ms,
|
||||
|
||||
ttft_p50_ms = EXCLUDED.ttft_p50_ms,
|
||||
ttft_p90_ms = EXCLUDED.ttft_p90_ms,
|
||||
ttft_p95_ms = EXCLUDED.ttft_p95_ms,
|
||||
ttft_p99_ms = EXCLUDED.ttft_p99_ms,
|
||||
ttft_avg_ms = EXCLUDED.ttft_avg_ms,
|
||||
ttft_max_ms = EXCLUDED.ttft_max_ms,
|
||||
|
||||
computed_at = NOW()
|
||||
`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, q, start, end)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) UpsertDailyMetrics(ctx context.Context, startTime, endTime time.Time) error {
|
||||
if r == nil || r.db == nil {
|
||||
return fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if startTime.IsZero() || endTime.IsZero() || !endTime.After(startTime) {
|
||||
return nil
|
||||
}
|
||||
|
||||
start := startTime.UTC()
|
||||
end := endTime.UTC()
|
||||
|
||||
q := `
|
||||
INSERT INTO ops_metrics_daily (
|
||||
bucket_date,
|
||||
platform,
|
||||
group_id,
|
||||
success_count,
|
||||
error_count_total,
|
||||
business_limited_count,
|
||||
error_count_sla,
|
||||
upstream_error_count_excl_429_529,
|
||||
upstream_429_count,
|
||||
upstream_529_count,
|
||||
token_consumed,
|
||||
duration_p50_ms,
|
||||
duration_p90_ms,
|
||||
duration_p95_ms,
|
||||
duration_p99_ms,
|
||||
duration_avg_ms,
|
||||
duration_max_ms,
|
||||
ttft_p50_ms,
|
||||
ttft_p90_ms,
|
||||
ttft_p95_ms,
|
||||
ttft_p99_ms,
|
||||
ttft_avg_ms,
|
||||
ttft_max_ms,
|
||||
computed_at
|
||||
)
|
||||
SELECT
|
||||
(bucket_start AT TIME ZONE 'UTC')::date AS bucket_date,
|
||||
platform,
|
||||
group_id,
|
||||
|
||||
COALESCE(SUM(success_count), 0) AS success_count,
|
||||
COALESCE(SUM(error_count_total), 0) AS error_count_total,
|
||||
COALESCE(SUM(business_limited_count), 0) AS business_limited_count,
|
||||
COALESCE(SUM(error_count_sla), 0) AS error_count_sla,
|
||||
COALESCE(SUM(upstream_error_count_excl_429_529), 0) AS upstream_error_count_excl_429_529,
|
||||
COALESCE(SUM(upstream_429_count), 0) AS upstream_429_count,
|
||||
COALESCE(SUM(upstream_529_count), 0) AS upstream_529_count,
|
||||
COALESCE(SUM(token_consumed), 0) AS token_consumed,
|
||||
|
||||
-- Approximation: weighted average for p50/p90, max for p95/p99 (conservative tail).
|
||||
ROUND(SUM(duration_p50_ms::double precision * success_count) FILTER (WHERE duration_p50_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE duration_p50_ms IS NOT NULL), 0))::int AS duration_p50_ms,
|
||||
ROUND(SUM(duration_p90_ms::double precision * success_count) FILTER (WHERE duration_p90_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE duration_p90_ms IS NOT NULL), 0))::int AS duration_p90_ms,
|
||||
MAX(duration_p95_ms) AS duration_p95_ms,
|
||||
MAX(duration_p99_ms) AS duration_p99_ms,
|
||||
SUM(duration_avg_ms * success_count) FILTER (WHERE duration_avg_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE duration_avg_ms IS NOT NULL), 0) AS duration_avg_ms,
|
||||
MAX(duration_max_ms) AS duration_max_ms,
|
||||
|
||||
ROUND(SUM(ttft_p50_ms::double precision * success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p50_ms IS NOT NULL), 0))::int AS ttft_p50_ms,
|
||||
ROUND(SUM(ttft_p90_ms::double precision * success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_p90_ms IS NOT NULL), 0))::int AS ttft_p90_ms,
|
||||
MAX(ttft_p95_ms) AS ttft_p95_ms,
|
||||
MAX(ttft_p99_ms) AS ttft_p99_ms,
|
||||
SUM(ttft_avg_ms * success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL)
|
||||
/ NULLIF(SUM(success_count) FILTER (WHERE ttft_avg_ms IS NOT NULL), 0) AS ttft_avg_ms,
|
||||
MAX(ttft_max_ms) AS ttft_max_ms,
|
||||
|
||||
NOW()
|
||||
FROM ops_metrics_hourly
|
||||
WHERE bucket_start >= $1 AND bucket_start < $2
|
||||
GROUP BY 1, 2, 3
|
||||
ON CONFLICT (bucket_date, COALESCE(platform, ''), COALESCE(group_id, 0)) DO UPDATE SET
|
||||
success_count = EXCLUDED.success_count,
|
||||
error_count_total = EXCLUDED.error_count_total,
|
||||
business_limited_count = EXCLUDED.business_limited_count,
|
||||
error_count_sla = EXCLUDED.error_count_sla,
|
||||
upstream_error_count_excl_429_529 = EXCLUDED.upstream_error_count_excl_429_529,
|
||||
upstream_429_count = EXCLUDED.upstream_429_count,
|
||||
upstream_529_count = EXCLUDED.upstream_529_count,
|
||||
token_consumed = EXCLUDED.token_consumed,
|
||||
|
||||
duration_p50_ms = EXCLUDED.duration_p50_ms,
|
||||
duration_p90_ms = EXCLUDED.duration_p90_ms,
|
||||
duration_p95_ms = EXCLUDED.duration_p95_ms,
|
||||
duration_p99_ms = EXCLUDED.duration_p99_ms,
|
||||
duration_avg_ms = EXCLUDED.duration_avg_ms,
|
||||
duration_max_ms = EXCLUDED.duration_max_ms,
|
||||
|
||||
ttft_p50_ms = EXCLUDED.ttft_p50_ms,
|
||||
ttft_p90_ms = EXCLUDED.ttft_p90_ms,
|
||||
ttft_p95_ms = EXCLUDED.ttft_p95_ms,
|
||||
ttft_p99_ms = EXCLUDED.ttft_p99_ms,
|
||||
ttft_avg_ms = EXCLUDED.ttft_avg_ms,
|
||||
ttft_max_ms = EXCLUDED.ttft_max_ms,
|
||||
|
||||
computed_at = NOW()
|
||||
`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, q, start, end)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetLatestHourlyBucketStart(ctx context.Context) (time.Time, bool, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return time.Time{}, false, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
var value sql.NullTime
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_start) FROM ops_metrics_hourly`).Scan(&value); err != nil {
|
||||
return time.Time{}, false, err
|
||||
}
|
||||
if !value.Valid {
|
||||
return time.Time{}, false, nil
|
||||
}
|
||||
return value.Time.UTC(), true, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetLatestDailyBucketDate(ctx context.Context) (time.Time, bool, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return time.Time{}, false, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
var value sql.NullTime
|
||||
if err := r.db.QueryRowContext(ctx, `SELECT MAX(bucket_date) FROM ops_metrics_daily`).Scan(&value); err != nil {
|
||||
return time.Time{}, false, err
|
||||
}
|
||||
if !value.Valid {
|
||||
return time.Time{}, false, nil
|
||||
}
|
||||
t := value.Time.UTC()
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC), true, nil
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetRealtimeTrafficSummary(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsRealtimeTrafficSummary, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
if start.After(end) {
|
||||
return nil, fmt.Errorf("start_time must be <= end_time")
|
||||
}
|
||||
|
||||
window := end.Sub(start)
|
||||
if window <= 0 {
|
||||
return nil, fmt.Errorf("invalid time window")
|
||||
}
|
||||
if window > time.Hour {
|
||||
return nil, fmt.Errorf("window too large")
|
||||
}
|
||||
|
||||
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
|
||||
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
|
||||
|
||||
q := `
|
||||
WITH usage_buckets AS (
|
||||
SELECT
|
||||
date_trunc('minute', ul.created_at) AS bucket,
|
||||
COALESCE(COUNT(*), 0) AS success_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_sum
|
||||
FROM usage_logs ul
|
||||
` + usageJoin + `
|
||||
` + usageWhere + `
|
||||
GROUP BY 1
|
||||
),
|
||||
error_buckets AS (
|
||||
SELECT
|
||||
date_trunc('minute', created_at) AS bucket,
|
||||
COALESCE(COUNT(*), 0) AS error_count
|
||||
FROM ops_error_logs
|
||||
` + errorWhere + `
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
GROUP BY 1
|
||||
),
|
||||
combined AS (
|
||||
SELECT
|
||||
COALESCE(u.bucket, e.bucket) AS bucket,
|
||||
COALESCE(u.success_count, 0) AS success_count,
|
||||
COALESCE(u.token_sum, 0) AS token_sum,
|
||||
COALESCE(e.error_count, 0) AS error_count,
|
||||
COALESCE(u.success_count, 0) + COALESCE(e.error_count, 0) AS request_total
|
||||
FROM usage_buckets u
|
||||
FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket
|
||||
)
|
||||
SELECT
|
||||
COALESCE(SUM(success_count), 0) AS success_total,
|
||||
COALESCE(SUM(error_count), 0) AS error_total,
|
||||
COALESCE(SUM(token_sum), 0) AS token_total,
|
||||
COALESCE(MAX(request_total), 0) AS peak_requests_per_min,
|
||||
COALESCE(MAX(token_sum), 0) AS peak_tokens_per_min
|
||||
FROM combined`
|
||||
|
||||
args := append(usageArgs, errorArgs...)
|
||||
var successCount int64
|
||||
var errorTotal int64
|
||||
var tokenConsumed int64
|
||||
var peakRequestsPerMin int64
|
||||
var peakTokensPerMin int64
|
||||
if err := r.db.QueryRowContext(ctx, q, args...).Scan(
|
||||
&successCount,
|
||||
&errorTotal,
|
||||
&tokenConsumed,
|
||||
&peakRequestsPerMin,
|
||||
&peakTokensPerMin,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
windowSeconds := window.Seconds()
|
||||
if windowSeconds <= 0 {
|
||||
windowSeconds = 1
|
||||
}
|
||||
|
||||
requestCountTotal := successCount + errorTotal
|
||||
qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds)
|
||||
tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds)
|
||||
|
||||
// Keep "current" consistent with the dashboard overview semantics: last 1 minute.
|
||||
// This remains "within the selected window" since end=start+window.
|
||||
qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qpsPeak := roundTo1DP(float64(peakRequestsPerMin) / 60.0)
|
||||
tpsPeak := roundTo1DP(float64(peakTokensPerMin) / 60.0)
|
||||
|
||||
return &service.OpsRealtimeTrafficSummary{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
Platform: strings.TrimSpace(filter.Platform),
|
||||
GroupID: filter.GroupID,
|
||||
QPS: service.OpsRateSummary{
|
||||
Current: qpsCurrent,
|
||||
Peak: qpsPeak,
|
||||
Avg: qpsAvg,
|
||||
},
|
||||
TPS: service.OpsRateSummary{
|
||||
Current: tpsCurrent,
|
||||
Peak: tpsPeak,
|
||||
Avg: tpsAvg,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,286 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) ListRequestDetails(ctx context.Context, filter *service.OpsRequestDetailFilter) ([]*service.OpsRequestDetail, int64, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, 0, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
|
||||
page, pageSize, startTime, endTime := filter.Normalize()
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
conditions := make([]string, 0, 16)
|
||||
args := make([]any, 0, 24)
|
||||
|
||||
// Placeholders $1/$2 reserved for time window inside the CTE.
|
||||
args = append(args, startTime.UTC(), endTime.UTC())
|
||||
|
||||
addCondition := func(condition string, values ...any) {
|
||||
conditions = append(conditions, condition)
|
||||
args = append(args, values...)
|
||||
}
|
||||
|
||||
if filter != nil {
|
||||
if kind := strings.TrimSpace(strings.ToLower(filter.Kind)); kind != "" && kind != "all" {
|
||||
if kind != string(service.OpsRequestKindSuccess) && kind != string(service.OpsRequestKindError) {
|
||||
return nil, 0, fmt.Errorf("invalid kind")
|
||||
}
|
||||
addCondition(fmt.Sprintf("kind = $%d", len(args)+1), kind)
|
||||
}
|
||||
|
||||
if platform := strings.TrimSpace(strings.ToLower(filter.Platform)); platform != "" {
|
||||
addCondition(fmt.Sprintf("platform = $%d", len(args)+1), platform)
|
||||
}
|
||||
if filter.GroupID != nil && *filter.GroupID > 0 {
|
||||
addCondition(fmt.Sprintf("group_id = $%d", len(args)+1), *filter.GroupID)
|
||||
}
|
||||
|
||||
if filter.UserID != nil && *filter.UserID > 0 {
|
||||
addCondition(fmt.Sprintf("user_id = $%d", len(args)+1), *filter.UserID)
|
||||
}
|
||||
if filter.APIKeyID != nil && *filter.APIKeyID > 0 {
|
||||
addCondition(fmt.Sprintf("api_key_id = $%d", len(args)+1), *filter.APIKeyID)
|
||||
}
|
||||
if filter.AccountID != nil && *filter.AccountID > 0 {
|
||||
addCondition(fmt.Sprintf("account_id = $%d", len(args)+1), *filter.AccountID)
|
||||
}
|
||||
|
||||
if model := strings.TrimSpace(filter.Model); model != "" {
|
||||
addCondition(fmt.Sprintf("model = $%d", len(args)+1), model)
|
||||
}
|
||||
if requestID := strings.TrimSpace(filter.RequestID); requestID != "" {
|
||||
addCondition(fmt.Sprintf("request_id = $%d", len(args)+1), requestID)
|
||||
}
|
||||
if q := strings.TrimSpace(filter.Query); q != "" {
|
||||
like := "%" + strings.ToLower(q) + "%"
|
||||
startIdx := len(args) + 1
|
||||
addCondition(
|
||||
fmt.Sprintf("(LOWER(COALESCE(request_id,'')) LIKE $%d OR LOWER(COALESCE(model,'')) LIKE $%d OR LOWER(COALESCE(message,'')) LIKE $%d)",
|
||||
startIdx, startIdx+1, startIdx+2,
|
||||
),
|
||||
like, like, like,
|
||||
)
|
||||
}
|
||||
|
||||
if filter.MinDurationMs != nil {
|
||||
addCondition(fmt.Sprintf("duration_ms >= $%d", len(args)+1), *filter.MinDurationMs)
|
||||
}
|
||||
if filter.MaxDurationMs != nil {
|
||||
addCondition(fmt.Sprintf("duration_ms <= $%d", len(args)+1), *filter.MaxDurationMs)
|
||||
}
|
||||
}
|
||||
|
||||
where := ""
|
||||
if len(conditions) > 0 {
|
||||
where = "WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
cte := `
|
||||
WITH combined AS (
|
||||
SELECT
|
||||
'success'::TEXT AS kind,
|
||||
ul.created_at AS created_at,
|
||||
ul.request_id AS request_id,
|
||||
COALESCE(NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
|
||||
ul.model AS model,
|
||||
ul.duration_ms AS duration_ms,
|
||||
NULL::INT AS status_code,
|
||||
NULL::BIGINT AS error_id,
|
||||
NULL::TEXT AS phase,
|
||||
NULL::TEXT AS severity,
|
||||
NULL::TEXT AS message,
|
||||
ul.user_id AS user_id,
|
||||
ul.api_key_id AS api_key_id,
|
||||
ul.account_id AS account_id,
|
||||
ul.group_id AS group_id,
|
||||
ul.stream AS stream
|
||||
FROM usage_logs ul
|
||||
LEFT JOIN groups g ON g.id = ul.group_id
|
||||
LEFT JOIN accounts a ON a.id = ul.account_id
|
||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||
|
||||
UNION ALL
|
||||
|
||||
SELECT
|
||||
'error'::TEXT AS kind,
|
||||
o.created_at AS created_at,
|
||||
COALESCE(NULLIF(o.request_id,''), NULLIF(o.client_request_id,''), '') AS request_id,
|
||||
COALESCE(NULLIF(o.platform, ''), NULLIF(g.platform, ''), NULLIF(a.platform, ''), '') AS platform,
|
||||
o.model AS model,
|
||||
o.duration_ms AS duration_ms,
|
||||
o.status_code AS status_code,
|
||||
o.id AS error_id,
|
||||
o.error_phase AS phase,
|
||||
o.severity AS severity,
|
||||
o.error_message AS message,
|
||||
o.user_id AS user_id,
|
||||
o.api_key_id AS api_key_id,
|
||||
o.account_id AS account_id,
|
||||
o.group_id AS group_id,
|
||||
o.stream AS stream
|
||||
FROM ops_error_logs o
|
||||
LEFT JOIN groups g ON g.id = o.group_id
|
||||
LEFT JOIN accounts a ON a.id = o.account_id
|
||||
WHERE o.created_at >= $1 AND o.created_at < $2
|
||||
AND COALESCE(o.status_code, 0) >= 400
|
||||
)
|
||||
`
|
||||
|
||||
countQuery := fmt.Sprintf(`%s SELECT COUNT(1) FROM combined %s`, cte, where)
|
||||
var total int64
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
total = 0
|
||||
} else {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
sort := "ORDER BY created_at DESC"
|
||||
if filter != nil {
|
||||
switch strings.TrimSpace(strings.ToLower(filter.Sort)) {
|
||||
case "", "created_at_desc":
|
||||
// default
|
||||
case "duration_desc":
|
||||
sort = "ORDER BY duration_ms DESC NULLS LAST, created_at DESC"
|
||||
default:
|
||||
return nil, 0, fmt.Errorf("invalid sort")
|
||||
}
|
||||
}
|
||||
|
||||
listQuery := fmt.Sprintf(`
|
||||
%s
|
||||
SELECT
|
||||
kind,
|
||||
created_at,
|
||||
request_id,
|
||||
platform,
|
||||
model,
|
||||
duration_ms,
|
||||
status_code,
|
||||
error_id,
|
||||
phase,
|
||||
severity,
|
||||
message,
|
||||
user_id,
|
||||
api_key_id,
|
||||
account_id,
|
||||
group_id,
|
||||
stream
|
||||
FROM combined
|
||||
%s
|
||||
%s
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, cte, where, sort, len(args)+1, len(args)+2)
|
||||
|
||||
listArgs := append(append([]any{}, args...), pageSize, offset)
|
||||
rows, err := r.db.QueryContext(ctx, listQuery, listArgs...)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
toIntPtr := func(v sql.NullInt64) *int {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
i := int(v.Int64)
|
||||
return &i
|
||||
}
|
||||
toInt64Ptr := func(v sql.NullInt64) *int64 {
|
||||
if !v.Valid {
|
||||
return nil
|
||||
}
|
||||
i := v.Int64
|
||||
return &i
|
||||
}
|
||||
|
||||
out := make([]*service.OpsRequestDetail, 0, pageSize)
|
||||
for rows.Next() {
|
||||
var (
|
||||
kind string
|
||||
createdAt time.Time
|
||||
requestID sql.NullString
|
||||
platform sql.NullString
|
||||
model sql.NullString
|
||||
|
||||
durationMs sql.NullInt64
|
||||
statusCode sql.NullInt64
|
||||
errorID sql.NullInt64
|
||||
|
||||
phase sql.NullString
|
||||
severity sql.NullString
|
||||
message sql.NullString
|
||||
|
||||
userID sql.NullInt64
|
||||
apiKeyID sql.NullInt64
|
||||
accountID sql.NullInt64
|
||||
groupID sql.NullInt64
|
||||
|
||||
stream bool
|
||||
)
|
||||
|
||||
if err := rows.Scan(
|
||||
&kind,
|
||||
&createdAt,
|
||||
&requestID,
|
||||
&platform,
|
||||
&model,
|
||||
&durationMs,
|
||||
&statusCode,
|
||||
&errorID,
|
||||
&phase,
|
||||
&severity,
|
||||
&message,
|
||||
&userID,
|
||||
&apiKeyID,
|
||||
&accountID,
|
||||
&groupID,
|
||||
&stream,
|
||||
); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
item := &service.OpsRequestDetail{
|
||||
Kind: service.OpsRequestKind(kind),
|
||||
CreatedAt: createdAt,
|
||||
RequestID: strings.TrimSpace(requestID.String),
|
||||
Platform: strings.TrimSpace(platform.String),
|
||||
Model: strings.TrimSpace(model.String),
|
||||
|
||||
DurationMs: toIntPtr(durationMs),
|
||||
StatusCode: toIntPtr(statusCode),
|
||||
ErrorID: toInt64Ptr(errorID),
|
||||
Phase: phase.String,
|
||||
Severity: severity.String,
|
||||
Message: message.String,
|
||||
|
||||
UserID: toInt64Ptr(userID),
|
||||
APIKeyID: toInt64Ptr(apiKeyID),
|
||||
AccountID: toInt64Ptr(accountID),
|
||||
GroupID: toInt64Ptr(groupID),
|
||||
|
||||
Stream: stream,
|
||||
}
|
||||
|
||||
if item.Platform == "" {
|
||||
item.Platform = "unknown"
|
||||
}
|
||||
|
||||
out = append(out, item)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return out, total, nil
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func TestBuildOpsSystemLogsWhere_WithClientRequestIDAndUserID(t *testing.T) {
|
||||
start := time.Date(2026, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
end := time.Date(2026, 2, 2, 0, 0, 0, 0, time.UTC)
|
||||
userID := int64(12)
|
||||
accountID := int64(34)
|
||||
|
||||
filter := &service.OpsSystemLogFilter{
|
||||
StartTime: &start,
|
||||
EndTime: &end,
|
||||
Level: "warn",
|
||||
Component: "http.access",
|
||||
RequestID: "req-1",
|
||||
ClientRequestID: "creq-1",
|
||||
UserID: &userID,
|
||||
AccountID: &accountID,
|
||||
Platform: "openai",
|
||||
Model: "gpt-5",
|
||||
Query: "timeout",
|
||||
}
|
||||
|
||||
where, args, hasConstraint := buildOpsSystemLogsWhere(filter)
|
||||
if !hasConstraint {
|
||||
t.Fatalf("expected hasConstraint=true")
|
||||
}
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 11 {
|
||||
t.Fatalf("args len = %d, want 11", len(args))
|
||||
}
|
||||
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
|
||||
t.Fatalf("where should include client_request_id condition: %s", where)
|
||||
}
|
||||
if !contains(where, "l.user_id = $") {
|
||||
t.Fatalf("where should include user_id condition: %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpsSystemLogsCleanupWhere_RequireConstraint(t *testing.T) {
|
||||
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(&service.OpsSystemLogCleanupFilter{})
|
||||
if hasConstraint {
|
||||
t.Fatalf("expected hasConstraint=false")
|
||||
}
|
||||
if where == "" {
|
||||
t.Fatalf("where should not be empty")
|
||||
}
|
||||
if len(args) != 0 {
|
||||
t.Fatalf("args len = %d, want 0", len(args))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpsSystemLogsCleanupWhere_WithClientRequestIDAndUserID(t *testing.T) {
|
||||
userID := int64(9)
|
||||
filter := &service.OpsSystemLogCleanupFilter{
|
||||
ClientRequestID: "creq-9",
|
||||
UserID: &userID,
|
||||
}
|
||||
|
||||
where, args, hasConstraint := buildOpsSystemLogsCleanupWhere(filter)
|
||||
if !hasConstraint {
|
||||
t.Fatalf("expected hasConstraint=true")
|
||||
}
|
||||
if len(args) != 2 {
|
||||
t.Fatalf("args len = %d, want 2", len(args))
|
||||
}
|
||||
if !contains(where, "COALESCE(l.client_request_id,'') = $") {
|
||||
t.Fatalf("where should include client_request_id condition: %s", where)
|
||||
}
|
||||
if !contains(where, "l.user_id = $") {
|
||||
t.Fatalf("where should include user_id condition: %s", where)
|
||||
}
|
||||
}
|
||||
|
||||
func contains(s string, sub string) bool {
|
||||
return strings.Contains(s, sub)
|
||||
}
|
||||
@@ -0,0 +1,606 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetThroughputTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsThroughputTrendResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
|
||||
// Keep a small, predictable set of supported buckets for now.
|
||||
bucketSeconds = 60
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
|
||||
usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1)
|
||||
errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next)
|
||||
|
||||
usageBucketExpr := opsBucketExprForUsage(bucketSeconds)
|
||||
errorBucketExpr := opsBucketExprForError(bucketSeconds)
|
||||
|
||||
q := `
|
||||
WITH usage_buckets AS (
|
||||
SELECT ` + usageBucketExpr + ` AS bucket,
|
||||
COUNT(*) AS success_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
|
||||
FROM usage_logs ul
|
||||
` + usageJoin + `
|
||||
` + usageWhere + `
|
||||
GROUP BY 1
|
||||
),
|
||||
error_buckets AS (
|
||||
SELECT ` + errorBucketExpr + ` AS bucket,
|
||||
COUNT(*) AS error_count
|
||||
FROM ops_error_logs
|
||||
` + errorWhere + `
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
GROUP BY 1
|
||||
),
|
||||
switch_buckets AS (
|
||||
SELECT ` + errorBucketExpr + ` AS bucket,
|
||||
COALESCE(SUM(CASE
|
||||
WHEN split_part(ev->>'kind', ':', 1) IN ('failover', 'retry_exhausted_failover', 'failover_on_400') THEN 1
|
||||
ELSE 0
|
||||
END), 0) AS switch_count
|
||||
FROM ops_error_logs
|
||||
CROSS JOIN LATERAL jsonb_array_elements(
|
||||
COALESCE(NULLIF(upstream_errors, 'null'::jsonb), '[]'::jsonb)
|
||||
) AS ev
|
||||
` + errorWhere + `
|
||||
AND upstream_errors IS NOT NULL
|
||||
GROUP BY 1
|
||||
),
|
||||
combined AS (
|
||||
SELECT
|
||||
bucket,
|
||||
SUM(success_count) AS success_count,
|
||||
SUM(error_count) AS error_count,
|
||||
SUM(token_consumed) AS token_consumed,
|
||||
SUM(switch_count) AS switch_count
|
||||
FROM (
|
||||
SELECT bucket, success_count, 0 AS error_count, token_consumed, 0 AS switch_count
|
||||
FROM usage_buckets
|
||||
UNION ALL
|
||||
SELECT bucket, 0, error_count, 0, 0
|
||||
FROM error_buckets
|
||||
UNION ALL
|
||||
SELECT bucket, 0, 0, 0, switch_count
|
||||
FROM switch_buckets
|
||||
) t
|
||||
GROUP BY bucket
|
||||
)
|
||||
SELECT
|
||||
bucket,
|
||||
(success_count + error_count) AS request_count,
|
||||
token_consumed,
|
||||
switch_count
|
||||
FROM combined
|
||||
ORDER BY bucket ASC`
|
||||
|
||||
args := append(usageArgs, errorArgs...)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
points := make([]*service.OpsThroughputTrendPoint, 0, 256)
|
||||
for rows.Next() {
|
||||
var bucket time.Time
|
||||
var requests int64
|
||||
var tokens sql.NullInt64
|
||||
var switches sql.NullInt64
|
||||
if err := rows.Scan(&bucket, &requests, &tokens, &switches); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenConsumed := int64(0)
|
||||
if tokens.Valid {
|
||||
tokenConsumed = tokens.Int64
|
||||
}
|
||||
switchCount := int64(0)
|
||||
if switches.Valid {
|
||||
switchCount = switches.Int64
|
||||
}
|
||||
|
||||
denom := float64(bucketSeconds)
|
||||
if denom <= 0 {
|
||||
denom = 60
|
||||
}
|
||||
qps := roundTo1DP(float64(requests) / denom)
|
||||
tps := roundTo1DP(float64(tokenConsumed) / denom)
|
||||
|
||||
points = append(points, &service.OpsThroughputTrendPoint{
|
||||
BucketStart: bucket.UTC(),
|
||||
RequestCount: requests,
|
||||
TokenConsumed: tokenConsumed,
|
||||
SwitchCount: switchCount,
|
||||
QPS: qps,
|
||||
TPS: tps,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fill missing buckets with zeros so charts render continuous timelines.
|
||||
points = fillOpsThroughputBuckets(start, end, bucketSeconds, points)
|
||||
|
||||
var byPlatform []*service.OpsThroughputPlatformBreakdownItem
|
||||
var topGroups []*service.OpsThroughputGroupBreakdownItem
|
||||
|
||||
platform := ""
|
||||
if filter != nil {
|
||||
platform = strings.TrimSpace(strings.ToLower(filter.Platform))
|
||||
}
|
||||
groupID := (*int64)(nil)
|
||||
if filter != nil {
|
||||
groupID = filter.GroupID
|
||||
}
|
||||
|
||||
// Drilldown helpers:
|
||||
// - No platform/group: totals by platform
|
||||
// - Platform selected but no group: top groups in that platform
|
||||
if platform == "" && (groupID == nil || *groupID <= 0) {
|
||||
items, err := r.getThroughputBreakdownByPlatform(ctx, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
byPlatform = items
|
||||
} else if platform != "" && (groupID == nil || *groupID <= 0) {
|
||||
items, err := r.getThroughputTopGroupsByPlatform(ctx, start, end, platform, 10)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
topGroups = items
|
||||
}
|
||||
|
||||
return &service.OpsThroughputTrendResponse{
|
||||
Bucket: opsBucketLabel(bucketSeconds),
|
||||
Points: points,
|
||||
|
||||
ByPlatform: byPlatform,
|
||||
TopGroups: topGroups,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) getThroughputBreakdownByPlatform(ctx context.Context, start, end time.Time) ([]*service.OpsThroughputPlatformBreakdownItem, error) {
|
||||
q := `
|
||||
WITH usage_totals AS (
|
||||
SELECT COALESCE(NULLIF(g.platform,''), a.platform) AS platform,
|
||||
COUNT(*) AS success_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
|
||||
FROM usage_logs ul
|
||||
LEFT JOIN groups g ON g.id = ul.group_id
|
||||
LEFT JOIN accounts a ON a.id = ul.account_id
|
||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||
GROUP BY 1
|
||||
),
|
||||
error_totals AS (
|
||||
SELECT platform,
|
||||
COUNT(*) AS error_count
|
||||
FROM ops_error_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
|
||||
GROUP BY 1
|
||||
),
|
||||
combined AS (
|
||||
SELECT COALESCE(u.platform, e.platform) AS platform,
|
||||
COALESCE(u.success_count, 0) AS success_count,
|
||||
COALESCE(e.error_count, 0) AS error_count,
|
||||
COALESCE(u.token_consumed, 0) AS token_consumed
|
||||
FROM usage_totals u
|
||||
FULL OUTER JOIN error_totals e ON u.platform = e.platform
|
||||
)
|
||||
SELECT platform, (success_count + error_count) AS request_count, token_consumed
|
||||
FROM combined
|
||||
WHERE platform IS NOT NULL AND platform <> ''
|
||||
ORDER BY request_count DESC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]*service.OpsThroughputPlatformBreakdownItem, 0, 8)
|
||||
for rows.Next() {
|
||||
var platform string
|
||||
var requests int64
|
||||
var tokens sql.NullInt64
|
||||
if err := rows.Scan(&platform, &requests, &tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenConsumed := int64(0)
|
||||
if tokens.Valid {
|
||||
tokenConsumed = tokens.Int64
|
||||
}
|
||||
items = append(items, &service.OpsThroughputPlatformBreakdownItem{
|
||||
Platform: platform,
|
||||
RequestCount: requests,
|
||||
TokenConsumed: tokenConsumed,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (r *opsRepository) getThroughputTopGroupsByPlatform(ctx context.Context, start, end time.Time, platform string, limit int) ([]*service.OpsThroughputGroupBreakdownItem, error) {
|
||||
if strings.TrimSpace(platform) == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
q := `
|
||||
WITH usage_totals AS (
|
||||
SELECT ul.group_id AS group_id,
|
||||
g.name AS group_name,
|
||||
COUNT(*) AS success_count,
|
||||
COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_consumed
|
||||
FROM usage_logs ul
|
||||
JOIN groups g ON g.id = ul.group_id
|
||||
WHERE ul.created_at >= $1 AND ul.created_at < $2
|
||||
AND g.platform = $3
|
||||
GROUP BY 1, 2
|
||||
),
|
||||
error_totals AS (
|
||||
SELECT group_id,
|
||||
COUNT(*) AS error_count
|
||||
FROM ops_error_logs
|
||||
WHERE created_at >= $1 AND created_at < $2
|
||||
AND platform = $3
|
||||
AND group_id IS NOT NULL
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
AND is_count_tokens = FALSE -- 排除 count_tokens 请求的错误
|
||||
GROUP BY 1
|
||||
),
|
||||
combined AS (
|
||||
SELECT COALESCE(u.group_id, e.group_id) AS group_id,
|
||||
COALESCE(u.group_name, g2.name, '') AS group_name,
|
||||
COALESCE(u.success_count, 0) AS success_count,
|
||||
COALESCE(e.error_count, 0) AS error_count,
|
||||
COALESCE(u.token_consumed, 0) AS token_consumed
|
||||
FROM usage_totals u
|
||||
FULL OUTER JOIN error_totals e ON u.group_id = e.group_id
|
||||
LEFT JOIN groups g2 ON g2.id = COALESCE(u.group_id, e.group_id)
|
||||
)
|
||||
SELECT group_id, group_name, (success_count + error_count) AS request_count, token_consumed
|
||||
FROM combined
|
||||
WHERE group_id IS NOT NULL
|
||||
ORDER BY request_count DESC
|
||||
LIMIT $4`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, start, end, platform, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]*service.OpsThroughputGroupBreakdownItem, 0, limit)
|
||||
for rows.Next() {
|
||||
var groupID int64
|
||||
var groupName sql.NullString
|
||||
var requests int64
|
||||
var tokens sql.NullInt64
|
||||
if err := rows.Scan(&groupID, &groupName, &requests, &tokens); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tokenConsumed := int64(0)
|
||||
if tokens.Valid {
|
||||
tokenConsumed = tokens.Int64
|
||||
}
|
||||
name := ""
|
||||
if groupName.Valid {
|
||||
name = groupName.String
|
||||
}
|
||||
items = append(items, &service.OpsThroughputGroupBreakdownItem{
|
||||
GroupID: groupID,
|
||||
GroupName: name,
|
||||
RequestCount: requests,
|
||||
TokenConsumed: tokenConsumed,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func opsBucketExprForUsage(bucketSeconds int) string {
|
||||
switch bucketSeconds {
|
||||
case 3600:
|
||||
return "date_trunc('hour', ul.created_at)"
|
||||
case 300:
|
||||
// 5-minute buckets in UTC.
|
||||
return "to_timestamp(floor(extract(epoch from ul.created_at) / 300) * 300)"
|
||||
default:
|
||||
return "date_trunc('minute', ul.created_at)"
|
||||
}
|
||||
}
|
||||
|
||||
func opsBucketExprForError(bucketSeconds int) string {
|
||||
switch bucketSeconds {
|
||||
case 3600:
|
||||
return "date_trunc('hour', created_at)"
|
||||
case 300:
|
||||
return "to_timestamp(floor(extract(epoch from created_at) / 300) * 300)"
|
||||
default:
|
||||
return "date_trunc('minute', created_at)"
|
||||
}
|
||||
}
|
||||
|
||||
func opsBucketLabel(bucketSeconds int) string {
|
||||
if bucketSeconds <= 0 {
|
||||
return "1m"
|
||||
}
|
||||
if bucketSeconds%3600 == 0 {
|
||||
h := bucketSeconds / 3600
|
||||
if h <= 0 {
|
||||
h = 1
|
||||
}
|
||||
return fmt.Sprintf("%dh", h)
|
||||
}
|
||||
m := bucketSeconds / 60
|
||||
if m <= 0 {
|
||||
m = 1
|
||||
}
|
||||
return fmt.Sprintf("%dm", m)
|
||||
}
|
||||
|
||||
func opsFloorToBucketStart(t time.Time, bucketSeconds int) time.Time {
|
||||
t = t.UTC()
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
secs := t.Unix()
|
||||
floored := secs - (secs % int64(bucketSeconds))
|
||||
return time.Unix(floored, 0).UTC()
|
||||
}
|
||||
|
||||
func fillOpsThroughputBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsThroughputTrendPoint) []*service.OpsThroughputTrendPoint {
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
if !start.Before(end) {
|
||||
return points
|
||||
}
|
||||
|
||||
endMinus := end.Add(-time.Nanosecond)
|
||||
if endMinus.Before(start) {
|
||||
return points
|
||||
}
|
||||
|
||||
first := opsFloorToBucketStart(start, bucketSeconds)
|
||||
last := opsFloorToBucketStart(endMinus, bucketSeconds)
|
||||
step := time.Duration(bucketSeconds) * time.Second
|
||||
|
||||
existing := make(map[int64]*service.OpsThroughputTrendPoint, len(points))
|
||||
for _, p := range points {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
existing[p.BucketStart.UTC().Unix()] = p
|
||||
}
|
||||
|
||||
out := make([]*service.OpsThroughputTrendPoint, 0, int(last.Sub(first)/step)+1)
|
||||
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
|
||||
if p, ok := existing[cursor.Unix()]; ok && p != nil {
|
||||
out = append(out, p)
|
||||
continue
|
||||
}
|
||||
out = append(out, &service.OpsThroughputTrendPoint{
|
||||
BucketStart: cursor,
|
||||
RequestCount: 0,
|
||||
TokenConsumed: 0,
|
||||
SwitchCount: 0,
|
||||
QPS: 0,
|
||||
TPS: 0,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetErrorTrend(ctx context.Context, filter *service.OpsDashboardFilter, bucketSeconds int) (*service.OpsErrorTrendResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
if bucketSeconds != 60 && bucketSeconds != 300 && bucketSeconds != 3600 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
where, args, _ := buildErrorWhere(filter, start, end, 1)
|
||||
bucketExpr := opsBucketExprForError(bucketSeconds)
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
` + bucketExpr + ` AS bucket,
|
||||
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400) AS error_total,
|
||||
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND is_business_limited) AS business_limited,
|
||||
COUNT(*) FILTER (WHERE COALESCE(status_code, 0) >= 400 AND NOT is_business_limited) AS error_sla,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)) AS upstream_excl,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 429) AS upstream_429,
|
||||
COUNT(*) FILTER (WHERE error_owner = 'provider' AND NOT is_business_limited AND COALESCE(upstream_status_code, status_code, 0) = 529) AS upstream_529
|
||||
FROM ops_error_logs
|
||||
` + where + `
|
||||
GROUP BY 1
|
||||
ORDER BY 1 ASC`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
points := make([]*service.OpsErrorTrendPoint, 0, 256)
|
||||
for rows.Next() {
|
||||
var bucket time.Time
|
||||
var total, businessLimited, sla, upstreamExcl, upstream429, upstream529 int64
|
||||
if err := rows.Scan(&bucket, &total, &businessLimited, &sla, &upstreamExcl, &upstream429, &upstream529); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
points = append(points, &service.OpsErrorTrendPoint{
|
||||
BucketStart: bucket.UTC(),
|
||||
|
||||
ErrorCountTotal: total,
|
||||
BusinessLimitedCount: businessLimited,
|
||||
ErrorCountSLA: sla,
|
||||
|
||||
UpstreamErrorCountExcl429529: upstreamExcl,
|
||||
Upstream429Count: upstream429,
|
||||
Upstream529Count: upstream529,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
points = fillOpsErrorTrendBuckets(start, end, bucketSeconds, points)
|
||||
|
||||
return &service.OpsErrorTrendResponse{
|
||||
Bucket: opsBucketLabel(bucketSeconds),
|
||||
Points: points,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func fillOpsErrorTrendBuckets(start, end time.Time, bucketSeconds int, points []*service.OpsErrorTrendPoint) []*service.OpsErrorTrendPoint {
|
||||
if bucketSeconds <= 0 {
|
||||
bucketSeconds = 60
|
||||
}
|
||||
if !start.Before(end) {
|
||||
return points
|
||||
}
|
||||
|
||||
endMinus := end.Add(-time.Nanosecond)
|
||||
if endMinus.Before(start) {
|
||||
return points
|
||||
}
|
||||
|
||||
first := opsFloorToBucketStart(start, bucketSeconds)
|
||||
last := opsFloorToBucketStart(endMinus, bucketSeconds)
|
||||
step := time.Duration(bucketSeconds) * time.Second
|
||||
|
||||
existing := make(map[int64]*service.OpsErrorTrendPoint, len(points))
|
||||
for _, p := range points {
|
||||
if p == nil {
|
||||
continue
|
||||
}
|
||||
existing[p.BucketStart.UTC().Unix()] = p
|
||||
}
|
||||
|
||||
out := make([]*service.OpsErrorTrendPoint, 0, int(last.Sub(first)/step)+1)
|
||||
for cursor := first; !cursor.After(last); cursor = cursor.Add(step) {
|
||||
if p, ok := existing[cursor.Unix()]; ok && p != nil {
|
||||
out = append(out, p)
|
||||
continue
|
||||
}
|
||||
out = append(out, &service.OpsErrorTrendPoint{
|
||||
BucketStart: cursor,
|
||||
|
||||
ErrorCountTotal: 0,
|
||||
BusinessLimitedCount: 0,
|
||||
ErrorCountSLA: 0,
|
||||
|
||||
UpstreamErrorCountExcl429529: 0,
|
||||
Upstream429Count: 0,
|
||||
Upstream529Count: 0,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (r *opsRepository) GetErrorDistribution(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsErrorDistributionResponse, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
where, args, _ := buildErrorWhere(filter, start, end, 1)
|
||||
|
||||
q := `
|
||||
SELECT
|
||||
COALESCE(upstream_status_code, status_code, 0) AS status_code,
|
||||
COUNT(*) AS total,
|
||||
COUNT(*) FILTER (WHERE NOT is_business_limited) AS sla,
|
||||
COUNT(*) FILTER (WHERE is_business_limited) AS business_limited
|
||||
FROM ops_error_logs
|
||||
` + where + `
|
||||
AND COALESCE(status_code, 0) >= 400
|
||||
GROUP BY 1
|
||||
ORDER BY total DESC
|
||||
LIMIT 20`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, q, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
items := make([]*service.OpsErrorDistributionItem, 0, 16)
|
||||
var total int64
|
||||
for rows.Next() {
|
||||
var statusCode int
|
||||
var cntTotal, cntSLA, cntBiz int64
|
||||
if err := rows.Scan(&statusCode, &cntTotal, &cntSLA, &cntBiz); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
total += cntTotal
|
||||
items = append(items, &service.OpsErrorDistributionItem{
|
||||
StatusCode: statusCode,
|
||||
Total: cntTotal,
|
||||
SLA: cntSLA,
|
||||
BusinessLimited: cntBiz,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &service.OpsErrorDistributionResponse{
|
||||
Total: total,
|
||||
Items: items,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,50 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func (r *opsRepository) GetWindowStats(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsWindowStats, error) {
|
||||
if r == nil || r.db == nil {
|
||||
return nil, fmt.Errorf("nil ops repository")
|
||||
}
|
||||
if filter == nil {
|
||||
return nil, fmt.Errorf("nil filter")
|
||||
}
|
||||
if filter.StartTime.IsZero() || filter.EndTime.IsZero() {
|
||||
return nil, fmt.Errorf("start_time/end_time required")
|
||||
}
|
||||
|
||||
start := filter.StartTime.UTC()
|
||||
end := filter.EndTime.UTC()
|
||||
if start.After(end) {
|
||||
return nil, fmt.Errorf("start_time must be <= end_time")
|
||||
}
|
||||
// Bound excessively large windows to prevent accidental heavy queries.
|
||||
if end.Sub(start) > 24*time.Hour {
|
||||
return nil, fmt.Errorf("window too large")
|
||||
}
|
||||
|
||||
successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
errorTotal, _, _, _, _, _, err := r.queryErrorCounts(ctx, filter, start, end)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &service.OpsWindowStats{
|
||||
StartTime: start,
|
||||
EndTime: end,
|
||||
|
||||
SuccessCount: successCount,
|
||||
ErrorCountTotal: errorTotal,
|
||||
TokenConsumed: tokenConsumed,
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
|
||||
|
||||
repo := NewOpsRepository(integrationDB).(*opsRepository)
|
||||
now := time.Now().UTC()
|
||||
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
|
||||
{
|
||||
RequestID: "batch-ops-1",
|
||||
ErrorPhase: "upstream",
|
||||
ErrorType: "upstream_error",
|
||||
Severity: "error",
|
||||
StatusCode: 429,
|
||||
ErrorMessage: "rate limited",
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
RequestID: "batch-ops-2",
|
||||
ErrorPhase: "internal",
|
||||
ErrorType: "api_error",
|
||||
Severity: "error",
|
||||
StatusCode: 500,
|
||||
ErrorMessage: "internal error",
|
||||
CreatedAt: now.Add(time.Millisecond),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, inserted)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||
|
||||
accountID := int64(12345)
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||
|
||||
accountID := int64(67890)
|
||||
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
|
||||
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package repository
|
||||
|
||||
import "github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
|
||||
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
return &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type pricingRemoteClient struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// pricingRemoteClientError 代理初始化失败时的错误占位客户端
|
||||
// 所有请求直接返回初始化错误,禁止回退到直连
|
||||
type pricingRemoteClientError struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *pricingRemoteClientError) FetchPricingJSON(_ context.Context, _ string) ([]byte, error) {
|
||||
return nil, c.err
|
||||
}
|
||||
|
||||
func (c *pricingRemoteClientError) FetchHashText(_ context.Context, _ string) (string, error) {
|
||||
return "", c.err
|
||||
}
|
||||
|
||||
// NewPricingRemoteClient 创建定价数据远程客户端
|
||||
// proxyURL 为空时直连,支持 http/https/socks5/socks5h 协议
|
||||
// 代理配置失败时行为由 allowDirectOnProxyError 控制:
|
||||
// - false(默认):返回错误占位客户端,禁止回退到直连
|
||||
// - true:回退到直连(仅限管理员显式开启)
|
||||
func NewPricingRemoteClient(proxyURL string, allowDirectOnProxyError bool) service.PricingRemoteClient {
|
||||
// 安全说明:httpclient.GetClient 的错误链(url.Parse / proxyutil)不含明文代理凭据,
|
||||
// 但仍通过 slog 仅在服务端日志记录,不会暴露给 HTTP 响应。
|
||||
sharedClient, err := httpclient.GetClient(httpclient.Options{
|
||||
Timeout: 30 * time.Second,
|
||||
ProxyURL: proxyURL,
|
||||
})
|
||||
if err != nil {
|
||||
if strings.TrimSpace(proxyURL) != "" && !allowDirectOnProxyError {
|
||||
slog.Warn("proxy client init failed, all requests will fail", "service", "pricing", "error", err)
|
||||
return &pricingRemoteClientError{err: fmt.Errorf("proxy client init failed and direct fallback is disabled; set security.proxy_fallback.allow_direct_on_error=true to allow fallback: %w", err)}
|
||||
}
|
||||
sharedClient = &http.Client{Timeout: 30 * time.Second}
|
||||
}
|
||||
return &pricingRemoteClient{
|
||||
httpClient: sharedClient,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *pricingRemoteClient) FetchPricingJSON(ctx context.Context, url string) ([]byte, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
func (c *pricingRemoteClient) FetchHashText(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 哈希文件格式:hash filename 或者纯 hash
|
||||
hash := strings.TrimSpace(string(body))
|
||||
parts := strings.Fields(hash)
|
||||
if len(parts) > 0 {
|
||||
return parts[0], nil
|
||||
}
|
||||
return hash, nil
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type PricingServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
srv *httptest.Server
|
||||
client *pricingRemoteClient
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
client, ok := NewPricingRemoteClient("", false).(*pricingRemoteClient)
|
||||
require.True(s.T(), ok, "type assertion failed")
|
||||
s.client = client
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TearDownTest() {
|
||||
if s.srv != nil {
|
||||
s.srv.Close()
|
||||
s.srv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) setupServer(handler http.HandlerFunc) {
|
||||
s.srv = newLocalTestServer(s.T(), handler)
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_Success() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path == "/ok" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"ok":true}`))
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
body, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/ok")
|
||||
require.NoError(s.T(), err, "FetchPricingJSON")
|
||||
require.Equal(s.T(), `{"ok":true}`, string(body), "body mismatch")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_NonOKStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}))
|
||||
|
||||
_, err := s.client.FetchPricingJSON(s.ctx, s.srv.URL+"/err")
|
||||
require.Error(s.T(), err, "expected error for non-200 status")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_ParsesFields() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/hashfile":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("abc123 model_prices.json\n"))
|
||||
case "/hashonly":
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("def456\n"))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashfile")
|
||||
require.NoError(s.T(), err, "FetchHashText")
|
||||
require.Equal(s.T(), "abc123", hash, "hash mismatch")
|
||||
|
||||
hash2, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/hashonly")
|
||||
require.NoError(s.T(), err, "FetchHashText")
|
||||
require.Equal(s.T(), "def456", hash2, "hash mismatch")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_NonOKStatus() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}))
|
||||
|
||||
_, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/nope")
|
||||
require.Error(s.T(), err, "expected error for non-200 status")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_InvalidURL() {
|
||||
_, err := s.client.FetchPricingJSON(s.ctx, "://invalid-url")
|
||||
require.Error(s.T(), err, "expected error for invalid URL")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_EmptyBody() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
// empty body
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/empty")
|
||||
require.NoError(s.T(), err, "FetchHashText empty body should not error")
|
||||
require.Equal(s.T(), "", hash, "expected empty hash")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchHashText_WhitespaceOnly() {
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(" \n"))
|
||||
}))
|
||||
|
||||
hash, err := s.client.FetchHashText(s.ctx, s.srv.URL+"/ws")
|
||||
require.NoError(s.T(), err, "FetchHashText whitespace body should not error")
|
||||
require.Equal(s.T(), "", hash, "expected empty hash after trimming")
|
||||
}
|
||||
|
||||
func (s *PricingServiceSuite) TestFetchPricingJSON_ContextCancel() {
|
||||
started := make(chan struct{})
|
||||
s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
close(started)
|
||||
<-r.Context().Done()
|
||||
}))
|
||||
|
||||
ctx, cancel := context.WithCancel(s.ctx)
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := s.client.FetchPricingJSON(ctx, s.srv.URL+"/block")
|
||||
done <- err
|
||||
}()
|
||||
|
||||
<-started
|
||||
cancel()
|
||||
|
||||
err := <-done
|
||||
require.Error(s.T(), err)
|
||||
}
|
||||
|
||||
func TestNewPricingRemoteClient_InvalidProxy_NoFallback(t *testing.T) {
|
||||
client := NewPricingRemoteClient("://bad", false)
|
||||
_, ok := client.(*pricingRemoteClientError)
|
||||
require.True(t, ok, "should return error client when proxy is invalid and fallback disabled")
|
||||
|
||||
_, err := client.FetchPricingJSON(context.Background(), "http://example.com")
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "proxy client init failed")
|
||||
}
|
||||
|
||||
func TestNewPricingRemoteClient_InvalidProxy_WithFallback(t *testing.T) {
|
||||
client := NewPricingRemoteClient("://bad", true)
|
||||
_, ok := client.(*pricingRemoteClient)
|
||||
require.True(t, ok, "should fallback to direct client when allowed")
|
||||
}
|
||||
|
||||
func TestPricingServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(PricingServiceSuite))
|
||||
}
|
||||
@@ -0,0 +1,273 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/promocodeusage"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type promoCodeRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewPromoCodeRepository(client *dbent.Client) service.PromoCodeRepository {
|
||||
return &promoCodeRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Create(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.Create().
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
code.ID = created.ID
|
||||
code.CreatedAt = created.CreatedAt
|
||||
code.UpdatedAt = created.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByID(ctx context.Context, id int64) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCode(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
m, err := r.client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetByCodeForUpdate(ctx context.Context, code string) (*service.PromoCode, error) {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
m, err := client.PromoCode.Query().
|
||||
Where(promocode.CodeEqualFold(code)).
|
||||
ForUpdate().
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrPromoCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Update(ctx context.Context, code *service.PromoCode) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
builder := client.PromoCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetBonusAmount(code.BonusAmount).
|
||||
SetMaxUses(code.MaxUses).
|
||||
SetUsedCount(code.UsedCount).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes)
|
||||
|
||||
if code.ExpiresAt != nil {
|
||||
builder.SetExpiresAt(*code.ExpiresAt)
|
||||
} else {
|
||||
builder.ClearExpiresAt()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrPromoCodeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
code.UpdatedAt = updated.UpdatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.Delete().Where(promocode.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "")
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.PromoCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCode.Query()
|
||||
|
||||
if status != "" {
|
||||
q = q.Where(promocode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(promocode.CodeContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocode.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outCodes := promoCodeEntitiesToService(codes)
|
||||
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) CreateUsage(ctx context.Context, usage *service.PromoCodeUsage) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
created, err := client.PromoCodeUsage.Create().
|
||||
SetPromoCodeID(usage.PromoCodeID).
|
||||
SetUserID(usage.UserID).
|
||||
SetBonusAmount(usage.BonusAmount).
|
||||
SetUsedAt(usage.UsedAt).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
usage.ID = created.ID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) GetUsageByPromoCodeAndUser(ctx context.Context, promoCodeID, userID int64) (*service.PromoCodeUsage, error) {
|
||||
m, err := r.client.PromoCodeUsage.Query().
|
||||
Where(
|
||||
promocodeusage.PromoCodeIDEQ(promoCodeID),
|
||||
promocodeusage.UserIDEQ(userID),
|
||||
).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return promoCodeUsageEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) ListUsagesByPromoCode(ctx context.Context, promoCodeID int64, params pagination.PaginationParams) ([]service.PromoCodeUsage, *pagination.PaginationResult, error) {
|
||||
q := r.client.PromoCodeUsage.Query().
|
||||
Where(promocodeusage.PromoCodeIDEQ(promoCodeID))
|
||||
|
||||
total, err := q.Clone().Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
usages, err := q.
|
||||
WithUser().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(promocodeusage.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outUsages := promoCodeUsageEntitiesToService(usages)
|
||||
|
||||
return outUsages, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *promoCodeRepository) IncrementUsedCount(ctx context.Context, id int64) error {
|
||||
client := clientFromContext(ctx, r.client)
|
||||
_, err := client.PromoCode.UpdateOneID(id).
|
||||
AddUsedCount(1).
|
||||
Save(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
// Entity to Service conversions
|
||||
|
||||
func promoCodeEntityToService(m *dbent.PromoCode) *service.PromoCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
return &service.PromoCode{
|
||||
ID: m.ID,
|
||||
Code: m.Code,
|
||||
BonusAmount: m.BonusAmount,
|
||||
MaxUses: m.MaxUses,
|
||||
UsedCount: m.UsedCount,
|
||||
Status: m.Status,
|
||||
ExpiresAt: m.ExpiresAt,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func promoCodeEntitiesToService(models []*dbent.PromoCode) []service.PromoCode {
|
||||
out := make([]service.PromoCode, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntityToService(m *dbent.PromoCodeUsage) *service.PromoCodeUsage {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.PromoCodeUsage{
|
||||
ID: m.ID,
|
||||
PromoCodeID: m.PromoCodeID,
|
||||
UserID: m.UserID,
|
||||
BonusAmount: m.BonusAmount,
|
||||
UsedAt: m.UsedAt,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func promoCodeUsageEntitiesToService(models []*dbent.PromoCodeUsage) []service.PromoCodeUsage {
|
||||
out := make([]service.PromoCodeUsage, 0, len(models))
|
||||
for i := range models {
|
||||
if s := promoCodeUsageEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,74 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const proxyLatencyKeyPrefix = "proxy:latency:"
|
||||
|
||||
func proxyLatencyKey(proxyID int64) string {
|
||||
return fmt.Sprintf("%s%d", proxyLatencyKeyPrefix, proxyID)
|
||||
}
|
||||
|
||||
type proxyLatencyCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewProxyLatencyCache(rdb *redis.Client) service.ProxyLatencyCache {
|
||||
return &proxyLatencyCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *proxyLatencyCache) GetProxyLatencies(ctx context.Context, proxyIDs []int64) (map[int64]*service.ProxyLatencyInfo, error) {
|
||||
results := make(map[int64]*service.ProxyLatencyInfo)
|
||||
if len(proxyIDs) == 0 {
|
||||
return results, nil
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(proxyIDs))
|
||||
for _, id := range proxyIDs {
|
||||
keys = append(keys, proxyLatencyKey(id))
|
||||
}
|
||||
|
||||
values, err := c.rdb.MGet(ctx, keys...).Result()
|
||||
if err != nil {
|
||||
return results, err
|
||||
}
|
||||
|
||||
for i, raw := range values {
|
||||
if raw == nil {
|
||||
continue
|
||||
}
|
||||
var payload []byte
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
payload = []byte(v)
|
||||
case []byte:
|
||||
payload = v
|
||||
default:
|
||||
continue
|
||||
}
|
||||
var info service.ProxyLatencyInfo
|
||||
if err := json.Unmarshal(payload, &info); err != nil {
|
||||
continue
|
||||
}
|
||||
results[proxyIDs[i]] = &info
|
||||
}
|
||||
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (c *proxyLatencyCache) SetProxyLatency(ctx context.Context, proxyID int64, info *service.ProxyLatencyInfo) error {
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
payload, err := json.Marshal(info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.rdb.Set(ctx, proxyLatencyKey(proxyID), payload, 0).Err()
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
func NewProxyExitInfoProber(cfg *config.Config) service.ProxyExitInfoProber {
|
||||
insecure := false
|
||||
allowPrivate := false
|
||||
validateResolvedIP := true
|
||||
maxResponseBytes := defaultProxyProbeResponseMaxBytes
|
||||
if cfg != nil {
|
||||
insecure = cfg.Security.ProxyProbe.InsecureSkipVerify
|
||||
allowPrivate = cfg.Security.URLAllowlist.AllowPrivateHosts
|
||||
validateResolvedIP = cfg.Security.URLAllowlist.Enabled
|
||||
if cfg.Gateway.ProxyProbeResponseReadMaxBytes > 0 {
|
||||
maxResponseBytes = cfg.Gateway.ProxyProbeResponseReadMaxBytes
|
||||
}
|
||||
}
|
||||
if insecure {
|
||||
log.Printf("[ProxyProbe] Warning: insecure_skip_verify is not allowed and will cause probe failure.")
|
||||
}
|
||||
return &proxyProbeService{
|
||||
insecureSkipVerify: insecure,
|
||||
allowPrivateHosts: allowPrivate,
|
||||
validateResolvedIP: validateResolvedIP,
|
||||
maxResponseBytes: maxResponseBytes,
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
defaultProxyProbeTimeout = 30 * time.Second
|
||||
defaultProxyProbeResponseMaxBytes = int64(1024 * 1024)
|
||||
)
|
||||
|
||||
// probeURLs 按优先级排列的探测 URL 列表
|
||||
// 某些 AI API 专用代理只允许访问特定域名,因此需要多个备选
|
||||
var probeURLs = []struct {
|
||||
url string
|
||||
parser string // "ip-api" or "httpbin"
|
||||
}{
|
||||
{"http://ip-api.com/json/?lang=zh-CN", "ip-api"},
|
||||
{"http://httpbin.org/ip", "httpbin"},
|
||||
}
|
||||
|
||||
type proxyProbeService struct {
|
||||
insecureSkipVerify bool
|
||||
allowPrivateHosts bool
|
||||
validateResolvedIP bool
|
||||
maxResponseBytes int64
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) ProbeProxy(ctx context.Context, proxyURL string) (*service.ProxyExitInfo, int64, error) {
|
||||
client, err := httpclient.GetClient(httpclient.Options{
|
||||
ProxyURL: proxyURL,
|
||||
Timeout: defaultProxyProbeTimeout,
|
||||
InsecureSkipVerify: s.insecureSkipVerify,
|
||||
ValidateResolvedIP: s.validateResolvedIP,
|
||||
AllowPrivateHosts: s.allowPrivateHosts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create proxy client: %w", err)
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, probe := range probeURLs {
|
||||
exitInfo, latencyMs, err := s.probeWithURL(ctx, client, probe.url, probe.parser)
|
||||
if err == nil {
|
||||
return exitInfo, latencyMs, nil
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
return nil, 0, fmt.Errorf("all probe URLs failed, last error: %w", lastErr)
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) probeWithURL(ctx context.Context, client *http.Client, url string, parser string) (*service.ProxyExitInfo, int64, error) {
|
||||
startTime := time.Now()
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("proxy connection failed: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
latencyMs := time.Since(startTime).Milliseconds()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, latencyMs, fmt.Errorf("request failed with status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
maxResponseBytes := s.maxResponseBytes
|
||||
if maxResponseBytes <= 0 {
|
||||
maxResponseBytes = defaultProxyProbeResponseMaxBytes
|
||||
}
|
||||
body, err := io.ReadAll(io.LimitReader(resp.Body, maxResponseBytes+1))
|
||||
if err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
if int64(len(body)) > maxResponseBytes {
|
||||
return nil, latencyMs, fmt.Errorf("proxy probe response exceeds limit: %d", maxResponseBytes)
|
||||
}
|
||||
|
||||
switch parser {
|
||||
case "ip-api":
|
||||
return s.parseIPAPI(body, latencyMs)
|
||||
case "httpbin":
|
||||
return s.parseHTTPBin(body, latencyMs)
|
||||
default:
|
||||
return nil, latencyMs, fmt.Errorf("unknown parser: %s", parser)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) parseIPAPI(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
|
||||
var ipInfo struct {
|
||||
Status string `json:"status"`
|
||||
Message string `json:"message"`
|
||||
Query string `json:"query"`
|
||||
City string `json:"city"`
|
||||
Region string `json:"region"`
|
||||
RegionName string `json:"regionName"`
|
||||
Country string `json:"country"`
|
||||
CountryCode string `json:"countryCode"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &ipInfo); err != nil {
|
||||
preview := string(body)
|
||||
if len(preview) > 200 {
|
||||
preview = preview[:200] + "..."
|
||||
}
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse response: %w (body: %s)", err, preview)
|
||||
}
|
||||
if strings.ToLower(ipInfo.Status) != "success" {
|
||||
if ipInfo.Message == "" {
|
||||
ipInfo.Message = "ip-api request failed"
|
||||
}
|
||||
return nil, latencyMs, fmt.Errorf("ip-api request failed: %s", ipInfo.Message)
|
||||
}
|
||||
|
||||
region := ipInfo.RegionName
|
||||
if region == "" {
|
||||
region = ipInfo.Region
|
||||
}
|
||||
return &service.ProxyExitInfo{
|
||||
IP: ipInfo.Query,
|
||||
City: ipInfo.City,
|
||||
Region: region,
|
||||
Country: ipInfo.Country,
|
||||
CountryCode: ipInfo.CountryCode,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
|
||||
func (s *proxyProbeService) parseHTTPBin(body []byte, latencyMs int64) (*service.ProxyExitInfo, int64, error) {
|
||||
// httpbin.org/ip 返回格式: {"origin": "1.2.3.4"}
|
||||
var result struct {
|
||||
Origin string `json:"origin"`
|
||||
}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, latencyMs, fmt.Errorf("failed to parse httpbin response: %w", err)
|
||||
}
|
||||
if result.Origin == "" {
|
||||
return nil, latencyMs, fmt.Errorf("httpbin: no IP found in response")
|
||||
}
|
||||
return &service.ProxyExitInfo{
|
||||
IP: result.Origin,
|
||||
}, latencyMs, nil
|
||||
}
|
||||
@@ -0,0 +1,171 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ProxyProbeServiceSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
proxySrv *httptest.Server
|
||||
prober *proxyProbeService
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.prober = &proxyProbeService{
|
||||
allowPrivateHosts: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TearDownTest() {
|
||||
if s.proxySrv != nil {
|
||||
s.proxySrv.Close()
|
||||
s.proxySrv = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) setupProxyServer(handler http.HandlerFunc) {
|
||||
s.proxySrv = newLocalTestServer(s.T(), handler)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidProxyURL() {
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, "://bad")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_UnsupportedProxyScheme() {
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, "ftp://127.0.0.1:1")
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "failed to create proxy client")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_IPAPI() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// 检查是否是 ip-api 请求
|
||||
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"status":"success","query":"1.2.3.4","city":"c","regionName":"r","country":"cc","countryCode":"CC"}`)
|
||||
return
|
||||
}
|
||||
// 其他请求返回错误
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.NoError(s.T(), err, "ProbeProxy")
|
||||
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
|
||||
require.Equal(s.T(), "1.2.3.4", info.IP)
|
||||
require.Equal(s.T(), "c", info.City)
|
||||
require.Equal(s.T(), "r", info.Region)
|
||||
require.Equal(s.T(), "cc", info.Country)
|
||||
require.Equal(s.T(), "CC", info.CountryCode)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_Success_HTTPBinFallback() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// ip-api 失败
|
||||
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
// httpbin 成功
|
||||
if strings.Contains(r.RequestURI, "httpbin.org") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, `{"origin": "5.6.7.8"}`)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
info, latencyMs, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.NoError(s.T(), err, "ProbeProxy should fallback to httpbin")
|
||||
require.GreaterOrEqual(s.T(), latencyMs, int64(0), "unexpected latency")
|
||||
require.Equal(s.T(), "5.6.7.8", info.IP)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_AllFailed() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "all probe URLs failed")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_InvalidJSON() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if strings.Contains(r.RequestURI, "ip-api.com") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
return
|
||||
}
|
||||
// httpbin 也返回无效响应
|
||||
if strings.Contains(r.RequestURI, "httpbin.org") {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_, _ = io.WriteString(w, "not-json")
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusServiceUnavailable)
|
||||
}))
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "all probe URLs failed")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestProbeProxy_ProxyServerClosed() {
|
||||
s.setupProxyServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
s.proxySrv.Close()
|
||||
|
||||
_, _, err := s.prober.ProbeProxy(s.ctx, s.proxySrv.URL)
|
||||
require.Error(s.T(), err, "expected error when proxy server is closed")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Success() {
|
||||
body := []byte(`{"status":"success","query":"1.2.3.4","city":"Beijing","regionName":"Beijing","country":"China","countryCode":"CN"}`)
|
||||
info, latencyMs, err := s.prober.parseIPAPI(body, 100)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), int64(100), latencyMs)
|
||||
require.Equal(s.T(), "1.2.3.4", info.IP)
|
||||
require.Equal(s.T(), "Beijing", info.City)
|
||||
require.Equal(s.T(), "Beijing", info.Region)
|
||||
require.Equal(s.T(), "China", info.Country)
|
||||
require.Equal(s.T(), "CN", info.CountryCode)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseIPAPI_Failure() {
|
||||
body := []byte(`{"status":"fail","message":"rate limited"}`)
|
||||
_, _, err := s.prober.parseIPAPI(body, 100)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "rate limited")
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_Success() {
|
||||
body := []byte(`{"origin": "9.8.7.6"}`)
|
||||
info, latencyMs, err := s.prober.parseHTTPBin(body, 50)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), int64(50), latencyMs)
|
||||
require.Equal(s.T(), "9.8.7.6", info.IP)
|
||||
}
|
||||
|
||||
func (s *ProxyProbeServiceSuite) TestParseHTTPBin_NoIP() {
|
||||
body := []byte(`{"origin": ""}`)
|
||||
_, _, err := s.prober.parseHTTPBin(body, 50)
|
||||
require.Error(s.T(), err)
|
||||
require.ErrorContains(s.T(), err, "no IP found")
|
||||
}
|
||||
|
||||
func TestProxyProbeServiceSuite(t *testing.T) {
|
||||
suite.Run(t, new(ProxyProbeServiceSuite))
|
||||
}
|
||||
@@ -0,0 +1,378 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/proxy"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type sqlQuerier interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
type proxyRepository struct {
|
||||
client *dbent.Client
|
||||
sql sqlQuerier
|
||||
}
|
||||
|
||||
func NewProxyRepository(client *dbent.Client, sqlDB *sql.DB) service.ProxyRepository {
|
||||
return newProxyRepositoryWithSQL(client, sqlDB)
|
||||
}
|
||||
|
||||
func newProxyRepositoryWithSQL(client *dbent.Client, sqlq sqlQuerier) *proxyRepository {
|
||||
return &proxyRepository{client: client, sql: sqlq}
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Create(ctx context.Context, proxyIn *service.Proxy) error {
|
||||
builder := r.client.Proxy.Create().
|
||||
SetName(proxyIn.Name).
|
||||
SetProtocol(proxyIn.Protocol).
|
||||
SetHost(proxyIn.Host).
|
||||
SetPort(proxyIn.Port).
|
||||
SetStatus(proxyIn.Status)
|
||||
if proxyIn.Username != "" {
|
||||
builder.SetUsername(proxyIn.Username)
|
||||
}
|
||||
if proxyIn.Password != "" {
|
||||
builder.SetPassword(proxyIn.Password)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyProxyEntityToService(proxyIn, created)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *proxyRepository) GetByID(ctx context.Context, id int64) (*service.Proxy, error) {
|
||||
m, err := r.client.Proxy.Get(ctx, id)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrProxyNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return proxyEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) ListByIDs(ctx context.Context, ids []int64) ([]service.Proxy, error) {
|
||||
if len(ids) == 0 {
|
||||
return []service.Proxy{}, nil
|
||||
}
|
||||
|
||||
proxies, err := r.client.Proxy.Query().
|
||||
Where(proxy.IDIn(ids...)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
out = append(out, *proxyEntityToService(proxies[i]))
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Update(ctx context.Context, proxyIn *service.Proxy) error {
|
||||
builder := r.client.Proxy.UpdateOneID(proxyIn.ID).
|
||||
SetName(proxyIn.Name).
|
||||
SetProtocol(proxyIn.Protocol).
|
||||
SetHost(proxyIn.Host).
|
||||
SetPort(proxyIn.Port).
|
||||
SetStatus(proxyIn.Status)
|
||||
if proxyIn.Username != "" {
|
||||
builder.SetUsername(proxyIn.Username)
|
||||
} else {
|
||||
builder.ClearUsername()
|
||||
}
|
||||
if proxyIn.Password != "" {
|
||||
builder.SetPassword(proxyIn.Password)
|
||||
} else {
|
||||
builder.ClearPassword()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
applyProxyEntityToService(proxyIn, updated)
|
||||
return nil
|
||||
}
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrProxyNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *proxyRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.client.Proxy.Delete().Where(proxy.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *proxyRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
// ListWithFilters lists proxies with optional filtering by protocol, status, and search query
|
||||
func (r *proxyRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.Proxy, *pagination.PaginationResult, error) {
|
||||
q := r.client.Proxy.Query()
|
||||
if protocol != "" {
|
||||
q = q.Where(proxy.ProtocolEQ(protocol))
|
||||
}
|
||||
if status != "" {
|
||||
q = q.Where(proxy.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(proxy.NameContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxies, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(proxy.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outProxies := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
outProxies = append(outProxies, *proxyEntityToService(proxies[i]))
|
||||
}
|
||||
|
||||
return outProxies, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// ListWithFiltersAndAccountCount lists proxies with filters and includes account count per proxy
|
||||
func (r *proxyRepository) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]service.ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
q := r.client.Proxy.Query()
|
||||
if protocol != "" {
|
||||
q = q.Where(proxy.ProtocolEQ(protocol))
|
||||
}
|
||||
if status != "" {
|
||||
q = q.Where(proxy.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(proxy.NameContainsFold(search))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxies, err := q.
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(proxy.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Get account counts
|
||||
counts, err := r.GetAccountCountsForProxies(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// Build result with account counts
|
||||
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
proxyOut := proxyEntityToService(proxies[i])
|
||||
if proxyOut == nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, service.ProxyWithAccountCount{
|
||||
Proxy: *proxyOut,
|
||||
AccountCount: counts[proxyOut.ID],
|
||||
})
|
||||
}
|
||||
|
||||
return result, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) ListActive(ctx context.Context) ([]service.Proxy, error) {
|
||||
proxies, err := r.client.Proxy.Query().
|
||||
Where(proxy.StatusEQ(service.StatusActive)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outProxies := make([]service.Proxy, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
outProxies = append(outProxies, *proxyEntityToService(proxies[i]))
|
||||
}
|
||||
return outProxies, nil
|
||||
}
|
||||
|
||||
// ExistsByHostPortAuth checks if a proxy with the same host, port, username, and password exists
|
||||
func (r *proxyRepository) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
q := r.client.Proxy.Query().
|
||||
Where(proxy.HostEQ(host), proxy.PortEQ(port))
|
||||
|
||||
if username == "" {
|
||||
q = q.Where(proxy.Or(proxy.UsernameIsNil(), proxy.UsernameEQ("")))
|
||||
} else {
|
||||
q = q.Where(proxy.UsernameEQ(username))
|
||||
}
|
||||
if password == "" {
|
||||
q = q.Where(proxy.Or(proxy.PasswordIsNil(), proxy.PasswordEQ("")))
|
||||
} else {
|
||||
q = q.Where(proxy.PasswordEQ(password))
|
||||
}
|
||||
|
||||
count, err := q.Count(ctx)
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// CountAccountsByProxyID returns the number of accounts using a specific proxy
|
||||
func (r *proxyRepository) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
var count int64
|
||||
if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM accounts WHERE proxy_id = $1 AND deleted_at IS NULL", []any{proxyID}, &count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (r *proxyRepository) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) {
|
||||
rows, err := r.sql.QueryContext(ctx, `
|
||||
SELECT id, name, platform, type, notes
|
||||
FROM accounts
|
||||
WHERE proxy_id = $1 AND deleted_at IS NULL
|
||||
ORDER BY id DESC
|
||||
`, proxyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
out := make([]service.ProxyAccountSummary, 0)
|
||||
for rows.Next() {
|
||||
var (
|
||||
id int64
|
||||
name string
|
||||
platform string
|
||||
accType string
|
||||
notes sql.NullString
|
||||
)
|
||||
if err := rows.Scan(&id, &name, &platform, &accType, ¬es); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var notesPtr *string
|
||||
if notes.Valid {
|
||||
notesPtr = ¬es.String
|
||||
}
|
||||
out = append(out, service.ProxyAccountSummary{
|
||||
ID: id,
|
||||
Name: name,
|
||||
Platform: platform,
|
||||
Type: accType,
|
||||
Notes: notesPtr,
|
||||
})
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GetAccountCountsForProxies returns a map of proxy ID to account count for all proxies
|
||||
func (r *proxyRepository) GetAccountCountsForProxies(ctx context.Context) (counts map[int64]int64, err error) {
|
||||
rows, err := r.sql.QueryContext(ctx, "SELECT proxy_id, COUNT(*) AS count FROM accounts WHERE proxy_id IS NOT NULL AND deleted_at IS NULL GROUP BY proxy_id")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil && err == nil {
|
||||
err = closeErr
|
||||
counts = nil
|
||||
}
|
||||
}()
|
||||
|
||||
counts = make(map[int64]int64)
|
||||
for rows.Next() {
|
||||
var proxyID, count int64
|
||||
if err = rows.Scan(&proxyID, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts[proxyID] = count
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
// ListActiveWithAccountCount returns all active proxies with account count, sorted by creation time descending
|
||||
func (r *proxyRepository) ListActiveWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) {
|
||||
proxies, err := r.client.Proxy.Query().
|
||||
Where(proxy.StatusEQ(service.StatusActive)).
|
||||
Order(dbent.Desc(proxy.FieldCreatedAt)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get account counts
|
||||
counts, err := r.GetAccountCountsForProxies(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Build result with account counts
|
||||
result := make([]service.ProxyWithAccountCount, 0, len(proxies))
|
||||
for i := range proxies {
|
||||
proxyOut := proxyEntityToService(proxies[i])
|
||||
if proxyOut == nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, service.ProxyWithAccountCount{
|
||||
Proxy: *proxyOut,
|
||||
AccountCount: counts[proxyOut.ID],
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func proxyEntityToService(m *dbent.Proxy) *service.Proxy {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.Proxy{
|
||||
ID: m.ID,
|
||||
Name: m.Name,
|
||||
Protocol: m.Protocol,
|
||||
Host: m.Host,
|
||||
Port: m.Port,
|
||||
Status: m.Status,
|
||||
CreatedAt: m.CreatedAt,
|
||||
UpdatedAt: m.UpdatedAt,
|
||||
}
|
||||
if m.Username != nil {
|
||||
out.Username = *m.Username
|
||||
}
|
||||
if m.Password != nil {
|
||||
out.Password = *m.Password
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func applyProxyEntityToService(dst *service.Proxy, src *dbent.Proxy) {
|
||||
if dst == nil || src == nil {
|
||||
return
|
||||
}
|
||||
dst.ID = src.ID
|
||||
dst.CreatedAt = src.CreatedAt
|
||||
dst.UpdatedAt = src.UpdatedAt
|
||||
}
|
||||
@@ -0,0 +1,329 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type ProxyRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
tx *dbent.Tx
|
||||
repo *proxyRepository
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.tx = tx
|
||||
s.repo = newProxyRepositoryWithSQL(tx.Client(), tx)
|
||||
}
|
||||
|
||||
func TestProxyRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(ProxyRepoSuite))
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestCreate() {
|
||||
proxy := &service.Proxy{
|
||||
Name: "test-create",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, proxy)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(proxy.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("test-create", got.Name)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestUpdate() {
|
||||
proxy := &service.Proxy{
|
||||
Name: "original",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, proxy))
|
||||
|
||||
proxy.Name = "updated"
|
||||
err := s.repo.Update(s.ctx, proxy)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", got.Name)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestDelete() {
|
||||
proxy := &service.Proxy{
|
||||
Name: "to-delete",
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, proxy))
|
||||
|
||||
err := s.repo.Delete(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, proxy.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestList() {
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
|
||||
|
||||
proxies, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(proxies, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Protocol() {
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "socks5", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "socks5", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Equal("socks5", proxies[0].Protocol)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Status() {
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusDisabled})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusDisabled, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Equal(service.StatusDisabled, proxies[0].Status)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestListWithFilters_Search() {
|
||||
s.mustCreateProxy(&service.Proxy{Name: "production-proxy", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "dev-proxy", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
|
||||
|
||||
proxies, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "prod")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Contains(proxies[0].Name, "production")
|
||||
}
|
||||
|
||||
// --- ListActive ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestListActive() {
|
||||
s.mustCreateProxy(&service.Proxy{Name: "active1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustCreateProxy(&service.Proxy{Name: "inactive1", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusDisabled})
|
||||
|
||||
proxies, err := s.repo.ListActive(s.ctx)
|
||||
s.Require().NoError(err, "ListActive")
|
||||
s.Require().Len(proxies, 1)
|
||||
s.Require().Equal("active1", proxies[0].Name)
|
||||
}
|
||||
|
||||
// --- ExistsByHostPortAuth ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth() {
|
||||
s.mustCreateProxy(&service.Proxy{
|
||||
Name: "p1",
|
||||
Protocol: "http",
|
||||
Host: "1.2.3.4",
|
||||
Port: 8080,
|
||||
Username: "user",
|
||||
Password: "pass",
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "user", "pass")
|
||||
s.Require().NoError(err, "ExistsByHostPortAuth")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "wrong", "creds")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_NoAuth() {
|
||||
s.mustCreateProxy(&service.Proxy{
|
||||
Name: "p-noauth",
|
||||
Protocol: "http",
|
||||
Host: "5.6.7.8",
|
||||
Port: 8081,
|
||||
Username: "",
|
||||
Password: "",
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "5.6.7.8", 8081, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().True(exists)
|
||||
}
|
||||
|
||||
// --- CountAccountsByProxyID ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestCountAccountsByProxyID() {
|
||||
proxy := s.mustCreateProxy(&service.Proxy{Name: "p-count", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
s.mustInsertAccount("a1", &proxy.ID)
|
||||
s.mustInsertAccount("a2", &proxy.ID)
|
||||
s.mustInsertAccount("a3", nil) // no proxy
|
||||
|
||||
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestCountAccountsByProxyID_Zero() {
|
||||
proxy := s.mustCreateProxy(&service.Proxy{Name: "p-zero", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
|
||||
count, err := s.repo.CountAccountsByProxyID(s.ctx, proxy.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- GetAccountCountsForProxies ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies() {
|
||||
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "127.0.0.1", Port: 8080, Status: service.StatusActive})
|
||||
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "127.0.0.1", Port: 8081, Status: service.StatusActive})
|
||||
|
||||
s.mustInsertAccount("a1", &p1.ID)
|
||||
s.mustInsertAccount("a2", &p1.ID)
|
||||
s.mustInsertAccount("a3", &p2.ID)
|
||||
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err, "GetAccountCountsForProxies")
|
||||
s.Require().Equal(int64(2), counts[p1.ID])
|
||||
s.Require().Equal(int64(1), counts[p2.ID])
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) TestGetAccountCountsForProxies_Empty() {
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Empty(counts)
|
||||
}
|
||||
|
||||
// --- ListActiveWithAccountCount ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestListActiveWithAccountCount() {
|
||||
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
p1 := s.mustCreateProxyWithTimes("p1", service.StatusActive, base.Add(-1*time.Hour))
|
||||
p2 := s.mustCreateProxyWithTimes("p2", service.StatusActive, base)
|
||||
s.mustCreateProxyWithTimes("p3-inactive", service.StatusDisabled, base.Add(1*time.Hour))
|
||||
|
||||
s.mustInsertAccount("a1", &p1.ID)
|
||||
s.mustInsertAccount("a2", &p1.ID)
|
||||
s.mustInsertAccount("a3", &p2.ID)
|
||||
|
||||
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
|
||||
s.Require().NoError(err, "ListActiveWithAccountCount")
|
||||
s.Require().Len(withCounts, 2, "expected 2 active proxies")
|
||||
|
||||
// Sorted by created_at DESC, so p2 first
|
||||
s.Require().Equal(p2.ID, withCounts[0].ID)
|
||||
s.Require().Equal(int64(1), withCounts[0].AccountCount)
|
||||
s.Require().Equal(p1.ID, withCounts[1].ID)
|
||||
s.Require().Equal(int64(2), withCounts[1].AccountCount)
|
||||
}
|
||||
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *ProxyRepoSuite) TestExistsByHostPortAuth_And_AccountCountAggregates() {
|
||||
p1 := s.mustCreateProxy(&service.Proxy{Name: "p1", Protocol: "http", Host: "1.2.3.4", Port: 8080, Username: "u", Password: "p", Status: service.StatusActive})
|
||||
p2 := s.mustCreateProxy(&service.Proxy{Name: "p2", Protocol: "http", Host: "5.6.7.8", Port: 8081, Username: "", Password: "", Status: service.StatusActive})
|
||||
|
||||
exists, err := s.repo.ExistsByHostPortAuth(s.ctx, "1.2.3.4", 8080, "u", "p")
|
||||
s.Require().NoError(err, "ExistsByHostPortAuth")
|
||||
s.Require().True(exists, "expected proxy to exist")
|
||||
|
||||
s.mustInsertAccount("a1", &p1.ID)
|
||||
s.mustInsertAccount("a2", &p1.ID)
|
||||
s.mustInsertAccount("a3", &p2.ID)
|
||||
|
||||
count1, err := s.repo.CountAccountsByProxyID(s.ctx, p1.ID)
|
||||
s.Require().NoError(err, "CountAccountsByProxyID")
|
||||
s.Require().Equal(int64(2), count1, "expected 2 accounts for p1")
|
||||
|
||||
counts, err := s.repo.GetAccountCountsForProxies(s.ctx)
|
||||
s.Require().NoError(err, "GetAccountCountsForProxies")
|
||||
s.Require().Equal(int64(2), counts[p1.ID])
|
||||
s.Require().Equal(int64(1), counts[p2.ID])
|
||||
|
||||
withCounts, err := s.repo.ListActiveWithAccountCount(s.ctx)
|
||||
s.Require().NoError(err, "ListActiveWithAccountCount")
|
||||
s.Require().Len(withCounts, 2, "expected 2 proxies")
|
||||
for _, pc := range withCounts {
|
||||
switch pc.ID {
|
||||
case p1.ID:
|
||||
s.Require().Equal(int64(2), pc.AccountCount, "p1 count mismatch")
|
||||
case p2.ID:
|
||||
s.Require().Equal(int64(1), pc.AccountCount, "p2 count mismatch")
|
||||
default:
|
||||
s.Require().Fail("unexpected proxy id", pc.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) mustCreateProxy(p *service.Proxy) *service.Proxy {
|
||||
s.T().Helper()
|
||||
s.Require().NoError(s.repo.Create(s.ctx, p), "create proxy")
|
||||
return p
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) mustCreateProxyWithTimes(name, status string, createdAt time.Time) *service.Proxy {
|
||||
s.T().Helper()
|
||||
|
||||
// Use the repository create for standard fields, then update timestamps via raw SQL to keep deterministic ordering.
|
||||
p := s.mustCreateProxy(&service.Proxy{
|
||||
Name: name,
|
||||
Protocol: "http",
|
||||
Host: "127.0.0.1",
|
||||
Port: 8080,
|
||||
Status: status,
|
||||
})
|
||||
_, err := s.tx.ExecContext(s.ctx, "UPDATE proxies SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, p.ID)
|
||||
s.Require().NoError(err, "update proxy timestamps")
|
||||
return p
|
||||
}
|
||||
|
||||
func (s *ProxyRepoSuite) mustInsertAccount(name string, proxyID *int64) {
|
||||
s.T().Helper()
|
||||
var pid any
|
||||
if proxyID != nil {
|
||||
pid = *proxyID
|
||||
}
|
||||
_, err := s.tx.ExecContext(
|
||||
s.ctx,
|
||||
"INSERT INTO accounts (name, platform, type, proxy_id) VALUES ($1, $2, $3, $4)",
|
||||
name,
|
||||
service.PlatformAnthropic,
|
||||
service.AccountTypeOAuth,
|
||||
pid,
|
||||
)
|
||||
s.Require().NoError(err, "insert account")
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
const (
|
||||
redeemRateLimitKeyPrefix = "redeem:ratelimit:"
|
||||
redeemLockKeyPrefix = "redeem:lock:"
|
||||
redeemRateLimitDuration = 24 * time.Hour
|
||||
)
|
||||
|
||||
// redeemRateLimitKey generates the Redis key for redeem attempt rate limiting.
|
||||
func redeemRateLimitKey(userID int64) string {
|
||||
return fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
}
|
||||
|
||||
// redeemLockKey generates the Redis key for redeem code locking.
|
||||
func redeemLockKey(code string) string {
|
||||
return redeemLockKeyPrefix + code
|
||||
}
|
||||
|
||||
type redeemCache struct {
|
||||
rdb *redis.Client
|
||||
}
|
||||
|
||||
func NewRedeemCache(rdb *redis.Client) service.RedeemCache {
|
||||
return &redeemCache{rdb: rdb}
|
||||
}
|
||||
|
||||
func (c *redeemCache) GetRedeemAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
key := redeemRateLimitKey(userID)
|
||||
count, err := c.rdb.Get(ctx, key).Int()
|
||||
if err == redis.Nil {
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (c *redeemCache) IncrementRedeemAttemptCount(ctx context.Context, userID int64) error {
|
||||
key := redeemRateLimitKey(userID)
|
||||
pipe := c.rdb.Pipeline()
|
||||
pipe.Incr(ctx, key)
|
||||
pipe.Expire(ctx, key, redeemRateLimitDuration)
|
||||
_, err := pipe.Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *redeemCache) AcquireRedeemLock(ctx context.Context, code string, ttl time.Duration) (bool, error) {
|
||||
key := redeemLockKey(code)
|
||||
return c.rdb.SetNX(ctx, key, 1, ttl).Result()
|
||||
}
|
||||
|
||||
func (c *redeemCache) ReleaseRedeemLock(ctx context.Context, code string) error {
|
||||
key := redeemLockKey(code)
|
||||
return c.rdb.Del(ctx, key).Err()
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type RedeemCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache *redeemCache
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewRedeemCache(s.rdb).(*redeemCache)
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestGetRedeemAttemptCount_Missing() {
|
||||
missingUserID := int64(99999)
|
||||
count, err := s.cache.GetRedeemAttemptCount(s.ctx, missingUserID)
|
||||
require.NoError(s.T(), err, "expected nil error for missing rate-limit key")
|
||||
require.Equal(s.T(), 0, count, "expected zero count for missing key")
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestIncrementAndGetRedeemAttemptCount() {
|
||||
userID := int64(1)
|
||||
key := fmt.Sprintf("%s%d", redeemRateLimitKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID), "IncrementRedeemAttemptCount")
|
||||
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetRedeemAttemptCount")
|
||||
require.Equal(s.T(), 1, count, "count mismatch")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, key).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, redeemRateLimitDuration)
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestMultipleIncrements() {
|
||||
userID := int64(2)
|
||||
|
||||
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
|
||||
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
|
||||
require.NoError(s.T(), s.cache.IncrementRedeemAttemptCount(s.ctx, userID))
|
||||
|
||||
count, err := s.cache.GetRedeemAttemptCount(s.ctx, userID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 3, count, "count after 3 increments")
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestAcquireAndReleaseRedeemLock() {
|
||||
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
|
||||
require.NoError(s.T(), err, "AcquireRedeemLock")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Second acquire should fail
|
||||
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
|
||||
require.NoError(s.T(), err, "AcquireRedeemLock 2")
|
||||
require.False(s.T(), ok, "expected lock to be held")
|
||||
|
||||
// Release
|
||||
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "CODE"), "ReleaseRedeemLock")
|
||||
|
||||
// Now acquire should succeed
|
||||
ok, err = s.cache.AcquireRedeemLock(s.ctx, "CODE", 10*time.Second)
|
||||
require.NoError(s.T(), err, "AcquireRedeemLock after release")
|
||||
require.True(s.T(), ok)
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestAcquireRedeemLock_TTL() {
|
||||
lockKey := redeemLockKeyPrefix + "CODE2"
|
||||
lockTTL := 15 * time.Second
|
||||
|
||||
ok, err := s.cache.AcquireRedeemLock(s.ctx, "CODE2", lockTTL)
|
||||
require.NoError(s.T(), err, "AcquireRedeemLock CODE2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, lockKey).Result()
|
||||
require.NoError(s.T(), err, "TTL lock key")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, lockTTL)
|
||||
}
|
||||
|
||||
func (s *RedeemCacheSuite) TestReleaseRedeemLock_Idempotent() {
|
||||
// Release a lock that doesn't exist should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "NONEXISTENT"))
|
||||
|
||||
// Acquire, release, release again
|
||||
ok, err := s.cache.AcquireRedeemLock(s.ctx, "IDEMPOTENT", 10*time.Second)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"))
|
||||
require.NoError(s.T(), s.cache.ReleaseRedeemLock(s.ctx, "IDEMPOTENT"), "second release should be idempotent")
|
||||
}
|
||||
|
||||
func TestRedeemCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(RedeemCacheSuite))
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestRedeemRateLimitKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "redeem:ratelimit:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "redeem:ratelimit:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "redeem:ratelimit:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "redeem:ratelimit:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := redeemRateLimitKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRedeemLockKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
code string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_code",
|
||||
code: "ABC123",
|
||||
expected: "redeem:lock:ABC123",
|
||||
},
|
||||
{
|
||||
name: "empty_code",
|
||||
code: "",
|
||||
expected: "redeem:lock:",
|
||||
},
|
||||
{
|
||||
name: "code_with_special_chars",
|
||||
code: "CODE-2024:test",
|
||||
expected: "redeem:lock:CODE-2024:test",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := redeemLockKey(tc.code)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,296 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/ent/redeemcode"
|
||||
"github.com/Wei-Shaw/sub2api/ent/user"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
)
|
||||
|
||||
type redeemCodeRepository struct {
|
||||
client *dbent.Client
|
||||
}
|
||||
|
||||
func NewRedeemCodeRepository(client *dbent.Client) service.RedeemCodeRepository {
|
||||
return &redeemCodeRepository{client: client}
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Create(ctx context.Context, code *service.RedeemCode) error {
|
||||
created, err := r.client.RedeemCode.Create().
|
||||
SetCode(code.Code).
|
||||
SetType(code.Type).
|
||||
SetValue(code.Value).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes).
|
||||
SetValidityDays(code.ValidityDays).
|
||||
SetNillableUsedBy(code.UsedBy).
|
||||
SetNillableUsedAt(code.UsedAt).
|
||||
SetNillableGroupID(code.GroupID).
|
||||
Save(ctx)
|
||||
if err == nil {
|
||||
code.ID = created.ID
|
||||
code.CreatedAt = created.CreatedAt
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) CreateBatch(ctx context.Context, codes []service.RedeemCode) error {
|
||||
if len(codes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
builders := make([]*dbent.RedeemCodeCreate, 0, len(codes))
|
||||
for i := range codes {
|
||||
c := &codes[i]
|
||||
b := r.client.RedeemCode.Create().
|
||||
SetCode(c.Code).
|
||||
SetType(c.Type).
|
||||
SetValue(c.Value).
|
||||
SetStatus(c.Status).
|
||||
SetNotes(c.Notes).
|
||||
SetValidityDays(c.ValidityDays).
|
||||
SetNillableUsedBy(c.UsedBy).
|
||||
SetNillableUsedAt(c.UsedAt).
|
||||
SetNillableGroupID(c.GroupID)
|
||||
builders = append(builders, b)
|
||||
}
|
||||
|
||||
return r.client.RedeemCode.CreateBulk(builders...).Exec(ctx)
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) GetByID(ctx context.Context, id int64) (*service.RedeemCode, error) {
|
||||
m, err := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.IDEQ(id)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return redeemCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) GetByCode(ctx context.Context, code string) (*service.RedeemCode, error) {
|
||||
m, err := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.CodeEQ(code)).
|
||||
Only(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return nil, service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return redeemCodeEntityToService(m), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Delete(ctx context.Context, id int64) error {
|
||||
_, err := r.client.RedeemCode.Delete().Where(redeemcode.IDEQ(id)).Exec(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) List(ctx context.Context, params pagination.PaginationParams) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
return r.ListWithFilters(ctx, params, "", "", "")
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.RedeemCode.Query()
|
||||
|
||||
if codeType != "" {
|
||||
q = q.Where(redeemcode.TypeEQ(codeType))
|
||||
}
|
||||
if status != "" {
|
||||
q = q.Where(redeemcode.StatusEQ(status))
|
||||
}
|
||||
if search != "" {
|
||||
q = q.Where(
|
||||
redeemcode.Or(
|
||||
redeemcode.CodeContainsFold(search),
|
||||
redeemcode.HasUserWith(user.EmailContainsFold(search)),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
WithUser().
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(redeemcode.FieldID)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
outCodes := redeemCodeEntitiesToService(codes)
|
||||
|
||||
return outCodes, paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Update(ctx context.Context, code *service.RedeemCode) error {
|
||||
up := r.client.RedeemCode.UpdateOneID(code.ID).
|
||||
SetCode(code.Code).
|
||||
SetType(code.Type).
|
||||
SetValue(code.Value).
|
||||
SetStatus(code.Status).
|
||||
SetNotes(code.Notes).
|
||||
SetValidityDays(code.ValidityDays)
|
||||
|
||||
if code.UsedBy != nil {
|
||||
up.SetUsedBy(*code.UsedBy)
|
||||
} else {
|
||||
up.ClearUsedBy()
|
||||
}
|
||||
if code.UsedAt != nil {
|
||||
up.SetUsedAt(*code.UsedAt)
|
||||
} else {
|
||||
up.ClearUsedAt()
|
||||
}
|
||||
if code.GroupID != nil {
|
||||
up.SetGroupID(*code.GroupID)
|
||||
} else {
|
||||
up.ClearGroupID()
|
||||
}
|
||||
|
||||
updated, err := up.Save(ctx)
|
||||
if err != nil {
|
||||
if dbent.IsNotFound(err) {
|
||||
return service.ErrRedeemCodeNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
code.CreatedAt = updated.CreatedAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) Use(ctx context.Context, id, userID int64) error {
|
||||
now := time.Now()
|
||||
client := clientFromContext(ctx, r.client)
|
||||
affected, err := client.RedeemCode.Update().
|
||||
Where(redeemcode.IDEQ(id), redeemcode.StatusEQ(service.StatusUnused)).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetUsedBy(userID).
|
||||
SetUsedAt(now).
|
||||
Save(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if affected == 0 {
|
||||
return service.ErrRedeemCodeUsed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *redeemCodeRepository) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) {
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
codes, err := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.UsedByEQ(userID)).
|
||||
WithGroup().
|
||||
Order(dbent.Desc(redeemcode.FieldUsedAt)).
|
||||
Limit(limit).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return redeemCodeEntitiesToService(codes), nil
|
||||
}
|
||||
|
||||
// ListByUserPaginated returns paginated balance/concurrency history for a user.
|
||||
// Supports optional type filter (e.g. "balance", "admin_balance", "concurrency", "admin_concurrency", "subscription").
|
||||
func (r *redeemCodeRepository) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]service.RedeemCode, *pagination.PaginationResult, error) {
|
||||
q := r.client.RedeemCode.Query().
|
||||
Where(redeemcode.UsedByEQ(userID))
|
||||
|
||||
// Optional type filter
|
||||
if codeType != "" {
|
||||
q = q.Where(redeemcode.TypeEQ(codeType))
|
||||
}
|
||||
|
||||
total, err := q.Count(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
codes, err := q.
|
||||
WithGroup().
|
||||
Offset(params.Offset()).
|
||||
Limit(params.Limit()).
|
||||
Order(dbent.Desc(redeemcode.FieldUsedAt)).
|
||||
All(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return redeemCodeEntitiesToService(codes), paginationResultFromTotal(int64(total), params), nil
|
||||
}
|
||||
|
||||
// SumPositiveBalanceByUser returns total recharged amount (sum of value > 0 where type is balance/admin_balance).
|
||||
func (r *redeemCodeRepository) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
|
||||
var result []struct {
|
||||
Sum float64 `json:"sum"`
|
||||
}
|
||||
err := r.client.RedeemCode.Query().
|
||||
Where(
|
||||
redeemcode.UsedByEQ(userID),
|
||||
redeemcode.ValueGT(0),
|
||||
redeemcode.TypeIn("balance", "admin_balance"),
|
||||
).
|
||||
Aggregate(dbent.As(dbent.Sum(redeemcode.FieldValue), "sum")).
|
||||
Scan(ctx, &result)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if len(result) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
return result[0].Sum, nil
|
||||
}
|
||||
|
||||
func redeemCodeEntityToService(m *dbent.RedeemCode) *service.RedeemCode {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
out := &service.RedeemCode{
|
||||
ID: m.ID,
|
||||
Code: m.Code,
|
||||
Type: m.Type,
|
||||
Value: m.Value,
|
||||
Status: m.Status,
|
||||
UsedBy: m.UsedBy,
|
||||
UsedAt: m.UsedAt,
|
||||
Notes: derefString(m.Notes),
|
||||
CreatedAt: m.CreatedAt,
|
||||
GroupID: m.GroupID,
|
||||
ValidityDays: m.ValidityDays,
|
||||
}
|
||||
if m.Edges.User != nil {
|
||||
out.User = userEntityToService(m.Edges.User)
|
||||
}
|
||||
if m.Edges.Group != nil {
|
||||
out.Group = groupEntityToService(m.Edges.Group)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func redeemCodeEntitiesToService(models []*dbent.RedeemCode) []service.RedeemCode {
|
||||
out := make([]service.RedeemCode, 0, len(models))
|
||||
for i := range models {
|
||||
if s := redeemCodeEntityToService(models[i]); s != nil {
|
||||
out = append(out, *s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -0,0 +1,390 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type RedeemCodeRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *redeemCodeRepository
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewRedeemCodeRepository(s.client).(*redeemCodeRepository)
|
||||
}
|
||||
|
||||
func TestRedeemCodeRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(RedeemCodeRepoSuite))
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) createUser(email string) *dbent.User {
|
||||
u, err := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create user")
|
||||
return u
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) createGroup(name string) *dbent.Group {
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return g
|
||||
}
|
||||
|
||||
// --- Create / CreateBatch / GetByID / GetByCode ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreate() {
|
||||
code := &service.RedeemCode{
|
||||
Code: "TEST-CREATE",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 100,
|
||||
Status: service.StatusUnused,
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, code)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(code.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("TEST-CREATE", got.Code)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreateBatch() {
|
||||
codes := []service.RedeemCode{
|
||||
{Code: "BATCH-1", Type: service.RedeemTypeBalance, Value: 10, Status: service.StatusUnused},
|
||||
{Code: "BATCH-2", Type: service.RedeemTypeBalance, Value: 20, Status: service.StatusUnused},
|
||||
}
|
||||
|
||||
err := s.repo.CreateBatch(s.ctx, codes)
|
||||
s.Require().NoError(err, "CreateBatch")
|
||||
|
||||
got1, err := s.repo.GetByCode(s.ctx, "BATCH-1")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(float64(10), got1.Value)
|
||||
|
||||
got2, err := s.repo.GetByCode(s.ctx, "BATCH-2")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(float64(20), got2.Value)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestGetByCode() {
|
||||
_, err := s.client.RedeemCode.Create().
|
||||
SetCode("GET-BY-CODE").
|
||||
SetType(service.RedeemTypeBalance).
|
||||
SetStatus(service.StatusUnused).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "seed redeem code")
|
||||
|
||||
got, err := s.repo.GetByCode(s.ctx, "GET-BY-CODE")
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
s.Require().Equal("GET-BY-CODE", got.Code)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestGetByCode_NotFound() {
|
||||
_, err := s.repo.GetByCode(s.ctx, "NON-EXISTENT")
|
||||
s.Require().Error(err, "expected error for non-existent code")
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
|
||||
}
|
||||
|
||||
// --- Delete ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestDelete() {
|
||||
created, err := s.client.RedeemCode.Create().
|
||||
SetCode("TO-DELETE").
|
||||
SetType(service.RedeemTypeBalance).
|
||||
SetStatus(service.StatusUnused).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
err = s.repo.Delete(s.ctx, created.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, created.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeNotFound)
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestList() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "LIST-1", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "LIST-2", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
|
||||
codes, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(codes, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Type() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "TYPE-BAL", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "TYPE-SUB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused}))
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().Equal(service.RedeemTypeSubscription, codes[0].Type)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Status() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-UNUSED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "STAT-USED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUsed}))
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", service.StatusUsed, "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().Equal(service.StatusUsed, codes[0].Status)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_Search() {
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "ALPHA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
s.Require().NoError(s.repo.Create(s.ctx, &service.RedeemCode{Code: "BETA-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}))
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "alpha")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().Contains(codes[0].Code, "ALPHA")
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListWithFilters_GroupPreload() {
|
||||
group := s.createGroup(uniqueTestValue(s.T(), "g-preload"))
|
||||
_, err := s.client.RedeemCode.Create().
|
||||
SetCode("WITH-GROUP").
|
||||
SetType(service.RedeemTypeSubscription).
|
||||
SetStatus(service.StatusUnused).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
SetGroupID(group.ID).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
codes, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().NotNil(codes[0].Group, "expected Group preload")
|
||||
s.Require().Equal(group.ID, codes[0].Group.ID)
|
||||
}
|
||||
|
||||
// --- Update ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUpdate() {
|
||||
code := &service.RedeemCode{
|
||||
Code: "UPDATE-ME",
|
||||
Type: service.RedeemTypeBalance,
|
||||
Value: 10,
|
||||
Status: service.StatusUnused,
|
||||
}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, code))
|
||||
|
||||
code.Value = 50
|
||||
err := s.repo.Update(s.ctx, code)
|
||||
s.Require().NoError(err, "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(float64(50), got.Value)
|
||||
}
|
||||
|
||||
// --- Use ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse() {
|
||||
user := s.createUser(uniqueTestValue(s.T(), "use") + "@example.com")
|
||||
code := &service.RedeemCode{Code: "USE-ME", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, code))
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().NoError(err, "Use")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, code.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.StatusUsed, got.Status)
|
||||
s.Require().NotNil(got.UsedBy)
|
||||
s.Require().Equal(user.ID, *got.UsedBy)
|
||||
s.Require().NotNil(got.UsedAt)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse_Idempotency() {
|
||||
user := s.createUser(uniqueTestValue(s.T(), "idem") + "@example.com")
|
||||
code := &service.RedeemCode{Code: "IDEM-CODE", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUnused}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, code))
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().NoError(err, "Use first time")
|
||||
|
||||
// Second use should fail
|
||||
err = s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().Error(err, "Use expected error on second call")
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestUse_AlreadyUsed() {
|
||||
user := s.createUser(uniqueTestValue(s.T(), "already") + "@example.com")
|
||||
code := &service.RedeemCode{Code: "ALREADY-USED", Type: service.RedeemTypeBalance, Value: 0, Status: service.StatusUsed}
|
||||
s.Require().NoError(s.repo.Create(s.ctx, code))
|
||||
|
||||
err := s.repo.Use(s.ctx, code.ID, user.ID)
|
||||
s.Require().Error(err, "expected error for already used code")
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
|
||||
}
|
||||
|
||||
// --- ListByUser ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser() {
|
||||
user := s.createUser(uniqueTestValue(s.T(), "listby") + "@example.com")
|
||||
base := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
|
||||
usedAt1 := base
|
||||
_, err := s.client.RedeemCode.Create().
|
||||
SetCode("USER-1").
|
||||
SetType(service.RedeemTypeBalance).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
SetUsedBy(user.ID).
|
||||
SetUsedAt(usedAt1).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
usedAt2 := base.Add(1 * time.Hour)
|
||||
_, err = s.client.RedeemCode.Create().
|
||||
SetCode("USER-2").
|
||||
SetType(service.RedeemTypeBalance).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
SetUsedBy(user.ID).
|
||||
SetUsedAt(usedAt2).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "ListByUser")
|
||||
s.Require().Len(codes, 2)
|
||||
// Ordered by used_at DESC, so USER-2 first
|
||||
s.Require().Equal("USER-2", codes[0].Code)
|
||||
s.Require().Equal("USER-1", codes[1].Code)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser_WithGroupPreload() {
|
||||
user := s.createUser(uniqueTestValue(s.T(), "grp") + "@example.com")
|
||||
group := s.createGroup(uniqueTestValue(s.T(), "g-listby"))
|
||||
|
||||
_, err := s.client.RedeemCode.Create().
|
||||
SetCode("WITH-GRP").
|
||||
SetType(service.RedeemTypeSubscription).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
SetUsedBy(user.ID).
|
||||
SetUsedAt(time.Now()).
|
||||
SetGroupID(group.ID).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
codes, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
s.Require().NotNil(codes[0].Group)
|
||||
s.Require().Equal(group.ID, codes[0].Group.ID)
|
||||
}
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestListByUser_DefaultLimit() {
|
||||
user := s.createUser(uniqueTestValue(s.T(), "deflimit") + "@example.com")
|
||||
_, err := s.client.RedeemCode.Create().
|
||||
SetCode("DEF-LIM").
|
||||
SetType(service.RedeemTypeBalance).
|
||||
SetStatus(service.StatusUsed).
|
||||
SetValue(0).
|
||||
SetNotes("").
|
||||
SetValidityDays(30).
|
||||
SetUsedBy(user.ID).
|
||||
SetUsedAt(time.Now()).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// limit <= 0 should default to 10
|
||||
codes, err := s.repo.ListByUser(s.ctx, user.ID, 0)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(codes, 1)
|
||||
}
|
||||
|
||||
// --- Combined original test ---
|
||||
|
||||
func (s *RedeemCodeRepoSuite) TestCreateBatch_Filters_Use_Idempotency_ListByUser() {
|
||||
user := s.createUser(uniqueTestValue(s.T(), "rc") + "@example.com")
|
||||
group := s.createGroup(uniqueTestValue(s.T(), "g-rc"))
|
||||
groupID := group.ID
|
||||
|
||||
codes := []service.RedeemCode{
|
||||
{Code: "CODEA", Type: service.RedeemTypeBalance, Value: 1, Status: service.StatusUnused, Notes: ""},
|
||||
{Code: "CODEB", Type: service.RedeemTypeSubscription, Value: 0, Status: service.StatusUnused, Notes: "", GroupID: &groupID, ValidityDays: 7},
|
||||
}
|
||||
s.Require().NoError(s.repo.CreateBatch(s.ctx, codes), "CreateBatch")
|
||||
|
||||
list, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.RedeemTypeSubscription, service.StatusUnused, "code")
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
s.Require().Len(list, 1)
|
||||
s.Require().NotNil(list[0].Group, "expected Group preload")
|
||||
s.Require().Equal(group.ID, list[0].Group.ID)
|
||||
|
||||
codeB, err := s.repo.GetByCode(s.ctx, "CODEB")
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
s.Require().NoError(s.repo.Use(s.ctx, codeB.ID, user.ID), "Use")
|
||||
err = s.repo.Use(s.ctx, codeB.ID, user.ID)
|
||||
s.Require().Error(err, "Use expected error on second call")
|
||||
s.Require().ErrorIs(err, service.ErrRedeemCodeUsed)
|
||||
|
||||
codeA, err := s.repo.GetByCode(s.ctx, "CODEA")
|
||||
s.Require().NoError(err, "GetByCode")
|
||||
|
||||
// Use fixed time instead of time.Sleep for deterministic ordering.
|
||||
_, err = s.client.RedeemCode.UpdateOneID(codeB.ID).
|
||||
SetUsedAt(time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NoError(s.repo.Use(s.ctx, codeA.ID, user.ID), "Use codeA")
|
||||
_, err = s.client.RedeemCode.UpdateOneID(codeA.ID).
|
||||
SetUsedAt(time.Date(2025, 1, 1, 13, 0, 0, 0, time.UTC)).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err)
|
||||
|
||||
used, err := s.repo.ListByUser(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "ListByUser")
|
||||
s.Require().Len(used, 2, "expected 2 used codes")
|
||||
s.Require().Equal("CODEA", used[0].Code, "expected newest used code first")
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user