feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers
This commit is contained in:
145
internal/repository/allowed_groups_contract_integration_test.go
Normal file
145
internal/repository/allowed_groups_contract_integration_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func uniqueTestValue(t *testing.T, prefix string) string {
|
||||
t.Helper()
|
||||
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
|
||||
return fmt.Sprintf("%s-%s", prefix, safeName)
|
||||
}
|
||||
|
||||
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "target-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "other-group")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
repo := newUserRepositoryWithSQL(entClient, tx)
|
||||
|
||||
u1 := &service.User{
|
||||
Email: uniqueTestValue(t, "u1") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u1))
|
||||
|
||||
u2 := &service.User{
|
||||
Email: uniqueTestValue(t, "u2") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u2))
|
||||
|
||||
u3 := &service.User{
|
||||
Email: uniqueTestValue(t, "u3") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, repo.Create(ctx, u3))
|
||||
|
||||
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), affected)
|
||||
|
||||
u1After, err := repo.GetByID(ctx, u1.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
|
||||
|
||||
u2After, err := repo.GetByID(ctx, u2.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
|
||||
}
|
||||
|
||||
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
tx := testEntTx(t)
|
||||
entClient := tx.Client()
|
||||
|
||||
targetGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-target")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
otherGroup, err := entClient.Group.Create().
|
||||
SetName(uniqueTestValue(t, "delete-cascade-other")).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
userRepo := newUserRepositoryWithSQL(entClient, tx)
|
||||
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
|
||||
apiKeyRepo := newAPIKeyRepositoryWithSQL(entClient, tx)
|
||||
|
||||
u := &service.User{
|
||||
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Concurrency: 5,
|
||||
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
|
||||
}
|
||||
require.NoError(t, userRepo.Create(ctx, u))
|
||||
|
||||
key := &service.APIKey{
|
||||
UserID: u.ID,
|
||||
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
|
||||
Name: "test key",
|
||||
GroupID: &targetGroup.ID,
|
||||
Status: service.StatusActive,
|
||||
}
|
||||
require.NoError(t, apiKeyRepo.Create(ctx, key))
|
||||
|
||||
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Deleted group should be hidden by default queries (soft-delete semantics).
|
||||
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
|
||||
require.ErrorIs(t, err, service.ErrGroupNotFound)
|
||||
|
||||
activeGroups, err := groupRepo.ListActive(ctx)
|
||||
require.NoError(t, err)
|
||||
for _, g := range activeGroups {
|
||||
require.NotEqual(t, targetGroup.ID, g.ID)
|
||||
}
|
||||
|
||||
// User.allowed_groups should no longer include the deleted group.
|
||||
uAfter, err := userRepo.GetByID(ctx, u.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
|
||||
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
|
||||
|
||||
// API keys bound to the deleted group should have group_id cleared.
|
||||
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, keyAfter.GroupID)
|
||||
}
|
||||
367
internal/repository/billing_cache_integration_test.go
Normal file
367
internal/repository/billing_cache_integration_test.go
Normal file
@@ -0,0 +1,367 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type BillingCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestUserBalance() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
_, err := cache.GetUserBalance(ctx, 1)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(1)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
|
||||
|
||||
_, err := rdb.Get(ctx, balanceKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(2)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 10.5, got, "balance mismatch")
|
||||
|
||||
ttl, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_reduces_balance",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(3)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
|
||||
|
||||
got, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance after deduct")
|
||||
require.Equal(s.T(), 8.25, got, "deduct mismatch")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(100)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
|
||||
|
||||
exists, err := rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
|
||||
|
||||
exists, err = rdb.Exists(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetUserBalance(ctx, userID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "deduct_refreshes_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(103)
|
||||
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
|
||||
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
|
||||
|
||||
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL before deduct")
|
||||
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
|
||||
|
||||
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
|
||||
|
||||
balance, err := cache.GetUserBalance(ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserBalance")
|
||||
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
|
||||
|
||||
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after deduct")
|
||||
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheSuite) TestSubscriptionCache() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
|
||||
}{
|
||||
{
|
||||
name: "missing_key_returns_redis_nil",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(10)
|
||||
groupID := int64(20)
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_on_nonexistent_is_noop",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(11)
|
||||
groupID := int64(21)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set_and_get_with_ttl",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(12)
|
||||
groupID := int64(22)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 7,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache")
|
||||
require.Equal(s.T(), "active", gotSub.Status)
|
||||
require.Equal(s.T(), int64(7), gotSub.Version)
|
||||
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
|
||||
|
||||
ttl, err := rdb.TTL(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "TTL subKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "update_usage_increments_all_fields",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(13)
|
||||
groupID := int64(23)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
|
||||
|
||||
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.NoError(s.T(), err, "GetSubscriptionCache after update")
|
||||
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
|
||||
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
|
||||
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalidate_removes_key",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(101)
|
||||
groupID := int64(10)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
DailyUsage: 1.0,
|
||||
WeeklyUsage: 2.0,
|
||||
MonthlyUsage: 3.0,
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
|
||||
|
||||
exists, err := rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists")
|
||||
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
|
||||
|
||||
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
|
||||
|
||||
exists, err = rdb.Exists(ctx, subKey).Result()
|
||||
require.NoError(s.T(), err, "Exists after invalidate")
|
||||
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
|
||||
|
||||
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "missing_status_returns_parsing_error",
|
||||
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
|
||||
userID := int64(102)
|
||||
groupID := int64(11)
|
||||
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
|
||||
|
||||
fields := map[string]any{
|
||||
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
|
||||
"daily_usage": 1.0,
|
||||
"weekly_usage": 2.0,
|
||||
"monthly_usage": 3.0,
|
||||
"version": 1,
|
||||
}
|
||||
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
|
||||
|
||||
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
require.Error(s.T(), err, "expected error for missing status field")
|
||||
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
|
||||
require.Equal(s.T(), "invalid cache: missing status", err.Error())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
tt.fn(ctx, rdb, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func(ctx context.Context, cache service.BillingCache)
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "key_not_exists_returns_nil",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
// key 不存在时,Lua 脚本返回 0(redis.Nil),应返回 nil 而非错误
|
||||
err := cache.DeductUserBalance(ctx, 99999, 1.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "existing_key_deducts_successfully",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
|
||||
err := cache.DeductUserBalance(ctx, 200, 10.0)
|
||||
require.NoError(s.T(), err, "DeductUserBalance should succeed")
|
||||
|
||||
bal, err := cache.GetUserBalance(ctx, 200)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "cancelled_context_propagates_error",
|
||||
fn: func(ctx context.Context, cache service.BillingCache) {
|
||||
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel() // 立即取消
|
||||
|
||||
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
tt.fn(ctx, cache)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
|
||||
// Redis 真实错误应传播,key 不存在(redis.Nil)应返回 nil。
|
||||
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
|
||||
s.Run("key_not_exists_returns_nil", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
|
||||
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
|
||||
})
|
||||
|
||||
s.Run("cancelled_context_propagates_error", func() {
|
||||
rdb := testRedis(s.T())
|
||||
cache := NewBillingCache(rdb)
|
||||
ctx := context.Background()
|
||||
|
||||
data := &service.SubscriptionCacheData{
|
||||
Status: "active",
|
||||
ExpiresAt: time.Now().Add(1 * time.Hour),
|
||||
Version: 1,
|
||||
}
|
||||
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
|
||||
|
||||
cancelCtx, cancel := context.WithCancel(ctx)
|
||||
cancel()
|
||||
|
||||
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
|
||||
require.Error(s.T(), err, "cancelled context should propagate error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBillingCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(BillingCacheSuite))
|
||||
}
|
||||
111
internal/repository/billing_cache_test.go
Normal file
111
internal/repository/billing_cache_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBillingBalanceKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_user_id",
|
||||
userID: 123,
|
||||
expected: "billing:balance:123",
|
||||
},
|
||||
{
|
||||
name: "zero_user_id",
|
||||
userID: 0,
|
||||
expected: "billing:balance:0",
|
||||
},
|
||||
{
|
||||
name: "negative_user_id",
|
||||
userID: -1,
|
||||
expected: "billing:balance:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
userID: math.MaxInt64,
|
||||
expected: "billing:balance:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingBalanceKey(tc.userID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBillingSubKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
groupID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_ids",
|
||||
userID: 123,
|
||||
groupID: 456,
|
||||
expected: "billing:sub:123:456",
|
||||
},
|
||||
{
|
||||
name: "zero_ids",
|
||||
userID: 0,
|
||||
groupID: 0,
|
||||
expected: "billing:sub:0:0",
|
||||
},
|
||||
{
|
||||
name: "negative_ids",
|
||||
userID: -1,
|
||||
groupID: -2,
|
||||
expected: "billing:sub:-1:-2",
|
||||
},
|
||||
{
|
||||
name: "max_int64_ids",
|
||||
userID: math.MaxInt64,
|
||||
groupID: math.MaxInt64,
|
||||
expected: "billing:sub:9223372036854775807:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := billingSubKey(tc.userID, tc.groupID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL(t *testing.T) {
|
||||
const (
|
||||
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
|
||||
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
|
||||
)
|
||||
|
||||
for i := 0; i < 200; i++ {
|
||||
ttl := jitteredTTL()
|
||||
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
|
||||
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func TestJitteredTTL_HasVariation(t *testing.T) {
|
||||
// 多次调用应该产生不同的值(验证抖动存在)
|
||||
seen := make(map[time.Duration]struct{}, 50)
|
||||
for i := 0; i < 50; i++ {
|
||||
seen[jitteredTTL()] = struct{}{}
|
||||
}
|
||||
// 50 次调用中应该至少有 2 个不同的值
|
||||
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
|
||||
}
|
||||
487
internal/repository/concurrency_cache_integration_test.go
Normal file
487
internal/repository/concurrency_cache_integration_test.go
Normal file
@@ -0,0 +1,487 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// 测试用 TTL 配置(15 分钟,与默认值一致)
|
||||
const testSlotTTLMinutes = 15
|
||||
|
||||
// 测试用 TTL Duration,用于 TTL 断言
|
||||
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
|
||||
|
||||
type ConcurrencyCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.ConcurrencyCache
|
||||
}
|
||||
|
||||
func TestConcurrencyCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(ConcurrencyCacheSuite))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
|
||||
accountID := int64(10)
|
||||
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot 3")
|
||||
require.False(s.T(), ok, "expected third acquire to fail")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency")
|
||||
require.Equal(s.T(), 2, cur, "concurrency mismatch")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
|
||||
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err, "GetAccountConcurrency after release")
|
||||
require.Equal(s.T(), 1, cur, "expected 1 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
|
||||
accountID := int64(11)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireAccountSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
|
||||
accountID := int64(12)
|
||||
reqID := "dup-req"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Acquiring with same reqID should be idempotent
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
|
||||
accountID := int64(13)
|
||||
reqID := "release-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
|
||||
// Releasing again should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
|
||||
// Releasing non-existent should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
|
||||
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
|
||||
accountID := int64(14)
|
||||
reqID := "max-zero-test"
|
||||
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
|
||||
require.NoError(s.T(), err)
|
||||
require.False(s.T(), ok, "expected acquire to fail with max=0")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
|
||||
userID := int64(42)
|
||||
reqID1, reqID2 := "req1", "req2"
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot 2")
|
||||
require.False(s.T(), ok, "expected second acquire to fail at max=1")
|
||||
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency")
|
||||
require.Equal(s.T(), 1, cur, "expected concurrency=1")
|
||||
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
|
||||
// Releasing a non-existent slot should not error
|
||||
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
|
||||
|
||||
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
|
||||
require.NoError(s.T(), err, "GetUserConcurrency after release")
|
||||
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
|
||||
userID := int64(200)
|
||||
reqID := "req_ttl_test"
|
||||
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
|
||||
require.NoError(s.T(), err, "AcquireUserSlot")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
|
||||
require.NoError(s.T(), err, "TTL")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
|
||||
userID := int64(20)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount 3")
|
||||
require.False(s.T(), ok, "expected wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
|
||||
userID := int64(300)
|
||||
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
|
||||
// Test decrement on non-existent key - should not error and should not create negative value
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
|
||||
|
||||
// Verify no key was created or it's not negative
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
|
||||
|
||||
// Set count to 1, then decrement twice
|
||||
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
|
||||
require.NoError(s.T(), err, "IncrementWaitCount")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Decrement once (1 -> 0)
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
|
||||
|
||||
// Decrement again on 0 - should not go negative
|
||||
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
|
||||
|
||||
// Verify count is 0, not negative
|
||||
val, err = s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey after double decrement")
|
||||
}
|
||||
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
|
||||
accountID := int64(30)
|
||||
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
|
||||
require.True(s.T(), ok)
|
||||
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
|
||||
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
|
||||
require.False(s.T(), ok, "expected account wait increment over max to fail")
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
|
||||
require.NoError(s.T(), err, "TTL account waitKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
|
||||
|
||||
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
|
||||
|
||||
val, err := s.rdb.Get(s.ctx, waitKey).Int()
|
||||
if !errors.Is(err, redis.Nil) {
|
||||
require.NoError(s.T(), err, "Get waitKey")
|
||||
}
|
||||
require.Equal(s.T(), 1, val, "expected account wait count 1")
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
now := time.Now().Unix()
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-1"},
|
||||
redis.Z{Score: float64(now), Member: "keep-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
|
||||
redis.Z{Score: float64(now), Member: "oldproc-2"},
|
||||
redis.Z{Score: float64(now), Member: "keep-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"keep-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.True(s.T(), errors.Is(err, redis.Nil))
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
|
||||
// When no slots exist, GetAccountConcurrency should return 0
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
|
||||
// When no slots exist, GetUserConcurrency should return 0
|
||||
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 0, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
|
||||
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
|
||||
// Setup: Create accounts with different load states
|
||||
account1 := int64(100)
|
||||
account2 := int64(101)
|
||||
account3 := int64(102)
|
||||
|
||||
// Account 1: 2/3 slots used, 1 waiting
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Account 2: 1/2 slots used, 0 waiting
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Account 3: 0/1 slots used, 0 waiting (idle)
|
||||
|
||||
// Query batch load
|
||||
accounts := []service.AccountWithConcurrency{
|
||||
{ID: account1, MaxConcurrency: 3},
|
||||
{ID: account2, MaxConcurrency: 2},
|
||||
{ID: account3, MaxConcurrency: 1},
|
||||
}
|
||||
|
||||
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
|
||||
require.NoError(s.T(), err)
|
||||
require.Len(s.T(), loadMap, 3)
|
||||
|
||||
// Verify account1: (2 + 1) / 3 = 100%
|
||||
load1 := loadMap[account1]
|
||||
require.NotNil(s.T(), load1)
|
||||
require.Equal(s.T(), account1, load1.AccountID)
|
||||
require.Equal(s.T(), 2, load1.CurrentConcurrency)
|
||||
require.Equal(s.T(), 1, load1.WaitingCount)
|
||||
require.Equal(s.T(), 100, load1.LoadRate)
|
||||
|
||||
// Verify account2: (1 + 0) / 2 = 50%
|
||||
load2 := loadMap[account2]
|
||||
require.NotNil(s.T(), load2)
|
||||
require.Equal(s.T(), account2, load2.AccountID)
|
||||
require.Equal(s.T(), 1, load2.CurrentConcurrency)
|
||||
require.Equal(s.T(), 0, load2.WaitingCount)
|
||||
require.Equal(s.T(), 50, load2.LoadRate)
|
||||
|
||||
// Verify account3: (0 + 0) / 1 = 0%
|
||||
load3 := loadMap[account3]
|
||||
require.NotNil(s.T(), load3)
|
||||
require.Equal(s.T(), account3, load3.AccountID)
|
||||
require.Equal(s.T(), 0, load3.CurrentConcurrency)
|
||||
require.Equal(s.T(), 0, load3.WaitingCount)
|
||||
require.Equal(s.T(), 0, load3.LoadRate)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
|
||||
// Test with empty account list
|
||||
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
|
||||
require.NoError(s.T(), err)
|
||||
require.Empty(s.T(), loadMap)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
|
||||
accountID := int64(200)
|
||||
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
|
||||
// Acquire 3 slots
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Verify 3 slots exist
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 3, cur)
|
||||
|
||||
// Manually set old timestamps for req1 and req2 (simulate expired slots)
|
||||
now := time.Now().Unix()
|
||||
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
|
||||
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
|
||||
require.NoError(s.T(), err)
|
||||
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Run cleanup
|
||||
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Verify only 1 slot remains (req3)
|
||||
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 1, cur)
|
||||
|
||||
// Verify req3 still exists
|
||||
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Len(s.T(), members, 1)
|
||||
require.Equal(s.T(), "req3", members[0])
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
|
||||
accountID := int64(201)
|
||||
|
||||
// Acquire 2 fresh slots
|
||||
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
|
||||
require.NoError(s.T(), err)
|
||||
require.True(s.T(), ok)
|
||||
|
||||
// Run cleanup (should not remove anything)
|
||||
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
|
||||
// Verify both slots still exist
|
||||
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), 2, cur)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
|
||||
accountID := int64(901)
|
||||
userID := int64(902)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
|
||||
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
|
||||
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
|
||||
|
||||
now := float64(time.Now().Unix())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-1"},
|
||||
redis.Z{Score: now, Member: "activeproc-1"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
|
||||
redis.Z{Score: now, Member: "oldproc-2"},
|
||||
redis.Z{Score: now, Member: "activeproc-2"},
|
||||
).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
|
||||
|
||||
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
|
||||
|
||||
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
|
||||
require.ErrorIs(s.T(), err, redis.Nil)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
|
||||
accountID := int64(903)
|
||||
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
|
||||
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
|
||||
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
|
||||
|
||||
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
|
||||
|
||||
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
|
||||
require.NoError(s.T(), err)
|
||||
require.EqualValues(s.T(), 0, exists)
|
||||
}
|
||||
150
internal/repository/custom_field.go
Normal file
150
internal/repository/custom_field.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// CustomFieldRepository 自定义字段数据访问层
|
||||
type CustomFieldRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewCustomFieldRepository 创建自定义字段数据访问层
|
||||
func NewCustomFieldRepository(db *gorm.DB) *CustomFieldRepository {
|
||||
return &CustomFieldRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建自定义字段
|
||||
func (r *CustomFieldRepository) Create(ctx context.Context, field *domain.CustomField) error {
|
||||
return r.db.WithContext(ctx).Create(field).Error
|
||||
}
|
||||
|
||||
// Update 更新自定义字段
|
||||
func (r *CustomFieldRepository) Update(ctx context.Context, field *domain.CustomField) error {
|
||||
return r.db.WithContext(ctx).Save(field).Error
|
||||
}
|
||||
|
||||
// Delete 删除自定义字段
|
||||
func (r *CustomFieldRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.CustomField{}, id).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取自定义字段
|
||||
func (r *CustomFieldRepository) GetByID(ctx context.Context, id int64) (*domain.CustomField, error) {
|
||||
var field domain.CustomField
|
||||
err := r.db.WithContext(ctx).First(&field, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &field, nil
|
||||
}
|
||||
|
||||
// GetByFieldKey 根据FieldKey获取自定义字段
|
||||
func (r *CustomFieldRepository) GetByFieldKey(ctx context.Context, fieldKey string) (*domain.CustomField, error) {
|
||||
var field domain.CustomField
|
||||
err := r.db.WithContext(ctx).Where("field_key = ?", fieldKey).First(&field).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &field, nil
|
||||
}
|
||||
|
||||
// List 获取所有启用的自定义字段
|
||||
func (r *CustomFieldRepository) List(ctx context.Context) ([]*domain.CustomField, error) {
|
||||
var fields []*domain.CustomField
|
||||
err := r.db.WithContext(ctx).Where("status = ?", 1).Order("sort ASC").Find(&fields).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
// ListAll 获取所有自定义字段
|
||||
func (r *CustomFieldRepository) ListAll(ctx context.Context) ([]*domain.CustomField, error) {
|
||||
var fields []*domain.CustomField
|
||||
err := r.db.WithContext(ctx).Order("sort ASC").Find(&fields).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return fields, nil
|
||||
}
|
||||
|
||||
// UserCustomFieldValueRepository 用户自定义字段值数据访问层
|
||||
type UserCustomFieldValueRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserCustomFieldValueRepository 创建用户自定义字段值数据访问层
|
||||
func NewUserCustomFieldValueRepository(db *gorm.DB) *UserCustomFieldValueRepository {
|
||||
return &UserCustomFieldValueRepository{db: db}
|
||||
}
|
||||
|
||||
// Set 为用户设置自定义字段值(upsert)
|
||||
func (r *UserCustomFieldValueRepository) Set(ctx context.Context, userID int64, fieldID int64, fieldKey, value string) error {
|
||||
return r.db.WithContext(ctx).Exec(`
|
||||
INSERT INTO user_custom_field_values (user_id, field_id, field_key, value, created_at, updated_at)
|
||||
VALUES (?, ?, ?, ?, NOW(), NOW())
|
||||
ON CONFLICT(user_id, field_id) DO UPDATE SET value = ?, updated_at = NOW()
|
||||
`, userID, fieldID, fieldKey, value, value).Error
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户的所有自定义字段值
|
||||
func (r *UserCustomFieldValueRepository) GetByUserID(ctx context.Context, userID int64) ([]*domain.UserCustomFieldValue, error) {
|
||||
var values []*domain.UserCustomFieldValue
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&values).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return values, nil
|
||||
}
|
||||
|
||||
// GetByUserIDAndFieldKey 获取用户指定字段的值
|
||||
func (r *UserCustomFieldValueRepository) GetByUserIDAndFieldKey(ctx context.Context, userID int64, fieldKey string) (*domain.UserCustomFieldValue, error) {
|
||||
var value domain.UserCustomFieldValue
|
||||
err := r.db.WithContext(ctx).Where("user_id = ? AND field_key = ?", userID, fieldKey).First(&value).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &value, nil
|
||||
}
|
||||
|
||||
// Delete 删除用户的自定义字段值
|
||||
func (r *UserCustomFieldValueRepository) Delete(ctx context.Context, userID int64, fieldID int64) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ? AND field_id = ?", userID, fieldID).Delete(&domain.UserCustomFieldValue{}).Error
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除用户的所有自定义字段值
|
||||
func (r *UserCustomFieldValueRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&domain.UserCustomFieldValue{}).Error
|
||||
}
|
||||
|
||||
// BatchSet 批量设置用户的自定义字段值
|
||||
func (r *UserCustomFieldValueRepository) BatchSet(ctx context.Context, userID int64, values map[string]string) error {
|
||||
if len(values) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
for fieldKey, value := range values {
|
||||
if err := tx.Exec(`
|
||||
INSERT INTO user_custom_field_values (user_id, field_id, field_key, value, created_at, updated_at)
|
||||
VALUES (
|
||||
?,
|
||||
(SELECT id FROM custom_fields WHERE field_key = ? LIMIT 1),
|
||||
?,
|
||||
?,
|
||||
NOW(),
|
||||
NOW()
|
||||
)
|
||||
ON CONFLICT(user_id, field_id) DO UPDATE SET value = ?, updated_at = NOW()
|
||||
`, userID, fieldKey, fieldKey, value, value).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
32
internal/repository/db_pool.go
Normal file
32
internal/repository/db_pool.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
type dbPoolSettings struct {
|
||||
MaxOpenConns int
|
||||
MaxIdleConns int
|
||||
ConnMaxLifetime time.Duration
|
||||
ConnMaxIdleTime time.Duration
|
||||
}
|
||||
|
||||
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
|
||||
return dbPoolSettings{
|
||||
MaxOpenConns: cfg.Database.MaxOpenConns,
|
||||
MaxIdleConns: cfg.Database.MaxIdleConns,
|
||||
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
|
||||
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
db.SetMaxOpenConns(settings.MaxOpenConns)
|
||||
db.SetMaxIdleConns(settings.MaxIdleConns)
|
||||
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
|
||||
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
|
||||
}
|
||||
50
internal/repository/db_pool_test.go
Normal file
50
internal/repository/db_pool_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func TestBuildDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 50,
|
||||
MaxIdleConns: 10,
|
||||
ConnMaxLifetimeMinutes: 30,
|
||||
ConnMaxIdleTimeMinutes: 5,
|
||||
},
|
||||
}
|
||||
|
||||
settings := buildDBPoolSettings(cfg)
|
||||
require.Equal(t, 50, settings.MaxOpenConns)
|
||||
require.Equal(t, 10, settings.MaxIdleConns)
|
||||
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
|
||||
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
|
||||
}
|
||||
|
||||
func TestApplyDBPoolSettings(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
MaxOpenConns: 40,
|
||||
MaxIdleConns: 8,
|
||||
ConnMaxLifetimeMinutes: 15,
|
||||
ConnMaxIdleTimeMinutes: 3,
|
||||
},
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_ = db.Close()
|
||||
})
|
||||
|
||||
applyDBPoolSettings(db, cfg)
|
||||
stats := db.Stats()
|
||||
require.Equal(t, 40, stats.MaxOpenConnections)
|
||||
}
|
||||
256
internal/repository/device.go
Normal file
256
internal/repository/device.go
Normal file
@@ -0,0 +1,256 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// DeviceRepository 设备数据访问层
|
||||
type DeviceRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewDeviceRepository 创建设备数据访问层
|
||||
func NewDeviceRepository(db *gorm.DB) *DeviceRepository {
|
||||
return &DeviceRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建设备
|
||||
func (r *DeviceRepository) Create(ctx context.Context, device *domain.Device) error {
|
||||
// GORM omits zero values on insert for fields with DB defaults. Explicitly
|
||||
// backfill inactive status so callers can persist status=0 devices.
|
||||
requestedStatus := device.Status
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(device).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if requestedStatus == domain.DeviceStatusInactive {
|
||||
if err := tx.Model(&domain.Device{}).Where("id = ?", device.ID).Update("status", requestedStatus).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
device.Status = requestedStatus
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Update 更新设备
|
||||
func (r *DeviceRepository) Update(ctx context.Context, device *domain.Device) error {
|
||||
return r.db.WithContext(ctx).Save(device).Error
|
||||
}
|
||||
|
||||
// Delete 删除设备
|
||||
func (r *DeviceRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.Device{}, id).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取设备
|
||||
func (r *DeviceRepository) GetByID(ctx context.Context, id int64) (*domain.Device, error) {
|
||||
var device domain.Device
|
||||
err := r.db.WithContext(ctx).First(&device, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &device, nil
|
||||
}
|
||||
|
||||
// GetByDeviceID 根据设备ID和用户ID获取设备
|
||||
func (r *DeviceRepository) GetByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
|
||||
var device domain.Device
|
||||
err := r.db.WithContext(ctx).Where("user_id = ? AND device_id = ?", userID, deviceID).First(&device).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &device, nil
|
||||
}
|
||||
|
||||
// List 获取设备列表
|
||||
func (r *DeviceRepository) List(ctx context.Context, offset, limit int) ([]*domain.Device, int64, error) {
|
||||
var devices []*domain.Device
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Device{})
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&devices).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return devices, total, nil
|
||||
}
|
||||
|
||||
// ListByUserID 根据用户ID获取设备列表
|
||||
func (r *DeviceRepository) ListByUserID(ctx context.Context, userID int64, offset, limit int) ([]*domain.Device, int64, error) {
|
||||
var devices []*domain.Device
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Device{}).Where("user_id = ?", userID)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Order("last_active_time DESC").Find(&devices).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return devices, total, nil
|
||||
}
|
||||
|
||||
// ListByStatus 根据状态获取设备列表
|
||||
func (r *DeviceRepository) ListByStatus(ctx context.Context, status domain.DeviceStatus, offset, limit int) ([]*domain.Device, int64, error) {
|
||||
var devices []*domain.Device
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Device{}).Where("status = ?", status)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&devices).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return devices, total, nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新设备状态
|
||||
func (r *DeviceRepository) UpdateStatus(ctx context.Context, id int64, status domain.DeviceStatus) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
// UpdateLastActiveTime 更新最后活跃时间
|
||||
func (r *DeviceRepository) UpdateLastActiveTime(ctx context.Context, id int64) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", id).Update("last_active_time", now).Error
|
||||
}
|
||||
|
||||
// Exists 检查设备是否存在
|
||||
func (r *DeviceRepository) Exists(ctx context.Context, userID int64, deviceID string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.Device{}).
|
||||
Where("user_id = ? AND device_id = ?", userID, deviceID).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除用户的所有设备
|
||||
func (r *DeviceRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&domain.Device{}).Error
|
||||
}
|
||||
|
||||
// GetActiveDevices 获取活跃设备
|
||||
func (r *DeviceRepository) GetActiveDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
|
||||
var devices []*domain.Device
|
||||
thirtyDaysAgo := time.Now().Add(-30 * 24 * time.Hour)
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("user_id = ? AND last_active_time > ?", userID, thirtyDaysAgo).
|
||||
Order("last_active_time DESC").
|
||||
Find(&devices).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
// TrustDevice 设置设备为信任状态
|
||||
func (r *DeviceRepository) TrustDevice(ctx context.Context, deviceID int64, expiresAt *time.Time) error {
|
||||
updates := map[string]interface{}{
|
||||
"is_trusted": true,
|
||||
"trust_expires_at": expiresAt,
|
||||
}
|
||||
return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", deviceID).Updates(updates).Error
|
||||
}
|
||||
|
||||
// UntrustDevice 取消设备信任状态
|
||||
func (r *DeviceRepository) UntrustDevice(ctx context.Context, deviceID int64) error {
|
||||
updates := map[string]interface{}{
|
||||
"is_trusted": false,
|
||||
"trust_expires_at": nil,
|
||||
}
|
||||
return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", deviceID).Updates(updates).Error
|
||||
}
|
||||
|
||||
// DeleteAllByUserIDExcept 删除用户的所有设备(除指定设备外)
|
||||
func (r *DeviceRepository) DeleteAllByUserIDExcept(ctx context.Context, userID int64, exceptDeviceID int64) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Where("user_id = ? AND id != ?", userID, exceptDeviceID).
|
||||
Delete(&domain.Device{}).Error
|
||||
}
|
||||
|
||||
// GetTrustedDevices 获取用户的信任设备列表
|
||||
func (r *DeviceRepository) GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
|
||||
var devices []*domain.Device
|
||||
now := time.Now()
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("user_id = ? AND is_trusted = ? AND (trust_expires_at IS NULL OR trust_expires_at > ?)", userID, true, now).
|
||||
Order("last_active_time DESC").
|
||||
Find(&devices).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
// ListDevicesParams 设备列表查询参数
|
||||
type ListDevicesParams struct {
|
||||
UserID int64
|
||||
Status domain.DeviceStatus
|
||||
IsTrusted *bool
|
||||
Keyword string
|
||||
Offset int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// ListAll 获取所有设备列表(支持筛选)
|
||||
func (r *DeviceRepository) ListAll(ctx context.Context, params *ListDevicesParams) ([]*domain.Device, int64, error) {
|
||||
var devices []*domain.Device
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Device{})
|
||||
|
||||
// 按用户ID筛选
|
||||
if params.UserID > 0 {
|
||||
query = query.Where("user_id = ?", params.UserID)
|
||||
}
|
||||
// 按状态筛选
|
||||
if params.Status >= 0 {
|
||||
query = query.Where("status = ?", params.Status)
|
||||
}
|
||||
// 按信任状态筛选
|
||||
if params.IsTrusted != nil {
|
||||
query = query.Where("is_trusted = ?", *params.IsTrusted)
|
||||
}
|
||||
// 按关键词筛选(设备名/IP/位置)
|
||||
if params.Keyword != "" {
|
||||
search := "%" + params.Keyword + "%"
|
||||
query = query.Where("device_name LIKE ? OR ip LIKE ? OR location LIKE ?", search, search, search)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(params.Offset).Limit(params.Limit).
|
||||
Order("last_active_time DESC").Find(&devices).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return devices, total, nil
|
||||
}
|
||||
92
internal/repository/email_cache_integration_test.go
Normal file
92
internal/repository/email_cache_integration_test.go
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type EmailCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.EmailCache
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewEmailCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
|
||||
email := "a@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
got, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode")
|
||||
require.Equal(s.T(), "123456", got.Code)
|
||||
require.Equal(s.T(), 1, got.Attempts)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
|
||||
email := "ttl@example.com"
|
||||
emailTTL := 2 * time.Minute
|
||||
data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
|
||||
|
||||
emailKey := verifyCodeKeyPrefix + email
|
||||
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
|
||||
require.NoError(s.T(), err, "TTL emailKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
|
||||
email := "delete@example.com"
|
||||
data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
|
||||
|
||||
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
|
||||
|
||||
// Verify it exists
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.NoError(s.T(), err, "GetVerificationCode before delete")
|
||||
|
||||
// Delete
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
|
||||
|
||||
// Verify it's gone
|
||||
_, err = s.cache.GetVerificationCode(s.ctx, email)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
|
||||
// Deleting a non-existent key should not error
|
||||
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
|
||||
}
|
||||
|
||||
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
|
||||
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
|
||||
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
|
||||
|
||||
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
|
||||
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestEmailCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(EmailCacheSuite))
|
||||
}
|
||||
45
internal/repository/email_cache_test.go
Normal file
45
internal/repository/email_cache_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestVerifyCodeKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_email",
|
||||
email: "user@example.com",
|
||||
expected: "verify_code:user@example.com",
|
||||
},
|
||||
{
|
||||
name: "empty_email",
|
||||
email: "",
|
||||
expected: "verify_code:",
|
||||
},
|
||||
{
|
||||
name: "email_with_plus",
|
||||
email: "user+tag@example.com",
|
||||
expected: "verify_code:user+tag@example.com",
|
||||
},
|
||||
{
|
||||
name: "email_with_special_chars",
|
||||
email: "user.name+tag@sub.domain.com",
|
||||
expected: "verify_code:user.name+tag@sub.domain.com",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := verifyCodeKey(tc.email)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
109
internal/repository/gateway_cache_integration_test.go
Normal file
109
internal/repository/gateway_cache_integration_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GatewayCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.GatewayCache
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewGatewayCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
|
||||
sessionID := "s1"
|
||||
accountID := int64(99)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.NoError(s.T(), err, "GetSessionAccountID")
|
||||
require.Equal(s.T(), accountID, sid, "session id mismatch")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
|
||||
sessionID := "s2"
|
||||
accountID := int64(100)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL sessionKey after Set")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
|
||||
sessionID := "s3"
|
||||
accountID := int64(101)
|
||||
groupID := int64(1)
|
||||
initialTTL := 1 * time.Minute
|
||||
refreshTTL := 3 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
|
||||
|
||||
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
|
||||
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
|
||||
require.NoError(s.T(), err, "TTL after Refresh")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
|
||||
// RefreshSessionTTL on a missing key should not error (no-op)
|
||||
err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
|
||||
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
|
||||
sessionID := "openai:s4"
|
||||
accountID := int64(102)
|
||||
groupID := int64(1)
|
||||
sessionTTL := 1 * time.Minute
|
||||
|
||||
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
|
||||
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
|
||||
sessionID := "corrupted"
|
||||
groupID := int64(1)
|
||||
sessionKey := buildSessionKey(groupID, sessionID)
|
||||
|
||||
// Set a non-integer value
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
|
||||
|
||||
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
|
||||
require.Error(s.T(), err, "expected error for corrupted value")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
|
||||
}
|
||||
|
||||
func TestGatewayCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayCacheSuite))
|
||||
}
|
||||
250
internal/repository/gateway_routing_integration_test.go
Normal file
250
internal/repository/gateway_routing_integration_test.go
Normal file
@@ -0,0 +1,250 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
dbent "github.com/user-management-system/ent"
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
// GatewayRoutingSuite 测试网关路由相关的数据库查询
|
||||
// 验证账户选择和分流逻辑在真实数据库环境下的行为
|
||||
type GatewayRoutingSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
accountRepo *accountRepository
|
||||
}
|
||||
|
||||
func (s *GatewayRoutingSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
|
||||
}
|
||||
|
||||
func TestGatewayRoutingSuite(t *testing.T) {
|
||||
suite.Run(t, new(GatewayRoutingSuite))
|
||||
}
|
||||
|
||||
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() {
|
||||
// 创建各平台账户
|
||||
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "gemini-oauth",
|
||||
Platform: service.PlatformGemini,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 1,
|
||||
})
|
||||
|
||||
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "antigravity-oauth",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 2,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"refresh_token": "test-refresh",
|
||||
"project_id": "test-project",
|
||||
},
|
||||
})
|
||||
|
||||
// 创建不应被选中的 anthropic 账户
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "anthropic-oauth",
|
||||
Platform: service.PlatformAnthropic,
|
||||
Type: service.AccountTypeOAuth,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Priority: 0,
|
||||
})
|
||||
|
||||
// 查询 gemini + antigravity 平台
|
||||
accounts, err := s.accountRepo.ListSchedulableByPlatforms(s.ctx, []string{
|
||||
service.PlatformGemini,
|
||||
service.PlatformAntigravity,
|
||||
})
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 2, "应返回 gemini 和 antigravity 两个账户")
|
||||
|
||||
// 验证返回的账户平台
|
||||
platforms := make(map[string]bool)
|
||||
for _, acc := range accounts {
|
||||
platforms[acc.Platform] = true
|
||||
}
|
||||
s.Require().True(platforms[service.PlatformGemini], "应包含 gemini 账户")
|
||||
s.Require().True(platforms[service.PlatformAntigravity], "应包含 antigravity 账户")
|
||||
s.Require().False(platforms[service.PlatformAnthropic], "不应包含 anthropic 账户")
|
||||
|
||||
// 验证账户 ID 匹配
|
||||
ids := make(map[int64]bool)
|
||||
for _, acc := range accounts {
|
||||
ids[acc.ID] = true
|
||||
}
|
||||
s.Require().True(ids[geminiAcc.ID])
|
||||
s.Require().True(ids[antigravityAcc.ID])
|
||||
}
|
||||
|
||||
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() {
|
||||
// 创建 gemini 分组
|
||||
group := mustCreateGroup(s.T(), s.client, &service.Group{
|
||||
Name: "gemini-group",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
// 创建账户
|
||||
boundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "bound-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
unboundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "unbound-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
// 只绑定一个账户到分组
|
||||
mustBindAccountToGroup(s.T(), s.client, boundAcc.ID, group.ID, 1)
|
||||
|
||||
// 查询分组内的账户
|
||||
accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{
|
||||
service.PlatformGemini,
|
||||
service.PlatformAntigravity,
|
||||
})
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1, "应只返回绑定到分组的账户")
|
||||
s.Require().Equal(boundAcc.ID, accounts[0].ID)
|
||||
|
||||
// 确认未绑定的账户不在结果中
|
||||
for _, acc := range accounts {
|
||||
s.Require().NotEqual(unboundAcc.ID, acc.ID, "不应包含未绑定的账户")
|
||||
}
|
||||
}
|
||||
|
||||
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
|
||||
func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
|
||||
// 创建多种平台账户
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "gemini-1",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
antigravity := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "antigravity-1",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
// 只查询 antigravity 平台
|
||||
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1)
|
||||
s.Require().Equal(antigravity.ID, accounts[0].ID)
|
||||
s.Require().Equal(service.PlatformAntigravity, accounts[0].Platform)
|
||||
}
|
||||
|
||||
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
|
||||
func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
|
||||
// 创建可调度账户
|
||||
activeAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "active-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true)
|
||||
inactiveAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "inactive-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
s.Require().NoError(s.client.Account.UpdateOneID(inactiveAcc.ID).SetSchedulable(false).Exec(s.ctx))
|
||||
|
||||
// 创建错误状态账户
|
||||
mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "error-antigravity",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusError,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
|
||||
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(accounts, 1, "应只返回可调度的 active 账户")
|
||||
s.Require().Equal(activeAcc.ID, accounts[0].ID)
|
||||
}
|
||||
|
||||
// TestPlatformRoutingDecision 验证平台路由决策
|
||||
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
|
||||
func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() {
|
||||
// 创建两种平台的账户
|
||||
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "gemini-route-test",
|
||||
Platform: service.PlatformGemini,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
|
||||
Name: "antigravity-route-test",
|
||||
Platform: service.PlatformAntigravity,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID int64
|
||||
expectedService string
|
||||
}{
|
||||
{
|
||||
name: "Gemini账户路由到ForwardNative",
|
||||
accountID: geminiAcc.ID,
|
||||
expectedService: "GeminiMessagesCompatService.ForwardNative",
|
||||
},
|
||||
{
|
||||
name: "Antigravity账户路由到ForwardGemini",
|
||||
accountID: antigravityAcc.ID,
|
||||
expectedService: "AntigravityGatewayService.ForwardGemini",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
s.Run(tt.name, func() {
|
||||
// 从数据库获取账户
|
||||
account, err := s.accountRepo.GetByID(s.ctx, tt.accountID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
// 模拟 Handler 层的路由决策
|
||||
var routedService string
|
||||
if account.Platform == service.PlatformAntigravity {
|
||||
routedService = "AntigravityGatewayService.ForwardGemini"
|
||||
} else {
|
||||
routedService = "GeminiMessagesCompatService.ForwardNative"
|
||||
}
|
||||
|
||||
s.Require().Equal(tt.expectedService, routedService)
|
||||
})
|
||||
}
|
||||
}
|
||||
9
internal/repository/gemini_drive_client.go
Normal file
9
internal/repository/gemini_drive_client.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package repository
|
||||
|
||||
import "github.com/user-management-system/internal/pkg/geminicli"
|
||||
|
||||
// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations.
|
||||
// Returned as geminicli.DriveClient interface for DI (Strategy A).
|
||||
func NewGeminiDriveClient() geminicli.DriveClient {
|
||||
return geminicli.NewDriveClient()
|
||||
}
|
||||
47
internal/repository/gemini_token_cache_integration_test.go
Normal file
47
internal/repository/gemini_token_cache_integration_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GeminiTokenCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache service.GeminiTokenCache
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewGeminiTokenCache(s.rdb)
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
|
||||
cacheKey := "project-123"
|
||||
token := "token-value"
|
||||
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
|
||||
|
||||
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), token, got)
|
||||
|
||||
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
|
||||
|
||||
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
|
||||
}
|
||||
|
||||
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
|
||||
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
|
||||
}
|
||||
|
||||
func TestGeminiTokenCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(GeminiTokenCacheSuite))
|
||||
}
|
||||
28
internal/repository/gemini_token_cache_test.go
Normal file
28
internal/repository/gemini_token_cache_test.go
Normal file
@@ -0,0 +1,28 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
|
||||
rdb := redis.NewClient(&redis.Options{
|
||||
Addr: "127.0.0.1:1",
|
||||
DialTimeout: 50 * time.Millisecond,
|
||||
ReadTimeout: 50 * time.Millisecond,
|
||||
WriteTimeout: 50 * time.Millisecond,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
_ = rdb.Close()
|
||||
})
|
||||
|
||||
cache := NewGeminiTokenCache(rdb)
|
||||
err := cache.DeleteAccessToken(context.Background(), "broken")
|
||||
require.Error(t, err)
|
||||
}
|
||||
67
internal/repository/identity_cache_integration_test.go
Normal file
67
internal/repository/identity_cache_integration_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type IdentityCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache *identityCache
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewIdentityCache(s.rdb).(*identityCache)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
|
||||
_, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint")
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
|
||||
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
|
||||
gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
|
||||
require.NoError(s.T(), err, "GetFingerprint")
|
||||
require.Equal(s.T(), "c1", gotFP.ClientID)
|
||||
require.Equal(s.T(), "ua", gotFP.UserAgent)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
|
||||
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
|
||||
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
|
||||
|
||||
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
|
||||
ttl, err := s.rdb.TTL(s.ctx, fpKey).Result()
|
||||
require.NoError(s.T(), err, "TTL fpKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL)
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() {
|
||||
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999)
|
||||
require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON")
|
||||
|
||||
_, err := s.cache.GetFingerprint(s.ctx, 999)
|
||||
require.Error(s.T(), err, "expected error for corrupted JSON")
|
||||
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
|
||||
}
|
||||
|
||||
func (s *IdentityCacheSuite) TestSetFingerprint_Nil() {
|
||||
err := s.cache.SetFingerprint(s.ctx, 100, nil)
|
||||
require.NoError(s.T(), err, "SetFingerprint(nil) should succeed")
|
||||
}
|
||||
|
||||
func TestIdentityCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(IdentityCacheSuite))
|
||||
}
|
||||
46
internal/repository/identity_cache_test.go
Normal file
46
internal/repository/identity_cache_test.go
Normal file
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFingerprintKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
accountID int64
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "normal_account_id",
|
||||
accountID: 123,
|
||||
expected: "fingerprint:123",
|
||||
},
|
||||
{
|
||||
name: "zero_account_id",
|
||||
accountID: 0,
|
||||
expected: "fingerprint:0",
|
||||
},
|
||||
{
|
||||
name: "negative_account_id",
|
||||
accountID: -1,
|
||||
expected: "fingerprint:-1",
|
||||
},
|
||||
{
|
||||
name: "max_int64",
|
||||
accountID: math.MaxInt64,
|
||||
expected: "fingerprint:9223372036854775807",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
got := fingerprintKey(tc.accountID)
|
||||
require.Equal(t, tc.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
63
internal/repository/inprocess_transport_test.go
Normal file
63
internal/repository/inprocess_transport_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
|
||||
// It captures the request body (if any) and then rewinds it before invoking the handler.
|
||||
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
|
||||
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
|
||||
var body []byte
|
||||
if r.Body != nil {
|
||||
body, _ = io.ReadAll(r.Body)
|
||||
_ = r.Body.Close()
|
||||
r.Body = io.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
if capture != nil {
|
||||
capture(r, body)
|
||||
}
|
||||
|
||||
rec := httptest.NewRecorder()
|
||||
handler(rec, r)
|
||||
return rec.Result(), nil
|
||||
})
|
||||
}
|
||||
|
||||
var (
|
||||
canListenOnce sync.Once
|
||||
canListen bool
|
||||
canListenErr error
|
||||
)
|
||||
|
||||
func localListenerAvailable() bool {
|
||||
canListenOnce.Do(func() {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
canListenErr = err
|
||||
canListen = false
|
||||
return
|
||||
}
|
||||
_ = ln.Close()
|
||||
canListen = true
|
||||
})
|
||||
return canListen
|
||||
}
|
||||
|
||||
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
|
||||
tb.Helper()
|
||||
if !localListenerAvailable() {
|
||||
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
|
||||
}
|
||||
return httptest.NewServer(handler)
|
||||
}
|
||||
140
internal/repository/login_log.go
Normal file
140
internal/repository/login_log.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// LoginLogRepository 登录日志仓储
|
||||
type LoginLogRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewLoginLogRepository 创建登录日志仓储
|
||||
func NewLoginLogRepository(db *gorm.DB) *LoginLogRepository {
|
||||
return &LoginLogRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建登录日志
|
||||
func (r *LoginLogRepository) Create(ctx context.Context, log *domain.LoginLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取登录日志
|
||||
func (r *LoginLogRepository) GetByID(ctx context.Context, id int64) (*domain.LoginLog, error) {
|
||||
var log domain.LoginLog
|
||||
if err := r.db.WithContext(ctx).First(&log, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &log, nil
|
||||
}
|
||||
|
||||
// ListByUserID 获取用户的登录日志列表
|
||||
func (r *LoginLogRepository) ListByUserID(ctx context.Context, userID int64, offset, limit int) ([]*domain.LoginLog, int64, error) {
|
||||
var logs []*domain.LoginLog
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&domain.LoginLog{}).Where("user_id = ?", userID)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// List 获取登录日志列表(管理员用)
|
||||
func (r *LoginLogRepository) List(ctx context.Context, offset, limit int) ([]*domain.LoginLog, int64, error) {
|
||||
var logs []*domain.LoginLog
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&domain.LoginLog{})
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// ListByStatus 按状态查询登录日志
|
||||
func (r *LoginLogRepository) ListByStatus(ctx context.Context, status int, offset, limit int) ([]*domain.LoginLog, int64, error) {
|
||||
var logs []*domain.LoginLog
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&domain.LoginLog{}).Where("status = ?", status)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// ListByTimeRange 按时间范围查询登录日志
|
||||
func (r *LoginLogRepository) ListByTimeRange(ctx context.Context, start, end time.Time, offset, limit int) ([]*domain.LoginLog, int64, error) {
|
||||
var logs []*domain.LoginLog
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&domain.LoginLog{}).
|
||||
Where("created_at >= ? AND created_at <= ?", start, end)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除用户所有登录日志
|
||||
func (r *LoginLogRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&domain.LoginLog{}).Error
|
||||
}
|
||||
|
||||
// DeleteOlderThan 删除指定天数前的日志
|
||||
func (r *LoginLogRepository) DeleteOlderThan(ctx context.Context, days int) error {
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
return r.db.WithContext(ctx).Where("created_at < ?", cutoff).Delete(&domain.LoginLog{}).Error
|
||||
}
|
||||
|
||||
// CountByResultSince 统计指定时间之后特定结果的登录次数
|
||||
// success=true 统计成功次数,false 统计失败次数
|
||||
func (r *LoginLogRepository) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 {
|
||||
status := 0 // 失败
|
||||
if success {
|
||||
status = 1 // 成功
|
||||
}
|
||||
var count int64
|
||||
r.db.WithContext(ctx).Model(&domain.LoginLog{}).
|
||||
Where("status = ? AND created_at >= ?", status, since).
|
||||
Count(&count)
|
||||
return count
|
||||
}
|
||||
|
||||
// ListAllForExport 获取所有登录日志(用于导出,无分页)
|
||||
func (r *LoginLogRepository) ListAllForExport(ctx context.Context, userID int64, status int, startAt, endAt *time.Time) ([]*domain.LoginLog, error) {
|
||||
var logs []*domain.LoginLog
|
||||
query := r.db.WithContext(ctx).Model(&domain.LoginLog{})
|
||||
|
||||
if userID > 0 {
|
||||
query = query.Where("user_id = ?", userID)
|
||||
}
|
||||
if status == 0 || status == 1 {
|
||||
query = query.Where("status = ?", status)
|
||||
}
|
||||
if startAt != nil {
|
||||
query = query.Where("created_at >= ?", startAt)
|
||||
}
|
||||
if endAt != nil {
|
||||
query = query.Where("created_at <= ?", endAt)
|
||||
}
|
||||
|
||||
if err := query.Order("created_at DESC").Find(&logs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logs, nil
|
||||
}
|
||||
113
internal/repository/operation_log.go
Normal file
113
internal/repository/operation_log.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// OperationLogRepository 操作日志仓储
|
||||
type OperationLogRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewOperationLogRepository 创建操作日志仓储
|
||||
func NewOperationLogRepository(db *gorm.DB) *OperationLogRepository {
|
||||
return &OperationLogRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建操作日志
|
||||
func (r *OperationLogRepository) Create(ctx context.Context, log *domain.OperationLog) error {
|
||||
return r.db.WithContext(ctx).Create(log).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取操作日志
|
||||
func (r *OperationLogRepository) GetByID(ctx context.Context, id int64) (*domain.OperationLog, error) {
|
||||
var log domain.OperationLog
|
||||
if err := r.db.WithContext(ctx).First(&log, id).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &log, nil
|
||||
}
|
||||
|
||||
// ListByUserID 获取用户的操作日志列表
|
||||
func (r *OperationLogRepository) ListByUserID(ctx context.Context, userID int64, offset, limit int) ([]*domain.OperationLog, int64, error) {
|
||||
var logs []*domain.OperationLog
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&domain.OperationLog{}).Where("user_id = ?", userID)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// List 获取操作日志列表(管理员用)
|
||||
func (r *OperationLogRepository) List(ctx context.Context, offset, limit int) ([]*domain.OperationLog, int64, error) {
|
||||
var logs []*domain.OperationLog
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&domain.OperationLog{})
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// ListByMethod 按HTTP方法查询操作日志
|
||||
func (r *OperationLogRepository) ListByMethod(ctx context.Context, method string, offset, limit int) ([]*domain.OperationLog, int64, error) {
|
||||
var logs []*domain.OperationLog
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&domain.OperationLog{}).Where("request_method = ?", method)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// ListByTimeRange 按时间范围查询操作日志
|
||||
func (r *OperationLogRepository) ListByTimeRange(ctx context.Context, start, end time.Time, offset, limit int) ([]*domain.OperationLog, int64, error) {
|
||||
var logs []*domain.OperationLog
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&domain.OperationLog{}).
|
||||
Where("created_at >= ? AND created_at <= ?", start, end)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return logs, total, nil
|
||||
}
|
||||
|
||||
// DeleteOlderThan 删除指定天数前的日志
|
||||
func (r *OperationLogRepository) DeleteOlderThan(ctx context.Context, days int) error {
|
||||
cutoff := time.Now().AddDate(0, 0, -days)
|
||||
return r.db.WithContext(ctx).Where("created_at < ?", cutoff).Delete(&domain.OperationLog{}).Error
|
||||
}
|
||||
|
||||
// Search 按关键词搜索操作日志
|
||||
func (r *OperationLogRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.OperationLog, int64, error) {
|
||||
var logs []*domain.OperationLog
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&domain.OperationLog{}).
|
||||
Where("operation_name LIKE ? OR request_path LIKE ? OR operation_type LIKE ?",
|
||||
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return logs, total, nil
|
||||
}
|
||||
79
internal/repository/ops_write_pressure_integration_test.go
Normal file
79
internal/repository/ops_write_pressure_integration_test.go
Normal file
@@ -0,0 +1,79 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
|
||||
|
||||
repo := NewOpsRepository(integrationDB).(*opsRepository)
|
||||
now := time.Now().UTC()
|
||||
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
|
||||
{
|
||||
RequestID: "batch-ops-1",
|
||||
ErrorPhase: "upstream",
|
||||
ErrorType: "upstream_error",
|
||||
Severity: "error",
|
||||
StatusCode: 429,
|
||||
ErrorMessage: "rate limited",
|
||||
CreatedAt: now,
|
||||
},
|
||||
{
|
||||
RequestID: "batch-ops-2",
|
||||
ErrorPhase: "internal",
|
||||
ErrorType: "api_error",
|
||||
Severity: "error",
|
||||
StatusCode: 500,
|
||||
ErrorMessage: "internal error",
|
||||
CreatedAt: now.Add(time.Millisecond),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, inserted)
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||
|
||||
accountID := int64(12345)
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||
require.Equal(t, 1, count)
|
||||
|
||||
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
|
||||
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
|
||||
|
||||
accountID := int64(67890)
|
||||
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
|
||||
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
|
||||
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
|
||||
|
||||
var count int
|
||||
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
|
||||
require.Equal(t, 2, count)
|
||||
}
|
||||
16
internal/repository/pagination.go
Normal file
16
internal/repository/pagination.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package repository
|
||||
|
||||
import "github.com/user-management-system/internal/pkg/pagination"
|
||||
|
||||
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
|
||||
pages := int(total) / params.Limit()
|
||||
if int(total)%params.Limit() > 0 {
|
||||
pages++
|
||||
}
|
||||
return &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: params.Page,
|
||||
PageSize: params.Limit(),
|
||||
Pages: pages,
|
||||
}
|
||||
}
|
||||
58
internal/repository/password_history.go
Normal file
58
internal/repository/password_history.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// PasswordHistoryRepository 密码历史记录数据访问层
|
||||
type PasswordHistoryRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewPasswordHistoryRepository 创建密码历史记录数据访问层
|
||||
func NewPasswordHistoryRepository(db *gorm.DB) *PasswordHistoryRepository {
|
||||
return &PasswordHistoryRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建密码历史记录
|
||||
func (r *PasswordHistoryRepository) Create(ctx context.Context, history *domain.PasswordHistory) error {
|
||||
return r.db.WithContext(ctx).Create(history).Error
|
||||
}
|
||||
|
||||
// GetByUserID 获取用户的密码历史记录(最近 N 条,按时间倒序)
|
||||
func (r *PasswordHistoryRepository) GetByUserID(ctx context.Context, userID int64, limit int) ([]*domain.PasswordHistory, error) {
|
||||
var histories []*domain.PasswordHistory
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Find(&histories).Error
|
||||
return histories, err
|
||||
}
|
||||
|
||||
// DeleteOldRecords 删除超过 keepCount 条的旧记录(保留最新的 keepCount 条)
|
||||
func (r *PasswordHistoryRepository) DeleteOldRecords(ctx context.Context, userID int64, keepCount int) error {
|
||||
// 找出要保留的最后一条记录的 ID
|
||||
var ids []int64
|
||||
err := r.db.WithContext(ctx).
|
||||
Model(&domain.PasswordHistory{}).
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Limit(keepCount).
|
||||
Pluck("id", &ids).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(ids) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 删除不在保留列表中的记录
|
||||
return r.db.WithContext(ctx).
|
||||
Where("user_id = ? AND id NOT IN ?", userID, ids).
|
||||
Delete(&domain.PasswordHistory{}).Error
|
||||
}
|
||||
202
internal/repository/permission.go
Normal file
202
internal/repository/permission.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// PermissionRepository 权限数据访问层
|
||||
type PermissionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewPermissionRepository 创建权限数据访问层
|
||||
func NewPermissionRepository(db *gorm.DB) *PermissionRepository {
|
||||
return &PermissionRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建权限
|
||||
func (r *PermissionRepository) Create(ctx context.Context, permission *domain.Permission) error {
|
||||
// GORM omits zero values on insert for fields with DB defaults. Explicitly
|
||||
// backfill disabled status so callers can persist status=0 permissions.
|
||||
requestedStatus := permission.Status
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(permission).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if requestedStatus == domain.PermissionStatusDisabled {
|
||||
if err := tx.Model(&domain.Permission{}).Where("id = ?", permission.ID).Update("status", requestedStatus).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
permission.Status = requestedStatus
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Update 更新权限
|
||||
func (r *PermissionRepository) Update(ctx context.Context, permission *domain.Permission) error {
|
||||
return r.db.WithContext(ctx).Save(permission).Error
|
||||
}
|
||||
|
||||
// Delete 删除权限
|
||||
func (r *PermissionRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.Permission{}, id).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取权限
|
||||
func (r *PermissionRepository) GetByID(ctx context.Context, id int64) (*domain.Permission, error) {
|
||||
var permission domain.Permission
|
||||
err := r.db.WithContext(ctx).First(&permission, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &permission, nil
|
||||
}
|
||||
|
||||
// GetByCode 根据代码获取权限
|
||||
func (r *PermissionRepository) GetByCode(ctx context.Context, code string) (*domain.Permission, error) {
|
||||
var permission domain.Permission
|
||||
err := r.db.WithContext(ctx).Where("code = ?", code).First(&permission).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &permission, nil
|
||||
}
|
||||
|
||||
// List 获取权限列表
|
||||
func (r *PermissionRepository) List(ctx context.Context, offset, limit int) ([]*domain.Permission, int64, error) {
|
||||
var permissions []*domain.Permission
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Permission{})
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&permissions).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return permissions, total, nil
|
||||
}
|
||||
|
||||
// ListByType 根据类型获取权限列表
|
||||
func (r *PermissionRepository) ListByType(ctx context.Context, permissionType domain.PermissionType, offset, limit int) ([]*domain.Permission, int64, error) {
|
||||
var permissions []*domain.Permission
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Permission{}).Where("type = ?", permissionType)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&permissions).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return permissions, total, nil
|
||||
}
|
||||
|
||||
// ListByStatus 根据状态获取权限列表
|
||||
func (r *PermissionRepository) ListByStatus(ctx context.Context, status domain.PermissionStatus, offset, limit int) ([]*domain.Permission, int64, error) {
|
||||
var permissions []*domain.Permission
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Permission{}).Where("status = ?", status)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&permissions).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return permissions, total, nil
|
||||
}
|
||||
|
||||
// GetByRoleIDs 根据角色ID获取权限列表
|
||||
func (r *PermissionRepository) GetByRoleIDs(ctx context.Context, roleIDs []int64) ([]*domain.Permission, error) {
|
||||
var permissions []*domain.Permission
|
||||
|
||||
err := r.db.WithContext(ctx).
|
||||
Joins("INNER JOIN role_permissions ON permissions.id = role_permissions.permission_id").
|
||||
Where("role_permissions.role_id IN ?", roleIDs).
|
||||
Where("permissions.status = ?", domain.PermissionStatusEnabled).
|
||||
Find(&permissions).Error
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// ExistsByCode 检查权限代码是否存在
|
||||
func (r *PermissionRepository) ExistsByCode(ctx context.Context, code string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.Permission{}).Where("code = ?", code).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// UpdateStatus 更新权限状态
|
||||
func (r *PermissionRepository) UpdateStatus(ctx context.Context, id int64, status domain.PermissionStatus) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.Permission{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
// Search 搜索权限
|
||||
func (r *PermissionRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.Permission, int64, error) {
|
||||
var permissions []*domain.Permission
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Permission{}).
|
||||
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&permissions).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return permissions, total, nil
|
||||
}
|
||||
|
||||
// ListByParentID 根据父ID获取权限列表
|
||||
func (r *PermissionRepository) ListByParentID(ctx context.Context, parentID int64) ([]*domain.Permission, error) {
|
||||
var permissions []*domain.Permission
|
||||
err := r.db.WithContext(ctx).Where("parent_id = ?", parentID).Find(&permissions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表批量获取权限
|
||||
func (r *PermissionRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.Permission, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*domain.Permission{}, nil
|
||||
}
|
||||
|
||||
var permissions []*domain.Permission
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&permissions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return permissions, nil
|
||||
}
|
||||
49
internal/repository/redis.go
Normal file
49
internal/repository/redis.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
)
|
||||
|
||||
// InitRedis 初始化 Redis 客户端
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
|
||||
// 1. 默认连接池大小可能不足以支撑高并发
|
||||
// 2. 无超时控制可能导致慢操作阻塞
|
||||
//
|
||||
// 新实现支持可配置的连接池和超时参数:
|
||||
// 1. PoolSize: 控制最大并发连接数(默认 128)
|
||||
// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10)
|
||||
// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时
|
||||
func InitRedis(cfg *config.Config) *redis.Client {
|
||||
return redis.NewClient(buildRedisOptions(cfg))
|
||||
}
|
||||
|
||||
// buildRedisOptions 构建 Redis 连接选项
|
||||
// 从配置文件读取连接池和超时参数,支持生产环境调优
|
||||
func buildRedisOptions(cfg *config.Config) *redis.Options {
|
||||
opts := &redis.Options{
|
||||
Addr: cfg.Redis.Address(),
|
||||
Password: cfg.Redis.Password,
|
||||
DB: cfg.Redis.DB,
|
||||
DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时
|
||||
ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时
|
||||
WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时
|
||||
PoolSize: cfg.Redis.PoolSize, // 连接池大小
|
||||
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
|
||||
}
|
||||
|
||||
if cfg.Redis.EnableTLS {
|
||||
opts.TLSConfig = &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: cfg.Redis.Host,
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
47
internal/repository/redis_test.go
Normal file
47
internal/repository/redis_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBuildRedisOptions(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Redis: config.RedisConfig{
|
||||
Host: "localhost",
|
||||
Port: 6379,
|
||||
Password: "secret",
|
||||
DB: 2,
|
||||
DialTimeoutSeconds: 5,
|
||||
ReadTimeoutSeconds: 3,
|
||||
WriteTimeoutSeconds: 4,
|
||||
PoolSize: 100,
|
||||
MinIdleConns: 10,
|
||||
},
|
||||
}
|
||||
|
||||
opts := buildRedisOptions(cfg)
|
||||
require.Equal(t, "localhost:6379", opts.Addr)
|
||||
require.Equal(t, "secret", opts.Password)
|
||||
require.Equal(t, 2, opts.DB)
|
||||
require.Equal(t, 5*time.Second, opts.DialTimeout)
|
||||
require.Equal(t, 3*time.Second, opts.ReadTimeout)
|
||||
require.Equal(t, 4*time.Second, opts.WriteTimeout)
|
||||
require.Equal(t, 100, opts.PoolSize)
|
||||
require.Equal(t, 10, opts.MinIdleConns)
|
||||
require.Nil(t, opts.TLSConfig)
|
||||
|
||||
// Test case with TLS enabled
|
||||
cfgTLS := &config.Config{
|
||||
Redis: config.RedisConfig{
|
||||
Host: "localhost",
|
||||
EnableTLS: true,
|
||||
},
|
||||
}
|
||||
optsTLS := buildRedisOptions(cfgTLS)
|
||||
require.NotNil(t, optsTLS.TLSConfig)
|
||||
require.Equal(t, "localhost", optsTLS.TLSConfig.ServerName)
|
||||
}
|
||||
305
internal/repository/repo_bench_test.go
Normal file
305
internal/repository/repo_bench_test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
// repo_bench_test.go — repository 层性能基准测试
|
||||
// 覆盖:批量写入、并发只读查询、分页列表、更新状态、软删除
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
var repoBenchCounter int64
|
||||
|
||||
// openBenchDB 为 Benchmark 打开独立内存 DB(不依赖 *testing.T)
|
||||
func openBenchDB(b *testing.B) *gorm.DB {
|
||||
b.Helper()
|
||||
id := atomic.AddInt64(&repoBenchCounter, 1)
|
||||
dsn := fmt.Sprintf("file:repobenchdb%d?mode=memory&cache=private", id)
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: dsn,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
b.Fatalf("openBenchDB: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(
|
||||
&domain.User{},
|
||||
&domain.Role{},
|
||||
&domain.Permission{},
|
||||
&domain.UserRole{},
|
||||
&domain.RolePermission{},
|
||||
); err != nil {
|
||||
b.Fatalf("AutoMigrate: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// seedUsers 往 DB 插入 n 条用户
|
||||
func seedUsers(b *testing.B, repo *UserRepository, n int) {
|
||||
b.Helper()
|
||||
ctx := context.Background()
|
||||
for i := 0; i < n; i++ {
|
||||
if err := repo.Create(ctx, &domain.User{
|
||||
Username: fmt.Sprintf("benchuser%06d", i),
|
||||
Email: domain.StrPtr(fmt.Sprintf("bench%06d@example.com", i)),
|
||||
Phone: domain.StrPtr(fmt.Sprintf("1380000%04d", i%10000)),
|
||||
Password: "hashed_placeholder",
|
||||
Status: domain.UserStatusActive,
|
||||
}); err != nil {
|
||||
b.Fatalf("seedUsers i=%d: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_Create — 单条写入吞吐 ----------
|
||||
|
||||
func BenchmarkRepo_Create(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
repo.Create(ctx, &domain.User{ //nolint:errcheck
|
||||
Username: fmt.Sprintf("cr_%d_%d", b.N, i),
|
||||
Email: domain.StrPtr(fmt.Sprintf("cr_%d_%d@bench.com", b.N, i)),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_BulkCreate — 批量写入(串行) ----------
|
||||
|
||||
func BenchmarkRepo_BulkCreate(b *testing.B) {
|
||||
sizes := []int{10, 100, 500}
|
||||
for _, size := range sizes {
|
||||
size := size
|
||||
b.Run(fmt.Sprintf("batch=%d", size), func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
users := make([]*domain.User, size)
|
||||
for j := 0; j < size; j++ {
|
||||
users[j] = &domain.User{
|
||||
Username: fmt.Sprintf("bulk_%d_%d_%d", i, j, size),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
}
|
||||
b.StartTimer()
|
||||
for _, u := range users {
|
||||
repo.Create(ctx, u) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_GetByID — 主键查询 ----------
|
||||
|
||||
func BenchmarkRepo_GetByID(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
seedUsers(b, repo, 1000)
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
id := int64(1)
|
||||
for pb.Next() {
|
||||
repo.GetByID(ctx, id) //nolint:errcheck
|
||||
id++
|
||||
if id > 1000 {
|
||||
id = 1
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_GetByUsername — 索引查询 ----------
|
||||
|
||||
func BenchmarkRepo_GetByUsername(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
seedUsers(b, repo, 500)
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
repo.GetByUsername(ctx, fmt.Sprintf("benchuser%06d", i%500)) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_GetByEmail — 索引查询 ----------
|
||||
|
||||
func BenchmarkRepo_GetByEmail(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
seedUsers(b, repo, 500)
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
repo.GetByEmail(ctx, fmt.Sprintf("bench%06d@example.com", i%500)) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_List — 分页列表 ----------
|
||||
|
||||
func BenchmarkRepo_List(b *testing.B) {
|
||||
pageSizes := []int{10, 50, 200}
|
||||
for _, ps := range pageSizes {
|
||||
ps := ps
|
||||
b.Run(fmt.Sprintf("pageSize=%d", ps), func(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
seedUsers(b, repo, 1000)
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
repo.List(ctx, 0, ps) //nolint:errcheck
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_ListByStatus ----------
|
||||
|
||||
func BenchmarkRepo_ListByStatus(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
seedUsers(b, repo, 1000)
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
repo.ListByStatus(ctx, domain.UserStatusActive, 0, 20) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_UpdateStatus ----------
|
||||
|
||||
func BenchmarkRepo_UpdateStatus(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
seedUsers(b, repo, 200)
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
id := int64(i%200) + 1
|
||||
repo.UpdateStatus(ctx, id, domain.UserStatusActive) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_Update — 全字段更新 ----------
|
||||
|
||||
func BenchmarkRepo_Update(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
seedUsers(b, repo, 100)
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
id := int64(i%100) + 1
|
||||
u, err := repo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
u.Nickname = fmt.Sprintf("nick%d", i)
|
||||
repo.Update(ctx, u) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_Delete — 软删除 ----------
|
||||
|
||||
func BenchmarkRepo_Delete(b *testing.B) {
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
b.StopTimer()
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
repo.Create(ctx, &domain.User{Username: "victim", Password: "hash", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
b.StartTimer()
|
||||
repo.Delete(ctx, 1) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_ExistsByUsername ----------
|
||||
|
||||
func BenchmarkRepo_ExistsByUsername(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
seedUsers(b, repo, 500)
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := 0
|
||||
for pb.Next() {
|
||||
repo.ExistsByUsername(ctx, fmt.Sprintf("benchuser%06d", i%500)) //nolint:errcheck
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_ConcurrentReadWrite — 高并发读写混合 ----------
|
||||
|
||||
func BenchmarkRepo_ConcurrentReadWrite(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
seedUsers(b, repo, 200)
|
||||
ctx := context.Background()
|
||||
|
||||
var mu sync.Mutex // SQLite 不支持多写并发,需要序列化写入
|
||||
b.ResetTimer()
|
||||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
i := int64(1)
|
||||
for pb.Next() {
|
||||
if i%10 == 0 {
|
||||
// 10% 写操作
|
||||
mu.Lock()
|
||||
repo.UpdateLastLogin(ctx, i%200+1, "10.0.0.1") //nolint:errcheck
|
||||
mu.Unlock()
|
||||
} else {
|
||||
// 90% 读操作
|
||||
repo.GetByID(ctx, i%200+1) //nolint:errcheck
|
||||
}
|
||||
i++
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// ---------- BenchmarkRepo_Search — 模糊搜索 ----------
|
||||
|
||||
func BenchmarkRepo_Search(b *testing.B) {
|
||||
db := openBenchDB(b)
|
||||
repo := NewUserRepository(db)
|
||||
seedUsers(b, repo, 2000)
|
||||
ctx := context.Background()
|
||||
b.ResetTimer()
|
||||
|
||||
keywords := []string{"benchuser000", "bench0001", "benchuser05"}
|
||||
for i := 0; i < b.N; i++ {
|
||||
repo.Search(ctx, keywords[i%len(keywords)], 0, 20) //nolint:errcheck
|
||||
}
|
||||
}
|
||||
536
internal/repository/repo_robustness_test.go
Normal file
536
internal/repository/repo_robustness_test.go
Normal file
@@ -0,0 +1,536 @@
|
||||
// repo_robustness_test.go — repository 层鲁棒性测试
|
||||
// 覆盖:重复主键、唯一索引冲突、大量数据分页正确性、
|
||||
// SQL 注入防护(参数化查询验证)、软删除后查询、
|
||||
// 空字符串/极值/特殊字符输入、上下文取消
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// ============================================================
|
||||
// 1. 唯一索引冲突
|
||||
// ============================================================
|
||||
|
||||
func TestRepo_Robust_DuplicateUsername(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
u1 := &domain.User{Username: "dupuser", Password: "hash", Status: domain.UserStatusActive}
|
||||
if err := repo.Create(ctx, u1); err != nil {
|
||||
t.Fatalf("第一次创建应成功: %v", err)
|
||||
}
|
||||
|
||||
u2 := &domain.User{Username: "dupuser", Password: "hash2", Status: domain.UserStatusActive}
|
||||
err := repo.Create(ctx, u2)
|
||||
if err == nil {
|
||||
t.Error("重复用户名应返回唯一索引冲突错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_DuplicateEmail(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
email := "dup@example.com"
|
||||
repo.Create(ctx, &domain.User{Username: "user1", Email: domain.StrPtr(email), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
|
||||
err := repo.Create(ctx, &domain.User{Username: "user2", Email: domain.StrPtr(email), Password: "h", Status: domain.UserStatusActive})
|
||||
if err == nil {
|
||||
t.Error("重复邮箱应返回唯一索引冲突错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_DuplicatePhone(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
phone := "13900000001"
|
||||
repo.Create(ctx, &domain.User{Username: "pa", Phone: domain.StrPtr(phone), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
|
||||
err := repo.Create(ctx, &domain.User{Username: "pb", Phone: domain.StrPtr(phone), Password: "h", Status: domain.UserStatusActive})
|
||||
if err == nil {
|
||||
t.Error("重复手机号应返回唯一索引冲突错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_MultipleNullEmail(t *testing.T) {
|
||||
// NULL 不触发唯一约束,多个用户可以都没有邮箱
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
err := repo.Create(ctx, &domain.User{
|
||||
Username: fmt.Sprintf("nomail%d", i),
|
||||
Email: nil, // NULL
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NULL email 用户%d 创建失败: %v", i, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 2. 查询不存在的记录
|
||||
// ============================================================
|
||||
|
||||
func TestRepo_Robust_GetByID_NotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
_, err := repo.GetByID(context.Background(), 99999)
|
||||
if err == nil {
|
||||
t.Error("查询不存在的 ID 应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_GetByUsername_NotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
_, err := repo.GetByUsername(context.Background(), "ghost")
|
||||
if err == nil {
|
||||
t.Error("查询不存在的用户名应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_GetByEmail_NotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
_, err := repo.GetByEmail(context.Background(), "nope@none.com")
|
||||
if err == nil {
|
||||
t.Error("查询不存在的邮箱应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_GetByPhone_NotFound(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
_, err := repo.GetByPhone(context.Background(), "00000000000")
|
||||
if err == nil {
|
||||
t.Error("查询不存在的手机号应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 3. 软删除后的查询行为
|
||||
// ============================================================
|
||||
|
||||
func TestRepo_Robust_SoftDelete_HiddenFromGet(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
u := &domain.User{Username: "softdel", Password: "h", Status: domain.UserStatusActive}
|
||||
repo.Create(ctx, u) //nolint:errcheck
|
||||
id := u.ID
|
||||
|
||||
if err := repo.Delete(ctx, id); err != nil {
|
||||
t.Fatalf("Delete: %v", err)
|
||||
}
|
||||
|
||||
_, err := repo.GetByID(ctx, id)
|
||||
if err == nil {
|
||||
t.Error("软删除后 GetByID 应返回错误(记录被隐藏)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_SoftDelete_HiddenFromList(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("listdel%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
}
|
||||
|
||||
users, total, _ := repo.List(ctx, 0, 100)
|
||||
initialCount := len(users)
|
||||
initialTotal := total
|
||||
|
||||
// 删除第一个
|
||||
repo.Delete(ctx, users[0].ID) //nolint:errcheck
|
||||
|
||||
users2, total2, _ := repo.List(ctx, 0, 100)
|
||||
if len(users2) != initialCount-1 {
|
||||
t.Errorf("删除后 List 应减少 1 条,实际 %d -> %d", initialCount, len(users2))
|
||||
}
|
||||
if total2 != initialTotal-1 {
|
||||
t.Errorf("删除后 total 应减少 1,实际 %d -> %d", initialTotal, total2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_DeleteNonExistent(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
// 软删除一个不存在的 ID,GORM 通常返回 nil(RowsAffected=0 不报错)
|
||||
err := repo.Delete(context.Background(), 99999)
|
||||
_ = err // 不 panic 即可
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 4. SQL 注入防护(参数化查询)
|
||||
// ============================================================
|
||||
|
||||
func TestRepo_Robust_SQLInjection_GetByUsername(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// 先插入一个真实用户
|
||||
repo.Create(ctx, &domain.User{Username: "legit", Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
|
||||
// 注入载荷:尝试用 OR '1'='1' 绕过查询
|
||||
injections := []string{
|
||||
"' OR '1'='1",
|
||||
"'; DROP TABLE users; --",
|
||||
`" OR "1"="1`,
|
||||
"admin'--",
|
||||
"legit' UNION SELECT * FROM users --",
|
||||
}
|
||||
|
||||
for _, payload := range injections {
|
||||
_, err := repo.GetByUsername(ctx, payload)
|
||||
if err == nil {
|
||||
t.Errorf("SQL 注入载荷 %q 不应返回用户(应返回 not found)", payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_SQLInjection_Search(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.User{Username: "victim", Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
|
||||
injections := []string{
|
||||
"' OR '1'='1",
|
||||
"%; SELECT * FROM users; --",
|
||||
"victim' UNION SELECT username FROM users --",
|
||||
}
|
||||
|
||||
for _, payload := range injections {
|
||||
users, _, err := repo.Search(ctx, payload, 0, 100)
|
||||
if err != nil {
|
||||
continue // 参数化查询报错也可接受
|
||||
}
|
||||
for _, u := range users {
|
||||
if u.Username == "victim" && !strings.Contains(payload, "victim") {
|
||||
t.Errorf("SQL 注入载荷 %q 不应返回不匹配的用户", payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_SQLInjection_ExistsByUsername(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.User{Username: "realuser", Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
|
||||
// 这些载荷不应导致 ExistsByUsername("' OR '1'='1") 返回 true(找到不存在的用户)
|
||||
exists, err := repo.ExistsByUsername(ctx, "' OR '1'='1")
|
||||
if err != nil {
|
||||
t.Logf("ExistsByUsername SQL注入: err=%v (可接受)", err)
|
||||
return
|
||||
}
|
||||
if exists {
|
||||
t.Error("SQL 注入载荷在 ExistsByUsername 中不应返回 true")
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 5. 分页边界值
|
||||
// ============================================================
|
||||
|
||||
func TestRepo_Robust_List_ZeroOffset(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 5; i++ {
|
||||
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("pg%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
}
|
||||
|
||||
users, total, err := repo.List(ctx, 0, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("List: %v", err)
|
||||
}
|
||||
if len(users) != 3 {
|
||||
t.Errorf("offset=0, limit=3 应返回 3 条,实际 %d", len(users))
|
||||
}
|
||||
if total != 5 {
|
||||
t.Errorf("total 应为 5,实际 %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_List_OffsetBeyondTotal(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 3; i++ {
|
||||
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("ov%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
}
|
||||
|
||||
users, total, err := repo.List(ctx, 100, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List: %v", err)
|
||||
}
|
||||
if len(users) != 0 {
|
||||
t.Errorf("offset 超过总数应返回空列表,实际 %d 条", len(users))
|
||||
}
|
||||
if total != 3 {
|
||||
t.Errorf("total 应为 3,实际 %d", total)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_List_LargeLimit(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 10; i++ {
|
||||
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("ll%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
}
|
||||
|
||||
users, _, err := repo.List(ctx, 0, 999999)
|
||||
if err != nil {
|
||||
t.Fatalf("List with huge limit: %v", err)
|
||||
}
|
||||
if len(users) != 10 {
|
||||
t.Errorf("超大 limit 应返回全部 10 条,实际 %d", len(users))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_List_EmptyDB(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
users, total, err := repo.List(context.Background(), 0, 20)
|
||||
if err != nil {
|
||||
t.Fatalf("空 DB List 应无错误: %v", err)
|
||||
}
|
||||
if total != 0 {
|
||||
t.Errorf("空 DB total 应为 0,实际 %d", total)
|
||||
}
|
||||
if len(users) != 0 {
|
||||
t.Errorf("空 DB 应返回空列表,实际 %d 条", len(users))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 6. 搜索边界值
|
||||
// ============================================================
|
||||
|
||||
func TestRepo_Robust_Search_EmptyKeyword(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
for i := 0; i < 5; i++ {
|
||||
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("sk%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
}
|
||||
|
||||
users, total, err := repo.Search(ctx, "", 0, 20)
|
||||
// 空关键字 → LIKE '%%' 匹配所有;验证不报错
|
||||
if err != nil {
|
||||
t.Fatalf("空关键字 Search 应无错误: %v", err)
|
||||
}
|
||||
if total < 5 {
|
||||
t.Errorf("空关键字应匹配所有用户(>=5),实际 total=%d,rows=%d", total, len(users))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_Search_SpecialCharsKeyword(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
repo.Create(ctx, &domain.User{Username: "normaluser", Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
|
||||
// 含 LIKE 元字符
|
||||
for _, kw := range []string{"%", "_", "\\", "%_%", "%%"} {
|
||||
_, _, err := repo.Search(ctx, kw, 0, 10)
|
||||
if err != nil {
|
||||
t.Logf("特殊关键字 %q 搜索出错(可接受): %v", kw, err)
|
||||
}
|
||||
// 主要验证不 panic
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_Search_VeryLongKeyword(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
longKw := strings.Repeat("a", 10000)
|
||||
_, _, err := repo.Search(ctx, longKw, 0, 10)
|
||||
_ = err // 不应 panic
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 7. 超长字段存储
|
||||
// ============================================================
|
||||
|
||||
func TestRepo_Robust_LongFieldValues(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
u := &domain.User{
|
||||
Username: strings.Repeat("x", 45), // varchar(50) 以内
|
||||
Password: strings.Repeat("y", 200),
|
||||
Nickname: strings.Repeat("n", 45),
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
err := repo.Create(ctx, u)
|
||||
// SQLite 不严格限制 varchar 长度,期望成功;其他数据库可能截断或报错
|
||||
if err != nil {
|
||||
t.Logf("超长字段创建结果: %v(SQLite 可能允许)", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 8. UpdateLastLogin 特殊 IP
|
||||
// ============================================================
|
||||
|
||||
func TestRepo_Robust_UpdateLastLogin_EmptyIP(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
u := &domain.User{Username: "iptest", Password: "h", Status: domain.UserStatusActive}
|
||||
repo.Create(ctx, u) //nolint:errcheck
|
||||
|
||||
// 空 IP 不应报错
|
||||
if err := repo.UpdateLastLogin(ctx, u.ID, ""); err != nil {
|
||||
t.Errorf("空 IP UpdateLastLogin 应无错误: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_UpdateLastLogin_LongIP(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
u := &domain.User{Username: "longiptest", Password: "h", Status: domain.UserStatusActive}
|
||||
repo.Create(ctx, u) //nolint:errcheck
|
||||
|
||||
longIP := strings.Repeat("1", 500)
|
||||
err := repo.UpdateLastLogin(ctx, u.ID, longIP)
|
||||
_ = err // SQLite 宽容,不 panic 即可
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 9. 并发写入安全(SQLite 序列化写入)
|
||||
// ============================================================
|
||||
|
||||
func TestRepo_Robust_ConcurrentCreate_NoDeadlock(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
// 启用 WAL 模式可减少锁冲突,这里使用默认设置
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
const goroutines = 20
|
||||
var wg sync.WaitGroup
|
||||
var mu sync.Mutex // SQLite 只允许单写,用互斥锁序列化
|
||||
errorCount := 0
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
err := repo.Create(ctx, &domain.User{
|
||||
Username: fmt.Sprintf("concurrent_%d", idx),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
})
|
||||
if err != nil {
|
||||
errorCount++
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if errorCount > 0 {
|
||||
t.Errorf("序列化并发写入:%d/%d 次失败", errorCount, goroutines)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepo_Robust_ConcurrentReadWrite_NoDataRace(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// 预先插入数据
|
||||
for i := 0; i < 10; i++ {
|
||||
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("rw%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var writeMu sync.Mutex
|
||||
|
||||
for i := 0; i < 30; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
if idx%5 == 0 {
|
||||
writeMu.Lock()
|
||||
repo.UpdateStatus(ctx, int64(idx%10)+1, domain.UserStatusActive) //nolint:errcheck
|
||||
writeMu.Unlock()
|
||||
} else {
|
||||
repo.GetByID(ctx, int64(idx%10)+1) //nolint:errcheck
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
// 无 panic / 数据竞争即通过
|
||||
}
|
||||
|
||||
// ============================================================
|
||||
// 10. Exists 方法边界
|
||||
// ============================================================
|
||||
|
||||
func TestRepo_Robust_ExistsByUsername_EmptyString(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
// 查询空字符串用户名,不应 panic
|
||||
exists, err := repo.ExistsByUsername(context.Background(), "")
|
||||
if err != nil {
|
||||
t.Logf("ExistsByUsername('') err: %v", err)
|
||||
}
|
||||
_ = exists
|
||||
}
|
||||
|
||||
func TestRepo_Robust_ExistsByEmail_NilEquivalent(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
// 查询空邮箱
|
||||
exists, err := repo.ExistsByEmail(context.Background(), "")
|
||||
_ = err
|
||||
_ = exists
|
||||
}
|
||||
|
||||
func TestRepo_Robust_ExistsByPhone_SQLInjection(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
repo.Create(ctx, &domain.User{Username: "phoneuser", Phone: domain.StrPtr("13900000001"), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
|
||||
|
||||
exists, err := repo.ExistsByPhone(ctx, "' OR '1'='1")
|
||||
if err != nil {
|
||||
t.Logf("ExistsByPhone SQL注入 err: %v", err)
|
||||
return
|
||||
}
|
||||
if exists {
|
||||
t.Error("SQL 注入载荷在 ExistsByPhone 中不应返回 true")
|
||||
}
|
||||
}
|
||||
466
internal/repository/repository_additional_test.go
Normal file
466
internal/repository/repository_additional_test.go
Normal file
@@ -0,0 +1,466 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func migrateRepositoryTables(t *testing.T, db *gorm.DB, tables ...interface{}) {
|
||||
t.Helper()
|
||||
|
||||
if err := db.AutoMigrate(tables...); err != nil {
|
||||
t.Fatalf("migrate repository tables failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func int64Ptr(v int64) *int64 {
|
||||
return &v
|
||||
}
|
||||
|
||||
func TestDeviceRepositoryLifecycleAndQueries(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
migrateRepositoryTables(t, db, &domain.Device{})
|
||||
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
|
||||
devices := []*domain.Device{
|
||||
{
|
||||
UserID: 1,
|
||||
DeviceID: "device-alpha",
|
||||
DeviceName: "Alpha",
|
||||
DeviceType: domain.DeviceTypeDesktop,
|
||||
DeviceOS: "Windows",
|
||||
DeviceBrowser: "Chrome",
|
||||
IP: "10.0.0.1",
|
||||
Location: "Shanghai",
|
||||
Status: domain.DeviceStatusActive,
|
||||
LastActiveTime: now.Add(-1 * time.Hour),
|
||||
},
|
||||
{
|
||||
UserID: 1,
|
||||
DeviceID: "device-beta",
|
||||
DeviceName: "Beta",
|
||||
DeviceType: domain.DeviceTypeWeb,
|
||||
DeviceOS: "macOS",
|
||||
DeviceBrowser: "Safari",
|
||||
IP: "10.0.0.2",
|
||||
Location: "Hangzhou",
|
||||
Status: domain.DeviceStatusInactive,
|
||||
LastActiveTime: now.Add(-2 * time.Hour),
|
||||
},
|
||||
{
|
||||
UserID: 2,
|
||||
DeviceID: "device-gamma",
|
||||
DeviceName: "Gamma",
|
||||
DeviceType: domain.DeviceTypeMobile,
|
||||
DeviceOS: "Android",
|
||||
DeviceBrowser: "WebView",
|
||||
IP: "10.0.0.3",
|
||||
Location: "Beijing",
|
||||
Status: domain.DeviceStatusActive,
|
||||
LastActiveTime: now.Add(-40 * 24 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, device := range devices {
|
||||
if err := repo.Create(ctx, device); err != nil {
|
||||
t.Fatalf("Create(%s) failed: %v", device.DeviceID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if allDevices, total, err := repo.List(ctx, 0, 10); err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
} else if total != 3 || len(allDevices) != 3 {
|
||||
t.Fatalf("expected 3 devices, got total=%d len=%d", total, len(allDevices))
|
||||
}
|
||||
|
||||
loadedByDeviceID, err := repo.GetByDeviceID(ctx, 1, "device-beta")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByDeviceID failed: %v", err)
|
||||
}
|
||||
if loadedByDeviceID.DeviceName != "Beta" {
|
||||
t.Fatalf("expected device name Beta, got %q", loadedByDeviceID.DeviceName)
|
||||
}
|
||||
|
||||
exists, err := repo.Exists(ctx, 1, "device-alpha")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists(device-alpha) failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatal("expected device-alpha to exist")
|
||||
}
|
||||
|
||||
missing, err := repo.Exists(ctx, 1, "missing-device")
|
||||
if err != nil {
|
||||
t.Fatalf("Exists(missing-device) failed: %v", err)
|
||||
}
|
||||
if missing {
|
||||
t.Fatal("expected missing-device to be absent")
|
||||
}
|
||||
|
||||
userDevices, total, err := repo.ListByUserID(ctx, 1, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByUserID failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(userDevices) != 2 {
|
||||
t.Fatalf("expected 2 devices for user 1, got total=%d len=%d", total, len(userDevices))
|
||||
}
|
||||
if userDevices[0].DeviceID != "device-alpha" {
|
||||
t.Fatalf("expected latest active device first, got %q", userDevices[0].DeviceID)
|
||||
}
|
||||
|
||||
activeDevices, total, err := repo.ListByStatus(ctx, domain.DeviceStatusActive, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByStatus failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(activeDevices) != 2 {
|
||||
t.Fatalf("expected 2 active devices, got total=%d len=%d", total, len(activeDevices))
|
||||
}
|
||||
|
||||
if err := repo.UpdateStatus(ctx, devices[1].ID, domain.DeviceStatusActive); err != nil {
|
||||
t.Fatalf("UpdateStatus failed: %v", err)
|
||||
}
|
||||
|
||||
beforeTouch, err := repo.GetByID(ctx, devices[1].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID before UpdateLastActiveTime failed: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if err := repo.UpdateLastActiveTime(ctx, devices[1].ID); err != nil {
|
||||
t.Fatalf("UpdateLastActiveTime failed: %v", err)
|
||||
}
|
||||
|
||||
afterTouch, err := repo.GetByID(ctx, devices[1].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID after UpdateLastActiveTime failed: %v", err)
|
||||
}
|
||||
if !afterTouch.LastActiveTime.After(beforeTouch.LastActiveTime) {
|
||||
t.Fatal("expected last_active_time to move forward")
|
||||
}
|
||||
|
||||
recentDevices, err := repo.GetActiveDevices(ctx, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("GetActiveDevices failed: %v", err)
|
||||
}
|
||||
if len(recentDevices) != 2 {
|
||||
t.Fatalf("expected 2 recent devices for user 1, got %d", len(recentDevices))
|
||||
}
|
||||
|
||||
if err := repo.DeleteByUserID(ctx, 1); err != nil {
|
||||
t.Fatalf("DeleteByUserID failed: %v", err)
|
||||
}
|
||||
|
||||
remainingDevices, remainingTotal, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List after DeleteByUserID failed: %v", err)
|
||||
}
|
||||
if remainingTotal != 1 || len(remainingDevices) != 1 {
|
||||
t.Fatalf("expected 1 remaining device, got total=%d len=%d", remainingTotal, len(remainingDevices))
|
||||
}
|
||||
|
||||
if err := repo.Delete(ctx, devices[2].ID); err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := repo.GetByID(ctx, devices[2].ID); err == nil {
|
||||
t.Fatal("expected deleted device lookup to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginLogRepositoryQueriesAndRetention(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
migrateRepositoryTables(t, db, &domain.LoginLog{})
|
||||
|
||||
repo := NewLoginLogRepository(db)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
|
||||
logs := []*domain.LoginLog{
|
||||
{
|
||||
UserID: int64Ptr(1),
|
||||
LoginType: int(domain.LoginTypePassword),
|
||||
DeviceID: "device-alpha",
|
||||
IP: "10.0.0.1",
|
||||
Location: "Shanghai",
|
||||
Status: 1,
|
||||
CreatedAt: now.Add(-1 * time.Hour),
|
||||
},
|
||||
{
|
||||
UserID: int64Ptr(1),
|
||||
LoginType: int(domain.LoginTypeSMSCode),
|
||||
DeviceID: "device-beta",
|
||||
IP: "10.0.0.2",
|
||||
Location: "Hangzhou",
|
||||
Status: 0,
|
||||
FailReason: "code expired",
|
||||
CreatedAt: now.Add(-30 * time.Minute),
|
||||
},
|
||||
{
|
||||
UserID: int64Ptr(2),
|
||||
LoginType: int(domain.LoginTypeOAuth),
|
||||
DeviceID: "device-gamma",
|
||||
IP: "10.0.0.3",
|
||||
Location: "Beijing",
|
||||
Status: 1,
|
||||
CreatedAt: now.Add(-45 * 24 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, log := range logs {
|
||||
if err := repo.Create(ctx, log); err != nil {
|
||||
t.Fatalf("Create login log failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
loaded, err := repo.GetByID(ctx, logs[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID failed: %v", err)
|
||||
}
|
||||
if loaded.DeviceID != "device-alpha" {
|
||||
t.Fatalf("expected device-alpha, got %q", loaded.DeviceID)
|
||||
}
|
||||
|
||||
userLogs, total, err := repo.ListByUserID(ctx, 1, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByUserID failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(userLogs) != 2 {
|
||||
t.Fatalf("expected 2 user logs, got total=%d len=%d", total, len(userLogs))
|
||||
}
|
||||
if userLogs[0].DeviceID != "device-beta" {
|
||||
t.Fatalf("expected newest login log first, got %q", userLogs[0].DeviceID)
|
||||
}
|
||||
|
||||
allLogs, total, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
if total != 3 || len(allLogs) != 3 {
|
||||
t.Fatalf("expected 3 total logs, got total=%d len=%d", total, len(allLogs))
|
||||
}
|
||||
|
||||
successLogs, total, err := repo.ListByStatus(ctx, 1, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByStatus failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(successLogs) != 2 {
|
||||
t.Fatalf("expected 2 success logs, got total=%d len=%d", total, len(successLogs))
|
||||
}
|
||||
|
||||
recentLogs, total, err := repo.ListByTimeRange(ctx, now.Add(-2*time.Hour), now, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByTimeRange failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(recentLogs) != 2 {
|
||||
t.Fatalf("expected 2 recent logs, got total=%d len=%d", total, len(recentLogs))
|
||||
}
|
||||
|
||||
if count := repo.CountByResultSince(ctx, true, now.Add(-2*time.Hour)); count != 1 {
|
||||
t.Fatalf("expected 1 recent success login, got %d", count)
|
||||
}
|
||||
if count := repo.CountByResultSince(ctx, false, now.Add(-2*time.Hour)); count != 1 {
|
||||
t.Fatalf("expected 1 recent failed login, got %d", count)
|
||||
}
|
||||
|
||||
if err := repo.DeleteOlderThan(ctx, 30); err != nil {
|
||||
t.Fatalf("DeleteOlderThan failed: %v", err)
|
||||
}
|
||||
|
||||
retainedLogs, retainedTotal, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List after DeleteOlderThan failed: %v", err)
|
||||
}
|
||||
if retainedTotal != 2 || len(retainedLogs) != 2 {
|
||||
t.Fatalf("expected 2 retained logs, got total=%d len=%d", retainedTotal, len(retainedLogs))
|
||||
}
|
||||
|
||||
if err := repo.DeleteByUserID(ctx, 1); err != nil {
|
||||
t.Fatalf("DeleteByUserID failed: %v", err)
|
||||
}
|
||||
|
||||
finalLogs, finalTotal, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List after DeleteByUserID failed: %v", err)
|
||||
}
|
||||
if finalTotal != 0 || len(finalLogs) != 0 {
|
||||
t.Fatalf("expected all logs removed, got total=%d len=%d", finalTotal, len(finalLogs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepositoryKeepsNewestRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
migrateRepositoryTables(t, db, &domain.PasswordHistory{})
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
|
||||
histories := []*domain.PasswordHistory{
|
||||
{UserID: 1, PasswordHash: "hash-1", CreatedAt: now.Add(-4 * time.Hour)},
|
||||
{UserID: 1, PasswordHash: "hash-2", CreatedAt: now.Add(-3 * time.Hour)},
|
||||
{UserID: 1, PasswordHash: "hash-3", CreatedAt: now.Add(-2 * time.Hour)},
|
||||
{UserID: 1, PasswordHash: "hash-4", CreatedAt: now.Add(-1 * time.Hour)},
|
||||
{UserID: 2, PasswordHash: "hash-foreign", CreatedAt: now.Add(-30 * time.Minute)},
|
||||
}
|
||||
|
||||
for _, history := range histories {
|
||||
if err := repo.Create(ctx, history); err != nil {
|
||||
t.Fatalf("Create password history failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
latestTwo, err := repo.GetByUserID(ctx, 1, 2)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUserID(limit=2) failed: %v", err)
|
||||
}
|
||||
if len(latestTwo) != 2 {
|
||||
t.Fatalf("expected 2 latest password histories, got %d", len(latestTwo))
|
||||
}
|
||||
if latestTwo[0].PasswordHash != "hash-4" || latestTwo[1].PasswordHash != "hash-3" {
|
||||
t.Fatalf("expected newest password hashes to be retained, got %q and %q", latestTwo[0].PasswordHash, latestTwo[1].PasswordHash)
|
||||
}
|
||||
|
||||
if err := repo.DeleteOldRecords(ctx, 1, 2); err != nil {
|
||||
t.Fatalf("DeleteOldRecords failed: %v", err)
|
||||
}
|
||||
|
||||
remainingHistories, err := repo.GetByUserID(ctx, 1, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUserID after DeleteOldRecords failed: %v", err)
|
||||
}
|
||||
if len(remainingHistories) != 2 {
|
||||
t.Fatalf("expected 2 remaining histories, got %d", len(remainingHistories))
|
||||
}
|
||||
if remainingHistories[0].PasswordHash != "hash-4" || remainingHistories[1].PasswordHash != "hash-3" {
|
||||
t.Fatalf("unexpected remaining password hashes: %q and %q", remainingHistories[0].PasswordHash, remainingHistories[1].PasswordHash)
|
||||
}
|
||||
|
||||
if err := repo.DeleteOldRecords(ctx, 999, 3); err != nil {
|
||||
t.Fatalf("DeleteOldRecords for missing user failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperationLogRepositorySearchAndRetention(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
migrateRepositoryTables(t, db, &domain.OperationLog{})
|
||||
|
||||
repo := NewOperationLogRepository(db)
|
||||
ctx := context.Background()
|
||||
now := time.Now().UTC()
|
||||
|
||||
logs := []*domain.OperationLog{
|
||||
{
|
||||
UserID: int64Ptr(1),
|
||||
OperationType: "user",
|
||||
OperationName: "create user",
|
||||
RequestMethod: "POST",
|
||||
RequestPath: "/api/v1/users",
|
||||
RequestParams: `{"username":"alice"}`,
|
||||
ResponseStatus: 201,
|
||||
IP: "10.0.0.1",
|
||||
UserAgent: "Chrome",
|
||||
CreatedAt: now.Add(-20 * time.Minute),
|
||||
},
|
||||
{
|
||||
UserID: int64Ptr(1),
|
||||
OperationType: "dashboard",
|
||||
OperationName: "view dashboard",
|
||||
RequestMethod: "GET",
|
||||
RequestPath: "/dashboard",
|
||||
RequestParams: "{}",
|
||||
ResponseStatus: 200,
|
||||
IP: "10.0.0.2",
|
||||
UserAgent: "Chrome",
|
||||
CreatedAt: now.Add(-10 * time.Minute),
|
||||
},
|
||||
{
|
||||
UserID: int64Ptr(2),
|
||||
OperationType: "user",
|
||||
OperationName: "delete user",
|
||||
RequestMethod: "DELETE",
|
||||
RequestPath: "/api/v1/users/7",
|
||||
RequestParams: "{}",
|
||||
ResponseStatus: 204,
|
||||
IP: "10.0.0.3",
|
||||
UserAgent: "Firefox",
|
||||
CreatedAt: now.Add(-40 * 24 * time.Hour),
|
||||
},
|
||||
}
|
||||
|
||||
for _, log := range logs {
|
||||
if err := repo.Create(ctx, log); err != nil {
|
||||
t.Fatalf("Create operation log failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
loaded, err := repo.GetByID(ctx, logs[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID failed: %v", err)
|
||||
}
|
||||
if loaded.OperationName != "create user" {
|
||||
t.Fatalf("expected create user log, got %q", loaded.OperationName)
|
||||
}
|
||||
|
||||
userLogs, total, err := repo.ListByUserID(ctx, 1, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByUserID failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(userLogs) != 2 {
|
||||
t.Fatalf("expected 2 user operation logs, got total=%d len=%d", total, len(userLogs))
|
||||
}
|
||||
if userLogs[0].OperationName != "view dashboard" {
|
||||
t.Fatalf("expected newest operation log first, got %q", userLogs[0].OperationName)
|
||||
}
|
||||
|
||||
allLogs, total, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
if total != 3 || len(allLogs) != 3 {
|
||||
t.Fatalf("expected 3 total operation logs, got total=%d len=%d", total, len(allLogs))
|
||||
}
|
||||
|
||||
postLogs, total, err := repo.ListByMethod(ctx, "POST", 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByMethod failed: %v", err)
|
||||
}
|
||||
if total != 1 || len(postLogs) != 1 || postLogs[0].OperationName != "create user" {
|
||||
t.Fatalf("expected a single POST operation log, got total=%d len=%d", total, len(postLogs))
|
||||
}
|
||||
|
||||
recentLogs, total, err := repo.ListByTimeRange(ctx, now.Add(-1*time.Hour), now, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByTimeRange failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(recentLogs) != 2 {
|
||||
t.Fatalf("expected 2 recent operation logs, got total=%d len=%d", total, len(recentLogs))
|
||||
}
|
||||
|
||||
searchResults, total, err := repo.Search(ctx, "user", 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Search failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(searchResults) != 2 {
|
||||
t.Fatalf("expected 2 operation logs matching user, got total=%d len=%d", total, len(searchResults))
|
||||
}
|
||||
|
||||
if err := repo.DeleteOlderThan(ctx, 30); err != nil {
|
||||
t.Fatalf("DeleteOlderThan failed: %v", err)
|
||||
}
|
||||
|
||||
retainedLogs, retainedTotal, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List after DeleteOlderThan failed: %v", err)
|
||||
}
|
||||
if retainedTotal != 2 || len(retainedLogs) != 2 {
|
||||
t.Fatalf("expected 2 retained operation logs, got total=%d len=%d", retainedTotal, len(retainedLogs))
|
||||
}
|
||||
}
|
||||
603
internal/repository/repository_relationships_test.go
Normal file
603
internal/repository/repository_relationships_test.go
Normal file
@@ -0,0 +1,603 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func containsInt64(values []int64, target int64) bool {
|
||||
for _, value := range values {
|
||||
if value == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func TestRoleRepositoryLifecycleAndQueries(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewRoleRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
admin := &domain.Role{
|
||||
Name: "Admin Test",
|
||||
Code: "admin-test",
|
||||
Description: "root role",
|
||||
Level: 1,
|
||||
IsSystem: true,
|
||||
Status: domain.RoleStatusEnabled,
|
||||
}
|
||||
if err := repo.Create(ctx, admin); err != nil {
|
||||
t.Fatalf("Create(admin) failed: %v", err)
|
||||
}
|
||||
|
||||
parentID := admin.ID
|
||||
auditor := &domain.Role{
|
||||
Name: "Auditor Test",
|
||||
Code: "auditor-test",
|
||||
Description: "audit role",
|
||||
ParentID: &parentID,
|
||||
Level: 2,
|
||||
IsDefault: true,
|
||||
Status: domain.RoleStatusDisabled,
|
||||
}
|
||||
viewer := &domain.Role{
|
||||
Name: "Viewer Test",
|
||||
Code: "viewer-test",
|
||||
Description: "view role",
|
||||
Level: 1,
|
||||
Status: domain.RoleStatusEnabled,
|
||||
}
|
||||
|
||||
for _, role := range []*domain.Role{auditor, viewer} {
|
||||
if err := repo.Create(ctx, role); err != nil {
|
||||
t.Fatalf("Create(%s) failed: %v", role.Code, err)
|
||||
}
|
||||
}
|
||||
|
||||
loadedByID, err := repo.GetByID(ctx, admin.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID failed: %v", err)
|
||||
}
|
||||
if loadedByID.Code != "admin-test" {
|
||||
t.Fatalf("expected admin-test, got %q", loadedByID.Code)
|
||||
}
|
||||
|
||||
loadedByCode, err := repo.GetByCode(ctx, "auditor-test")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByCode failed: %v", err)
|
||||
}
|
||||
if loadedByCode.ID != auditor.ID {
|
||||
t.Fatalf("expected auditor id %d, got %d", auditor.ID, loadedByCode.ID)
|
||||
}
|
||||
|
||||
allRoles, total, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
if total != 3 || len(allRoles) != 3 {
|
||||
t.Fatalf("expected 3 roles, got total=%d len=%d", total, len(allRoles))
|
||||
}
|
||||
|
||||
enabledRoles, total, err := repo.ListByStatus(ctx, domain.RoleStatusEnabled, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByStatus failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(enabledRoles) != 2 {
|
||||
t.Fatalf("expected 2 enabled roles, got total=%d len=%d", total, len(enabledRoles))
|
||||
}
|
||||
|
||||
defaultRoles, err := repo.GetDefaultRoles(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetDefaultRoles failed: %v", err)
|
||||
}
|
||||
if len(defaultRoles) != 1 || defaultRoles[0].ID != auditor.ID {
|
||||
t.Fatalf("expected auditor as default role, got %+v", defaultRoles)
|
||||
}
|
||||
|
||||
exists, err := repo.ExistsByCode(ctx, "viewer-test")
|
||||
if err != nil {
|
||||
t.Fatalf("ExistsByCode(viewer-test) failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatal("expected viewer-test to exist")
|
||||
}
|
||||
|
||||
missing, err := repo.ExistsByCode(ctx, "missing-role")
|
||||
if err != nil {
|
||||
t.Fatalf("ExistsByCode(missing-role) failed: %v", err)
|
||||
}
|
||||
if missing {
|
||||
t.Fatal("expected missing-role to be absent")
|
||||
}
|
||||
|
||||
auditor.Description = "audit role updated"
|
||||
if err := repo.Update(ctx, auditor); err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
if err := repo.UpdateStatus(ctx, auditor.ID, domain.RoleStatusEnabled); err != nil {
|
||||
t.Fatalf("UpdateStatus failed: %v", err)
|
||||
}
|
||||
|
||||
searchResults, total, err := repo.Search(ctx, "audit", 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Search failed: %v", err)
|
||||
}
|
||||
if total != 1 || len(searchResults) != 1 || searchResults[0].ID != auditor.ID {
|
||||
t.Fatalf("expected auditor search hit, got total=%d len=%d", total, len(searchResults))
|
||||
}
|
||||
|
||||
childRoles, err := repo.ListByParentID(ctx, admin.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByParentID failed: %v", err)
|
||||
}
|
||||
if len(childRoles) != 1 || childRoles[0].ID != auditor.ID {
|
||||
t.Fatalf("expected auditor child role, got %+v", childRoles)
|
||||
}
|
||||
|
||||
roleSubset, err := repo.GetByIDs(ctx, []int64{admin.ID, auditor.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("GetByIDs failed: %v", err)
|
||||
}
|
||||
if len(roleSubset) != 2 {
|
||||
t.Fatalf("expected 2 roles from GetByIDs, got %d", len(roleSubset))
|
||||
}
|
||||
|
||||
emptySubset, err := repo.GetByIDs(ctx, []int64{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetByIDs(empty) failed: %v", err)
|
||||
}
|
||||
if len(emptySubset) != 0 {
|
||||
t.Fatalf("expected empty slice for GetByIDs(empty), got %d", len(emptySubset))
|
||||
}
|
||||
|
||||
if err := repo.Delete(ctx, viewer.ID); err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := repo.GetByID(ctx, viewer.ID); err == nil {
|
||||
t.Fatal("expected deleted role lookup to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionRepositoryLifecycleAndQueries(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
repo := NewPermissionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
parent := &domain.Permission{
|
||||
Name: "Dashboard",
|
||||
Code: "dashboard:view",
|
||||
Type: domain.PermissionTypeMenu,
|
||||
Description: "dashboard menu",
|
||||
Path: "/dashboard",
|
||||
Sort: 1,
|
||||
Status: domain.PermissionStatusEnabled,
|
||||
}
|
||||
if err := repo.Create(ctx, parent); err != nil {
|
||||
t.Fatalf("Create(parent) failed: %v", err)
|
||||
}
|
||||
|
||||
parentID := parent.ID
|
||||
apiPermission := &domain.Permission{
|
||||
Name: "Audit API",
|
||||
Code: "audit:read",
|
||||
Type: domain.PermissionTypeAPI,
|
||||
Description: "audit api",
|
||||
ParentID: &parentID,
|
||||
Path: "/api/audit",
|
||||
Method: "GET",
|
||||
Sort: 2,
|
||||
Status: domain.PermissionStatusDisabled,
|
||||
}
|
||||
buttonPermission := &domain.Permission{
|
||||
Name: "Audit Button",
|
||||
Code: "audit:button",
|
||||
Type: domain.PermissionTypeButton,
|
||||
Description: "audit action",
|
||||
Sort: 3,
|
||||
Status: domain.PermissionStatusEnabled,
|
||||
}
|
||||
|
||||
for _, permission := range []*domain.Permission{apiPermission, buttonPermission} {
|
||||
if err := repo.Create(ctx, permission); err != nil {
|
||||
t.Fatalf("Create(%s) failed: %v", permission.Code, err)
|
||||
}
|
||||
}
|
||||
|
||||
role := &domain.Role{
|
||||
Name: "Permission Role",
|
||||
Code: "permission-role",
|
||||
Description: "role for permission join queries",
|
||||
Status: domain.RoleStatusEnabled,
|
||||
}
|
||||
if err := db.WithContext(ctx).Create(role).Error; err != nil {
|
||||
t.Fatalf("create role for permission joins failed: %v", err)
|
||||
}
|
||||
|
||||
for _, rolePermission := range []*domain.RolePermission{
|
||||
{RoleID: role.ID, PermissionID: parent.ID},
|
||||
{RoleID: role.ID, PermissionID: apiPermission.ID},
|
||||
} {
|
||||
if err := db.WithContext(ctx).Create(rolePermission).Error; err != nil {
|
||||
t.Fatalf("create role_permission failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
loadedByID, err := repo.GetByID(ctx, parent.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID failed: %v", err)
|
||||
}
|
||||
if loadedByID.Code != "dashboard:view" {
|
||||
t.Fatalf("expected dashboard:view, got %q", loadedByID.Code)
|
||||
}
|
||||
|
||||
loadedByCode, err := repo.GetByCode(ctx, "audit:read")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByCode failed: %v", err)
|
||||
}
|
||||
if loadedByCode.ID != apiPermission.ID {
|
||||
t.Fatalf("expected audit:read id %d, got %d", apiPermission.ID, loadedByCode.ID)
|
||||
}
|
||||
|
||||
allPermissions, total, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List failed: %v", err)
|
||||
}
|
||||
if total != 3 || len(allPermissions) != 3 {
|
||||
t.Fatalf("expected 3 permissions, got total=%d len=%d", total, len(allPermissions))
|
||||
}
|
||||
|
||||
apiPermissions, total, err := repo.ListByType(ctx, domain.PermissionTypeAPI, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByType failed: %v", err)
|
||||
}
|
||||
if total != 1 || len(apiPermissions) != 1 || apiPermissions[0].ID != apiPermission.ID {
|
||||
t.Fatalf("expected audit api permission, got total=%d len=%d", total, len(apiPermissions))
|
||||
}
|
||||
|
||||
enabledPermissions, total, err := repo.ListByStatus(ctx, domain.PermissionStatusEnabled, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByStatus failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(enabledPermissions) != 2 {
|
||||
t.Fatalf("expected 2 enabled permissions, got total=%d len=%d", total, len(enabledPermissions))
|
||||
}
|
||||
|
||||
rolePermissions, err := repo.GetByRoleIDs(ctx, []int64{role.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("GetByRoleIDs failed: %v", err)
|
||||
}
|
||||
if len(rolePermissions) != 1 || rolePermissions[0].ID != parent.ID {
|
||||
t.Fatalf("expected only enabled parent permission in join query, got %+v", rolePermissions)
|
||||
}
|
||||
|
||||
exists, err := repo.ExistsByCode(ctx, "audit:button")
|
||||
if err != nil {
|
||||
t.Fatalf("ExistsByCode(audit:button) failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatal("expected audit:button to exist")
|
||||
}
|
||||
|
||||
missing, err := repo.ExistsByCode(ctx, "permission:missing")
|
||||
if err != nil {
|
||||
t.Fatalf("ExistsByCode(missing) failed: %v", err)
|
||||
}
|
||||
if missing {
|
||||
t.Fatal("expected permission:missing to be absent")
|
||||
}
|
||||
|
||||
apiPermission.Description = "audit api updated"
|
||||
if err := repo.Update(ctx, apiPermission); err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
if err := repo.UpdateStatus(ctx, apiPermission.ID, domain.PermissionStatusEnabled); err != nil {
|
||||
t.Fatalf("UpdateStatus failed: %v", err)
|
||||
}
|
||||
|
||||
searchResults, total, err := repo.Search(ctx, "audit", 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("Search failed: %v", err)
|
||||
}
|
||||
if total != 2 || len(searchResults) != 2 {
|
||||
t.Fatalf("expected 2 audit-related permissions, got total=%d len=%d", total, len(searchResults))
|
||||
}
|
||||
|
||||
childPermissions, err := repo.ListByParentID(ctx, parent.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByParentID failed: %v", err)
|
||||
}
|
||||
if len(childPermissions) != 1 || childPermissions[0].ID != apiPermission.ID {
|
||||
t.Fatalf("expected api permission child, got %+v", childPermissions)
|
||||
}
|
||||
|
||||
permissionSubset, err := repo.GetByIDs(ctx, []int64{parent.ID, apiPermission.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("GetByIDs failed: %v", err)
|
||||
}
|
||||
if len(permissionSubset) != 2 {
|
||||
t.Fatalf("expected 2 permissions from GetByIDs, got %d", len(permissionSubset))
|
||||
}
|
||||
|
||||
emptySubset, err := repo.GetByIDs(ctx, []int64{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetByIDs(empty) failed: %v", err)
|
||||
}
|
||||
if len(emptySubset) != 0 {
|
||||
t.Fatalf("expected empty slice for GetByIDs(empty), got %d", len(emptySubset))
|
||||
}
|
||||
|
||||
if err := repo.Delete(ctx, buttonPermission.ID); err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := repo.GetByID(ctx, buttonPermission.ID); err == nil {
|
||||
t.Fatal("expected deleted permission lookup to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserRoleAndRolePermissionRepositoriesLifecycle(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
userRoleRepo := NewUserRoleRepository(db)
|
||||
rolePermissionRepo := NewRolePermissionRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
users := []*domain.User{
|
||||
{Username: "repo-user-1", Password: "hash", Status: domain.UserStatusActive},
|
||||
{Username: "repo-user-2", Password: "hash", Status: domain.UserStatusActive},
|
||||
}
|
||||
for _, user := range users {
|
||||
if err := db.WithContext(ctx).Create(user).Error; err != nil {
|
||||
t.Fatalf("create user failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
roles := []*domain.Role{
|
||||
{Name: "Repo Role 1", Code: "repo-role-1", Status: domain.RoleStatusEnabled},
|
||||
{Name: "Repo Role 2", Code: "repo-role-2", Status: domain.RoleStatusEnabled},
|
||||
}
|
||||
for _, role := range roles {
|
||||
if err := db.WithContext(ctx).Create(role).Error; err != nil {
|
||||
t.Fatalf("create role failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
permissions := []*domain.Permission{
|
||||
{Name: "Repo Permission 1", Code: "repo:permission:1", Type: domain.PermissionTypeAPI, Status: domain.PermissionStatusEnabled},
|
||||
{Name: "Repo Permission 2", Code: "repo:permission:2", Type: domain.PermissionTypeAPI, Status: domain.PermissionStatusEnabled},
|
||||
}
|
||||
for _, permission := range permissions {
|
||||
if err := db.WithContext(ctx).Create(permission).Error; err != nil {
|
||||
t.Fatalf("create permission failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
userRolePrimary := &domain.UserRole{UserID: users[0].ID, RoleID: roles[0].ID}
|
||||
if err := userRoleRepo.Create(ctx, userRolePrimary); err != nil {
|
||||
t.Fatalf("UserRole Create failed: %v", err)
|
||||
}
|
||||
|
||||
if err := userRoleRepo.BatchCreate(ctx, []*domain.UserRole{}); err != nil {
|
||||
t.Fatalf("UserRole BatchCreate(empty) failed: %v", err)
|
||||
}
|
||||
|
||||
userRoleBatch := []*domain.UserRole{
|
||||
{UserID: users[0].ID, RoleID: roles[1].ID},
|
||||
{UserID: users[1].ID, RoleID: roles[0].ID},
|
||||
}
|
||||
if err := userRoleRepo.BatchCreate(ctx, userRoleBatch); err != nil {
|
||||
t.Fatalf("UserRole BatchCreate failed: %v", err)
|
||||
}
|
||||
|
||||
exists, err := userRoleRepo.Exists(ctx, users[0].ID, roles[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("UserRole Exists failed: %v", err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatal("expected primary user-role relation to exist")
|
||||
}
|
||||
|
||||
missing, err := userRoleRepo.Exists(ctx, users[1].ID, roles[1].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("UserRole Exists(missing) failed: %v", err)
|
||||
}
|
||||
if missing {
|
||||
t.Fatal("expected missing user-role relation to be absent")
|
||||
}
|
||||
|
||||
rolesForUserOne, err := userRoleRepo.GetByUserID(ctx, users[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUserID failed: %v", err)
|
||||
}
|
||||
if len(rolesForUserOne) != 2 {
|
||||
t.Fatalf("expected 2 roles for user one, got %d", len(rolesForUserOne))
|
||||
}
|
||||
|
||||
usersForRoleOne, err := userRoleRepo.GetByRoleID(ctx, roles[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByRoleID failed: %v", err)
|
||||
}
|
||||
if len(usersForRoleOne) != 2 {
|
||||
t.Fatalf("expected 2 users for role one, got %d", len(usersForRoleOne))
|
||||
}
|
||||
|
||||
roleIDs, err := userRoleRepo.GetRoleIDsByUserID(ctx, users[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRoleIDsByUserID failed: %v", err)
|
||||
}
|
||||
if len(roleIDs) != 2 || !containsInt64(roleIDs, roles[0].ID) || !containsInt64(roleIDs, roles[1].ID) {
|
||||
t.Fatalf("unexpected role IDs for user one: %+v", roleIDs)
|
||||
}
|
||||
|
||||
userIDs, err := userRoleRepo.GetUserIDByRoleID(ctx, roles[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetUserIDByRoleID failed: %v", err)
|
||||
}
|
||||
if len(userIDs) != 2 || !containsInt64(userIDs, users[0].ID) || !containsInt64(userIDs, users[1].ID) {
|
||||
t.Fatalf("unexpected user IDs for role one: %+v", userIDs)
|
||||
}
|
||||
|
||||
if err := userRoleRepo.BatchDelete(ctx, []*domain.UserRole{}); err != nil {
|
||||
t.Fatalf("UserRole BatchDelete(empty) failed: %v", err)
|
||||
}
|
||||
|
||||
if err := userRoleRepo.BatchDelete(ctx, []*domain.UserRole{userRoleBatch[0]}); err != nil {
|
||||
t.Fatalf("UserRole BatchDelete failed: %v", err)
|
||||
}
|
||||
|
||||
if err := userRoleRepo.Delete(ctx, userRolePrimary.ID); err != nil {
|
||||
t.Fatalf("UserRole Delete failed: %v", err)
|
||||
}
|
||||
|
||||
existsAfterDelete, err := userRoleRepo.Exists(ctx, users[0].ID, roles[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("UserRole Exists after Delete failed: %v", err)
|
||||
}
|
||||
if existsAfterDelete {
|
||||
t.Fatal("expected primary user-role relation to be removed")
|
||||
}
|
||||
|
||||
if err := userRoleRepo.DeleteByUserID(ctx, users[1].ID); err != nil {
|
||||
t.Fatalf("DeleteByUserID failed: %v", err)
|
||||
}
|
||||
|
||||
if err := userRoleRepo.Create(ctx, &domain.UserRole{UserID: users[0].ID, RoleID: roles[1].ID}); err != nil {
|
||||
t.Fatalf("recreate user-role failed: %v", err)
|
||||
}
|
||||
if err := userRoleRepo.DeleteByRoleID(ctx, roles[1].ID); err != nil {
|
||||
t.Fatalf("DeleteByRoleID failed: %v", err)
|
||||
}
|
||||
|
||||
remainingUserRoles, err := userRoleRepo.GetByRoleID(ctx, roles[1].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByRoleID after DeleteByRoleID failed: %v", err)
|
||||
}
|
||||
if len(remainingUserRoles) != 0 {
|
||||
t.Fatalf("expected no user-role relations for role two, got %d", len(remainingUserRoles))
|
||||
}
|
||||
|
||||
rolePermissionPrimary := &domain.RolePermission{RoleID: roles[0].ID, PermissionID: permissions[0].ID}
|
||||
if err := rolePermissionRepo.Create(ctx, rolePermissionPrimary); err != nil {
|
||||
t.Fatalf("RolePermission Create failed: %v", err)
|
||||
}
|
||||
|
||||
if err := rolePermissionRepo.BatchCreate(ctx, []*domain.RolePermission{}); err != nil {
|
||||
t.Fatalf("RolePermission BatchCreate(empty) failed: %v", err)
|
||||
}
|
||||
|
||||
rolePermissionBatch := []*domain.RolePermission{
|
||||
{RoleID: roles[0].ID, PermissionID: permissions[1].ID},
|
||||
{RoleID: roles[1].ID, PermissionID: permissions[0].ID},
|
||||
}
|
||||
if err := rolePermissionRepo.BatchCreate(ctx, rolePermissionBatch); err != nil {
|
||||
t.Fatalf("RolePermission BatchCreate failed: %v", err)
|
||||
}
|
||||
|
||||
rpExists, err := rolePermissionRepo.Exists(ctx, roles[0].ID, permissions[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("RolePermission Exists failed: %v", err)
|
||||
}
|
||||
if !rpExists {
|
||||
t.Fatal("expected primary role-permission relation to exist")
|
||||
}
|
||||
|
||||
rpMissing, err := rolePermissionRepo.Exists(ctx, roles[1].ID, permissions[1].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("RolePermission Exists(missing) failed: %v", err)
|
||||
}
|
||||
if rpMissing {
|
||||
t.Fatal("expected missing role-permission relation to be absent")
|
||||
}
|
||||
|
||||
permissionsForRoleOne, err := rolePermissionRepo.GetByRoleID(ctx, roles[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByRoleID failed: %v", err)
|
||||
}
|
||||
if len(permissionsForRoleOne) != 2 {
|
||||
t.Fatalf("expected 2 permissions for role one, got %d", len(permissionsForRoleOne))
|
||||
}
|
||||
|
||||
rolesForPermissionOne, err := rolePermissionRepo.GetByPermissionID(ctx, permissions[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByPermissionID failed: %v", err)
|
||||
}
|
||||
if len(rolesForPermissionOne) != 2 {
|
||||
t.Fatalf("expected 2 roles for permission one, got %d", len(rolesForPermissionOne))
|
||||
}
|
||||
|
||||
permissionIDs, err := rolePermissionRepo.GetPermissionIDsByRoleID(ctx, roles[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPermissionIDsByRoleID failed: %v", err)
|
||||
}
|
||||
if len(permissionIDs) != 2 || !containsInt64(permissionIDs, permissions[0].ID) || !containsInt64(permissionIDs, permissions[1].ID) {
|
||||
t.Fatalf("unexpected permission IDs for role one: %+v", permissionIDs)
|
||||
}
|
||||
|
||||
roleIDsByPermission, err := rolePermissionRepo.GetRoleIDByPermissionID(ctx, permissions[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetRoleIDByPermissionID failed: %v", err)
|
||||
}
|
||||
if len(roleIDsByPermission) != 2 || !containsInt64(roleIDsByPermission, roles[0].ID) || !containsInt64(roleIDsByPermission, roles[1].ID) {
|
||||
t.Fatalf("unexpected role IDs for permission one: %+v", roleIDsByPermission)
|
||||
}
|
||||
|
||||
loadedPermission, err := rolePermissionRepo.GetPermissionByID(ctx, permissions[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetPermissionByID failed: %v", err)
|
||||
}
|
||||
if loadedPermission.Code != "repo:permission:1" {
|
||||
t.Fatalf("expected repo:permission:1, got %q", loadedPermission.Code)
|
||||
}
|
||||
|
||||
permissionIDsByRoleIDs, err := rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, []int64{roles[0].ID, roles[1].ID})
|
||||
if err != nil {
|
||||
t.Fatalf("GetPermissionIDsByRoleIDs failed: %v", err)
|
||||
}
|
||||
if len(permissionIDsByRoleIDs) != 3 {
|
||||
t.Fatalf("expected 3 permission IDs from combined roles, got %d", len(permissionIDsByRoleIDs))
|
||||
}
|
||||
|
||||
emptyPermissionIDs, err := rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, []int64{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetPermissionIDsByRoleIDs(empty) failed: %v", err)
|
||||
}
|
||||
if len(emptyPermissionIDs) != 0 {
|
||||
t.Fatalf("expected empty slice for GetPermissionIDsByRoleIDs(empty), got %d", len(emptyPermissionIDs))
|
||||
}
|
||||
|
||||
if err := rolePermissionRepo.BatchDelete(ctx, []*domain.RolePermission{}); err != nil {
|
||||
t.Fatalf("RolePermission BatchDelete(empty) failed: %v", err)
|
||||
}
|
||||
|
||||
if err := rolePermissionRepo.BatchDelete(ctx, []*domain.RolePermission{rolePermissionBatch[0]}); err != nil {
|
||||
t.Fatalf("RolePermission BatchDelete failed: %v", err)
|
||||
}
|
||||
|
||||
if err := rolePermissionRepo.Delete(ctx, rolePermissionPrimary.ID); err != nil {
|
||||
t.Fatalf("RolePermission Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if err := rolePermissionRepo.DeleteByPermissionID(ctx, permissions[0].ID); err != nil {
|
||||
t.Fatalf("DeleteByPermissionID failed: %v", err)
|
||||
}
|
||||
|
||||
if err := rolePermissionRepo.Create(ctx, &domain.RolePermission{RoleID: roles[0].ID, PermissionID: permissions[1].ID}); err != nil {
|
||||
t.Fatalf("recreate role-permission failed: %v", err)
|
||||
}
|
||||
if err := rolePermissionRepo.DeleteByRoleID(ctx, roles[0].ID); err != nil {
|
||||
t.Fatalf("DeleteByRoleID failed: %v", err)
|
||||
}
|
||||
|
||||
remainingRolePermissions, err := rolePermissionRepo.GetByRoleID(ctx, roles[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByRoleID after DeleteByRoleID failed: %v", err)
|
||||
}
|
||||
if len(remainingRolePermissions) != 0 {
|
||||
t.Fatalf("expected no role-permission relations for role one, got %d", len(remainingRolePermissions))
|
||||
}
|
||||
}
|
||||
213
internal/repository/role.go
Normal file
213
internal/repository/role.go
Normal file
@@ -0,0 +1,213 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// RoleRepository 角色数据访问层
|
||||
type RoleRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewRoleRepository 创建角色数据访问层
|
||||
func NewRoleRepository(db *gorm.DB) *RoleRepository {
|
||||
return &RoleRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建角色
|
||||
func (r *RoleRepository) Create(ctx context.Context, role *domain.Role) error {
|
||||
// GORM omits zero values on insert for fields with DB defaults. Explicitly
|
||||
// backfill disabled status so callers can persist status=0 roles.
|
||||
requestedStatus := role.Status
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(role).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if requestedStatus == domain.RoleStatusDisabled {
|
||||
if err := tx.Model(&domain.Role{}).Where("id = ?", role.ID).Update("status", requestedStatus).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
role.Status = requestedStatus
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Update 更新角色
|
||||
func (r *RoleRepository) Update(ctx context.Context, role *domain.Role) error {
|
||||
return r.db.WithContext(ctx).Save(role).Error
|
||||
}
|
||||
|
||||
// Delete 删除角色
|
||||
func (r *RoleRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.Role{}, id).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取角色
|
||||
func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) {
|
||||
var role domain.Role
|
||||
err := r.db.WithContext(ctx).First(&role, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
// GetByCode 根据代码获取角色
|
||||
func (r *RoleRepository) GetByCode(ctx context.Context, code string) (*domain.Role, error) {
|
||||
var role domain.Role
|
||||
err := r.db.WithContext(ctx).Where("code = ?", code).First(&role).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
// List 获取角色列表
|
||||
func (r *RoleRepository) List(ctx context.Context, offset, limit int) ([]*domain.Role, int64, error) {
|
||||
var roles []*domain.Role
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Role{})
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return roles, total, nil
|
||||
}
|
||||
|
||||
// ListByStatus 根据状态获取角色列表
|
||||
func (r *RoleRepository) ListByStatus(ctx context.Context, status domain.RoleStatus, offset, limit int) ([]*domain.Role, int64, error) {
|
||||
var roles []*domain.Role
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Role{}).Where("status = ?", status)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return roles, total, nil
|
||||
}
|
||||
|
||||
// GetDefaultRoles 获取默认角色
|
||||
func (r *RoleRepository) GetDefaultRoles(ctx context.Context) ([]*domain.Role, error) {
|
||||
var roles []*domain.Role
|
||||
err := r.db.WithContext(ctx).Where("is_default = ?", true).Find(&roles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// ExistsByCode 检查角色代码是否存在
|
||||
func (r *RoleRepository) ExistsByCode(ctx context.Context, code string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.Role{}).Where("code = ?", code).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// UpdateStatus 更新角色状态
|
||||
func (r *RoleRepository) UpdateStatus(ctx context.Context, id int64, status domain.RoleStatus) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.Role{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
// Search 搜索角色
|
||||
func (r *RoleRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.Role, int64, error) {
|
||||
var roles []*domain.Role
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Role{}).
|
||||
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return roles, total, nil
|
||||
}
|
||||
|
||||
// ListByParentID 根据父ID获取角色列表
|
||||
func (r *RoleRepository) ListByParentID(ctx context.Context, parentID int64) ([]*domain.Role, error) {
|
||||
var roles []*domain.Role
|
||||
err := r.db.WithContext(ctx).Where("parent_id = ?", parentID).Find(&roles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// GetByIDs 根据ID列表批量获取角色
|
||||
func (r *RoleRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.Role, error) {
|
||||
if len(ids) == 0 {
|
||||
return []*domain.Role{}, nil
|
||||
}
|
||||
|
||||
var roles []*domain.Role
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&roles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// GetAncestorIDs 获取角色的所有祖先角色ID(用于权限继承)
|
||||
func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
|
||||
var ancestorIDs []int64
|
||||
currentID := roleID
|
||||
|
||||
// 循环向上查找父角色,直到没有父角色为止
|
||||
for {
|
||||
var role domain.Role
|
||||
err := r.db.WithContext(ctx).Select("id", "parent_id").First(&role, currentID).Error
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
break
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if role.ParentID == nil {
|
||||
break
|
||||
}
|
||||
ancestorIDs = append(ancestorIDs, *role.ParentID)
|
||||
currentID = *role.ParentID
|
||||
}
|
||||
|
||||
return ancestorIDs, nil
|
||||
}
|
||||
|
||||
// GetAncestors 获取角色的完整继承链(从父到子)
|
||||
func (r *RoleRepository) GetAncestors(ctx context.Context, roleID int64) ([]*domain.Role, error) {
|
||||
ancestorIDs, err := r.GetAncestorIDs(ctx, roleID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ancestorIDs) == 0 {
|
||||
return []*domain.Role{}, nil
|
||||
}
|
||||
return r.GetByIDs(ctx, ancestorIDs)
|
||||
}
|
||||
150
internal/repository/role_permission.go
Normal file
150
internal/repository/role_permission.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// RolePermissionRepository 角色权限关联数据访问层
|
||||
type RolePermissionRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewRolePermissionRepository 创建角色权限关联数据访问层
|
||||
func NewRolePermissionRepository(db *gorm.DB) *RolePermissionRepository {
|
||||
return &RolePermissionRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建角色权限关联
|
||||
func (r *RolePermissionRepository) Create(ctx context.Context, rolePermission *domain.RolePermission) error {
|
||||
return r.db.WithContext(ctx).Create(rolePermission).Error
|
||||
}
|
||||
|
||||
// Delete 删除角色权限关联
|
||||
func (r *RolePermissionRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.RolePermission{}, id).Error
|
||||
}
|
||||
|
||||
// DeleteByRoleID 删除角色的所有权限
|
||||
func (r *RolePermissionRepository) DeleteByRoleID(ctx context.Context, roleID int64) error {
|
||||
return r.db.WithContext(ctx).Where("role_id = ?", roleID).Delete(&domain.RolePermission{}).Error
|
||||
}
|
||||
|
||||
// DeleteByPermissionID 删除权限的所有角色
|
||||
func (r *RolePermissionRepository) DeleteByPermissionID(ctx context.Context, permissionID int64) error {
|
||||
return r.db.WithContext(ctx).Where("permission_id = ?", permissionID).Delete(&domain.RolePermission{}).Error
|
||||
}
|
||||
|
||||
// GetByRoleID 根据角色ID获取权限列表
|
||||
func (r *RolePermissionRepository) GetByRoleID(ctx context.Context, roleID int64) ([]*domain.RolePermission, error) {
|
||||
var rolePermissions []*domain.RolePermission
|
||||
err := r.db.WithContext(ctx).Where("role_id = ?", roleID).Find(&rolePermissions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rolePermissions, nil
|
||||
}
|
||||
|
||||
// GetByPermissionID 根据权限ID获取角色列表
|
||||
func (r *RolePermissionRepository) GetByPermissionID(ctx context.Context, permissionID int64) ([]*domain.RolePermission, error) {
|
||||
var rolePermissions []*domain.RolePermission
|
||||
err := r.db.WithContext(ctx).Where("permission_id = ?", permissionID).Find(&rolePermissions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rolePermissions, nil
|
||||
}
|
||||
|
||||
// GetPermissionIDsByRoleID 根据角色ID获取权限ID列表
|
||||
func (r *RolePermissionRepository) GetPermissionIDsByRoleID(ctx context.Context, roleID int64) ([]int64, error) {
|
||||
var permissionIDs []int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.RolePermission{}).Where("role_id = ?", roleID).Pluck("permission_id", &permissionIDs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return permissionIDs, nil
|
||||
}
|
||||
|
||||
// GetRoleIDByPermissionID 根据权限ID获取角色ID列表
|
||||
func (r *RolePermissionRepository) GetRoleIDByPermissionID(ctx context.Context, permissionID int64) ([]int64, error) {
|
||||
var roleIDs []int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.RolePermission{}).Where("permission_id = ?", permissionID).Pluck("role_id", &roleIDs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return roleIDs, nil
|
||||
}
|
||||
|
||||
// Exists 检查角色权限关联是否存在
|
||||
func (r *RolePermissionRepository) Exists(ctx context.Context, roleID, permissionID int64) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.RolePermission{}).
|
||||
Where("role_id = ? AND permission_id = ?", roleID, permissionID).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// BatchCreate 批量创建角色权限关联
|
||||
func (r *RolePermissionRepository) BatchCreate(ctx context.Context, rolePermissions []*domain.RolePermission) error {
|
||||
if len(rolePermissions) == 0 {
|
||||
return nil
|
||||
}
|
||||
return r.db.WithContext(ctx).Create(&rolePermissions).Error
|
||||
}
|
||||
|
||||
// BatchDelete 批量删除角色权限关联
|
||||
func (r *RolePermissionRepository) BatchDelete(ctx context.Context, rolePermissions []*domain.RolePermission) error {
|
||||
if len(rolePermissions) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ids []int64
|
||||
for _, rp := range rolePermissions {
|
||||
ids = append(ids, rp.ID)
|
||||
}
|
||||
|
||||
return r.db.WithContext(ctx).Delete(&domain.RolePermission{}, ids).Error
|
||||
}
|
||||
|
||||
// GetPermissionByID 根据权限ID获取权限信息
|
||||
func (r *RolePermissionRepository) GetPermissionByID(ctx context.Context, permissionID int64) (*domain.Permission, error) {
|
||||
var permission domain.Permission
|
||||
err := r.db.WithContext(ctx).First(&permission, permissionID).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &permission, nil
|
||||
}
|
||||
|
||||
// GetPermissionIDsByRoleIDs 根据角色ID列表批量获取权限ID
|
||||
func (r *RolePermissionRepository) GetPermissionIDsByRoleIDs(ctx context.Context, roleIDs []int64) ([]int64, error) {
|
||||
if len(roleIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
var permissionIDs []int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.RolePermission{}).
|
||||
Where("role_id IN ?", roleIDs).
|
||||
Pluck("permission_id", &permissionIDs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return permissionIDs, nil
|
||||
}
|
||||
|
||||
// GetPermissionsByIDs 根据权限ID列表批量获取权限
|
||||
func (r *RolePermissionRepository) GetPermissionsByIDs(ctx context.Context, permissionIDs []int64) ([]*domain.Permission, error) {
|
||||
if len(permissionIDs) == 0 {
|
||||
return []*domain.Permission{}, nil
|
||||
}
|
||||
|
||||
var permissions []*domain.Permission
|
||||
err := r.db.WithContext(ctx).Where("id IN ?", permissionIDs).Find(&permissions).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return permissions, nil
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
rdb := testRedis(t)
|
||||
client := testEntClient(t)
|
||||
|
||||
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
|
||||
|
||||
accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
|
||||
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
|
||||
cache := NewSchedulerCache(rdb)
|
||||
|
||||
cfg := &config.Config{
|
||||
RunMode: config.RunModeStandard,
|
||||
Gateway: config.GatewayConfig{
|
||||
Scheduling: config.GatewaySchedulingConfig{
|
||||
OutboxPollIntervalSeconds: 1,
|
||||
FullRebuildIntervalSeconds: 0,
|
||||
DbFallbackEnabled: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account := &service.Account{
|
||||
Name: "outbox-replay-" + time.Now().Format("150405.000000"),
|
||||
Platform: service.PlatformOpenAI,
|
||||
Type: service.AccountTypeAPIKey,
|
||||
Status: service.StatusActive,
|
||||
Schedulable: true,
|
||||
Concurrency: 3,
|
||||
Priority: 1,
|
||||
Credentials: map[string]any{},
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.NoError(t, accountRepo.Create(ctx, account))
|
||||
require.NoError(t, cache.SetAccount(ctx, account))
|
||||
|
||||
svc := service.NewSchedulerSnapshotService(cache, outboxRepo, accountRepo, nil, cfg)
|
||||
svc.Start()
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
require.NoError(t, accountRepo.UpdateLastUsed(ctx, account.ID))
|
||||
updated, err := accountRepo.GetByID(ctx, account.ID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated.LastUsedAt)
|
||||
expectedUnix := updated.LastUsedAt.Unix()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
cached, err := cache.GetAccount(ctx, account.ID)
|
||||
if err != nil || cached == nil || cached.LastUsedAt == nil {
|
||||
return false
|
||||
}
|
||||
return cached.LastUsedAt.Unix() == expectedUnix
|
||||
}, 5*time.Second, 100*time.Millisecond)
|
||||
}
|
||||
295
internal/repository/social_account_repo.go
Normal file
295
internal/repository/social_account_repo.go
Normal file
@@ -0,0 +1,295 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SocialAccountRepository 社交账号仓库接口
|
||||
type SocialAccountRepository interface {
|
||||
Create(ctx context.Context, account *domain.SocialAccount) error
|
||||
Update(ctx context.Context, account *domain.SocialAccount) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error
|
||||
GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error)
|
||||
GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error)
|
||||
GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error)
|
||||
List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error)
|
||||
}
|
||||
|
||||
// SocialAccountRepositoryImpl 社交账号仓库实现
|
||||
type SocialAccountRepositoryImpl struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewSocialAccountRepository 创建社交账号仓库(支持 gorm.DB 或 *sql.DB)
|
||||
func NewSocialAccountRepository(db interface{}) (SocialAccountRepository, error) {
|
||||
var sqlDB *sql.DB
|
||||
switch d := db.(type) {
|
||||
case *gorm.DB:
|
||||
var err error
|
||||
sqlDB, err = d.DB()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve sql db from gorm db failed: %w", err)
|
||||
}
|
||||
case *sql.DB:
|
||||
sqlDB = d
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported db type: %T", db)
|
||||
}
|
||||
if sqlDB == nil {
|
||||
return nil, fmt.Errorf("sql db is nil")
|
||||
}
|
||||
return &SocialAccountRepositoryImpl{db: sqlDB}, nil
|
||||
}
|
||||
|
||||
// Create 创建社交账号
|
||||
func (r *SocialAccountRepositoryImpl) Create(ctx context.Context, account *domain.SocialAccount) error {
|
||||
query := `
|
||||
INSERT INTO user_social_accounts (user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
`
|
||||
|
||||
result, err := r.db.ExecContext(ctx, query,
|
||||
account.UserID,
|
||||
account.Provider,
|
||||
account.OpenID,
|
||||
account.UnionID,
|
||||
account.Nickname,
|
||||
account.Avatar,
|
||||
account.Gender,
|
||||
account.Email,
|
||||
account.Phone,
|
||||
account.Extra,
|
||||
account.Status,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create social account: %w", err)
|
||||
}
|
||||
|
||||
id, err := result.LastInsertId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
account.ID = id
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update 更新社交账号
|
||||
func (r *SocialAccountRepositoryImpl) Update(ctx context.Context, account *domain.SocialAccount) error {
|
||||
query := `
|
||||
UPDATE user_social_accounts
|
||||
SET union_id = ?, nickname = ?, avatar = ?, gender = ?, email = ?, phone = ?, extra = ?, status = ?, updated_at = CURRENT_TIMESTAMP
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query,
|
||||
account.UnionID,
|
||||
account.Nickname,
|
||||
account.Avatar,
|
||||
account.Gender,
|
||||
account.Email,
|
||||
account.Phone,
|
||||
account.Extra,
|
||||
account.Status,
|
||||
account.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update social account: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete 删除社交账号
|
||||
func (r *SocialAccountRepositoryImpl) Delete(ctx context.Context, id int64) error {
|
||||
query := `DELETE FROM user_social_accounts WHERE id = ?`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete social account: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteByProviderAndUserID 删除指定用户和提供商的社交账号
|
||||
func (r *SocialAccountRepositoryImpl) DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error {
|
||||
query := `DELETE FROM user_social_accounts WHERE provider = ? AND user_id = ?`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query, provider, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete social account: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取社交账号
|
||||
func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error) {
|
||||
query := `
|
||||
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
|
||||
FROM user_social_accounts
|
||||
WHERE id = ?
|
||||
`
|
||||
|
||||
var account domain.SocialAccount
|
||||
err := r.db.QueryRowContext(ctx, query, id).Scan(
|
||||
&account.ID,
|
||||
&account.UserID,
|
||||
&account.Provider,
|
||||
&account.OpenID,
|
||||
&account.UnionID,
|
||||
&account.Nickname,
|
||||
&account.Avatar,
|
||||
&account.Gender,
|
||||
&account.Email,
|
||||
&account.Phone,
|
||||
&account.Extra,
|
||||
&account.Status,
|
||||
&account.CreatedAt,
|
||||
&account.UpdatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get social account: %w", err)
|
||||
}
|
||||
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取社交账号列表
|
||||
func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
|
||||
query := `
|
||||
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
|
||||
FROM user_social_accounts
|
||||
WHERE user_id = ?
|
||||
ORDER BY created_at DESC
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query social accounts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var accounts []*domain.SocialAccount
|
||||
for rows.Next() {
|
||||
var account domain.SocialAccount
|
||||
err := rows.Scan(
|
||||
&account.ID,
|
||||
&account.UserID,
|
||||
&account.Provider,
|
||||
&account.OpenID,
|
||||
&account.UnionID,
|
||||
&account.Nickname,
|
||||
&account.Avatar,
|
||||
&account.Gender,
|
||||
&account.Email,
|
||||
&account.Phone,
|
||||
&account.Extra,
|
||||
&account.Status,
|
||||
&account.CreatedAt,
|
||||
&account.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accounts = append(accounts, &account)
|
||||
}
|
||||
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// GetByProviderAndOpenID 根据提供商和OpenID获取社交账号
|
||||
func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error) {
|
||||
query := `
|
||||
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
|
||||
FROM user_social_accounts
|
||||
WHERE provider = ? AND open_id = ?
|
||||
`
|
||||
|
||||
var account domain.SocialAccount
|
||||
err := r.db.QueryRowContext(ctx, query, provider, openID).Scan(
|
||||
&account.ID,
|
||||
&account.UserID,
|
||||
&account.Provider,
|
||||
&account.OpenID,
|
||||
&account.UnionID,
|
||||
&account.Nickname,
|
||||
&account.Avatar,
|
||||
&account.Gender,
|
||||
&account.Email,
|
||||
&account.Phone,
|
||||
&account.Extra,
|
||||
&account.Status,
|
||||
&account.CreatedAt,
|
||||
&account.UpdatedAt,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get social account: %w", err)
|
||||
}
|
||||
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
// List 分页获取社交账号列表
|
||||
func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error) {
|
||||
// 获取总数
|
||||
var total int64
|
||||
countQuery := `SELECT COUNT(*) FROM user_social_accounts`
|
||||
if err := r.db.QueryRowContext(ctx, countQuery).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to count social accounts: %w", err)
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
query := `
|
||||
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
|
||||
FROM user_social_accounts
|
||||
ORDER BY created_at DESC
|
||||
LIMIT ? OFFSET ?
|
||||
`
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, query, limit, offset)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query social accounts: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var accounts []*domain.SocialAccount
|
||||
for rows.Next() {
|
||||
var account domain.SocialAccount
|
||||
err := rows.Scan(
|
||||
&account.ID,
|
||||
&account.UserID,
|
||||
&account.Provider,
|
||||
&account.OpenID,
|
||||
&account.UnionID,
|
||||
&account.Nickname,
|
||||
&account.Avatar,
|
||||
&account.Gender,
|
||||
&account.Email,
|
||||
&account.Phone,
|
||||
&account.Extra,
|
||||
&account.Status,
|
||||
&account.CreatedAt,
|
||||
&account.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
accounts = append(accounts, &account)
|
||||
}
|
||||
|
||||
return accounts, total, nil
|
||||
}
|
||||
41
internal/repository/social_account_repo_constructor_test.go
Normal file
41
internal/repository/social_account_repo_constructor_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package repository
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestNewSocialAccountRepository_AcceptsGormDB(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
repo, err := NewSocialAccountRepository(db)
|
||||
if err != nil {
|
||||
t.Fatalf("expected constructor to succeed: %v", err)
|
||||
}
|
||||
if repo == nil {
|
||||
t.Fatal("expected repository instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSocialAccountRepository_AcceptsSQLDB(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("expected sql db handle: %v", err)
|
||||
}
|
||||
|
||||
repo, err := NewSocialAccountRepository(sqlDB)
|
||||
if err != nil {
|
||||
t.Fatalf("expected constructor to succeed: %v", err)
|
||||
}
|
||||
if repo == nil {
|
||||
t.Fatal("expected repository instance")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSocialAccountRepository_RejectsUnsupportedType(t *testing.T) {
|
||||
repo, err := NewSocialAccountRepository(struct{}{})
|
||||
if err == nil {
|
||||
t.Fatal("expected constructor error")
|
||||
}
|
||||
if repo != nil {
|
||||
t.Fatal("did not expect repository instance")
|
||||
}
|
||||
}
|
||||
42
internal/repository/sql_scan.go
Normal file
42
internal/repository/sql_scan.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
)
|
||||
|
||||
type sqlQueryer interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
// scanSingleRow 执行查询并扫描第一行到 dest。
|
||||
// 若无结果,可通过 errors.Is(err, sql.ErrNoRows) 判断。
|
||||
// 如果 Close 失败,会与原始错误合并返回。
|
||||
// 设计目的:仅依赖 QueryContext,避免 QueryRowContext 对 *sql.Tx 的强绑定,
|
||||
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
|
||||
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
|
||||
rows, err := q.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if closeErr := rows.Close(); closeErr != nil {
|
||||
err = errors.Join(err, closeErr)
|
||||
}
|
||||
}()
|
||||
|
||||
if !rows.Next() {
|
||||
if err = rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
if err = rows.Scan(dest...); err != nil {
|
||||
return err
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
48
internal/repository/testdb_helper_test.go
Normal file
48
internal/repository/testdb_helper_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite" // 纯 Go SQLite,注册 "sqlite" 驱动
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
var repoDBCounter int64
|
||||
|
||||
// openTestDB 为每个测试打开独立的内存数据库(使用 modernc.org/sqlite,无需 CGO)
|
||||
// 每次调用都生成唯一的 DSN,避免多个测试共用同一内存 DB 导致 index 重复错误
|
||||
func openTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
id := atomic.AddInt64(&repoDBCounter, 1)
|
||||
dsn := fmt.Sprintf("file:repotestdb%d?mode=memory&cache=private", id)
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: dsn,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("打开测试数据库失败: %v", err)
|
||||
}
|
||||
|
||||
tables := []interface{}{
|
||||
&domain.User{},
|
||||
&domain.Role{},
|
||||
&domain.Permission{},
|
||||
&domain.UserRole{},
|
||||
&domain.RolePermission{},
|
||||
}
|
||||
if err := db.AutoMigrate(tables...); err != nil {
|
||||
t.Fatalf("数据库迁移失败: %v", err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
99
internal/repository/theme.go
Normal file
99
internal/repository/theme.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// ThemeConfigRepository 主题配置数据访问层
|
||||
type ThemeConfigRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewThemeConfigRepository 创建主题配置数据访问层
|
||||
func NewThemeConfigRepository(db *gorm.DB) *ThemeConfigRepository {
|
||||
return &ThemeConfigRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建主题配置
|
||||
func (r *ThemeConfigRepository) Create(ctx context.Context, theme *domain.ThemeConfig) error {
|
||||
return r.db.WithContext(ctx).Create(theme).Error
|
||||
}
|
||||
|
||||
// Update 更新主题配置
|
||||
func (r *ThemeConfigRepository) Update(ctx context.Context, theme *domain.ThemeConfig) error {
|
||||
return r.db.WithContext(ctx).Save(theme).Error
|
||||
}
|
||||
|
||||
// Delete 删除主题配置
|
||||
func (r *ThemeConfigRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.ThemeConfig{}, id).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取主题配置
|
||||
func (r *ThemeConfigRepository) GetByID(ctx context.Context, id int64) (*domain.ThemeConfig, error) {
|
||||
var theme domain.ThemeConfig
|
||||
err := r.db.WithContext(ctx).First(&theme, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &theme, nil
|
||||
}
|
||||
|
||||
// GetByName 根据名称获取主题配置
|
||||
func (r *ThemeConfigRepository) GetByName(ctx context.Context, name string) (*domain.ThemeConfig, error) {
|
||||
var theme domain.ThemeConfig
|
||||
err := r.db.WithContext(ctx).Where("name = ?", name).First(&theme).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &theme, nil
|
||||
}
|
||||
|
||||
// GetDefault 获取默认主题
|
||||
func (r *ThemeConfigRepository) GetDefault(ctx context.Context) (*domain.ThemeConfig, error) {
|
||||
var theme domain.ThemeConfig
|
||||
err := r.db.WithContext(ctx).Where("is_default = ?", true).First(&theme).Error
|
||||
if err != nil {
|
||||
// 如果没有默认主题,返回默认配置
|
||||
if err == gorm.ErrRecordNotFound {
|
||||
return domain.DefaultThemeConfig(), nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &theme, nil
|
||||
}
|
||||
|
||||
// List 获取所有已启用的主题配置
|
||||
func (r *ThemeConfigRepository) List(ctx context.Context) ([]*domain.ThemeConfig, error) {
|
||||
var themes []*domain.ThemeConfig
|
||||
err := r.db.WithContext(ctx).Where("enabled = ?", true).Order("is_default DESC, id ASC").Find(&themes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return themes, nil
|
||||
}
|
||||
|
||||
// ListAll 获取所有主题配置
|
||||
func (r *ThemeConfigRepository) ListAll(ctx context.Context) ([]*domain.ThemeConfig, error) {
|
||||
var themes []*domain.ThemeConfig
|
||||
err := r.db.WithContext(ctx).Order("is_default DESC, id ASC").Find(&themes).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return themes, nil
|
||||
}
|
||||
|
||||
// SetDefault 设置默认主题
|
||||
func (r *ThemeConfigRepository) SetDefault(ctx context.Context, id int64) error {
|
||||
// 先清除所有默认标记
|
||||
if err := r.db.WithContext(ctx).Model(&domain.ThemeConfig{}).Where("is_default = ?", true).Update("is_default", false).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 设置新的默认主题
|
||||
return r.db.WithContext(ctx).Model(&domain.ThemeConfig{}).Where("id = ?", id).Update("is_default", true).Error
|
||||
}
|
||||
73
internal/repository/update_cache_integration_test.go
Normal file
73
internal/repository/update_cache_integration_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type UpdateCacheSuite struct {
|
||||
IntegrationRedisSuite
|
||||
cache *updateCache
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) SetupTest() {
|
||||
s.IntegrationRedisSuite.SetupTest()
|
||||
s.cache = NewUpdateCache(s.rdb).(*updateCache)
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) TestGetUpdateInfo_Missing() {
|
||||
_, err := s.cache.GetUpdateInfo(s.ctx)
|
||||
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing update info")
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) TestSetAndGetUpdateInfo() {
|
||||
updateTTL := 5 * time.Minute
|
||||
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL), "SetUpdateInfo")
|
||||
|
||||
info, err := s.cache.GetUpdateInfo(s.ctx)
|
||||
require.NoError(s.T(), err, "GetUpdateInfo")
|
||||
require.Equal(s.T(), "v1.2.3", info, "update info mismatch")
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) TestSetUpdateInfo_TTL() {
|
||||
updateTTL := 5 * time.Minute
|
||||
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL))
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
|
||||
require.NoError(s.T(), err, "TTL updateCacheKey")
|
||||
s.AssertTTLWithin(ttl, 1*time.Second, updateTTL)
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) TestSetUpdateInfo_Overwrite() {
|
||||
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.0.0", 5*time.Minute))
|
||||
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v2.0.0", 5*time.Minute))
|
||||
|
||||
info, err := s.cache.GetUpdateInfo(s.ctx)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), "v2.0.0", info, "expected overwritten value")
|
||||
}
|
||||
|
||||
func (s *UpdateCacheSuite) TestSetUpdateInfo_ZeroTTL() {
|
||||
// TTL=0 means persist forever (no expiry) in Redis SET command
|
||||
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v0.0.0", 0))
|
||||
|
||||
info, err := s.cache.GetUpdateInfo(s.ctx)
|
||||
require.NoError(s.T(), err)
|
||||
require.Equal(s.T(), "v0.0.0", info)
|
||||
|
||||
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
|
||||
require.NoError(s.T(), err)
|
||||
// TTL=-1 means no expiry, TTL=-2 means key doesn't exist
|
||||
require.Equal(s.T(), time.Duration(-1), ttl, "expected TTL=-1 for key with no expiry")
|
||||
}
|
||||
|
||||
func TestUpdateCacheSuite(t *testing.T) {
|
||||
suite.Run(t, new(UpdateCacheSuite))
|
||||
}
|
||||
314
internal/repository/user.go
Normal file
314
internal/repository/user.go
Normal file
@@ -0,0 +1,314 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// escapeLikePattern 转义 LIKE 模式中的特殊字符(% 和 _)
|
||||
// 这些字符在 LIKE 查询中有特殊含义,需要转义才能作为普通字符匹配
|
||||
func escapeLikePattern(s string) string {
|
||||
// 先转义 \,再转义 % 和 _(顺序很重要)
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `%`, `\%`)
|
||||
s = strings.ReplaceAll(s, `_`, `\_`)
|
||||
return s
|
||||
}
|
||||
|
||||
// UserRepository 用户数据访问层
|
||||
type UserRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserRepository 创建用户数据访问层
|
||||
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||||
return &UserRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建用户
|
||||
func (r *UserRepository) Create(ctx context.Context, user *domain.User) error {
|
||||
return r.db.WithContext(ctx).Create(user).Error
|
||||
}
|
||||
|
||||
// Update 更新用户
|
||||
func (r *UserRepository) Update(ctx context.Context, user *domain.User) error {
|
||||
return r.db.WithContext(ctx).Save(user).Error
|
||||
}
|
||||
|
||||
// Delete 删除用户(软删除)
|
||||
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.User{}, id).Error
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*domain.User, error) {
|
||||
var user domain.User
|
||||
err := r.db.WithContext(ctx).First(&user, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByUsername 根据用户名获取用户
|
||||
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*domain.User, error) {
|
||||
var user domain.User
|
||||
err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByEmail 根据邮箱获取用户
|
||||
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
var user domain.User
|
||||
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetByPhone 根据手机号获取用户
|
||||
func (r *UserRepository) GetByPhone(ctx context.Context, phone string) (*domain.User, error) {
|
||||
var user domain.User
|
||||
err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// List 获取用户列表
|
||||
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
|
||||
var users []*domain.User
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.User{})
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// ListByStatus 根据状态获取用户列表
|
||||
func (r *UserRepository) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) {
|
||||
var users []*domain.User
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.User{}).Where("status = ?", status)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新用户状态
|
||||
func (r *UserRepository) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("status", status).Error
|
||||
}
|
||||
|
||||
// UpdateLastLogin 更新最后登录信息
|
||||
func (r *UserRepository) UpdateLastLogin(ctx context.Context, id int64, ip string) error {
|
||||
now := time.Now()
|
||||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||||
"last_login_time": &now,
|
||||
"last_login_ip": ip,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ExistsByUsername 检查用户名是否存在
|
||||
func (r *UserRepository) ExistsByUsername(ctx context.Context, username string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("username = ?", username).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// ExistsByEmail 检查邮箱是否存在
|
||||
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("email = ?", email).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// ExistsByPhone 检查手机号是否存在
|
||||
func (r *UserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("phone = ?", phone).Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// Search 搜索用户
|
||||
func (r *UserRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
|
||||
var users []*domain.User
|
||||
var total int64
|
||||
|
||||
// 转义 LIKE 特殊字符,防止搜索被意外干扰
|
||||
escapedKeyword := escapeLikePattern(keyword)
|
||||
pattern := "%" + escapedKeyword + "%"
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.User{}).Where(
|
||||
"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
|
||||
pattern, pattern, pattern, pattern,
|
||||
)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 获取列表
|
||||
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// UpdateTOTP 更新用户的 TOTP 字段
|
||||
func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) error {
|
||||
return r.db.WithContext(ctx).Model(user).Updates(map[string]interface{}{
|
||||
"totp_enabled": user.TOTPEnabled,
|
||||
"totp_secret": user.TOTPSecret,
|
||||
"totp_recovery_codes": user.TOTPRecoveryCodes,
|
||||
}).Error
|
||||
}
|
||||
|
||||
// UpdatePassword 更新用户密码
|
||||
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
|
||||
}
|
||||
|
||||
// ListCreatedAfter 查询指定时间之后创建的用户(limit=0表示不限制数量)
|
||||
func (r *UserRepository) ListCreatedAfter(ctx context.Context, since time.Time, offset, limit int) ([]*domain.User, int64, error) {
|
||||
var users []*domain.User
|
||||
var total int64
|
||||
query := r.db.WithContext(ctx).Model(&domain.User{}).Where("created_at >= ?", since)
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if limit > 0 {
|
||||
query = query.Offset(offset).Limit(limit)
|
||||
}
|
||||
if err := query.Find(&users).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// AdvancedFilter 高级用户筛选请求
|
||||
type AdvancedFilter struct {
|
||||
Keyword string // 关键字(用户名/邮箱/手机号/昵称)
|
||||
Status int // 状态:-1 全部,0/1/2/3 对应 UserStatus
|
||||
RoleIDs []int64 // 角色ID列表(按角色筛选)
|
||||
CreatedFrom *time.Time // 注册时间范围(起始)
|
||||
CreatedTo *time.Time // 注册时间范围(截止)
|
||||
LastLoginFrom *time.Time // 最后登录时间范围(起始)
|
||||
SortBy string // 排序字段:created_at, last_login_time, username
|
||||
SortOrder string // 排序方向:asc, desc
|
||||
Offset int
|
||||
Limit int
|
||||
}
|
||||
|
||||
// AdvancedSearch 高级用户搜索(支持多维度组合筛选)
|
||||
func (r *UserRepository) AdvancedSearch(ctx context.Context, filter *AdvancedFilter) ([]*domain.User, int64, error) {
|
||||
var users []*domain.User
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.User{})
|
||||
|
||||
// 关键字搜索(转义 LIKE 特殊字符)
|
||||
if filter.Keyword != "" {
|
||||
like := "%" + escapeLikePattern(filter.Keyword) + "%"
|
||||
query = query.Where(
|
||||
"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
|
||||
like, like, like, like,
|
||||
)
|
||||
}
|
||||
|
||||
// 状态筛选
|
||||
if filter.Status >= 0 {
|
||||
query = query.Where("status = ?", filter.Status)
|
||||
}
|
||||
|
||||
// 注册时间范围
|
||||
if filter.CreatedFrom != nil {
|
||||
query = query.Where("created_at >= ?", filter.CreatedFrom)
|
||||
}
|
||||
if filter.CreatedTo != nil {
|
||||
query = query.Where("created_at <= ?", filter.CreatedTo)
|
||||
}
|
||||
|
||||
// 最后登录时间范围
|
||||
if filter.LastLoginFrom != nil {
|
||||
query = query.Where("last_login_time >= ?", filter.LastLoginFrom)
|
||||
}
|
||||
|
||||
// 按角色筛选(子查询)
|
||||
if len(filter.RoleIDs) > 0 {
|
||||
query = query.Where(
|
||||
"id IN (SELECT user_id FROM user_roles WHERE role_id IN ? AND deleted_at IS NULL)",
|
||||
filter.RoleIDs,
|
||||
)
|
||||
}
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// 排序
|
||||
sortBy := "created_at"
|
||||
sortOrder := "DESC"
|
||||
if filter.SortBy != "" {
|
||||
allowedFields := map[string]bool{
|
||||
"created_at": true, "last_login_time": true,
|
||||
"username": true, "updated_at": true,
|
||||
}
|
||||
if allowedFields[filter.SortBy] {
|
||||
sortBy = filter.SortBy
|
||||
}
|
||||
}
|
||||
if filter.SortOrder == "asc" {
|
||||
sortOrder = "ASC"
|
||||
}
|
||||
query = query.Order(sortBy + " " + sortOrder)
|
||||
|
||||
// 分页
|
||||
limit := filter.Limit
|
||||
if limit <= 0 {
|
||||
limit = 20
|
||||
}
|
||||
if limit > 200 {
|
||||
limit = 200
|
||||
}
|
||||
query = query.Offset(filter.Offset).Limit(limit)
|
||||
|
||||
if err := query.Find(&users).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
537
internal/repository/user_repo_integration_test.go
Normal file
537
internal/repository/user_repo_integration_test.go
Normal file
@@ -0,0 +1,537 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/user-management-system/ent"
|
||||
"github.com/user-management-system/internal/pkg/pagination"
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type UserRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *userRepository
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
s.client = testEntClient(s.T())
|
||||
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
|
||||
|
||||
// 清理测试数据,确保每个测试从干净状态开始
|
||||
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
|
||||
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
|
||||
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
|
||||
}
|
||||
|
||||
func TestUserRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UserRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) mustCreateUser(u *service.User) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
if u.Email == "" {
|
||||
u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
|
||||
}
|
||||
if u.PasswordHash == "" {
|
||||
u.PasswordHash = "test-password-hash"
|
||||
}
|
||||
if u.Role == "" {
|
||||
u.Role = service.RoleUser
|
||||
}
|
||||
if u.Status == "" {
|
||||
u.Status = service.StatusActive
|
||||
}
|
||||
if u.Concurrency == 0 {
|
||||
u.Concurrency = 5
|
||||
}
|
||||
|
||||
s.Require().NoError(s.repo.Create(s.ctx, u), "create user")
|
||||
return u
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) mustCreateSubscription(userID, groupID int64, mutate func(*dbent.UserSubscriptionCreate)) *dbent.UserSubscription {
|
||||
s.T().Helper()
|
||||
|
||||
now := time.Now()
|
||||
create := s.client.UserSubscription.Create().
|
||||
SetUserID(userID).
|
||||
SetGroupID(groupID).
|
||||
SetStartsAt(now.Add(-1 * time.Hour)).
|
||||
SetExpiresAt(now.Add(24 * time.Hour)).
|
||||
SetStatus(service.SubscriptionStatusActive).
|
||||
SetAssignedAt(now).
|
||||
SetNotes("")
|
||||
|
||||
if mutate != nil {
|
||||
mutate(create)
|
||||
}
|
||||
|
||||
sub, err := create.Save(s.ctx)
|
||||
s.Require().NoError(err, "create subscription")
|
||||
return sub
|
||||
}
|
||||
|
||||
// --- Create / GetByID / GetByEmail / Update / Delete ---
|
||||
|
||||
func (s *UserRepoSuite) TestCreate() {
|
||||
user := s.mustCreateUser(&service.User{
|
||||
Email: "create@test.com",
|
||||
Username: "testuser",
|
||||
PasswordHash: "test-password-hash",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
s.Require().NotZero(user.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal("create@test.com", got.Email)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByEmail() {
|
||||
user := s.mustCreateUser(&service.User{Email: "byemail@test.com"})
|
||||
|
||||
got, err := s.repo.GetByEmail(s.ctx, user.Email)
|
||||
s.Require().NoError(err, "GetByEmail")
|
||||
s.Require().Equal(user.ID, got.ID)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
|
||||
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
|
||||
s.Require().Error(err, "expected error for non-existent email")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
got.Username = "updated"
|
||||
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
|
||||
|
||||
updated, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated", updated.Username)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
|
||||
|
||||
err := s.repo.Delete(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
// --- List / ListWithFilters ---
|
||||
|
||||
func (s *UserRepoSuite) TestList() {
|
||||
s.mustCreateUser(&service.User{Email: "list1@test.com"})
|
||||
s.mustCreateUser(&service.User{Email: "list2@test.com"})
|
||||
|
||||
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(users, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Status() {
|
||||
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
|
||||
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal(service.StatusActive, users[0].Status)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Role() {
|
||||
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
|
||||
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Role: service.RoleAdmin})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal(service.RoleAdmin, users[0].Role)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_Search() {
|
||||
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
|
||||
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "alice"})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Contains(users[0].Email, "alice")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
|
||||
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
|
||||
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "john"})
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(users, 1)
|
||||
s.Require().Equal("JohnDoe", users[0].Username)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
|
||||
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
|
||||
groupActive := s.mustCreateGroup("g-sub-active")
|
||||
groupExpired := s.mustCreateGroup("g-sub-expired")
|
||||
|
||||
_ = s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusActive)
|
||||
c.SetExpiresAt(time.Now().Add(1 * time.Hour))
|
||||
})
|
||||
_ = s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
|
||||
})
|
||||
|
||||
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "sub@"})
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Len(users, 1, "expected 1 user")
|
||||
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
|
||||
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
|
||||
s.Require().Equal(groupActive.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "a@example.com",
|
||||
Username: "Alice",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
})
|
||||
target := s.mustCreateUser(&service.User{
|
||||
Email: "b@example.com",
|
||||
Username: "Bob",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
Balance: 1,
|
||||
})
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "c@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||
s.Require().Equal(target.ID, users[0].ID, "ListWithFilters result mismatch")
|
||||
}
|
||||
|
||||
// --- Balance operations ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance() {
|
||||
user := s.mustCreateUser(&service.User{Email: "bal@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
|
||||
s.Require().NoError(err, "UpdateBalance")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(12.5, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
|
||||
user := s.mustCreateUser(&service.User{Email: "balneg@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
|
||||
s.Require().NoError(err, "UpdateBalance with negative")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(7.0, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance() {
|
||||
user := s.mustCreateUser(&service.User{Email: "deduct@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
|
||||
s.Require().NoError(err, "DeductBalance")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(5.0, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
|
||||
user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5})
|
||||
|
||||
// 透支策略:允许扣除超过余额的金额
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
|
||||
s.Require().NoError(err, "DeductBalance should allow overdraft")
|
||||
|
||||
// 验证余额变为负数
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(-994.0, got.Balance, 1e-6, "Balance should be negative after overdraft")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
|
||||
user := s.mustCreateUser(&service.User{Email: "exact@test.com", Balance: 10})
|
||||
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
|
||||
s.Require().NoError(err, "DeductBalance exact amount")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(0.0, got.Balance, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_AllowsOverdraft() {
|
||||
user := s.mustCreateUser(&service.User{Email: "overdraft@test.com", Balance: 5.0})
|
||||
|
||||
// 扣除超过余额的金额 - 应该成功
|
||||
err := s.repo.DeductBalance(s.ctx, user.ID, 10.0)
|
||||
s.Require().NoError(err, "DeductBalance should allow overdraft")
|
||||
|
||||
// 验证余额为负
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(-5.0, got.Balance, 1e-6, "Balance should be -5.0 after overdraft")
|
||||
}
|
||||
|
||||
// --- Concurrency ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency() {
|
||||
user := s.mustCreateUser(&service.User{Email: "conc@test.com", Concurrency: 5})
|
||||
|
||||
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
|
||||
s.Require().NoError(err, "UpdateConcurrency")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(8, got.Concurrency)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
|
||||
user := s.mustCreateUser(&service.User{Email: "concneg@test.com", Concurrency: 5})
|
||||
|
||||
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
|
||||
s.Require().NoError(err, "UpdateConcurrency negative")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(3, got.Concurrency)
|
||||
}
|
||||
|
||||
// --- ExistsByEmail ---
|
||||
|
||||
func (s *UserRepoSuite) TestExistsByEmail() {
|
||||
s.mustCreateUser(&service.User{Email: "exists@test.com"})
|
||||
|
||||
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
|
||||
s.Require().NoError(err, "ExistsByEmail")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByEmail(s.ctx, "notexists@test.com")
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- RemoveGroupFromAllowedGroups ---
|
||||
|
||||
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
|
||||
target := s.mustCreateGroup("target-42")
|
||||
other := s.mustCreateGroup("other-7")
|
||||
|
||||
userA := s.mustCreateUser(&service.User{
|
||||
Email: "a1@example.com",
|
||||
AllowedGroups: []int64{target.ID, other.ID},
|
||||
})
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "a2@example.com",
|
||||
AllowedGroups: []int64{other.ID},
|
||||
})
|
||||
|
||||
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, target.ID)
|
||||
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, userA.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().NotContains(got.AllowedGroups, target.ID)
|
||||
s.Require().Contains(got.AllowedGroups, other.ID)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
|
||||
groupA := s.mustCreateGroup("nomatch-a")
|
||||
groupB := s.mustCreateGroup("nomatch-b")
|
||||
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "nomatch@test.com",
|
||||
AllowedGroups: []int64{groupA.ID, groupB.ID},
|
||||
})
|
||||
|
||||
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999999)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Zero(affected, "expected no affected rows")
|
||||
}
|
||||
|
||||
// --- GetFirstAdmin ---
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin() {
|
||||
admin1 := s.mustCreateUser(&service.User{
|
||||
Email: "admin1@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "admin2@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
s.Require().NoError(err, "GetFirstAdmin")
|
||||
s.Require().Equal(admin1.ID, got.ID, "GetFirstAdmin mismatch")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "user@example.com",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
_, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
s.Require().Error(err, "expected error when no admin exists")
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "disabled@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
activeAdmin := s.mustCreateUser(&service.User{
|
||||
Email: "active@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetFirstAdmin(s.ctx)
|
||||
s.Require().NoError(err, "GetFirstAdmin")
|
||||
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
|
||||
}
|
||||
|
||||
// --- Combined ---
|
||||
|
||||
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
|
||||
user1 := s.mustCreateUser(&service.User{
|
||||
Email: "a@example.com",
|
||||
Username: "Alice",
|
||||
Role: service.RoleUser,
|
||||
Status: service.StatusActive,
|
||||
Balance: 10,
|
||||
})
|
||||
user2 := s.mustCreateUser(&service.User{
|
||||
Email: "b@example.com",
|
||||
Username: "Bob",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusActive,
|
||||
Balance: 1,
|
||||
})
|
||||
s.mustCreateUser(&service.User{
|
||||
Email: "c@example.com",
|
||||
Role: service.RoleAdmin,
|
||||
Status: service.StatusDisabled,
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(user1.Email, got.Email, "GetByID email mismatch")
|
||||
|
||||
gotByEmail, err := s.repo.GetByEmail(s.ctx, user2.Email)
|
||||
s.Require().NoError(err, "GetByEmail")
|
||||
s.Require().Equal(user2.ID, gotByEmail.ID, "GetByEmail ID mismatch")
|
||||
|
||||
got.Username = "Alice2"
|
||||
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
|
||||
got2, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("Alice2", got2.Username, "Update did not persist")
|
||||
|
||||
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
|
||||
got3, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after UpdateBalance")
|
||||
s.Require().InDelta(12.5, got3.Balance, 1e-6)
|
||||
|
||||
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
|
||||
got4, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after DeductBalance")
|
||||
s.Require().InDelta(7.5, got4.Balance, 1e-6)
|
||||
|
||||
// 透支策略:允许扣除超过余额的金额
|
||||
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
|
||||
s.Require().NoError(err, "DeductBalance should allow overdraft")
|
||||
gotOverdraft, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after overdraft")
|
||||
s.Require().Less(gotOverdraft.Balance, 0.0, "Balance should be negative after overdraft")
|
||||
|
||||
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
|
||||
got5, err := s.repo.GetByID(s.ctx, user1.ID)
|
||||
s.Require().NoError(err, "GetByID after UpdateConcurrency")
|
||||
s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
|
||||
|
||||
params := pagination.PaginationParams{Page: 1, PageSize: 10}
|
||||
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
|
||||
s.Require().NoError(err, "ListWithFilters")
|
||||
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
|
||||
s.Require().Len(users, 1, "ListWithFilters len mismatch")
|
||||
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
|
||||
}
|
||||
|
||||
// --- UpdateBalance/UpdateConcurrency 影响行数校验测试 ---
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateBalance_NotFound() {
|
||||
err := s.repo.UpdateBalance(s.ctx, 999999, 10.0)
|
||||
s.Require().Error(err, "expected error for non-existent user")
|
||||
s.Require().ErrorIs(err, service.ErrUserNotFound)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
|
||||
err := s.repo.UpdateConcurrency(s.ctx, 999999, 5)
|
||||
s.Require().Error(err, "expected error for non-existent user")
|
||||
s.Require().ErrorIs(err, service.ErrUserNotFound)
|
||||
}
|
||||
|
||||
func (s *UserRepoSuite) TestDeductBalance_NotFound() {
|
||||
err := s.repo.DeductBalance(s.ctx, 999999, 5)
|
||||
s.Require().Error(err, "expected error for non-existent user")
|
||||
// DeductBalance 在用户不存在时返回 ErrUserNotFound
|
||||
s.Require().ErrorIs(err, service.ErrUserNotFound)
|
||||
}
|
||||
198
internal/repository/user_repository_test.go
Normal file
198
internal/repository/user_repository_test.go
Normal file
@@ -0,0 +1,198 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func setupTestDB(t *testing.T) *gorm.DB {
|
||||
return openTestDB(t)
|
||||
}
|
||||
|
||||
// TestUserRepository_Create 测试创建用户
|
||||
func TestUserRepository_Create(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
Username: "testuser",
|
||||
Email: domain.StrPtr("test@example.com"),
|
||||
Phone: domain.StrPtr("13800138000"),
|
||||
Password: "hashedpassword",
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
|
||||
if err := repo.Create(ctx, user); err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if user.ID == 0 {
|
||||
t.Error("创建后用户ID不应为0")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_GetByUsername 测试根据用户名查询
|
||||
func TestUserRepository_GetByUsername(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
Username: "findme",
|
||||
Email: domain.StrPtr("findme@example.com"),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
repo.Create(ctx, user)
|
||||
|
||||
found, err := repo.GetByUsername(ctx, "findme")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUsername() error = %v", err)
|
||||
}
|
||||
if found.Username != "findme" {
|
||||
t.Errorf("Username = %v, want findme", found.Username)
|
||||
}
|
||||
|
||||
_, err = repo.GetByUsername(ctx, "notexist")
|
||||
if err == nil {
|
||||
t.Error("查找不存在的用户应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_GetByEmail 测试根据邮箱查询
|
||||
func TestUserRepository_GetByEmail(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
Username: "emailuser",
|
||||
Email: domain.StrPtr("email@example.com"),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
repo.Create(ctx, user)
|
||||
|
||||
found, err := repo.GetByEmail(ctx, "email@example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByEmail() error = %v", err)
|
||||
}
|
||||
if domain.DerefStr(found.Email) != "email@example.com" {
|
||||
t.Errorf("Email = %v, want email@example.com", domain.DerefStr(found.Email))
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_Update 测试更新用户
|
||||
func TestUserRepository_Update(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
Username: "updateme",
|
||||
Email: domain.StrPtr("update@example.com"),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
repo.Create(ctx, user)
|
||||
|
||||
user.Nickname = "已更新"
|
||||
if err := repo.Update(ctx, user); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
found, _ := repo.GetByID(ctx, user.ID)
|
||||
if found.Nickname != "已更新" {
|
||||
t.Errorf("Nickname = %v, want 已更新", found.Nickname)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_Delete 测试删除用户
|
||||
func TestUserRepository_Delete(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
Username: "deleteme",
|
||||
Email: domain.StrPtr("delete@example.com"),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
repo.Create(ctx, user)
|
||||
|
||||
if err := repo.Delete(ctx, user.ID); err != nil {
|
||||
t.Fatalf("Delete() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := repo.GetByID(ctx, user.ID)
|
||||
if err == nil {
|
||||
t.Error("删除后查询应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_ExistsBy 测试存在性检查
|
||||
func TestUserRepository_ExistsBy(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
Username: "existsuser",
|
||||
Email: domain.StrPtr("exists@example.com"),
|
||||
Phone: domain.StrPtr("13900139000"),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
repo.Create(ctx, user)
|
||||
|
||||
ok, _ := repo.ExistsByUsername(ctx, "existsuser")
|
||||
if !ok {
|
||||
t.Error("ExistsByUsername 应返回 true")
|
||||
}
|
||||
|
||||
ok, _ = repo.ExistsByEmail(ctx, "exists@example.com")
|
||||
if !ok {
|
||||
t.Error("ExistsByEmail 应返回 true")
|
||||
}
|
||||
|
||||
ok, _ = repo.ExistsByPhone(ctx, "13900139000")
|
||||
if !ok {
|
||||
t.Error("ExistsByPhone 应返回 true")
|
||||
}
|
||||
|
||||
ok, _ = repo.ExistsByUsername(ctx, "notexist")
|
||||
if ok {
|
||||
t.Error("不存在的用户 ExistsByUsername 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_List 测试列表查询
|
||||
func TestUserRepository_List(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
repo.Create(ctx, &domain.User{
|
||||
Username: "listuser" + string(rune('0'+i)),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
})
|
||||
}
|
||||
|
||||
users, total, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
if len(users) != 5 {
|
||||
t.Errorf("len(users) = %d, want 5", len(users))
|
||||
}
|
||||
if total != 5 {
|
||||
t.Errorf("total = %d, want 5", total)
|
||||
}
|
||||
}
|
||||
175
internal/repository/user_role.go
Normal file
175
internal/repository/user_role.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// UserRoleRepository 用户角色关联数据访问层
|
||||
type UserRoleRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewUserRoleRepository 创建用户角色关联数据访问层
|
||||
func NewUserRoleRepository(db *gorm.DB) *UserRoleRepository {
|
||||
return &UserRoleRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建用户角色关联
|
||||
func (r *UserRoleRepository) Create(ctx context.Context, userRole *domain.UserRole) error {
|
||||
return r.db.WithContext(ctx).Create(userRole).Error
|
||||
}
|
||||
|
||||
// Delete 删除用户角色关联
|
||||
func (r *UserRoleRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.UserRole{}, id).Error
|
||||
}
|
||||
|
||||
// DeleteByUserID 删除用户的所有角色
|
||||
func (r *UserRoleRepository) DeleteByUserID(ctx context.Context, userID int64) error {
|
||||
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&domain.UserRole{}).Error
|
||||
}
|
||||
|
||||
// DeleteByRoleID 删除角色的所有用户
|
||||
func (r *UserRoleRepository) DeleteByRoleID(ctx context.Context, roleID int64) error {
|
||||
return r.db.WithContext(ctx).Where("role_id = ?", roleID).Delete(&domain.UserRole{}).Error
|
||||
}
|
||||
|
||||
// GetByUserID 根据用户ID获取角色列表
|
||||
func (r *UserRoleRepository) GetByUserID(ctx context.Context, userID int64) ([]*domain.UserRole, error) {
|
||||
var userRoles []*domain.UserRole
|
||||
err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&userRoles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userRoles, nil
|
||||
}
|
||||
|
||||
// GetByRoleID 根据角色ID获取用户列表
|
||||
func (r *UserRoleRepository) GetByRoleID(ctx context.Context, roleID int64) ([]*domain.UserRole, error) {
|
||||
var userRoles []*domain.UserRole
|
||||
err := r.db.WithContext(ctx).Where("role_id = ?", roleID).Find(&userRoles).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userRoles, nil
|
||||
}
|
||||
|
||||
// GetRoleIDsByUserID 根据用户ID获取角色ID列表
|
||||
func (r *UserRoleRepository) GetRoleIDsByUserID(ctx context.Context, userID int64) ([]int64, error) {
|
||||
var roleIDs []int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("user_id = ?", userID).Pluck("role_id", &roleIDs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return roleIDs, nil
|
||||
}
|
||||
|
||||
// GetUserRolesAndPermissions 获取用户角色和权限(PERF-01 优化:合并为单次 JOIN 查询)
|
||||
func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) {
|
||||
var results []struct {
|
||||
RoleID int64
|
||||
RoleName string
|
||||
RoleCode string
|
||||
RoleStatus int
|
||||
PermissionID int64
|
||||
PermissionCode string
|
||||
PermissionName string
|
||||
}
|
||||
|
||||
// 使用 LEFT JOIN 一次性获取用户角色和权限
|
||||
err := r.db.WithContext(ctx).
|
||||
Raw(`
|
||||
SELECT DISTINCT r.id as role_id, r.name as role_name, r.code as role_code, r.status as role_status,
|
||||
p.id as permission_id, p.code as permission_code, p.name as permission_name
|
||||
FROM user_roles ur
|
||||
JOIN roles r ON ur.role_id = r.id
|
||||
LEFT JOIN role_permissions rp ON r.id = rp.role_id
|
||||
LEFT JOIN permissions p ON rp.permission_id = p.id
|
||||
WHERE ur.user_id = ? AND r.status = 1
|
||||
`, userID).
|
||||
Scan(&results).Error
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 构建角色和权限列表
|
||||
roleMap := make(map[int64]*domain.Role)
|
||||
permMap := make(map[int64]*domain.Permission)
|
||||
|
||||
for _, row := range results {
|
||||
if _, ok := roleMap[row.RoleID]; !ok {
|
||||
roleMap[row.RoleID] = &domain.Role{
|
||||
ID: row.RoleID,
|
||||
Name: row.RoleName,
|
||||
Code: row.RoleCode,
|
||||
Status: domain.RoleStatus(row.RoleStatus),
|
||||
}
|
||||
}
|
||||
if row.PermissionID > 0 {
|
||||
if _, ok := permMap[row.PermissionID]; !ok {
|
||||
permMap[row.PermissionID] = &domain.Permission{
|
||||
ID: row.PermissionID,
|
||||
Code: row.PermissionCode,
|
||||
Name: row.PermissionName,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
roles := make([]*domain.Role, 0, len(roleMap))
|
||||
for _, role := range roleMap {
|
||||
roles = append(roles, role)
|
||||
}
|
||||
|
||||
perms := make([]*domain.Permission, 0, len(permMap))
|
||||
for _, perm := range permMap {
|
||||
perms = append(perms, perm)
|
||||
}
|
||||
|
||||
return roles, perms, nil
|
||||
}
|
||||
|
||||
// GetUserIDByRoleID 根据角色ID获取用户ID列表
|
||||
func (r *UserRoleRepository) GetUserIDByRoleID(ctx context.Context, roleID int64) ([]int64, error) {
|
||||
var userIDs []int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("role_id = ?", roleID).Pluck("user_id", &userIDs).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return userIDs, nil
|
||||
}
|
||||
|
||||
// Exists 检查用户角色关联是否存在
|
||||
func (r *UserRoleRepository) Exists(ctx context.Context, userID, roleID int64) (bool, error) {
|
||||
var count int64
|
||||
err := r.db.WithContext(ctx).Model(&domain.UserRole{}).
|
||||
Where("user_id = ? AND role_id = ?", userID, roleID).
|
||||
Count(&count).Error
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// BatchCreate 批量创建用户角色关联
|
||||
func (r *UserRoleRepository) BatchCreate(ctx context.Context, userRoles []*domain.UserRole) error {
|
||||
if len(userRoles) == 0 {
|
||||
return nil
|
||||
}
|
||||
return r.db.WithContext(ctx).Create(&userRoles).Error
|
||||
}
|
||||
|
||||
// BatchDelete 批量删除用户角色关联
|
||||
func (r *UserRoleRepository) BatchDelete(ctx context.Context, userRoles []*domain.UserRole) error {
|
||||
if len(userRoles) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var ids []int64
|
||||
for _, ur := range userRoles {
|
||||
ids = append(ids, ur.ID)
|
||||
}
|
||||
|
||||
return r.db.WithContext(ctx).Delete(&domain.UserRole{}, ids).Error
|
||||
}
|
||||
747
internal/repository/user_subscription_repo_integration_test.go
Normal file
747
internal/repository/user_subscription_repo_integration_test.go
Normal file
@@ -0,0 +1,747 @@
|
||||
//go:build integration
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
dbent "github.com/user-management-system/ent"
|
||||
"github.com/user-management-system/internal/pkg/pagination"
|
||||
"github.com/user-management-system/internal/service"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type UserSubscriptionRepoSuite struct {
|
||||
suite.Suite
|
||||
ctx context.Context
|
||||
client *dbent.Client
|
||||
repo *userSubscriptionRepository
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) SetupTest() {
|
||||
s.ctx = context.Background()
|
||||
tx := testEntTx(s.T())
|
||||
s.client = tx.Client()
|
||||
s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository)
|
||||
}
|
||||
|
||||
func TestUserSubscriptionRepoSuite(t *testing.T) {
|
||||
suite.Run(t, new(UserSubscriptionRepoSuite))
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) mustCreateUser(email string, role string) *service.User {
|
||||
s.T().Helper()
|
||||
|
||||
if role == "" {
|
||||
role = service.RoleUser
|
||||
}
|
||||
|
||||
u, err := s.client.User.Create().
|
||||
SetEmail(email).
|
||||
SetPasswordHash("test-password-hash").
|
||||
SetStatus(service.StatusActive).
|
||||
SetRole(role).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create user")
|
||||
return userEntityToService(u)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) mustCreateGroup(name string) *service.Group {
|
||||
s.T().Helper()
|
||||
|
||||
g, err := s.client.Group.Create().
|
||||
SetName(name).
|
||||
SetStatus(service.StatusActive).
|
||||
Save(s.ctx)
|
||||
s.Require().NoError(err, "create group")
|
||||
return groupEntityToService(g)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) mustCreateSubscription(userID, groupID int64, mutate func(*dbent.UserSubscriptionCreate)) *dbent.UserSubscription {
|
||||
s.T().Helper()
|
||||
|
||||
now := time.Now()
|
||||
create := s.client.UserSubscription.Create().
|
||||
SetUserID(userID).
|
||||
SetGroupID(groupID).
|
||||
SetStartsAt(now.Add(-1 * time.Hour)).
|
||||
SetExpiresAt(now.Add(24 * time.Hour)).
|
||||
SetStatus(service.SubscriptionStatusActive).
|
||||
SetAssignedAt(now).
|
||||
SetNotes("")
|
||||
|
||||
if mutate != nil {
|
||||
mutate(create)
|
||||
}
|
||||
|
||||
sub, err := create.Save(s.ctx)
|
||||
s.Require().NoError(err, "create user subscription")
|
||||
return sub
|
||||
}
|
||||
|
||||
// --- Create / GetByID / Update / Delete ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCreate() {
|
||||
user := s.mustCreateUser("sub-create@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-create")
|
||||
|
||||
sub := &service.UserSubscription{
|
||||
UserID: user.ID,
|
||||
GroupID: group.ID,
|
||||
Status: service.SubscriptionStatusActive,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
err := s.repo.Create(s.ctx, sub)
|
||||
s.Require().NoError(err, "Create")
|
||||
s.Require().NotZero(sub.ID, "expected ID to be set")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().Equal(sub.UserID, got.UserID)
|
||||
s.Require().Equal(sub.GroupID, got.GroupID)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
|
||||
user := s.mustCreateUser("preload@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-preload")
|
||||
admin := s.mustCreateUser("admin@test.com", service.RoleAdmin)
|
||||
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetAssignedBy(admin.ID)
|
||||
})
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().NotNil(got.User, "expected User preload")
|
||||
s.Require().NotNil(got.Group, "expected Group preload")
|
||||
s.Require().NotNil(got.AssignedByUser, "expected AssignedByUser preload")
|
||||
s.Require().Equal(user.ID, got.User.ID)
|
||||
s.Require().Equal(group.ID, got.Group.ID)
|
||||
s.Require().Equal(admin.ID, got.AssignedByUser.ID)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
|
||||
_, err := s.repo.GetByID(s.ctx, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent ID")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdate() {
|
||||
user := s.mustCreateUser("update@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-update")
|
||||
created := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
sub, err := s.repo.GetByID(s.ctx, created.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
|
||||
sub.Notes = "updated notes"
|
||||
s.Require().NoError(s.repo.Update(s.ctx, sub), "Update")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err, "GetByID after update")
|
||||
s.Require().Equal("updated notes", got.Notes)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestDelete() {
|
||||
user := s.mustCreateUser("delete@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-delete")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
err := s.repo.Delete(s.ctx, sub.ID)
|
||||
s.Require().NoError(err, "Delete")
|
||||
|
||||
_, err = s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().Error(err, "expected error after delete")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestDelete_Idempotent() {
|
||||
s.Require().NoError(s.repo.Delete(s.ctx, 42424242), "Delete should be idempotent")
|
||||
}
|
||||
|
||||
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
|
||||
user := s.mustCreateUser("byuser@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-byuser")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
got, err := s.repo.GetByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().NoError(err, "GetByUserIDAndGroupID")
|
||||
s.Require().Equal(sub.ID, got.ID)
|
||||
s.Require().NotNil(got.Group, "expected Group preload")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
|
||||
_, err := s.repo.GetByUserIDAndGroupID(s.ctx, 999999, 999999)
|
||||
s.Require().Error(err, "expected error for non-existent pair")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
|
||||
user := s.mustCreateUser("active@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-active")
|
||||
|
||||
active := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(2 * time.Hour))
|
||||
})
|
||||
|
||||
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
|
||||
s.Require().Equal(active.ID, got.ID)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
|
||||
user := s.mustCreateUser("expired@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-expired")
|
||||
|
||||
s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(-2 * time.Hour))
|
||||
})
|
||||
|
||||
_, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().Error(err, "expected error for expired subscription")
|
||||
}
|
||||
|
||||
// --- ListByUserID / ListActiveByUserID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
|
||||
user := s.mustCreateUser("listby@test.com", service.RoleUser)
|
||||
g1 := s.mustCreateGroup("g-list1")
|
||||
g2 := s.mustCreateGroup("g-list2")
|
||||
|
||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||
s.mustCreateSubscription(user.ID, g2.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
subs, err := s.repo.ListByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "ListByUserID")
|
||||
s.Require().Len(subs, 2)
|
||||
for _, sub := range subs {
|
||||
s.Require().NotNil(sub.Group, "expected Group preload")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
|
||||
user := s.mustCreateUser("listactive@test.com", service.RoleUser)
|
||||
g1 := s.mustCreateGroup("g-act1")
|
||||
g2 := s.mustCreateGroup("g-act2")
|
||||
|
||||
s.mustCreateSubscription(user.ID, g1.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
|
||||
})
|
||||
s.mustCreateSubscription(user.ID, g2.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
|
||||
s.Require().NoError(err, "ListActiveByUserID")
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(service.SubscriptionStatusActive, subs[0].Status)
|
||||
}
|
||||
|
||||
// --- ListByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
|
||||
user1 := s.mustCreateUser("u1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("u2@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-listgrp")
|
||||
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||
|
||||
subs, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
|
||||
s.Require().NoError(err, "ListByGroupID")
|
||||
s.Require().Len(subs, 2)
|
||||
s.Require().Equal(int64(2), page.Total)
|
||||
for _, sub := range subs {
|
||||
s.Require().NotNil(sub.User, "expected User preload")
|
||||
s.Require().NotNil(sub.Group, "expected Group preload")
|
||||
}
|
||||
}
|
||||
|
||||
// --- List with filters ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
|
||||
user := s.mustCreateUser("list@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-list")
|
||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
|
||||
s.Require().NoError(err, "List")
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(int64(1), page.Total)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
|
||||
user1 := s.mustCreateUser("filter1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("filter2@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-filter")
|
||||
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(user1.ID, subs[0].UserID)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
|
||||
user := s.mustCreateUser("grpfilter@test.com", service.RoleUser)
|
||||
g1 := s.mustCreateGroup("g-f1")
|
||||
g2 := s.mustCreateGroup("g-f2")
|
||||
|
||||
s.mustCreateSubscription(user.ID, g1.ID, nil)
|
||||
s.mustCreateSubscription(user.ID, g2.ID, nil)
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(g1.ID, subs[0].GroupID)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
|
||||
user1 := s.mustCreateUser("statfilter1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("statfilter2@test.com", service.RoleUser)
|
||||
group1 := s.mustCreateGroup("g-stat-1")
|
||||
group2 := s.mustCreateGroup("g-stat-2")
|
||||
|
||||
s.mustCreateSubscription(user1.ID, group1.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusActive)
|
||||
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
|
||||
})
|
||||
s.mustCreateSubscription(user2.ID, group2.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
|
||||
s.Require().NoError(err)
|
||||
s.Require().Len(subs, 1)
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
|
||||
}
|
||||
|
||||
// --- Usage tracking ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
|
||||
user := s.mustCreateUser("usage@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-usage")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1.25)
|
||||
s.Require().NoError(err, "IncrementUsage")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(1.25, got.DailyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(1.25, got.WeeklyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(1.25, got.MonthlyUsageUSD, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
|
||||
user := s.mustCreateUser("accum@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-accum")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 1.0))
|
||||
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 2.5))
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(3.5, got.DailyUsageUSD, 1e-6)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
|
||||
user := s.mustCreateUser("activate@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-activate")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
activateAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ActivateWindows(s.ctx, sub.ID, activateAt)
|
||||
s.Require().NoError(err, "ActivateWindows")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().NotNil(got.DailyWindowStart)
|
||||
s.Require().NotNil(got.WeeklyWindowStart)
|
||||
s.Require().NotNil(got.MonthlyWindowStart)
|
||||
s.Require().WithinDuration(activateAt, *got.DailyWindowStart, time.Microsecond)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
|
||||
user := s.mustCreateUser("resetd@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-resetd")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetDailyUsageUsd(10.0)
|
||||
c.SetWeeklyUsageUsd(20.0)
|
||||
})
|
||||
|
||||
resetAt := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ResetDailyUsage(s.ctx, sub.ID, resetAt)
|
||||
s.Require().NoError(err, "ResetDailyUsage")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(0.0, got.DailyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(20.0, got.WeeklyUsageUSD, 1e-6)
|
||||
s.Require().NotNil(got.DailyWindowStart)
|
||||
s.Require().WithinDuration(resetAt, *got.DailyWindowStart, time.Microsecond)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
|
||||
user := s.mustCreateUser("resetw@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-resetw")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetWeeklyUsageUsd(15.0)
|
||||
c.SetMonthlyUsageUsd(30.0)
|
||||
})
|
||||
|
||||
resetAt := time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ResetWeeklyUsage(s.ctx, sub.ID, resetAt)
|
||||
s.Require().NoError(err, "ResetWeeklyUsage")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(0.0, got.WeeklyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(30.0, got.MonthlyUsageUSD, 1e-6)
|
||||
s.Require().NotNil(got.WeeklyWindowStart)
|
||||
s.Require().WithinDuration(resetAt, *got.WeeklyWindowStart, time.Microsecond)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
|
||||
user := s.mustCreateUser("resetm@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-resetm")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetMonthlyUsageUsd(25.0)
|
||||
})
|
||||
|
||||
resetAt := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ResetMonthlyUsage(s.ctx, sub.ID, resetAt)
|
||||
s.Require().NoError(err, "ResetMonthlyUsage")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().InDelta(0.0, got.MonthlyUsageUSD, 1e-6)
|
||||
s.Require().NotNil(got.MonthlyWindowStart)
|
||||
s.Require().WithinDuration(resetAt, *got.MonthlyWindowStart, time.Microsecond)
|
||||
}
|
||||
|
||||
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
|
||||
user := s.mustCreateUser("status@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-status")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired)
|
||||
s.Require().NoError(err, "UpdateStatus")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, got.Status)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
|
||||
user := s.mustCreateUser("extend@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-extend")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
newExpiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
|
||||
err := s.repo.ExtendExpiry(s.ctx, sub.ID, newExpiry)
|
||||
s.Require().NoError(err, "ExtendExpiry")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().WithinDuration(newExpiry, got.ExpiresAt, time.Microsecond)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
|
||||
user := s.mustCreateUser("notes@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-notes")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
err := s.repo.UpdateNotes(s.ctx, sub.ID, "VIP user")
|
||||
s.Require().NoError(err, "UpdateNotes")
|
||||
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
s.Require().Equal("VIP user", got.Notes)
|
||||
}
|
||||
|
||||
// --- ListExpired / BatchUpdateExpiredStatus ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestListExpired() {
|
||||
user := s.mustCreateUser("listexp@test.com", service.RoleUser)
|
||||
groupActive := s.mustCreateGroup("g-listexp-active")
|
||||
groupExpired := s.mustCreateGroup("g-listexp-expired")
|
||||
|
||||
s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
|
||||
})
|
||||
s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
expired, err := s.repo.ListExpired(s.ctx)
|
||||
s.Require().NoError(err, "ListExpired")
|
||||
s.Require().Len(expired, 1)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
|
||||
user := s.mustCreateUser("batch@test.com", service.RoleUser)
|
||||
groupFuture := s.mustCreateGroup("g-batch-future")
|
||||
groupPast := s.mustCreateGroup("g-batch-past")
|
||||
|
||||
active := s.mustCreateSubscription(user.ID, groupFuture.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
|
||||
})
|
||||
expiredActive := s.mustCreateSubscription(user.ID, groupPast.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
|
||||
s.Require().NoError(err, "BatchUpdateExpiredStatus")
|
||||
s.Require().Equal(int64(1), affected)
|
||||
|
||||
gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
|
||||
s.Require().Equal(service.SubscriptionStatusActive, gotActive.Status)
|
||||
|
||||
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, gotExpired.Status)
|
||||
}
|
||||
|
||||
// --- ExistsByUserIDAndGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
|
||||
user := s.mustCreateUser("exists@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-exists")
|
||||
|
||||
s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
exists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, group.ID)
|
||||
s.Require().NoError(err, "ExistsByUserIDAndGroupID")
|
||||
s.Require().True(exists)
|
||||
|
||||
notExists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, 999999)
|
||||
s.Require().NoError(err)
|
||||
s.Require().False(notExists)
|
||||
}
|
||||
|
||||
// --- CountByGroupID / CountActiveByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
|
||||
user1 := s.mustCreateUser("cnt1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("cnt2@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-count")
|
||||
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetStatus(service.SubscriptionStatusExpired)
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
|
||||
})
|
||||
|
||||
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountByGroupID")
|
||||
s.Require().Equal(int64(2), count)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
|
||||
user1 := s.mustCreateUser("cntact1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("cntact2@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-cntact")
|
||||
|
||||
s.mustCreateSubscription(user1.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
|
||||
})
|
||||
s.mustCreateSubscription(user2.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) // expired by time
|
||||
})
|
||||
|
||||
count, err := s.repo.CountActiveByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "CountActiveByGroupID")
|
||||
s.Require().Equal(int64(1), count, "only future expiry counts as active")
|
||||
}
|
||||
|
||||
// --- DeleteByGroupID ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
|
||||
user1 := s.mustCreateUser("delgrp1@test.com", service.RoleUser)
|
||||
user2 := s.mustCreateUser("delgrp2@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-delgrp")
|
||||
|
||||
s.mustCreateSubscription(user1.ID, group.ID, nil)
|
||||
s.mustCreateSubscription(user2.ID, group.ID, nil)
|
||||
|
||||
affected, err := s.repo.DeleteByGroupID(s.ctx, group.ID)
|
||||
s.Require().NoError(err, "DeleteByGroupID")
|
||||
s.Require().Equal(int64(2), affected)
|
||||
|
||||
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
|
||||
s.Require().Zero(count)
|
||||
}
|
||||
|
||||
// --- Combined scenario ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
|
||||
user := s.mustCreateUser("subr@example.com", service.RoleUser)
|
||||
groupActive := s.mustCreateGroup("g-subr-active")
|
||||
groupExpired := s.mustCreateGroup("g-subr-expired")
|
||||
|
||||
active := s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(2 * time.Hour))
|
||||
})
|
||||
expiredActive := s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
|
||||
c.SetExpiresAt(time.Now().Add(-2 * time.Hour))
|
||||
})
|
||||
|
||||
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, groupActive.ID)
|
||||
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
|
||||
s.Require().Equal(active.ID, got.ID, "expected active subscription")
|
||||
|
||||
activateAt := time.Now().Add(-25 * time.Hour)
|
||||
s.Require().NoError(s.repo.ActivateWindows(s.ctx, active.ID, activateAt), "ActivateWindows")
|
||||
s.Require().NoError(s.repo.IncrementUsage(s.ctx, active.ID, 1.25), "IncrementUsage")
|
||||
|
||||
after, err := s.repo.GetByID(s.ctx, active.ID)
|
||||
s.Require().NoError(err, "GetByID")
|
||||
s.Require().InDelta(1.25, after.DailyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(1.25, after.WeeklyUsageUSD, 1e-6)
|
||||
s.Require().InDelta(1.25, after.MonthlyUsageUSD, 1e-6)
|
||||
s.Require().NotNil(after.DailyWindowStart, "expected DailyWindowStart activated")
|
||||
s.Require().NotNil(after.WeeklyWindowStart, "expected WeeklyWindowStart activated")
|
||||
s.Require().NotNil(after.MonthlyWindowStart, "expected MonthlyWindowStart activated")
|
||||
|
||||
resetAt := time.Now().Truncate(time.Microsecond) // truncate to microsecond for DB precision
|
||||
s.Require().NoError(s.repo.ResetDailyUsage(s.ctx, active.ID, resetAt), "ResetDailyUsage")
|
||||
afterReset, err := s.repo.GetByID(s.ctx, active.ID)
|
||||
s.Require().NoError(err, "GetByID after reset")
|
||||
s.Require().InDelta(0.0, afterReset.DailyUsageUSD, 1e-6)
|
||||
s.Require().NotNil(afterReset.DailyWindowStart)
|
||||
s.Require().WithinDuration(resetAt, *afterReset.DailyWindowStart, time.Microsecond)
|
||||
|
||||
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
|
||||
s.Require().NoError(err, "BatchUpdateExpiredStatus")
|
||||
s.Require().Equal(int64(1), affected, "expected 1 affected row")
|
||||
|
||||
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
|
||||
s.Require().NoError(err, "GetByID expired")
|
||||
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
|
||||
}
|
||||
|
||||
// --- 软删除过滤测试 ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
|
||||
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-softdeleted")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
// 软删除分组
|
||||
_, err := s.client.Group.UpdateOneID(group.ID).SetDeletedAt(time.Now()).Save(s.ctx)
|
||||
s.Require().NoError(err, "soft delete group")
|
||||
|
||||
// IncrementUsage 应该失败,因为分组已软删除
|
||||
err = s.repo.IncrementUsage(s.ctx, sub.ID, 1.0)
|
||||
s.Require().Error(err, "should fail for soft-deleted group")
|
||||
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NotFound() {
|
||||
err := s.repo.IncrementUsage(s.ctx, 999999, 1.0)
|
||||
s.Require().Error(err, "should fail for non-existent subscription")
|
||||
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
|
||||
}
|
||||
|
||||
// --- nil 入参测试 ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestCreate_NilInput() {
|
||||
err := s.repo.Create(s.ctx, nil)
|
||||
s.Require().Error(err, "Create should fail with nil input")
|
||||
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
|
||||
err := s.repo.Update(s.ctx, nil)
|
||||
s.Require().Error(err, "Update should fail with nil input")
|
||||
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
|
||||
}
|
||||
|
||||
// --- 并发用量更新测试 ---
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
|
||||
user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
|
||||
group := s.mustCreateGroup("g-concurrent")
|
||||
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
|
||||
|
||||
const numGoroutines = 10
|
||||
const incrementPerGoroutine = 1.5
|
||||
|
||||
// 启动多个 goroutine 并发调用 IncrementUsage
|
||||
errCh := make(chan error, numGoroutines)
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
go func() {
|
||||
errCh <- s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerGoroutine)
|
||||
}()
|
||||
}
|
||||
|
||||
// 等待所有 goroutine 完成
|
||||
for i := 0; i < numGoroutines; i++ {
|
||||
err := <-errCh
|
||||
s.Require().NoError(err, "IncrementUsage should succeed")
|
||||
}
|
||||
|
||||
// 验证累加结果正确
|
||||
got, err := s.repo.GetByID(s.ctx, sub.ID)
|
||||
s.Require().NoError(err)
|
||||
expectedUsage := float64(numGoroutines) * incrementPerGoroutine
|
||||
s.Require().InDelta(expectedUsage, got.DailyUsageUSD, 1e-6, "daily usage should be correctly accumulated")
|
||||
s.Require().InDelta(expectedUsage, got.WeeklyUsageUSD, 1e-6, "weekly usage should be correctly accumulated")
|
||||
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
|
||||
}
|
||||
|
||||
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
|
||||
baseClient := testEntClient(s.T())
|
||||
tx, err := baseClient.Tx(context.Background())
|
||||
s.Require().NoError(err, "begin tx")
|
||||
defer func() {
|
||||
if tx != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
txCtx := dbent.NewTxContext(context.Background(), tx)
|
||||
suffix := fmt.Sprintf("%d", time.Now().UnixNano())
|
||||
|
||||
userEnt, err := tx.Client().User.Create().
|
||||
SetEmail("tx-user-" + suffix + "@example.com").
|
||||
SetPasswordHash("test").
|
||||
Save(txCtx)
|
||||
s.Require().NoError(err, "create user in tx")
|
||||
|
||||
groupEnt, err := tx.Client().Group.Create().
|
||||
SetName("tx-group-" + suffix).
|
||||
Save(txCtx)
|
||||
s.Require().NoError(err, "create group in tx")
|
||||
|
||||
repo := NewUserSubscriptionRepository(baseClient)
|
||||
sub := &service.UserSubscription{
|
||||
UserID: userEnt.ID,
|
||||
GroupID: groupEnt.ID,
|
||||
ExpiresAt: time.Now().AddDate(0, 0, 30),
|
||||
Status: service.SubscriptionStatusActive,
|
||||
AssignedAt: time.Now(),
|
||||
Notes: "tx",
|
||||
}
|
||||
s.Require().NoError(repo.Create(txCtx, sub), "create subscription in tx")
|
||||
s.Require().NoError(repo.UpdateNotes(txCtx, sub.ID, "tx-note"), "update subscription in tx")
|
||||
|
||||
s.Require().NoError(tx.Rollback(), "rollback tx")
|
||||
tx = nil
|
||||
|
||||
_, err = repo.GetByID(context.Background(), sub.ID)
|
||||
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
|
||||
}
|
||||
126
internal/repository/webhook_repository.go
Normal file
126
internal/repository/webhook_repository.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// WebhookRepository Webhook 持久化仓储
|
||||
type WebhookRepository struct {
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
// NewWebhookRepository 创建 Webhook 仓储
|
||||
func NewWebhookRepository(db *gorm.DB) *WebhookRepository {
|
||||
return &WebhookRepository{db: db}
|
||||
}
|
||||
|
||||
// Create 创建 Webhook
|
||||
func (r *WebhookRepository) Create(ctx context.Context, wh *domain.Webhook) error {
|
||||
// GORM omits zero values on insert for fields with DB defaults. Explicitly
|
||||
// backfill inactive status so repository callers can persist status=0.
|
||||
requestedStatus := wh.Status
|
||||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Create(wh).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
if requestedStatus == domain.WebhookStatusInactive {
|
||||
if err := tx.Model(&domain.Webhook{}).Where("id = ?", wh.ID).Update("status", requestedStatus).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
wh.Status = requestedStatus
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Update 更新 Webhook 字段(只更新 updates map 中的字段)
|
||||
func (r *WebhookRepository) Update(ctx context.Context, id int64, updates map[string]interface{}) error {
|
||||
return r.db.WithContext(ctx).
|
||||
Model(&domain.Webhook{}).
|
||||
Where("id = ?", id).
|
||||
Updates(updates).Error
|
||||
}
|
||||
|
||||
// Delete 删除 Webhook(软删除)
|
||||
func (r *WebhookRepository) Delete(ctx context.Context, id int64) error {
|
||||
return r.db.WithContext(ctx).Delete(&domain.Webhook{}, id).Error
|
||||
}
|
||||
|
||||
// GetByID 按 ID 获取 Webhook
|
||||
func (r *WebhookRepository) GetByID(ctx context.Context, id int64) (*domain.Webhook, error) {
|
||||
var wh domain.Webhook
|
||||
err := r.db.WithContext(ctx).First(&wh, id).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &wh, nil
|
||||
}
|
||||
|
||||
// ListByCreator 按创建者列出 Webhook(createdBy=0 表示列出所有)
|
||||
func (r *WebhookRepository) ListByCreator(ctx context.Context, createdBy int64) ([]*domain.Webhook, error) {
|
||||
var webhooks []*domain.Webhook
|
||||
query := r.db.WithContext(ctx)
|
||||
if createdBy > 0 {
|
||||
query = query.Where("created_by = ?", createdBy)
|
||||
}
|
||||
if err := query.Find(&webhooks).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return webhooks, nil
|
||||
}
|
||||
|
||||
// ListByCreatorPaginated 按创建者分页列出 Webhook(createdBy=0 表示列出所有)
|
||||
func (r *WebhookRepository) ListByCreatorPaginated(ctx context.Context, createdBy int64, offset, limit int) ([]*domain.Webhook, int64, error) {
|
||||
var webhooks []*domain.Webhook
|
||||
var total int64
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Webhook{})
|
||||
if createdBy > 0 {
|
||||
query = query.Where("created_by = ?", createdBy)
|
||||
}
|
||||
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if offset > 0 {
|
||||
query = query.Offset(offset)
|
||||
}
|
||||
if limit > 0 {
|
||||
query = query.Limit(limit)
|
||||
}
|
||||
|
||||
if err := query.Order("created_at DESC").Find(&webhooks).Error; err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return webhooks, total, nil
|
||||
}
|
||||
|
||||
// ListActive 列出所有状态为活跃的 Webhook
|
||||
func (r *WebhookRepository) ListActive(ctx context.Context) ([]*domain.Webhook, error) {
|
||||
var webhooks []*domain.Webhook
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("status = ?", domain.WebhookStatusActive).
|
||||
Find(&webhooks).Error
|
||||
return webhooks, err
|
||||
}
|
||||
|
||||
// CreateDelivery 记录投递日志
|
||||
func (r *WebhookRepository) CreateDelivery(ctx context.Context, delivery *domain.WebhookDelivery) error {
|
||||
return r.db.WithContext(ctx).Create(delivery).Error
|
||||
}
|
||||
|
||||
// ListDeliveries 按 Webhook ID 分页查询投递记录(最新在前)
|
||||
func (r *WebhookRepository) ListDeliveries(ctx context.Context, webhookID int64, limit int) ([]*domain.WebhookDelivery, error) {
|
||||
var deliveries []*domain.WebhookDelivery
|
||||
err := r.db.WithContext(ctx).
|
||||
Where("webhook_id = ?", webhookID).
|
||||
Order("created_at DESC").
|
||||
Limit(limit).
|
||||
Find(&deliveries).Error
|
||||
return deliveries, err
|
||||
}
|
||||
190
internal/repository/webhook_repository_test.go
Normal file
190
internal/repository/webhook_repository_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func setupWebhookRepository(t *testing.T) *WebhookRepository {
|
||||
t.Helper()
|
||||
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.Webhook{}, &domain.WebhookDelivery{}); err != nil {
|
||||
t.Fatalf("migrate webhook tables failed: %v", err)
|
||||
}
|
||||
|
||||
return NewWebhookRepository(db)
|
||||
}
|
||||
|
||||
func newWebhookFixture(name string, createdBy int64, status domain.WebhookStatus) *domain.Webhook {
|
||||
return &domain.Webhook{
|
||||
Name: name,
|
||||
URL: "https://example.com/webhook",
|
||||
Secret: "secret-demo",
|
||||
Events: `["user.registered"]`,
|
||||
Status: status,
|
||||
MaxRetries: 3,
|
||||
TimeoutSec: 10,
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookRepositoryCreateGetUpdateAndDelete(t *testing.T) {
|
||||
repo := setupWebhookRepository(t)
|
||||
ctx := context.Background()
|
||||
|
||||
webhook := newWebhookFixture("alpha", 101, domain.WebhookStatusActive)
|
||||
if err := repo.Create(ctx, webhook); err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
if webhook.ID == 0 {
|
||||
t.Fatal("expected webhook id to be assigned")
|
||||
}
|
||||
|
||||
loaded, err := repo.GetByID(ctx, webhook.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID failed: %v", err)
|
||||
}
|
||||
if loaded.Name != "alpha" {
|
||||
t.Fatalf("expected loaded webhook name alpha, got %q", loaded.Name)
|
||||
}
|
||||
|
||||
if err := repo.Update(ctx, webhook.ID, map[string]interface{}{
|
||||
"name": "alpha-updated",
|
||||
"status": domain.WebhookStatusInactive,
|
||||
}); err != nil {
|
||||
t.Fatalf("Update failed: %v", err)
|
||||
}
|
||||
|
||||
updated, err := repo.GetByID(ctx, webhook.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID after update failed: %v", err)
|
||||
}
|
||||
if updated.Name != "alpha-updated" {
|
||||
t.Fatalf("expected updated name alpha-updated, got %q", updated.Name)
|
||||
}
|
||||
if updated.Status != domain.WebhookStatusInactive {
|
||||
t.Fatalf("expected updated status inactive, got %d", updated.Status)
|
||||
}
|
||||
|
||||
if err := repo.Delete(ctx, webhook.ID); err != nil {
|
||||
t.Fatalf("Delete failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := repo.GetByID(ctx, webhook.ID); err == nil {
|
||||
t.Fatal("expected deleted webhook lookup to fail")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookRepositoryListsByCreatorAndActiveStatus(t *testing.T) {
|
||||
repo := setupWebhookRepository(t)
|
||||
ctx := context.Background()
|
||||
|
||||
fixtures := []*domain.Webhook{
|
||||
newWebhookFixture("creator-1-active", 1, domain.WebhookStatusActive),
|
||||
newWebhookFixture("creator-1-inactive", 1, domain.WebhookStatusInactive),
|
||||
newWebhookFixture("creator-2-active", 2, domain.WebhookStatusActive),
|
||||
}
|
||||
|
||||
for _, webhook := range fixtures {
|
||||
if err := repo.Create(ctx, webhook); err != nil {
|
||||
t.Fatalf("Create(%s) failed: %v", webhook.Name, err)
|
||||
}
|
||||
}
|
||||
|
||||
creatorOneHooks, err := repo.ListByCreator(ctx, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByCreator(1) failed: %v", err)
|
||||
}
|
||||
if len(creatorOneHooks) != 2 {
|
||||
t.Fatalf("expected 2 hooks for creator 1, got %d", len(creatorOneHooks))
|
||||
}
|
||||
|
||||
allHooks, err := repo.ListByCreator(ctx, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByCreator(0) failed: %v", err)
|
||||
}
|
||||
if len(allHooks) != 3 {
|
||||
t.Fatalf("expected 3 hooks when listing all creators, got %d", len(allHooks))
|
||||
}
|
||||
|
||||
activeHooks, err := repo.ListActive(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListActive failed: %v", err)
|
||||
}
|
||||
if len(activeHooks) != 2 {
|
||||
t.Fatalf("expected 2 active hooks, got %d", len(activeHooks))
|
||||
}
|
||||
for _, hook := range activeHooks {
|
||||
if hook.Status != domain.WebhookStatusActive {
|
||||
t.Fatalf("expected active hook status, got %d", hook.Status)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookRepositoryCreateAndListDeliveries(t *testing.T) {
|
||||
repo := setupWebhookRepository(t)
|
||||
ctx := context.Background()
|
||||
|
||||
webhook := newWebhookFixture("delivery-hook", 7, domain.WebhookStatusActive)
|
||||
if err := repo.Create(ctx, webhook); err != nil {
|
||||
t.Fatalf("Create webhook failed: %v", err)
|
||||
}
|
||||
|
||||
olderTime := time.Now().Add(-time.Minute)
|
||||
newerTime := time.Now()
|
||||
|
||||
firstDelivery := &domain.WebhookDelivery{
|
||||
WebhookID: webhook.ID,
|
||||
EventType: domain.EventUserRegistered,
|
||||
Payload: `{"user":"older"}`,
|
||||
StatusCode: 200,
|
||||
ResponseBody: `{"ok":true}`,
|
||||
Attempt: 1,
|
||||
Success: true,
|
||||
CreatedAt: olderTime,
|
||||
}
|
||||
secondDelivery := &domain.WebhookDelivery{
|
||||
WebhookID: webhook.ID,
|
||||
EventType: domain.EventUserLogin,
|
||||
Payload: `{"user":"newer"}`,
|
||||
StatusCode: 500,
|
||||
ResponseBody: `{"ok":false}`,
|
||||
Attempt: 2,
|
||||
Success: false,
|
||||
Error: "delivery failed",
|
||||
CreatedAt: newerTime,
|
||||
}
|
||||
|
||||
if err := repo.CreateDelivery(ctx, firstDelivery); err != nil {
|
||||
t.Fatalf("CreateDelivery(first) failed: %v", err)
|
||||
}
|
||||
if err := repo.CreateDelivery(ctx, secondDelivery); err != nil {
|
||||
t.Fatalf("CreateDelivery(second) failed: %v", err)
|
||||
}
|
||||
|
||||
latestOnly, err := repo.ListDeliveries(ctx, webhook.ID, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("ListDeliveries(limit=1) failed: %v", err)
|
||||
}
|
||||
if len(latestOnly) != 1 {
|
||||
t.Fatalf("expected 1 latest delivery, got %d", len(latestOnly))
|
||||
}
|
||||
if latestOnly[0].ID != secondDelivery.ID {
|
||||
t.Fatalf("expected latest delivery id %d, got %d", secondDelivery.ID, latestOnly[0].ID)
|
||||
}
|
||||
|
||||
allDeliveries, err := repo.ListDeliveries(ctx, webhook.ID, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListDeliveries(limit=10) failed: %v", err)
|
||||
}
|
||||
if len(allDeliveries) != 2 {
|
||||
t.Fatalf("expected 2 deliveries, got %d", len(allDeliveries))
|
||||
}
|
||||
if allDeliveries[0].ID != secondDelivery.ID || allDeliveries[1].ID != firstDelivery.ID {
|
||||
t.Fatal("expected deliveries to be returned in reverse created_at order")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user