feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers

This commit is contained in:
2026-04-02 11:19:50 +08:00
parent e59a77bc49
commit dcc1f186f8
298 changed files with 62603 additions and 0 deletions

View File

@@ -0,0 +1,145 @@
//go:build integration
package repository
import (
"context"
"fmt"
"strings"
"testing"
"github.com/user-management-system/internal/service"
"github.com/stretchr/testify/require"
)
func uniqueTestValue(t *testing.T, prefix string) string {
t.Helper()
safeName := strings.NewReplacer("/", "_", " ", "_").Replace(t.Name())
return fmt.Sprintf("%s-%s", prefix, safeName)
}
func TestUserRepository_RemoveGroupFromAllowedGroups_RemovesAllOccurrences(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "target-group")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
otherGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "other-group")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
repo := newUserRepositoryWithSQL(entClient, tx)
u1 := &service.User{
Email: uniqueTestValue(t, "u1") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
}
require.NoError(t, repo.Create(ctx, u1))
u2 := &service.User{
Email: uniqueTestValue(t, "u2") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID},
}
require.NoError(t, repo.Create(ctx, u2))
u3 := &service.User{
Email: uniqueTestValue(t, "u3") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{otherGroup.ID},
}
require.NoError(t, repo.Create(ctx, u3))
affected, err := repo.RemoveGroupFromAllowedGroups(ctx, targetGroup.ID)
require.NoError(t, err)
require.Equal(t, int64(2), affected)
u1After, err := repo.GetByID(ctx, u1.ID)
require.NoError(t, err)
require.NotContains(t, u1After.AllowedGroups, targetGroup.ID)
require.Contains(t, u1After.AllowedGroups, otherGroup.ID)
u2After, err := repo.GetByID(ctx, u2.ID)
require.NoError(t, err)
require.NotContains(t, u2After.AllowedGroups, targetGroup.ID)
}
func TestGroupRepository_DeleteCascade_RemovesAllowedGroupsAndClearsApiKeys(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
entClient := tx.Client()
targetGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-target")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
otherGroup, err := entClient.Group.Create().
SetName(uniqueTestValue(t, "delete-cascade-other")).
SetStatus(service.StatusActive).
Save(ctx)
require.NoError(t, err)
userRepo := newUserRepositoryWithSQL(entClient, tx)
groupRepo := newGroupRepositoryWithSQL(entClient, tx)
apiKeyRepo := newAPIKeyRepositoryWithSQL(entClient, tx)
u := &service.User{
Email: uniqueTestValue(t, "cascade-user") + "@example.com",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
Concurrency: 5,
AllowedGroups: []int64{targetGroup.ID, otherGroup.ID},
}
require.NoError(t, userRepo.Create(ctx, u))
key := &service.APIKey{
UserID: u.ID,
Key: uniqueTestValue(t, "sk-test-delete-cascade"),
Name: "test key",
GroupID: &targetGroup.ID,
Status: service.StatusActive,
}
require.NoError(t, apiKeyRepo.Create(ctx, key))
_, err = groupRepo.DeleteCascade(ctx, targetGroup.ID)
require.NoError(t, err)
// Deleted group should be hidden by default queries (soft-delete semantics).
_, err = groupRepo.GetByID(ctx, targetGroup.ID)
require.ErrorIs(t, err, service.ErrGroupNotFound)
activeGroups, err := groupRepo.ListActive(ctx)
require.NoError(t, err)
for _, g := range activeGroups {
require.NotEqual(t, targetGroup.ID, g.ID)
}
// User.allowed_groups should no longer include the deleted group.
uAfter, err := userRepo.GetByID(ctx, u.ID)
require.NoError(t, err)
require.NotContains(t, uAfter.AllowedGroups, targetGroup.ID)
require.Contains(t, uAfter.AllowedGroups, otherGroup.ID)
// API keys bound to the deleted group should have group_id cleared.
keyAfter, err := apiKeyRepo.GetByID(ctx, key.ID)
require.NoError(t, err)
require.Nil(t, keyAfter.GroupID)
}

View File

@@ -0,0 +1,367 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
"github.com/user-management-system/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type BillingCacheSuite struct {
IntegrationRedisSuite
}
func (s *BillingCacheSuite) TestUserBalance() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
_, err := cache.GetUserBalance(ctx, 1)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing balance key")
},
},
{
name: "deduct_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(1)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 1), "DeductUserBalance should not error")
_, err := rdb.Get(ctx, balanceKey).Result()
require.ErrorIs(s.T(), err, redis.Nil, "expected missing key after deduct on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(2)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 10.5, got, "balance mismatch")
ttl, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "deduct_reduces_balance",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(3)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 10.5), "SetUserBalance")
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 2.25), "DeductUserBalance")
got, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance after deduct")
require.Equal(s.T(), 8.25, got, "deduct mismatch")
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(100)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 50.0), "SetUserBalance")
exists, err := rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected balance key to exist")
require.NoError(s.T(), cache.InvalidateUserBalance(ctx, userID), "InvalidateUserBalance")
exists, err = rdb.Exists(ctx, balanceKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected balance key to be removed after invalidate")
_, err = cache.GetUserBalance(ctx, userID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "deduct_refreshes_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(103)
balanceKey := fmt.Sprintf("%s%d", billingBalanceKeyPrefix, userID)
require.NoError(s.T(), cache.SetUserBalance(ctx, userID, 100.0), "SetUserBalance")
ttl1, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL before deduct")
s.AssertTTLWithin(ttl1, 1*time.Second, billingCacheTTL)
require.NoError(s.T(), cache.DeductUserBalance(ctx, userID, 25.0), "DeductUserBalance")
balance, err := cache.GetUserBalance(ctx, userID)
require.NoError(s.T(), err, "GetUserBalance")
require.Equal(s.T(), 75.0, balance, "expected balance 75.0")
ttl2, err := rdb.TTL(ctx, balanceKey).Result()
require.NoError(s.T(), err, "TTL after deduct")
s.AssertTTLWithin(ttl2, 1*time.Second, billingCacheTTL)
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
func (s *BillingCacheSuite) TestSubscriptionCache() {
tests := []struct {
name string
fn func(ctx context.Context, rdb *redis.Client, cache service.BillingCache)
}{
{
name: "missing_key_returns_redis_nil",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(10)
groupID := int64(20)
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil for missing subscription key")
},
},
{
name: "update_usage_on_nonexistent_is_noop",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(11)
groupID := int64(21)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 1.0), "UpdateSubscriptionUsage should not error")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(0), exists, "expected missing subscription key after UpdateSubscriptionUsage on non-existent")
},
},
{
name: "set_and_get_with_ttl",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(12)
groupID := int64(22)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 7,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache")
require.Equal(s.T(), "active", gotSub.Status)
require.Equal(s.T(), int64(7), gotSub.Version)
require.Equal(s.T(), 1.0, gotSub.DailyUsage)
ttl, err := rdb.TTL(ctx, subKey).Result()
require.NoError(s.T(), err, "TTL subKey")
s.AssertTTLWithin(ttl, 1*time.Second, billingCacheTTL)
},
},
{
name: "update_usage_increments_all_fields",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(13)
groupID := int64(23)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
require.NoError(s.T(), cache.UpdateSubscriptionUsage(ctx, userID, groupID, 0.5), "UpdateSubscriptionUsage")
gotSub, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.NoError(s.T(), err, "GetSubscriptionCache after update")
require.Equal(s.T(), 1.5, gotSub.DailyUsage)
require.Equal(s.T(), 2.5, gotSub.WeeklyUsage)
require.Equal(s.T(), 3.5, gotSub.MonthlyUsage)
},
},
{
name: "invalidate_removes_key",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(101)
groupID := int64(10)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
DailyUsage: 1.0,
WeeklyUsage: 2.0,
MonthlyUsage: 3.0,
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, userID, groupID, data), "SetSubscriptionCache")
exists, err := rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists")
require.Equal(s.T(), int64(1), exists, "expected subscription key to exist")
require.NoError(s.T(), cache.InvalidateSubscriptionCache(ctx, userID, groupID), "InvalidateSubscriptionCache")
exists, err = rdb.Exists(ctx, subKey).Result()
require.NoError(s.T(), err, "Exists after invalidate")
require.Equal(s.T(), int64(0), exists, "expected subscription key to be removed after invalidate")
_, err = cache.GetSubscriptionCache(ctx, userID, groupID)
require.ErrorIs(s.T(), err, redis.Nil, "expected redis.Nil after invalidate")
},
},
{
name: "missing_status_returns_parsing_error",
fn: func(ctx context.Context, rdb *redis.Client, cache service.BillingCache) {
userID := int64(102)
groupID := int64(11)
subKey := fmt.Sprintf("%s%d:%d", billingSubKeyPrefix, userID, groupID)
fields := map[string]any{
"expires_at": time.Now().Add(1 * time.Hour).Unix(),
"daily_usage": 1.0,
"weekly_usage": 2.0,
"monthly_usage": 3.0,
"version": 1,
}
require.NoError(s.T(), rdb.HSet(ctx, subKey, fields).Err(), "HSet")
_, err := cache.GetSubscriptionCache(ctx, userID, groupID)
require.Error(s.T(), err, "expected error for missing status field")
require.NotErrorIs(s.T(), err, redis.Nil, "expected parsing error, not redis.Nil")
require.Equal(s.T(), "invalid cache: missing status", err.Error())
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, rdb, cache)
})
}
}
// TestDeductUserBalance_ErrorPropagation 验证 P2-12 修复:
// Redis 真实错误应传播key 不存在redis.Nil应返回 nil。
func (s *BillingCacheSuite) TestDeductUserBalance_ErrorPropagation() {
tests := []struct {
name string
fn func(ctx context.Context, cache service.BillingCache)
expectErr bool
}{
{
name: "key_not_exists_returns_nil",
fn: func(ctx context.Context, cache service.BillingCache) {
// key 不存在时Lua 脚本返回 0redis.Nil应返回 nil 而非错误
err := cache.DeductUserBalance(ctx, 99999, 1.0)
require.NoError(s.T(), err, "DeductUserBalance on non-existent key should return nil")
},
},
{
name: "existing_key_deducts_successfully",
fn: func(ctx context.Context, cache service.BillingCache) {
require.NoError(s.T(), cache.SetUserBalance(ctx, 200, 50.0))
err := cache.DeductUserBalance(ctx, 200, 10.0)
require.NoError(s.T(), err, "DeductUserBalance should succeed")
bal, err := cache.GetUserBalance(ctx, 200)
require.NoError(s.T(), err)
require.Equal(s.T(), 40.0, bal, "余额应为 40.0")
},
},
{
name: "cancelled_context_propagates_error",
fn: func(ctx context.Context, cache service.BillingCache) {
require.NoError(s.T(), cache.SetUserBalance(ctx, 201, 50.0))
cancelCtx, cancel := context.WithCancel(ctx)
cancel() // 立即取消
err := cache.DeductUserBalance(cancelCtx, 201, 10.0)
require.Error(s.T(), err, "cancelled context should propagate error")
},
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
tt.fn(ctx, cache)
})
}
}
// TestUpdateSubscriptionUsage_ErrorPropagation 验证 P2-12 修复:
// Redis 真实错误应传播key 不存在redis.Nil应返回 nil。
func (s *BillingCacheSuite) TestUpdateSubscriptionUsage_ErrorPropagation() {
s.Run("key_not_exists_returns_nil", func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
err := cache.UpdateSubscriptionUsage(ctx, 88888, 77777, 1.0)
require.NoError(s.T(), err, "UpdateSubscriptionUsage on non-existent key should return nil")
})
s.Run("cancelled_context_propagates_error", func() {
rdb := testRedis(s.T())
cache := NewBillingCache(rdb)
ctx := context.Background()
data := &service.SubscriptionCacheData{
Status: "active",
ExpiresAt: time.Now().Add(1 * time.Hour),
Version: 1,
}
require.NoError(s.T(), cache.SetSubscriptionCache(ctx, 301, 401, data))
cancelCtx, cancel := context.WithCancel(ctx)
cancel()
err := cache.UpdateSubscriptionUsage(cancelCtx, 301, 401, 1.0)
require.Error(s.T(), err, "cancelled context should propagate error")
})
}
func TestBillingCacheSuite(t *testing.T) {
suite.Run(t, new(BillingCacheSuite))
}

View File

@@ -0,0 +1,111 @@
//go:build unit
package repository
import (
"math"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestBillingBalanceKey(t *testing.T) {
tests := []struct {
name string
userID int64
expected string
}{
{
name: "normal_user_id",
userID: 123,
expected: "billing:balance:123",
},
{
name: "zero_user_id",
userID: 0,
expected: "billing:balance:0",
},
{
name: "negative_user_id",
userID: -1,
expected: "billing:balance:-1",
},
{
name: "max_int64",
userID: math.MaxInt64,
expected: "billing:balance:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingBalanceKey(tc.userID)
require.Equal(t, tc.expected, got)
})
}
}
func TestBillingSubKey(t *testing.T) {
tests := []struct {
name string
userID int64
groupID int64
expected string
}{
{
name: "normal_ids",
userID: 123,
groupID: 456,
expected: "billing:sub:123:456",
},
{
name: "zero_ids",
userID: 0,
groupID: 0,
expected: "billing:sub:0:0",
},
{
name: "negative_ids",
userID: -1,
groupID: -2,
expected: "billing:sub:-1:-2",
},
{
name: "max_int64_ids",
userID: math.MaxInt64,
groupID: math.MaxInt64,
expected: "billing:sub:9223372036854775807:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := billingSubKey(tc.userID, tc.groupID)
require.Equal(t, tc.expected, got)
})
}
}
func TestJitteredTTL(t *testing.T) {
const (
minTTL = 4*time.Minute + 30*time.Second // 270s = 5min - 30s
maxTTL = 5*time.Minute + 30*time.Second // 330s = 5min + 30s
)
for i := 0; i < 200; i++ {
ttl := jitteredTTL()
require.GreaterOrEqual(t, ttl, minTTL, "jitteredTTL() 返回值低于下限: %v", ttl)
require.LessOrEqual(t, ttl, maxTTL, "jitteredTTL() 返回值超过上限: %v", ttl)
}
}
func TestJitteredTTL_HasVariation(t *testing.T) {
// 多次调用应该产生不同的值(验证抖动存在)
seen := make(map[time.Duration]struct{}, 50)
for i := 0; i < 50; i++ {
seen[jitteredTTL()] = struct{}{}
}
// 50 次调用中应该至少有 2 个不同的值
require.Greater(t, len(seen), 1, "jitteredTTL() 应产生不同的 TTL 值")
}

View File

@@ -0,0 +1,487 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/user-management-system/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
// 测试用 TTL 配置15 分钟,与默认值一致)
const testSlotTTLMinutes = 15
// 测试用 TTL Duration用于 TTL 断言
var testSlotTTL = time.Duration(testSlotTTLMinutes) * time.Minute
type ConcurrencyCacheSuite struct {
IntegrationRedisSuite
cache service.ConcurrencyCache
}
func TestConcurrencyCacheSuite(t *testing.T) {
suite.Run(t, new(ConcurrencyCacheSuite))
}
func (s *ConcurrencyCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewConcurrencyCache(s.rdb, testSlotTTLMinutes, int(testSlotTTL.Seconds()))
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_AcquireAndRelease() {
accountID := int64(10)
reqID1, reqID2, reqID3 := "req1", "req2", "req3"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID1)
require.NoError(s.T(), err, "AcquireAccountSlot 1")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID2)
require.NoError(s.T(), err, "AcquireAccountSlot 2")
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID3)
require.NoError(s.T(), err, "AcquireAccountSlot 3")
require.False(s.T(), ok, "expected third acquire to fail")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency")
require.Equal(s.T(), 2, cur, "concurrency mismatch")
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID1), "ReleaseAccountSlot")
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err, "GetAccountConcurrency after release")
require.Equal(s.T(), 1, cur, "expected 1 after release")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_TTL() {
accountID := int64(11)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, reqID)
require.NoError(s.T(), err, "AcquireAccountSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_DuplicateReqID() {
accountID := int64(12)
reqID := "dup-req"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Acquiring with same reqID should be idempotent
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 2, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur, "expected concurrency=1 (idempotent)")
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_ReleaseIdempotent() {
accountID := int64(13)
reqID := "release-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 1, reqID)
require.NoError(s.T(), err)
require.True(s.T(), ok)
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot")
// Releasing again should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, reqID), "ReleaseAccountSlot again")
// Releasing non-existent should not error
require.NoError(s.T(), s.cache.ReleaseAccountSlot(s.ctx, accountID, "non-existent"), "ReleaseAccountSlot non-existent")
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestAccountSlot_MaxZero() {
accountID := int64(14)
reqID := "max-zero-test"
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 0, reqID)
require.NoError(s.T(), err)
require.False(s.T(), ok, "expected acquire to fail with max=0")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_AcquireAndRelease() {
userID := int64(42)
reqID1, reqID2 := "req1", "req2"
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID1)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ok, err = s.cache.AcquireUserSlot(s.ctx, userID, 1, reqID2)
require.NoError(s.T(), err, "AcquireUserSlot 2")
require.False(s.T(), ok, "expected second acquire to fail at max=1")
cur, err := s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency")
require.Equal(s.T(), 1, cur, "expected concurrency=1")
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, reqID1), "ReleaseUserSlot")
// Releasing a non-existent slot should not error
require.NoError(s.T(), s.cache.ReleaseUserSlot(s.ctx, userID, "non-existent"), "ReleaseUserSlot non-existent")
cur, err = s.cache.GetUserConcurrency(s.ctx, userID)
require.NoError(s.T(), err, "GetUserConcurrency after release")
require.Equal(s.T(), 0, cur, "expected concurrency=0 after release")
}
func (s *ConcurrencyCacheSuite) TestUserSlot_TTL() {
userID := int64(200)
reqID := "req_ttl_test"
slotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
ok, err := s.cache.AcquireUserSlot(s.ctx, userID, 5, reqID)
require.NoError(s.T(), err, "AcquireUserSlot")
require.True(s.T(), ok)
ttl, err := s.rdb.TTL(s.ctx, slotKey).Result()
require.NoError(s.T(), err, "TTL")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_IncrementAndDecrement() {
userID := int64(20)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementWaitCount(s.ctx, userID, 2)
require.NoError(s.T(), err, "IncrementWaitCount 3")
require.False(s.T(), ok, "expected wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected wait count 1")
}
func (s *ConcurrencyCacheSuite) TestWaitQueue_DecrementNoNegative() {
userID := int64(300)
waitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
// Test decrement on non-existent key - should not error and should not create negative value
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on non-existent key")
// Verify no key was created or it's not negative
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count after decrement on empty")
// Set count to 1, then decrement twice
ok, err := s.cache.IncrementWaitCount(s.ctx, userID, 5)
require.NoError(s.T(), err, "IncrementWaitCount")
require.True(s.T(), ok)
// Decrement once (1 -> 0)
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount")
// Decrement again on 0 - should not go negative
require.NoError(s.T(), s.cache.DecrementWaitCount(s.ctx, userID), "DecrementWaitCount on zero")
// Verify count is 0, not negative
val, err = s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey after double decrement")
}
require.GreaterOrEqual(s.T(), val, 0, "expected non-negative wait count")
}
func (s *ConcurrencyCacheSuite) TestAccountWaitQueue_IncrementAndDecrement() {
accountID := int64(30)
waitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
ok, err := s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 1")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 2")
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, accountID, 2)
require.NoError(s.T(), err, "IncrementAccountWaitCount 3")
require.False(s.T(), ok, "expected account wait increment over max to fail")
ttl, err := s.rdb.TTL(s.ctx, waitKey).Result()
require.NoError(s.T(), err, "TTL account waitKey")
s.AssertTTLWithin(ttl, 1*time.Second, testSlotTTL)
require.NoError(s.T(), s.cache.DecrementAccountWaitCount(s.ctx, accountID), "DecrementAccountWaitCount")
val, err := s.rdb.Get(s.ctx, waitKey).Int()
if !errors.Is(err, redis.Nil) {
require.NoError(s.T(), err, "Get waitKey")
}
require.Equal(s.T(), 1, val, "expected account wait count 1")
}
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots() {
accountID := int64(901)
userID := int64(902)
accountKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
userKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
now := time.Now().Unix()
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountKey,
redis.Z{Score: float64(now), Member: "oldproc-1"},
redis.Z{Score: float64(now), Member: "keep-1"},
).Err())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userKey,
redis.Z{Score: float64(now), Member: "oldproc-2"},
redis.Z{Score: float64(now), Member: "keep-2"},
).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, time.Minute).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, time.Minute).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "keep-"))
accountMembers, err := s.rdb.ZRange(s.ctx, accountKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"keep-1"}, accountMembers)
userMembers, err := s.rdb.ZRange(s.ctx, userKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"keep-2"}, userMembers)
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil))
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
require.True(s.T(), errors.Is(err, redis.Nil))
}
func (s *ConcurrencyCacheSuite) TestGetAccountConcurrency_Missing() {
// When no slots exist, GetAccountConcurrency should return 0
cur, err := s.cache.GetAccountConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetUserConcurrency_Missing() {
// When no slots exist, GetUserConcurrency should return 0
cur, err := s.cache.GetUserConcurrency(s.ctx, 999)
require.NoError(s.T(), err)
require.Equal(s.T(), 0, cur)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch() {
s.T().Skip("TODO: Fix this test - CurrentConcurrency returns 0 instead of expected value in CI")
// Setup: Create accounts with different load states
account1 := int64(100)
account2 := int64(101)
account3 := int64(102)
// Account 1: 2/3 slots used, 1 waiting
ok, err := s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, account1, 3, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.IncrementAccountWaitCount(s.ctx, account1, 5)
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 2: 1/2 slots used, 0 waiting
ok, err = s.cache.AcquireAccountSlot(s.ctx, account2, 2, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Account 3: 0/1 slots used, 0 waiting (idle)
// Query batch load
accounts := []service.AccountWithConcurrency{
{ID: account1, MaxConcurrency: 3},
{ID: account2, MaxConcurrency: 2},
{ID: account3, MaxConcurrency: 1},
}
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, accounts)
require.NoError(s.T(), err)
require.Len(s.T(), loadMap, 3)
// Verify account1: (2 + 1) / 3 = 100%
load1 := loadMap[account1]
require.NotNil(s.T(), load1)
require.Equal(s.T(), account1, load1.AccountID)
require.Equal(s.T(), 2, load1.CurrentConcurrency)
require.Equal(s.T(), 1, load1.WaitingCount)
require.Equal(s.T(), 100, load1.LoadRate)
// Verify account2: (1 + 0) / 2 = 50%
load2 := loadMap[account2]
require.NotNil(s.T(), load2)
require.Equal(s.T(), account2, load2.AccountID)
require.Equal(s.T(), 1, load2.CurrentConcurrency)
require.Equal(s.T(), 0, load2.WaitingCount)
require.Equal(s.T(), 50, load2.LoadRate)
// Verify account3: (0 + 0) / 1 = 0%
load3 := loadMap[account3]
require.NotNil(s.T(), load3)
require.Equal(s.T(), account3, load3.AccountID)
require.Equal(s.T(), 0, load3.CurrentConcurrency)
require.Equal(s.T(), 0, load3.WaitingCount)
require.Equal(s.T(), 0, load3.LoadRate)
}
func (s *ConcurrencyCacheSuite) TestGetAccountsLoadBatch_Empty() {
// Test with empty account list
loadMap, err := s.cache.GetAccountsLoadBatch(s.ctx, []service.AccountWithConcurrency{})
require.NoError(s.T(), err)
require.Empty(s.T(), loadMap)
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots() {
accountID := int64(200)
slotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
// Acquire 3 slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req3")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Verify 3 slots exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 3, cur)
// Manually set old timestamps for req1 and req2 (simulate expired slots)
now := time.Now().Unix()
expiredTime := now - int64(testSlotTTL.Seconds()) - 10 // 10 seconds past TTL
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req1"}).Err()
require.NoError(s.T(), err)
err = s.rdb.ZAdd(s.ctx, slotKey, redis.Z{Score: float64(expiredTime), Member: "req2"}).Err()
require.NoError(s.T(), err)
// Run cleanup
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify only 1 slot remains (req3)
cur, err = s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 1, cur)
// Verify req3 still exists
members, err := s.rdb.ZRange(s.ctx, slotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Len(s.T(), members, 1)
require.Equal(s.T(), "req3", members[0])
}
func (s *ConcurrencyCacheSuite) TestCleanupExpiredAccountSlots_NoExpired() {
accountID := int64(201)
// Acquire 2 fresh slots
ok, err := s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req1")
require.NoError(s.T(), err)
require.True(s.T(), ok)
ok, err = s.cache.AcquireAccountSlot(s.ctx, accountID, 5, "req2")
require.NoError(s.T(), err)
require.True(s.T(), ok)
// Run cleanup (should not remove anything)
err = s.cache.CleanupExpiredAccountSlots(s.ctx, accountID)
require.NoError(s.T(), err)
// Verify both slots still exist
cur, err := s.cache.GetAccountConcurrency(s.ctx, accountID)
require.NoError(s.T(), err)
require.Equal(s.T(), 2, cur)
}
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_RemovesOldPrefixesAndWaitCounters() {
accountID := int64(901)
userID := int64(902)
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
userSlotKey := fmt.Sprintf("%s%d", userSlotKeyPrefix, userID)
userWaitKey := fmt.Sprintf("%s%d", waitQueueKeyPrefix, userID)
accountWaitKey := fmt.Sprintf("%s%d", accountWaitKeyPrefix, accountID)
now := float64(time.Now().Unix())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey,
redis.Z{Score: now, Member: "oldproc-1"},
redis.Z{Score: now, Member: "activeproc-1"},
).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, userSlotKey,
redis.Z{Score: now, Member: "oldproc-2"},
redis.Z{Score: now, Member: "activeproc-2"},
).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, userSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, userWaitKey, 3, testSlotTTL).Err())
require.NoError(s.T(), s.rdb.Set(s.ctx, accountWaitKey, 2, testSlotTTL).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
accountMembers, err := s.rdb.ZRange(s.ctx, accountSlotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"activeproc-1"}, accountMembers)
userMembers, err := s.rdb.ZRange(s.ctx, userSlotKey, 0, -1).Result()
require.NoError(s.T(), err)
require.Equal(s.T(), []string{"activeproc-2"}, userMembers)
_, err = s.rdb.Get(s.ctx, userWaitKey).Result()
require.ErrorIs(s.T(), err, redis.Nil)
_, err = s.rdb.Get(s.ctx, accountWaitKey).Result()
require.ErrorIs(s.T(), err, redis.Nil)
}
func (s *ConcurrencyCacheSuite) TestCleanupStaleProcessSlots_DeletesEmptySlotKeys() {
accountID := int64(903)
accountSlotKey := fmt.Sprintf("%s%d", accountSlotKeyPrefix, accountID)
require.NoError(s.T(), s.rdb.ZAdd(s.ctx, accountSlotKey, redis.Z{Score: float64(time.Now().Unix()), Member: "oldproc-1"}).Err())
require.NoError(s.T(), s.rdb.Expire(s.ctx, accountSlotKey, testSlotTTL).Err())
require.NoError(s.T(), s.cache.CleanupStaleProcessSlots(s.ctx, "activeproc-"))
exists, err := s.rdb.Exists(s.ctx, accountSlotKey).Result()
require.NoError(s.T(), err)
require.EqualValues(s.T(), 0, exists)
}

View File

@@ -0,0 +1,150 @@
package repository
import (
"context"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// CustomFieldRepository 自定义字段数据访问层
type CustomFieldRepository struct {
db *gorm.DB
}
// NewCustomFieldRepository 创建自定义字段数据访问层
func NewCustomFieldRepository(db *gorm.DB) *CustomFieldRepository {
return &CustomFieldRepository{db: db}
}
// Create 创建自定义字段
func (r *CustomFieldRepository) Create(ctx context.Context, field *domain.CustomField) error {
return r.db.WithContext(ctx).Create(field).Error
}
// Update 更新自定义字段
func (r *CustomFieldRepository) Update(ctx context.Context, field *domain.CustomField) error {
return r.db.WithContext(ctx).Save(field).Error
}
// Delete 删除自定义字段
func (r *CustomFieldRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.CustomField{}, id).Error
}
// GetByID 根据ID获取自定义字段
func (r *CustomFieldRepository) GetByID(ctx context.Context, id int64) (*domain.CustomField, error) {
var field domain.CustomField
err := r.db.WithContext(ctx).First(&field, id).Error
if err != nil {
return nil, err
}
return &field, nil
}
// GetByFieldKey 根据FieldKey获取自定义字段
func (r *CustomFieldRepository) GetByFieldKey(ctx context.Context, fieldKey string) (*domain.CustomField, error) {
var field domain.CustomField
err := r.db.WithContext(ctx).Where("field_key = ?", fieldKey).First(&field).Error
if err != nil {
return nil, err
}
return &field, nil
}
// List 获取所有启用的自定义字段
func (r *CustomFieldRepository) List(ctx context.Context) ([]*domain.CustomField, error) {
var fields []*domain.CustomField
err := r.db.WithContext(ctx).Where("status = ?", 1).Order("sort ASC").Find(&fields).Error
if err != nil {
return nil, err
}
return fields, nil
}
// ListAll 获取所有自定义字段
func (r *CustomFieldRepository) ListAll(ctx context.Context) ([]*domain.CustomField, error) {
var fields []*domain.CustomField
err := r.db.WithContext(ctx).Order("sort ASC").Find(&fields).Error
if err != nil {
return nil, err
}
return fields, nil
}
// UserCustomFieldValueRepository 用户自定义字段值数据访问层
type UserCustomFieldValueRepository struct {
db *gorm.DB
}
// NewUserCustomFieldValueRepository 创建用户自定义字段值数据访问层
func NewUserCustomFieldValueRepository(db *gorm.DB) *UserCustomFieldValueRepository {
return &UserCustomFieldValueRepository{db: db}
}
// Set 为用户设置自定义字段值upsert
func (r *UserCustomFieldValueRepository) Set(ctx context.Context, userID int64, fieldID int64, fieldKey, value string) error {
return r.db.WithContext(ctx).Exec(`
INSERT INTO user_custom_field_values (user_id, field_id, field_key, value, created_at, updated_at)
VALUES (?, ?, ?, ?, NOW(), NOW())
ON CONFLICT(user_id, field_id) DO UPDATE SET value = ?, updated_at = NOW()
`, userID, fieldID, fieldKey, value, value).Error
}
// GetByUserID 获取用户的所有自定义字段值
func (r *UserCustomFieldValueRepository) GetByUserID(ctx context.Context, userID int64) ([]*domain.UserCustomFieldValue, error) {
var values []*domain.UserCustomFieldValue
err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&values).Error
if err != nil {
return nil, err
}
return values, nil
}
// GetByUserIDAndFieldKey 获取用户指定字段的值
func (r *UserCustomFieldValueRepository) GetByUserIDAndFieldKey(ctx context.Context, userID int64, fieldKey string) (*domain.UserCustomFieldValue, error) {
var value domain.UserCustomFieldValue
err := r.db.WithContext(ctx).Where("user_id = ? AND field_key = ?", userID, fieldKey).First(&value).Error
if err != nil {
return nil, err
}
return &value, nil
}
// Delete 删除用户的自定义字段值
func (r *UserCustomFieldValueRepository) Delete(ctx context.Context, userID int64, fieldID int64) error {
return r.db.WithContext(ctx).Where("user_id = ? AND field_id = ?", userID, fieldID).Delete(&domain.UserCustomFieldValue{}).Error
}
// DeleteByUserID 删除用户的所有自定义字段值
func (r *UserCustomFieldValueRepository) DeleteByUserID(ctx context.Context, userID int64) error {
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&domain.UserCustomFieldValue{}).Error
}
// BatchSet 批量设置用户的自定义字段值
func (r *UserCustomFieldValueRepository) BatchSet(ctx context.Context, userID int64, values map[string]string) error {
if len(values) == 0 {
return nil
}
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
for fieldKey, value := range values {
if err := tx.Exec(`
INSERT INTO user_custom_field_values (user_id, field_id, field_key, value, created_at, updated_at)
VALUES (
?,
(SELECT id FROM custom_fields WHERE field_key = ? LIMIT 1),
?,
?,
NOW(),
NOW()
)
ON CONFLICT(user_id, field_id) DO UPDATE SET value = ?, updated_at = NOW()
`, userID, fieldKey, fieldKey, value, value).Error; err != nil {
return err
}
}
return nil
})
}

View File

@@ -0,0 +1,32 @@
package repository
import (
"database/sql"
"time"
"github.com/user-management-system/internal/config"
)
type dbPoolSettings struct {
MaxOpenConns int
MaxIdleConns int
ConnMaxLifetime time.Duration
ConnMaxIdleTime time.Duration
}
func buildDBPoolSettings(cfg *config.Config) dbPoolSettings {
return dbPoolSettings{
MaxOpenConns: cfg.Database.MaxOpenConns,
MaxIdleConns: cfg.Database.MaxIdleConns,
ConnMaxLifetime: time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute,
ConnMaxIdleTime: time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute,
}
}
func applyDBPoolSettings(db *sql.DB, cfg *config.Config) {
settings := buildDBPoolSettings(cfg)
db.SetMaxOpenConns(settings.MaxOpenConns)
db.SetMaxIdleConns(settings.MaxIdleConns)
db.SetConnMaxLifetime(settings.ConnMaxLifetime)
db.SetConnMaxIdleTime(settings.ConnMaxIdleTime)
}

View File

@@ -0,0 +1,50 @@
package repository
import (
"database/sql"
"testing"
"time"
"github.com/user-management-system/internal/config"
"github.com/stretchr/testify/require"
_ "github.com/lib/pq"
)
func TestBuildDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 50,
MaxIdleConns: 10,
ConnMaxLifetimeMinutes: 30,
ConnMaxIdleTimeMinutes: 5,
},
}
settings := buildDBPoolSettings(cfg)
require.Equal(t, 50, settings.MaxOpenConns)
require.Equal(t, 10, settings.MaxIdleConns)
require.Equal(t, 30*time.Minute, settings.ConnMaxLifetime)
require.Equal(t, 5*time.Minute, settings.ConnMaxIdleTime)
}
func TestApplyDBPoolSettings(t *testing.T) {
cfg := &config.Config{
Database: config.DatabaseConfig{
MaxOpenConns: 40,
MaxIdleConns: 8,
ConnMaxLifetimeMinutes: 15,
ConnMaxIdleTimeMinutes: 3,
},
}
db, err := sql.Open("postgres", "host=127.0.0.1 port=5432 user=postgres sslmode=disable")
require.NoError(t, err)
t.Cleanup(func() {
_ = db.Close()
})
applyDBPoolSettings(db, cfg)
stats := db.Stats()
require.Equal(t, 40, stats.MaxOpenConnections)
}

View File

@@ -0,0 +1,256 @@
package repository
import (
"context"
"time"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// DeviceRepository 设备数据访问层
type DeviceRepository struct {
db *gorm.DB
}
// NewDeviceRepository 创建设备数据访问层
func NewDeviceRepository(db *gorm.DB) *DeviceRepository {
return &DeviceRepository{db: db}
}
// Create 创建设备
func (r *DeviceRepository) Create(ctx context.Context, device *domain.Device) error {
// GORM omits zero values on insert for fields with DB defaults. Explicitly
// backfill inactive status so callers can persist status=0 devices.
requestedStatus := device.Status
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Create(device).Error; err != nil {
return err
}
if requestedStatus == domain.DeviceStatusInactive {
if err := tx.Model(&domain.Device{}).Where("id = ?", device.ID).Update("status", requestedStatus).Error; err != nil {
return err
}
device.Status = requestedStatus
}
return nil
})
}
// Update 更新设备
func (r *DeviceRepository) Update(ctx context.Context, device *domain.Device) error {
return r.db.WithContext(ctx).Save(device).Error
}
// Delete 删除设备
func (r *DeviceRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.Device{}, id).Error
}
// GetByID 根据ID获取设备
func (r *DeviceRepository) GetByID(ctx context.Context, id int64) (*domain.Device, error) {
var device domain.Device
err := r.db.WithContext(ctx).First(&device, id).Error
if err != nil {
return nil, err
}
return &device, nil
}
// GetByDeviceID 根据设备ID和用户ID获取设备
func (r *DeviceRepository) GetByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
var device domain.Device
err := r.db.WithContext(ctx).Where("user_id = ? AND device_id = ?", userID, deviceID).First(&device).Error
if err != nil {
return nil, err
}
return &device, nil
}
// List 获取设备列表
func (r *DeviceRepository) List(ctx context.Context, offset, limit int) ([]*domain.Device, int64, error) {
var devices []*domain.Device
var total int64
query := r.db.WithContext(ctx).Model(&domain.Device{})
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&devices).Error; err != nil {
return nil, 0, err
}
return devices, total, nil
}
// ListByUserID 根据用户ID获取设备列表
func (r *DeviceRepository) ListByUserID(ctx context.Context, userID int64, offset, limit int) ([]*domain.Device, int64, error) {
var devices []*domain.Device
var total int64
query := r.db.WithContext(ctx).Model(&domain.Device{}).Where("user_id = ?", userID)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Order("last_active_time DESC").Find(&devices).Error; err != nil {
return nil, 0, err
}
return devices, total, nil
}
// ListByStatus 根据状态获取设备列表
func (r *DeviceRepository) ListByStatus(ctx context.Context, status domain.DeviceStatus, offset, limit int) ([]*domain.Device, int64, error) {
var devices []*domain.Device
var total int64
query := r.db.WithContext(ctx).Model(&domain.Device{}).Where("status = ?", status)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&devices).Error; err != nil {
return nil, 0, err
}
return devices, total, nil
}
// UpdateStatus 更新设备状态
func (r *DeviceRepository) UpdateStatus(ctx context.Context, id int64, status domain.DeviceStatus) error {
return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", id).Update("status", status).Error
}
// UpdateLastActiveTime 更新最后活跃时间
func (r *DeviceRepository) UpdateLastActiveTime(ctx context.Context, id int64) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", id).Update("last_active_time", now).Error
}
// Exists 检查设备是否存在
func (r *DeviceRepository) Exists(ctx context.Context, userID int64, deviceID string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.Device{}).
Where("user_id = ? AND device_id = ?", userID, deviceID).
Count(&count).Error
return count > 0, err
}
// DeleteByUserID 删除用户的所有设备
func (r *DeviceRepository) DeleteByUserID(ctx context.Context, userID int64) error {
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&domain.Device{}).Error
}
// GetActiveDevices 获取活跃设备
func (r *DeviceRepository) GetActiveDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
var devices []*domain.Device
thirtyDaysAgo := time.Now().Add(-30 * 24 * time.Hour)
err := r.db.WithContext(ctx).
Where("user_id = ? AND last_active_time > ?", userID, thirtyDaysAgo).
Order("last_active_time DESC").
Find(&devices).Error
if err != nil {
return nil, err
}
return devices, nil
}
// TrustDevice 设置设备为信任状态
func (r *DeviceRepository) TrustDevice(ctx context.Context, deviceID int64, expiresAt *time.Time) error {
updates := map[string]interface{}{
"is_trusted": true,
"trust_expires_at": expiresAt,
}
return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", deviceID).Updates(updates).Error
}
// UntrustDevice 取消设备信任状态
func (r *DeviceRepository) UntrustDevice(ctx context.Context, deviceID int64) error {
updates := map[string]interface{}{
"is_trusted": false,
"trust_expires_at": nil,
}
return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", deviceID).Updates(updates).Error
}
// DeleteAllByUserIDExcept 删除用户的所有设备(除指定设备外)
func (r *DeviceRepository) DeleteAllByUserIDExcept(ctx context.Context, userID int64, exceptDeviceID int64) error {
return r.db.WithContext(ctx).
Where("user_id = ? AND id != ?", userID, exceptDeviceID).
Delete(&domain.Device{}).Error
}
// GetTrustedDevices 获取用户的信任设备列表
func (r *DeviceRepository) GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
var devices []*domain.Device
now := time.Now()
err := r.db.WithContext(ctx).
Where("user_id = ? AND is_trusted = ? AND (trust_expires_at IS NULL OR trust_expires_at > ?)", userID, true, now).
Order("last_active_time DESC").
Find(&devices).Error
if err != nil {
return nil, err
}
return devices, nil
}
// ListDevicesParams 设备列表查询参数
type ListDevicesParams struct {
UserID int64
Status domain.DeviceStatus
IsTrusted *bool
Keyword string
Offset int
Limit int
}
// ListAll 获取所有设备列表(支持筛选)
func (r *DeviceRepository) ListAll(ctx context.Context, params *ListDevicesParams) ([]*domain.Device, int64, error) {
var devices []*domain.Device
var total int64
query := r.db.WithContext(ctx).Model(&domain.Device{})
// 按用户ID筛选
if params.UserID > 0 {
query = query.Where("user_id = ?", params.UserID)
}
// 按状态筛选
if params.Status >= 0 {
query = query.Where("status = ?", params.Status)
}
// 按信任状态筛选
if params.IsTrusted != nil {
query = query.Where("is_trusted = ?", *params.IsTrusted)
}
// 按关键词筛选(设备名/IP/位置)
if params.Keyword != "" {
search := "%" + params.Keyword + "%"
query = query.Where("device_name LIKE ? OR ip LIKE ? OR location LIKE ?", search, search, search)
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(params.Offset).Limit(params.Limit).
Order("last_active_time DESC").Find(&devices).Error; err != nil {
return nil, 0, err
}
return devices, total, nil
}

View File

@@ -0,0 +1,92 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/user-management-system/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type EmailCacheSuite struct {
IntegrationRedisSuite
cache service.EmailCache
}
func (s *EmailCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewEmailCache(s.rdb)
}
func (s *EmailCacheSuite) TestGetVerificationCode_Missing() {
_, err := s.cache.GetVerificationCode(s.ctx, "nonexistent@example.com")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing verification code")
}
func (s *EmailCacheSuite) TestSetAndGetVerificationCode() {
email := "a@example.com"
emailTTL := 2 * time.Minute
data := &service.VerificationCodeData{Code: "123456", Attempts: 1, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
got, err := s.cache.GetVerificationCode(s.ctx, email)
require.NoError(s.T(), err, "GetVerificationCode")
require.Equal(s.T(), "123456", got.Code)
require.Equal(s.T(), 1, got.Attempts)
}
func (s *EmailCacheSuite) TestVerificationCode_TTL() {
email := "ttl@example.com"
emailTTL := 2 * time.Minute
data := &service.VerificationCodeData{Code: "654321", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, emailTTL), "SetVerificationCode")
emailKey := verifyCodeKeyPrefix + email
ttl, err := s.rdb.TTL(s.ctx, emailKey).Result()
require.NoError(s.T(), err, "TTL emailKey")
s.AssertTTLWithin(ttl, 1*time.Second, emailTTL)
}
func (s *EmailCacheSuite) TestDeleteVerificationCode() {
email := "delete@example.com"
data := &service.VerificationCodeData{Code: "999999", Attempts: 0, CreatedAt: time.Now()}
require.NoError(s.T(), s.cache.SetVerificationCode(s.ctx, email, data, 2*time.Minute), "SetVerificationCode")
// Verify it exists
_, err := s.cache.GetVerificationCode(s.ctx, email)
require.NoError(s.T(), err, "GetVerificationCode before delete")
// Delete
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, email), "DeleteVerificationCode")
// Verify it's gone
_, err = s.cache.GetVerificationCode(s.ctx, email)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *EmailCacheSuite) TestDeleteVerificationCode_NonExistent() {
// Deleting a non-existent key should not error
require.NoError(s.T(), s.cache.DeleteVerificationCode(s.ctx, "nonexistent@example.com"), "DeleteVerificationCode non-existent")
}
func (s *EmailCacheSuite) TestGetVerificationCode_JSONCorruption() {
emailKey := verifyCodeKeyPrefix + "corrupted@example.com"
require.NoError(s.T(), s.rdb.Set(s.ctx, emailKey, "not-json", 1*time.Minute).Err(), "Set invalid JSON")
_, err := s.cache.GetVerificationCode(s.ctx, "corrupted@example.com")
require.Error(s.T(), err, "expected error for corrupted JSON")
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}
func TestEmailCacheSuite(t *testing.T) {
suite.Run(t, new(EmailCacheSuite))
}

View File

@@ -0,0 +1,45 @@
//go:build unit
package repository
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestVerifyCodeKey(t *testing.T) {
tests := []struct {
name string
email string
expected string
}{
{
name: "normal_email",
email: "user@example.com",
expected: "verify_code:user@example.com",
},
{
name: "empty_email",
email: "",
expected: "verify_code:",
},
{
name: "email_with_plus",
email: "user+tag@example.com",
expected: "verify_code:user+tag@example.com",
},
{
name: "email_with_special_chars",
email: "user.name+tag@sub.domain.com",
expected: "verify_code:user.name+tag@sub.domain.com",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := verifyCodeKey(tc.email)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -0,0 +1,109 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/user-management-system/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GatewayCacheSuite struct {
IntegrationRedisSuite
cache service.GatewayCache
}
func (s *GatewayCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGatewayCache(s.rdb)
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_Missing() {
_, err := s.cache.GetSessionAccountID(s.ctx, 1, "nonexistent")
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing session")
}
func (s *GatewayCacheSuite) TestSetAndGetSessionAccountID() {
sessionID := "s1"
accountID := int64(99)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sid, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.NoError(s.T(), err, "GetSessionAccountID")
require.Equal(s.T(), accountID, sid, "session id mismatch")
}
func (s *GatewayCacheSuite) TestSessionAccountID_TTL() {
sessionID := "s2"
accountID := int64(100)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL sessionKey after Set")
s.AssertTTLWithin(ttl, 1*time.Second, sessionTTL)
}
func (s *GatewayCacheSuite) TestRefreshSessionTTL() {
sessionID := "s3"
accountID := int64(101)
groupID := int64(1)
initialTTL := 1 * time.Minute
refreshTTL := 3 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, initialTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.RefreshSessionTTL(s.ctx, groupID, sessionID, refreshTTL), "RefreshSessionTTL")
sessionKey := buildSessionKey(groupID, sessionID)
ttl, err := s.rdb.TTL(s.ctx, sessionKey).Result()
require.NoError(s.T(), err, "TTL after Refresh")
s.AssertTTLWithin(ttl, 1*time.Second, refreshTTL)
}
func (s *GatewayCacheSuite) TestRefreshSessionTTL_MissingKey() {
// RefreshSessionTTL on a missing key should not error (no-op)
err := s.cache.RefreshSessionTTL(s.ctx, 1, "missing-session", 1*time.Minute)
require.NoError(s.T(), err, "RefreshSessionTTL on missing key should not error")
}
func (s *GatewayCacheSuite) TestDeleteSessionAccountID() {
sessionID := "openai:s4"
accountID := int64(102)
groupID := int64(1)
sessionTTL := 1 * time.Minute
require.NoError(s.T(), s.cache.SetSessionAccountID(s.ctx, groupID, sessionID, accountID, sessionTTL), "SetSessionAccountID")
require.NoError(s.T(), s.cache.DeleteSessionAccountID(s.ctx, groupID, sessionID), "DeleteSessionAccountID")
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() {
sessionID := "corrupted"
groupID := int64(1)
sessionKey := buildSessionKey(groupID, sessionID)
// Set a non-integer value
require.NoError(s.T(), s.rdb.Set(s.ctx, sessionKey, "not-a-number", 1*time.Minute).Err(), "Set invalid value")
_, err := s.cache.GetSessionAccountID(s.ctx, groupID, sessionID)
require.Error(s.T(), err, "expected error for corrupted value")
require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil")
}
func TestGatewayCacheSuite(t *testing.T) {
suite.Run(t, new(GatewayCacheSuite))
}

View File

@@ -0,0 +1,250 @@
//go:build integration
package repository
import (
"context"
"testing"
dbent "github.com/user-management-system/ent"
"github.com/user-management-system/internal/service"
"github.com/stretchr/testify/suite"
)
// GatewayRoutingSuite 测试网关路由相关的数据库查询
// 验证账户选择和分流逻辑在真实数据库环境下的行为
type GatewayRoutingSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
accountRepo *accountRepository
}
func (s *GatewayRoutingSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.accountRepo = newAccountRepositoryWithSQL(s.client, tx, nil)
}
func TestGatewayRoutingSuite(t *testing.T) {
suite.Run(t, new(GatewayRoutingSuite))
}
// TestListSchedulableByPlatforms_GeminiAndAntigravity 验证多平台账户查询
func (s *GatewayRoutingSuite) TestListSchedulableByPlatforms_GeminiAndAntigravity() {
// 创建各平台账户
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "gemini-oauth",
Platform: service.PlatformGemini,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 1,
})
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "antigravity-oauth",
Platform: service.PlatformAntigravity,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 2,
Credentials: map[string]any{
"access_token": "test-token",
"refresh_token": "test-refresh",
"project_id": "test-project",
},
})
// 创建不应被选中的 anthropic 账户
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "anthropic-oauth",
Platform: service.PlatformAnthropic,
Type: service.AccountTypeOAuth,
Status: service.StatusActive,
Schedulable: true,
Priority: 0,
})
// 查询 gemini + antigravity 平台
accounts, err := s.accountRepo.ListSchedulableByPlatforms(s.ctx, []string{
service.PlatformGemini,
service.PlatformAntigravity,
})
s.Require().NoError(err)
s.Require().Len(accounts, 2, "应返回 gemini 和 antigravity 两个账户")
// 验证返回的账户平台
platforms := make(map[string]bool)
for _, acc := range accounts {
platforms[acc.Platform] = true
}
s.Require().True(platforms[service.PlatformGemini], "应包含 gemini 账户")
s.Require().True(platforms[service.PlatformAntigravity], "应包含 antigravity 账户")
s.Require().False(platforms[service.PlatformAnthropic], "不应包含 anthropic 账户")
// 验证账户 ID 匹配
ids := make(map[int64]bool)
for _, acc := range accounts {
ids[acc.ID] = true
}
s.Require().True(ids[geminiAcc.ID])
s.Require().True(ids[antigravityAcc.ID])
}
// TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding 验证按分组过滤
func (s *GatewayRoutingSuite) TestListSchedulableByGroupIDAndPlatforms_WithGroupBinding() {
// 创建 gemini 分组
group := mustCreateGroup(s.T(), s.client, &service.Group{
Name: "gemini-group",
Platform: service.PlatformGemini,
Status: service.StatusActive,
})
// 创建账户
boundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "bound-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
unboundAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "unbound-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 只绑定一个账户到分组
mustBindAccountToGroup(s.T(), s.client, boundAcc.ID, group.ID, 1)
// 查询分组内的账户
accounts, err := s.accountRepo.ListSchedulableByGroupIDAndPlatforms(s.ctx, group.ID, []string{
service.PlatformGemini,
service.PlatformAntigravity,
})
s.Require().NoError(err)
s.Require().Len(accounts, 1, "应只返回绑定到分组的账户")
s.Require().Equal(boundAcc.ID, accounts[0].ID)
// 确认未绑定的账户不在结果中
for _, acc := range accounts {
s.Require().NotEqual(unboundAcc.ID, acc.ID, "不应包含未绑定的账户")
}
}
// TestListSchedulableByPlatform_Antigravity 验证单平台查询
func (s *GatewayRoutingSuite) TestListSchedulableByPlatform_Antigravity() {
// 创建多种平台账户
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "gemini-1",
Platform: service.PlatformGemini,
Status: service.StatusActive,
Schedulable: true,
})
antigravity := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "antigravity-1",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 只查询 antigravity 平台
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
s.Require().NoError(err)
s.Require().Len(accounts, 1)
s.Require().Equal(antigravity.ID, accounts[0].ID)
s.Require().Equal(service.PlatformAntigravity, accounts[0].Platform)
}
// TestSchedulableFilter_ExcludesInactive 验证不可调度账户被过滤
func (s *GatewayRoutingSuite) TestSchedulableFilter_ExcludesInactive() {
// 创建可调度账户
activeAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "active-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
// 创建不可调度账户(需要先创建再更新,因为 fixture 默认设置 Schedulable=true
inactiveAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "inactive-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
})
s.Require().NoError(s.client.Account.UpdateOneID(inactiveAcc.ID).SetSchedulable(false).Exec(s.ctx))
// 创建错误状态账户
mustCreateAccount(s.T(), s.client, &service.Account{
Name: "error-antigravity",
Platform: service.PlatformAntigravity,
Status: service.StatusError,
Schedulable: true,
})
accounts, err := s.accountRepo.ListSchedulableByPlatform(s.ctx, service.PlatformAntigravity)
s.Require().NoError(err)
s.Require().Len(accounts, 1, "应只返回可调度的 active 账户")
s.Require().Equal(activeAcc.ID, accounts[0].ID)
}
// TestPlatformRoutingDecision 验证平台路由决策
// 这个测试模拟 Handler 层在选择账户后的路由决策逻辑
func (s *GatewayRoutingSuite) TestPlatformRoutingDecision() {
// 创建两种平台的账户
geminiAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "gemini-route-test",
Platform: service.PlatformGemini,
Status: service.StatusActive,
Schedulable: true,
})
antigravityAcc := mustCreateAccount(s.T(), s.client, &service.Account{
Name: "antigravity-route-test",
Platform: service.PlatformAntigravity,
Status: service.StatusActive,
Schedulable: true,
})
tests := []struct {
name string
accountID int64
expectedService string
}{
{
name: "Gemini账户路由到ForwardNative",
accountID: geminiAcc.ID,
expectedService: "GeminiMessagesCompatService.ForwardNative",
},
{
name: "Antigravity账户路由到ForwardGemini",
accountID: antigravityAcc.ID,
expectedService: "AntigravityGatewayService.ForwardGemini",
},
}
for _, tt := range tests {
s.Run(tt.name, func() {
// 从数据库获取账户
account, err := s.accountRepo.GetByID(s.ctx, tt.accountID)
s.Require().NoError(err)
// 模拟 Handler 层的路由决策
var routedService string
if account.Platform == service.PlatformAntigravity {
routedService = "AntigravityGatewayService.ForwardGemini"
} else {
routedService = "GeminiMessagesCompatService.ForwardNative"
}
s.Require().Equal(tt.expectedService, routedService)
})
}
}

View File

@@ -0,0 +1,9 @@
package repository
import "github.com/user-management-system/internal/pkg/geminicli"
// NewGeminiDriveClient creates a concrete DriveClient for Google Drive API operations.
// Returned as geminicli.DriveClient interface for DI (Strategy A).
func NewGeminiDriveClient() geminicli.DriveClient {
return geminicli.NewDriveClient()
}

View File

@@ -0,0 +1,47 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/user-management-system/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type GeminiTokenCacheSuite struct {
IntegrationRedisSuite
cache service.GeminiTokenCache
}
func (s *GeminiTokenCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewGeminiTokenCache(s.rdb)
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken() {
cacheKey := "project-123"
token := "token-value"
require.NoError(s.T(), s.cache.SetAccessToken(s.ctx, cacheKey, token, time.Minute))
got, err := s.cache.GetAccessToken(s.ctx, cacheKey)
require.NoError(s.T(), err)
require.Equal(s.T(), token, got)
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, cacheKey))
_, err = s.cache.GetAccessToken(s.ctx, cacheKey)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil after delete")
}
func (s *GeminiTokenCacheSuite) TestDeleteAccessToken_MissingKey() {
require.NoError(s.T(), s.cache.DeleteAccessToken(s.ctx, "missing-key"))
}
func TestGeminiTokenCacheSuite(t *testing.T) {
suite.Run(t, new(GeminiTokenCacheSuite))
}

View File

@@ -0,0 +1,28 @@
//go:build unit
package repository
import (
"context"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
func TestGeminiTokenCache_DeleteAccessToken_RedisError(t *testing.T) {
rdb := redis.NewClient(&redis.Options{
Addr: "127.0.0.1:1",
DialTimeout: 50 * time.Millisecond,
ReadTimeout: 50 * time.Millisecond,
WriteTimeout: 50 * time.Millisecond,
})
t.Cleanup(func() {
_ = rdb.Close()
})
cache := NewGeminiTokenCache(rdb)
err := cache.DeleteAccessToken(context.Background(), "broken")
require.Error(t, err)
}

View File

@@ -0,0 +1,67 @@
//go:build integration
package repository
import (
"errors"
"fmt"
"testing"
"time"
"github.com/user-management-system/internal/service"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type IdentityCacheSuite struct {
IntegrationRedisSuite
cache *identityCache
}
func (s *IdentityCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewIdentityCache(s.rdb).(*identityCache)
}
func (s *IdentityCacheSuite) TestGetFingerprint_Missing() {
_, err := s.cache.GetFingerprint(s.ctx, 1)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing fingerprint")
}
func (s *IdentityCacheSuite) TestSetAndGetFingerprint() {
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 1, fp), "SetFingerprint")
gotFP, err := s.cache.GetFingerprint(s.ctx, 1)
require.NoError(s.T(), err, "GetFingerprint")
require.Equal(s.T(), "c1", gotFP.ClientID)
require.Equal(s.T(), "ua", gotFP.UserAgent)
}
func (s *IdentityCacheSuite) TestFingerprint_TTL() {
fp := &service.Fingerprint{ClientID: "c1", UserAgent: "ua"}
require.NoError(s.T(), s.cache.SetFingerprint(s.ctx, 2, fp))
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 2)
ttl, err := s.rdb.TTL(s.ctx, fpKey).Result()
require.NoError(s.T(), err, "TTL fpKey")
s.AssertTTLWithin(ttl, 1*time.Second, fingerprintTTL)
}
func (s *IdentityCacheSuite) TestGetFingerprint_JSONCorruption() {
fpKey := fmt.Sprintf("%s%d", fingerprintKeyPrefix, 999)
require.NoError(s.T(), s.rdb.Set(s.ctx, fpKey, "invalid-json-data", 1*time.Minute).Err(), "Set invalid JSON")
_, err := s.cache.GetFingerprint(s.ctx, 999)
require.Error(s.T(), err, "expected error for corrupted JSON")
require.False(s.T(), errors.Is(err, redis.Nil), "expected decoding error, not redis.Nil")
}
func (s *IdentityCacheSuite) TestSetFingerprint_Nil() {
err := s.cache.SetFingerprint(s.ctx, 100, nil)
require.NoError(s.T(), err, "SetFingerprint(nil) should succeed")
}
func TestIdentityCacheSuite(t *testing.T) {
suite.Run(t, new(IdentityCacheSuite))
}

View File

@@ -0,0 +1,46 @@
//go:build unit
package repository
import (
"math"
"testing"
"github.com/stretchr/testify/require"
)
func TestFingerprintKey(t *testing.T) {
tests := []struct {
name string
accountID int64
expected string
}{
{
name: "normal_account_id",
accountID: 123,
expected: "fingerprint:123",
},
{
name: "zero_account_id",
accountID: 0,
expected: "fingerprint:0",
},
{
name: "negative_account_id",
accountID: -1,
expected: "fingerprint:-1",
},
{
name: "max_int64",
accountID: math.MaxInt64,
expected: "fingerprint:9223372036854775807",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := fingerprintKey(tc.accountID)
require.Equal(t, tc.expected, got)
})
}
}

View File

@@ -0,0 +1,63 @@
package repository
import (
"bytes"
"io"
"net"
"net/http"
"net/http/httptest"
"sync"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
// newInProcessTransport adapts an http.HandlerFunc into an http.RoundTripper without opening sockets.
// It captures the request body (if any) and then rewinds it before invoking the handler.
func newInProcessTransport(handler http.HandlerFunc, capture func(r *http.Request, body []byte)) http.RoundTripper {
return roundTripFunc(func(r *http.Request) (*http.Response, error) {
var body []byte
if r.Body != nil {
body, _ = io.ReadAll(r.Body)
_ = r.Body.Close()
r.Body = io.NopCloser(bytes.NewReader(body))
}
if capture != nil {
capture(r, body)
}
rec := httptest.NewRecorder()
handler(rec, r)
return rec.Result(), nil
})
}
var (
canListenOnce sync.Once
canListen bool
canListenErr error
)
func localListenerAvailable() bool {
canListenOnce.Do(func() {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
canListenErr = err
canListen = false
return
}
_ = ln.Close()
canListen = true
})
return canListen
}
func newLocalTestServer(tb testing.TB, handler http.Handler) *httptest.Server {
tb.Helper()
if !localListenerAvailable() {
tb.Skipf("local listeners are not permitted in this environment: %v", canListenErr)
}
return httptest.NewServer(handler)
}

View File

@@ -0,0 +1,140 @@
package repository
import (
"context"
"time"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// LoginLogRepository 登录日志仓储
type LoginLogRepository struct {
db *gorm.DB
}
// NewLoginLogRepository 创建登录日志仓储
func NewLoginLogRepository(db *gorm.DB) *LoginLogRepository {
return &LoginLogRepository{db: db}
}
// Create 创建登录日志
func (r *LoginLogRepository) Create(ctx context.Context, log *domain.LoginLog) error {
return r.db.WithContext(ctx).Create(log).Error
}
// GetByID 根据ID获取登录日志
func (r *LoginLogRepository) GetByID(ctx context.Context, id int64) (*domain.LoginLog, error) {
var log domain.LoginLog
if err := r.db.WithContext(ctx).First(&log, id).Error; err != nil {
return nil, err
}
return &log, nil
}
// ListByUserID 获取用户的登录日志列表
func (r *LoginLogRepository) ListByUserID(ctx context.Context, userID int64, offset, limit int) ([]*domain.LoginLog, int64, error) {
var logs []*domain.LoginLog
var total int64
query := r.db.WithContext(ctx).Model(&domain.LoginLog{}).Where("user_id = ?", userID)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}
// List 获取登录日志列表(管理员用)
func (r *LoginLogRepository) List(ctx context.Context, offset, limit int) ([]*domain.LoginLog, int64, error) {
var logs []*domain.LoginLog
var total int64
query := r.db.WithContext(ctx).Model(&domain.LoginLog{})
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}
// ListByStatus 按状态查询登录日志
func (r *LoginLogRepository) ListByStatus(ctx context.Context, status int, offset, limit int) ([]*domain.LoginLog, int64, error) {
var logs []*domain.LoginLog
var total int64
query := r.db.WithContext(ctx).Model(&domain.LoginLog{}).Where("status = ?", status)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}
// ListByTimeRange 按时间范围查询登录日志
func (r *LoginLogRepository) ListByTimeRange(ctx context.Context, start, end time.Time, offset, limit int) ([]*domain.LoginLog, int64, error) {
var logs []*domain.LoginLog
var total int64
query := r.db.WithContext(ctx).Model(&domain.LoginLog{}).
Where("created_at >= ? AND created_at <= ?", start, end)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}
// DeleteByUserID 删除用户所有登录日志
func (r *LoginLogRepository) DeleteByUserID(ctx context.Context, userID int64) error {
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&domain.LoginLog{}).Error
}
// DeleteOlderThan 删除指定天数前的日志
func (r *LoginLogRepository) DeleteOlderThan(ctx context.Context, days int) error {
cutoff := time.Now().AddDate(0, 0, -days)
return r.db.WithContext(ctx).Where("created_at < ?", cutoff).Delete(&domain.LoginLog{}).Error
}
// CountByResultSince 统计指定时间之后特定结果的登录次数
// success=true 统计成功次数false 统计失败次数
func (r *LoginLogRepository) CountByResultSince(ctx context.Context, success bool, since time.Time) int64 {
status := 0 // 失败
if success {
status = 1 // 成功
}
var count int64
r.db.WithContext(ctx).Model(&domain.LoginLog{}).
Where("status = ? AND created_at >= ?", status, since).
Count(&count)
return count
}
// ListAllForExport 获取所有登录日志(用于导出,无分页)
func (r *LoginLogRepository) ListAllForExport(ctx context.Context, userID int64, status int, startAt, endAt *time.Time) ([]*domain.LoginLog, error) {
var logs []*domain.LoginLog
query := r.db.WithContext(ctx).Model(&domain.LoginLog{})
if userID > 0 {
query = query.Where("user_id = ?", userID)
}
if status == 0 || status == 1 {
query = query.Where("status = ?", status)
}
if startAt != nil {
query = query.Where("created_at >= ?", startAt)
}
if endAt != nil {
query = query.Where("created_at <= ?", endAt)
}
if err := query.Order("created_at DESC").Find(&logs).Error; err != nil {
return nil, err
}
return logs, nil
}

View File

@@ -0,0 +1,113 @@
package repository
import (
"context"
"time"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// OperationLogRepository 操作日志仓储
type OperationLogRepository struct {
db *gorm.DB
}
// NewOperationLogRepository 创建操作日志仓储
func NewOperationLogRepository(db *gorm.DB) *OperationLogRepository {
return &OperationLogRepository{db: db}
}
// Create 创建操作日志
func (r *OperationLogRepository) Create(ctx context.Context, log *domain.OperationLog) error {
return r.db.WithContext(ctx).Create(log).Error
}
// GetByID 根据ID获取操作日志
func (r *OperationLogRepository) GetByID(ctx context.Context, id int64) (*domain.OperationLog, error) {
var log domain.OperationLog
if err := r.db.WithContext(ctx).First(&log, id).Error; err != nil {
return nil, err
}
return &log, nil
}
// ListByUserID 获取用户的操作日志列表
func (r *OperationLogRepository) ListByUserID(ctx context.Context, userID int64, offset, limit int) ([]*domain.OperationLog, int64, error) {
var logs []*domain.OperationLog
var total int64
query := r.db.WithContext(ctx).Model(&domain.OperationLog{}).Where("user_id = ?", userID)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}
// List 获取操作日志列表(管理员用)
func (r *OperationLogRepository) List(ctx context.Context, offset, limit int) ([]*domain.OperationLog, int64, error) {
var logs []*domain.OperationLog
var total int64
query := r.db.WithContext(ctx).Model(&domain.OperationLog{})
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}
// ListByMethod 按HTTP方法查询操作日志
func (r *OperationLogRepository) ListByMethod(ctx context.Context, method string, offset, limit int) ([]*domain.OperationLog, int64, error) {
var logs []*domain.OperationLog
var total int64
query := r.db.WithContext(ctx).Model(&domain.OperationLog{}).Where("request_method = ?", method)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}
// ListByTimeRange 按时间范围查询操作日志
func (r *OperationLogRepository) ListByTimeRange(ctx context.Context, start, end time.Time, offset, limit int) ([]*domain.OperationLog, int64, error) {
var logs []*domain.OperationLog
var total int64
query := r.db.WithContext(ctx).Model(&domain.OperationLog{}).
Where("created_at >= ? AND created_at <= ?", start, end)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}
// DeleteOlderThan 删除指定天数前的日志
func (r *OperationLogRepository) DeleteOlderThan(ctx context.Context, days int) error {
cutoff := time.Now().AddDate(0, 0, -days)
return r.db.WithContext(ctx).Where("created_at < ?", cutoff).Delete(&domain.OperationLog{}).Error
}
// Search 按关键词搜索操作日志
func (r *OperationLogRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.OperationLog, int64, error) {
var logs []*domain.OperationLog
var total int64
query := r.db.WithContext(ctx).Model(&domain.OperationLog{}).
Where("operation_name LIKE ? OR request_path LIKE ? OR operation_type LIKE ?",
"%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&logs).Error; err != nil {
return nil, 0, err
}
return logs, total, nil
}

View File

@@ -0,0 +1,79 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/user-management-system/internal/service"
"github.com/stretchr/testify/require"
)
func TestOpsRepositoryBatchInsertErrorLogs(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE ops_error_logs RESTART IDENTITY")
repo := NewOpsRepository(integrationDB).(*opsRepository)
now := time.Now().UTC()
inserted, err := repo.BatchInsertErrorLogs(ctx, []*service.OpsInsertErrorLogInput{
{
RequestID: "batch-ops-1",
ErrorPhase: "upstream",
ErrorType: "upstream_error",
Severity: "error",
StatusCode: 429,
ErrorMessage: "rate limited",
CreatedAt: now,
},
{
RequestID: "batch-ops-2",
ErrorPhase: "internal",
ErrorType: "api_error",
Severity: "error",
StatusCode: 500,
ErrorMessage: "internal error",
CreatedAt: now.Add(time.Millisecond),
},
})
require.NoError(t, err)
require.EqualValues(t, 2, inserted)
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM ops_error_logs WHERE request_id IN ('batch-ops-1', 'batch-ops-2')").Scan(&count))
require.Equal(t, 2, count)
}
func TestEnqueueSchedulerOutbox_DeduplicatesIdempotentEvents(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
accountID := int64(12345)
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
require.Equal(t, 1, count)
time.Sleep(schedulerOutboxDedupWindow + 150*time.Millisecond)
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil))
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountChanged).Scan(&count))
require.Equal(t, 2, count)
}
func TestEnqueueSchedulerOutbox_DoesNotDeduplicateLastUsed(t *testing.T) {
ctx := context.Background()
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox RESTART IDENTITY")
accountID := int64(67890)
payload1 := map[string]any{"last_used": map[string]int64{"67890": 100}}
payload2 := map[string]any{"last_used": map[string]int64{"67890": 200}}
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload1))
require.NoError(t, enqueueSchedulerOutbox(ctx, integrationDB, service.SchedulerOutboxEventAccountLastUsed, &accountID, nil, payload2))
var count int
require.NoError(t, integrationDB.QueryRowContext(ctx, "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1", service.SchedulerOutboxEventAccountLastUsed).Scan(&count))
require.Equal(t, 2, count)
}

View File

@@ -0,0 +1,16 @@
package repository
import "github.com/user-management-system/internal/pkg/pagination"
func paginationResultFromTotal(total int64, params pagination.PaginationParams) *pagination.PaginationResult {
pages := int(total) / params.Limit()
if int(total)%params.Limit() > 0 {
pages++
}
return &pagination.PaginationResult{
Total: total,
Page: params.Page,
PageSize: params.Limit(),
Pages: pages,
}
}

View File

@@ -0,0 +1,58 @@
package repository
import (
"context"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// PasswordHistoryRepository 密码历史记录数据访问层
type PasswordHistoryRepository struct {
db *gorm.DB
}
// NewPasswordHistoryRepository 创建密码历史记录数据访问层
func NewPasswordHistoryRepository(db *gorm.DB) *PasswordHistoryRepository {
return &PasswordHistoryRepository{db: db}
}
// Create 创建密码历史记录
func (r *PasswordHistoryRepository) Create(ctx context.Context, history *domain.PasswordHistory) error {
return r.db.WithContext(ctx).Create(history).Error
}
// GetByUserID 获取用户的密码历史记录(最近 N 条,按时间倒序)
func (r *PasswordHistoryRepository) GetByUserID(ctx context.Context, userID int64, limit int) ([]*domain.PasswordHistory, error) {
var histories []*domain.PasswordHistory
err := r.db.WithContext(ctx).
Where("user_id = ?", userID).
Order("created_at DESC").
Limit(limit).
Find(&histories).Error
return histories, err
}
// DeleteOldRecords 删除超过 keepCount 条的旧记录(保留最新的 keepCount 条)
func (r *PasswordHistoryRepository) DeleteOldRecords(ctx context.Context, userID int64, keepCount int) error {
// 找出要保留的最后一条记录的 ID
var ids []int64
err := r.db.WithContext(ctx).
Model(&domain.PasswordHistory{}).
Where("user_id = ?", userID).
Order("created_at DESC").
Limit(keepCount).
Pluck("id", &ids).Error
if err != nil {
return err
}
if len(ids) == 0 {
return nil
}
// 删除不在保留列表中的记录
return r.db.WithContext(ctx).
Where("user_id = ? AND id NOT IN ?", userID, ids).
Delete(&domain.PasswordHistory{}).Error
}

View File

@@ -0,0 +1,202 @@
package repository
import (
"context"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// PermissionRepository 权限数据访问层
type PermissionRepository struct {
db *gorm.DB
}
// NewPermissionRepository 创建权限数据访问层
func NewPermissionRepository(db *gorm.DB) *PermissionRepository {
return &PermissionRepository{db: db}
}
// Create 创建权限
func (r *PermissionRepository) Create(ctx context.Context, permission *domain.Permission) error {
// GORM omits zero values on insert for fields with DB defaults. Explicitly
// backfill disabled status so callers can persist status=0 permissions.
requestedStatus := permission.Status
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Create(permission).Error; err != nil {
return err
}
if requestedStatus == domain.PermissionStatusDisabled {
if err := tx.Model(&domain.Permission{}).Where("id = ?", permission.ID).Update("status", requestedStatus).Error; err != nil {
return err
}
permission.Status = requestedStatus
}
return nil
})
}
// Update 更新权限
func (r *PermissionRepository) Update(ctx context.Context, permission *domain.Permission) error {
return r.db.WithContext(ctx).Save(permission).Error
}
// Delete 删除权限
func (r *PermissionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.Permission{}, id).Error
}
// GetByID 根据ID获取权限
func (r *PermissionRepository) GetByID(ctx context.Context, id int64) (*domain.Permission, error) {
var permission domain.Permission
err := r.db.WithContext(ctx).First(&permission, id).Error
if err != nil {
return nil, err
}
return &permission, nil
}
// GetByCode 根据代码获取权限
func (r *PermissionRepository) GetByCode(ctx context.Context, code string) (*domain.Permission, error) {
var permission domain.Permission
err := r.db.WithContext(ctx).Where("code = ?", code).First(&permission).Error
if err != nil {
return nil, err
}
return &permission, nil
}
// List 获取权限列表
func (r *PermissionRepository) List(ctx context.Context, offset, limit int) ([]*domain.Permission, int64, error) {
var permissions []*domain.Permission
var total int64
query := r.db.WithContext(ctx).Model(&domain.Permission{})
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&permissions).Error; err != nil {
return nil, 0, err
}
return permissions, total, nil
}
// ListByType 根据类型获取权限列表
func (r *PermissionRepository) ListByType(ctx context.Context, permissionType domain.PermissionType, offset, limit int) ([]*domain.Permission, int64, error) {
var permissions []*domain.Permission
var total int64
query := r.db.WithContext(ctx).Model(&domain.Permission{}).Where("type = ?", permissionType)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&permissions).Error; err != nil {
return nil, 0, err
}
return permissions, total, nil
}
// ListByStatus 根据状态获取权限列表
func (r *PermissionRepository) ListByStatus(ctx context.Context, status domain.PermissionStatus, offset, limit int) ([]*domain.Permission, int64, error) {
var permissions []*domain.Permission
var total int64
query := r.db.WithContext(ctx).Model(&domain.Permission{}).Where("status = ?", status)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&permissions).Error; err != nil {
return nil, 0, err
}
return permissions, total, nil
}
// GetByRoleIDs 根据角色ID获取权限列表
func (r *PermissionRepository) GetByRoleIDs(ctx context.Context, roleIDs []int64) ([]*domain.Permission, error) {
var permissions []*domain.Permission
err := r.db.WithContext(ctx).
Joins("INNER JOIN role_permissions ON permissions.id = role_permissions.permission_id").
Where("role_permissions.role_id IN ?", roleIDs).
Where("permissions.status = ?", domain.PermissionStatusEnabled).
Find(&permissions).Error
if err != nil {
return nil, err
}
return permissions, nil
}
// ExistsByCode 检查权限代码是否存在
func (r *PermissionRepository) ExistsByCode(ctx context.Context, code string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.Permission{}).Where("code = ?", code).Count(&count).Error
return count > 0, err
}
// UpdateStatus 更新权限状态
func (r *PermissionRepository) UpdateStatus(ctx context.Context, id int64, status domain.PermissionStatus) error {
return r.db.WithContext(ctx).Model(&domain.Permission{}).Where("id = ?", id).Update("status", status).Error
}
// Search 搜索权限
func (r *PermissionRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.Permission, int64, error) {
var permissions []*domain.Permission
var total int64
query := r.db.WithContext(ctx).Model(&domain.Permission{}).
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&permissions).Error; err != nil {
return nil, 0, err
}
return permissions, total, nil
}
// ListByParentID 根据父ID获取权限列表
func (r *PermissionRepository) ListByParentID(ctx context.Context, parentID int64) ([]*domain.Permission, error) {
var permissions []*domain.Permission
err := r.db.WithContext(ctx).Where("parent_id = ?", parentID).Find(&permissions).Error
if err != nil {
return nil, err
}
return permissions, nil
}
// GetByIDs 根据ID列表批量获取权限
func (r *PermissionRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.Permission, error) {
if len(ids) == 0 {
return []*domain.Permission{}, nil
}
var permissions []*domain.Permission
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&permissions).Error
if err != nil {
return nil, err
}
return permissions, nil
}

View File

@@ -0,0 +1,49 @@
package repository
import (
"crypto/tls"
"time"
"github.com/user-management-system/internal/config"
"github.com/redis/go-redis/v9"
)
// InitRedis 初始化 Redis 客户端
//
// 性能优化说明:
// 原实现使用 go-redis 默认配置,未设置连接池和超时参数:
// 1. 默认连接池大小可能不足以支撑高并发
// 2. 无超时控制可能导致慢操作阻塞
//
// 新实现支持可配置的连接池和超时参数:
// 1. PoolSize: 控制最大并发连接数(默认 128
// 2. MinIdleConns: 保持最小空闲连接,减少冷启动延迟(默认 10
// 3. DialTimeout/ReadTimeout/WriteTimeout: 精确控制各阶段超时
func InitRedis(cfg *config.Config) *redis.Client {
return redis.NewClient(buildRedisOptions(cfg))
}
// buildRedisOptions 构建 Redis 连接选项
// 从配置文件读取连接池和超时参数,支持生产环境调优
func buildRedisOptions(cfg *config.Config) *redis.Options {
opts := &redis.Options{
Addr: cfg.Redis.Address(),
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
DialTimeout: time.Duration(cfg.Redis.DialTimeoutSeconds) * time.Second, // 建连超时
ReadTimeout: time.Duration(cfg.Redis.ReadTimeoutSeconds) * time.Second, // 读取超时
WriteTimeout: time.Duration(cfg.Redis.WriteTimeoutSeconds) * time.Second, // 写入超时
PoolSize: cfg.Redis.PoolSize, // 连接池大小
MinIdleConns: cfg.Redis.MinIdleConns, // 最小空闲连接
}
if cfg.Redis.EnableTLS {
opts.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
ServerName: cfg.Redis.Host,
}
}
return opts
}

View File

@@ -0,0 +1,47 @@
package repository
import (
"testing"
"time"
"github.com/user-management-system/internal/config"
"github.com/stretchr/testify/require"
)
func TestBuildRedisOptions(t *testing.T) {
cfg := &config.Config{
Redis: config.RedisConfig{
Host: "localhost",
Port: 6379,
Password: "secret",
DB: 2,
DialTimeoutSeconds: 5,
ReadTimeoutSeconds: 3,
WriteTimeoutSeconds: 4,
PoolSize: 100,
MinIdleConns: 10,
},
}
opts := buildRedisOptions(cfg)
require.Equal(t, "localhost:6379", opts.Addr)
require.Equal(t, "secret", opts.Password)
require.Equal(t, 2, opts.DB)
require.Equal(t, 5*time.Second, opts.DialTimeout)
require.Equal(t, 3*time.Second, opts.ReadTimeout)
require.Equal(t, 4*time.Second, opts.WriteTimeout)
require.Equal(t, 100, opts.PoolSize)
require.Equal(t, 10, opts.MinIdleConns)
require.Nil(t, opts.TLSConfig)
// Test case with TLS enabled
cfgTLS := &config.Config{
Redis: config.RedisConfig{
Host: "localhost",
EnableTLS: true,
},
}
optsTLS := buildRedisOptions(cfgTLS)
require.NotNil(t, optsTLS.TLSConfig)
require.Equal(t, "localhost", optsTLS.TLSConfig.ServerName)
}

View File

@@ -0,0 +1,305 @@
// repo_bench_test.go — repository 层性能基准测试
// 覆盖:批量写入、并发只读查询、分页列表、更新状态、软删除
package repository
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
_ "modernc.org/sqlite"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/user-management-system/internal/domain"
)
var repoBenchCounter int64
// openBenchDB 为 Benchmark 打开独立内存 DB不依赖 *testing.T
func openBenchDB(b *testing.B) *gorm.DB {
b.Helper()
id := atomic.AddInt64(&repoBenchCounter, 1)
dsn := fmt.Sprintf("file:repobenchdb%d?mode=memory&cache=private", id)
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: dsn,
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
b.Fatalf("openBenchDB: %v", err)
}
if err := db.AutoMigrate(
&domain.User{},
&domain.Role{},
&domain.Permission{},
&domain.UserRole{},
&domain.RolePermission{},
); err != nil {
b.Fatalf("AutoMigrate: %v", err)
}
return db
}
// seedUsers 往 DB 插入 n 条用户
func seedUsers(b *testing.B, repo *UserRepository, n int) {
b.Helper()
ctx := context.Background()
for i := 0; i < n; i++ {
if err := repo.Create(ctx, &domain.User{
Username: fmt.Sprintf("benchuser%06d", i),
Email: domain.StrPtr(fmt.Sprintf("bench%06d@example.com", i)),
Phone: domain.StrPtr(fmt.Sprintf("1380000%04d", i%10000)),
Password: "hashed_placeholder",
Status: domain.UserStatusActive,
}); err != nil {
b.Fatalf("seedUsers i=%d: %v", i, err)
}
}
}
// ---------- BenchmarkRepo_Create — 单条写入吞吐 ----------
func BenchmarkRepo_Create(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
repo.Create(ctx, &domain.User{ //nolint:errcheck
Username: fmt.Sprintf("cr_%d_%d", b.N, i),
Email: domain.StrPtr(fmt.Sprintf("cr_%d_%d@bench.com", b.N, i)),
Password: "hash",
Status: domain.UserStatusActive,
})
}
}
// ---------- BenchmarkRepo_BulkCreate — 批量写入(串行) ----------
func BenchmarkRepo_BulkCreate(b *testing.B) {
sizes := []int{10, 100, 500}
for _, size := range sizes {
size := size
b.Run(fmt.Sprintf("batch=%d", size), func(b *testing.B) {
for i := 0; i < b.N; i++ {
b.StopTimer()
db := openBenchDB(b)
repo := NewUserRepository(db)
ctx := context.Background()
users := make([]*domain.User, size)
for j := 0; j < size; j++ {
users[j] = &domain.User{
Username: fmt.Sprintf("bulk_%d_%d_%d", i, j, size),
Password: "hash",
Status: domain.UserStatusActive,
}
}
b.StartTimer()
for _, u := range users {
repo.Create(ctx, u) //nolint:errcheck
}
}
})
}
}
// ---------- BenchmarkRepo_GetByID — 主键查询 ----------
func BenchmarkRepo_GetByID(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
seedUsers(b, repo, 1000)
ctx := context.Background()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
id := int64(1)
for pb.Next() {
repo.GetByID(ctx, id) //nolint:errcheck
id++
if id > 1000 {
id = 1
}
}
})
}
// ---------- BenchmarkRepo_GetByUsername — 索引查询 ----------
func BenchmarkRepo_GetByUsername(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
seedUsers(b, repo, 500)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
repo.GetByUsername(ctx, fmt.Sprintf("benchuser%06d", i%500)) //nolint:errcheck
}
}
// ---------- BenchmarkRepo_GetByEmail — 索引查询 ----------
func BenchmarkRepo_GetByEmail(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
seedUsers(b, repo, 500)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
repo.GetByEmail(ctx, fmt.Sprintf("bench%06d@example.com", i%500)) //nolint:errcheck
}
}
// ---------- BenchmarkRepo_List — 分页列表 ----------
func BenchmarkRepo_List(b *testing.B) {
pageSizes := []int{10, 50, 200}
for _, ps := range pageSizes {
ps := ps
b.Run(fmt.Sprintf("pageSize=%d", ps), func(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
seedUsers(b, repo, 1000)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
repo.List(ctx, 0, ps) //nolint:errcheck
}
})
}
}
// ---------- BenchmarkRepo_ListByStatus ----------
func BenchmarkRepo_ListByStatus(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
seedUsers(b, repo, 1000)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
repo.ListByStatus(ctx, domain.UserStatusActive, 0, 20) //nolint:errcheck
}
}
// ---------- BenchmarkRepo_UpdateStatus ----------
func BenchmarkRepo_UpdateStatus(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
seedUsers(b, repo, 200)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
id := int64(i%200) + 1
repo.UpdateStatus(ctx, id, domain.UserStatusActive) //nolint:errcheck
}
}
// ---------- BenchmarkRepo_Update — 全字段更新 ----------
func BenchmarkRepo_Update(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
seedUsers(b, repo, 100)
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
id := int64(i%100) + 1
u, err := repo.GetByID(ctx, id)
if err != nil {
continue
}
u.Nickname = fmt.Sprintf("nick%d", i)
repo.Update(ctx, u) //nolint:errcheck
}
}
// ---------- BenchmarkRepo_Delete — 软删除 ----------
func BenchmarkRepo_Delete(b *testing.B) {
ctx := context.Background()
b.ResetTimer()
for i := 0; i < b.N; i++ {
b.StopTimer()
db := openBenchDB(b)
repo := NewUserRepository(db)
repo.Create(ctx, &domain.User{Username: "victim", Password: "hash", Status: domain.UserStatusActive}) //nolint:errcheck
b.StartTimer()
repo.Delete(ctx, 1) //nolint:errcheck
}
}
// ---------- BenchmarkRepo_ExistsByUsername ----------
func BenchmarkRepo_ExistsByUsername(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
seedUsers(b, repo, 500)
ctx := context.Background()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := 0
for pb.Next() {
repo.ExistsByUsername(ctx, fmt.Sprintf("benchuser%06d", i%500)) //nolint:errcheck
i++
}
})
}
// ---------- BenchmarkRepo_ConcurrentReadWrite — 高并发读写混合 ----------
func BenchmarkRepo_ConcurrentReadWrite(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
seedUsers(b, repo, 200)
ctx := context.Background()
var mu sync.Mutex // SQLite 不支持多写并发,需要序列化写入
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
i := int64(1)
for pb.Next() {
if i%10 == 0 {
// 10% 写操作
mu.Lock()
repo.UpdateLastLogin(ctx, i%200+1, "10.0.0.1") //nolint:errcheck
mu.Unlock()
} else {
// 90% 读操作
repo.GetByID(ctx, i%200+1) //nolint:errcheck
}
i++
}
})
}
// ---------- BenchmarkRepo_Search — 模糊搜索 ----------
func BenchmarkRepo_Search(b *testing.B) {
db := openBenchDB(b)
repo := NewUserRepository(db)
seedUsers(b, repo, 2000)
ctx := context.Background()
b.ResetTimer()
keywords := []string{"benchuser000", "bench0001", "benchuser05"}
for i := 0; i < b.N; i++ {
repo.Search(ctx, keywords[i%len(keywords)], 0, 20) //nolint:errcheck
}
}

View File

@@ -0,0 +1,536 @@
// repo_robustness_test.go — repository 层鲁棒性测试
// 覆盖:重复主键、唯一索引冲突、大量数据分页正确性、
// SQL 注入防护(参数化查询验证)、软删除后查询、
// 空字符串/极值/特殊字符输入、上下文取消
package repository
import (
"context"
"fmt"
"strings"
"sync"
"testing"
"github.com/user-management-system/internal/domain"
)
// ============================================================
// 1. 唯一索引冲突
// ============================================================
func TestRepo_Robust_DuplicateUsername(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
u1 := &domain.User{Username: "dupuser", Password: "hash", Status: domain.UserStatusActive}
if err := repo.Create(ctx, u1); err != nil {
t.Fatalf("第一次创建应成功: %v", err)
}
u2 := &domain.User{Username: "dupuser", Password: "hash2", Status: domain.UserStatusActive}
err := repo.Create(ctx, u2)
if err == nil {
t.Error("重复用户名应返回唯一索引冲突错误")
}
}
func TestRepo_Robust_DuplicateEmail(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
email := "dup@example.com"
repo.Create(ctx, &domain.User{Username: "user1", Email: domain.StrPtr(email), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
err := repo.Create(ctx, &domain.User{Username: "user2", Email: domain.StrPtr(email), Password: "h", Status: domain.UserStatusActive})
if err == nil {
t.Error("重复邮箱应返回唯一索引冲突错误")
}
}
func TestRepo_Robust_DuplicatePhone(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
phone := "13900000001"
repo.Create(ctx, &domain.User{Username: "pa", Phone: domain.StrPtr(phone), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
err := repo.Create(ctx, &domain.User{Username: "pb", Phone: domain.StrPtr(phone), Password: "h", Status: domain.UserStatusActive})
if err == nil {
t.Error("重复手机号应返回唯一索引冲突错误")
}
}
func TestRepo_Robust_MultipleNullEmail(t *testing.T) {
// NULL 不触发唯一约束,多个用户可以都没有邮箱
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
for i := 0; i < 5; i++ {
err := repo.Create(ctx, &domain.User{
Username: fmt.Sprintf("nomail%d", i),
Email: nil, // NULL
Password: "hash",
Status: domain.UserStatusActive,
})
if err != nil {
t.Fatalf("NULL email 用户%d 创建失败: %v", i, err)
}
}
}
// ============================================================
// 2. 查询不存在的记录
// ============================================================
func TestRepo_Robust_GetByID_NotFound(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
_, err := repo.GetByID(context.Background(), 99999)
if err == nil {
t.Error("查询不存在的 ID 应返回错误")
}
}
func TestRepo_Robust_GetByUsername_NotFound(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
_, err := repo.GetByUsername(context.Background(), "ghost")
if err == nil {
t.Error("查询不存在的用户名应返回错误")
}
}
func TestRepo_Robust_GetByEmail_NotFound(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
_, err := repo.GetByEmail(context.Background(), "nope@none.com")
if err == nil {
t.Error("查询不存在的邮箱应返回错误")
}
}
func TestRepo_Robust_GetByPhone_NotFound(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
_, err := repo.GetByPhone(context.Background(), "00000000000")
if err == nil {
t.Error("查询不存在的手机号应返回错误")
}
}
// ============================================================
// 3. 软删除后的查询行为
// ============================================================
func TestRepo_Robust_SoftDelete_HiddenFromGet(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
u := &domain.User{Username: "softdel", Password: "h", Status: domain.UserStatusActive}
repo.Create(ctx, u) //nolint:errcheck
id := u.ID
if err := repo.Delete(ctx, id); err != nil {
t.Fatalf("Delete: %v", err)
}
_, err := repo.GetByID(ctx, id)
if err == nil {
t.Error("软删除后 GetByID 应返回错误(记录被隐藏)")
}
}
func TestRepo_Robust_SoftDelete_HiddenFromList(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
for i := 0; i < 3; i++ {
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("listdel%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
}
users, total, _ := repo.List(ctx, 0, 100)
initialCount := len(users)
initialTotal := total
// 删除第一个
repo.Delete(ctx, users[0].ID) //nolint:errcheck
users2, total2, _ := repo.List(ctx, 0, 100)
if len(users2) != initialCount-1 {
t.Errorf("删除后 List 应减少 1 条,实际 %d -> %d", initialCount, len(users2))
}
if total2 != initialTotal-1 {
t.Errorf("删除后 total 应减少 1实际 %d -> %d", initialTotal, total2)
}
}
func TestRepo_Robust_DeleteNonExistent(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
// 软删除一个不存在的 IDGORM 通常返回 nilRowsAffected=0 不报错)
err := repo.Delete(context.Background(), 99999)
_ = err // 不 panic 即可
}
// ============================================================
// 4. SQL 注入防护(参数化查询)
// ============================================================
func TestRepo_Robust_SQLInjection_GetByUsername(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
// 先插入一个真实用户
repo.Create(ctx, &domain.User{Username: "legit", Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
// 注入载荷:尝试用 OR '1'='1' 绕过查询
injections := []string{
"' OR '1'='1",
"'; DROP TABLE users; --",
`" OR "1"="1`,
"admin'--",
"legit' UNION SELECT * FROM users --",
}
for _, payload := range injections {
_, err := repo.GetByUsername(ctx, payload)
if err == nil {
t.Errorf("SQL 注入载荷 %q 不应返回用户(应返回 not found", payload)
}
}
}
func TestRepo_Robust_SQLInjection_Search(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.User{Username: "victim", Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
injections := []string{
"' OR '1'='1",
"%; SELECT * FROM users; --",
"victim' UNION SELECT username FROM users --",
}
for _, payload := range injections {
users, _, err := repo.Search(ctx, payload, 0, 100)
if err != nil {
continue // 参数化查询报错也可接受
}
for _, u := range users {
if u.Username == "victim" && !strings.Contains(payload, "victim") {
t.Errorf("SQL 注入载荷 %q 不应返回不匹配的用户", payload)
}
}
}
}
func TestRepo_Robust_SQLInjection_ExistsByUsername(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.User{Username: "realuser", Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
// 这些载荷不应导致 ExistsByUsername("' OR '1'='1") 返回 true找到不存在的用户
exists, err := repo.ExistsByUsername(ctx, "' OR '1'='1")
if err != nil {
t.Logf("ExistsByUsername SQL注入: err=%v (可接受)", err)
return
}
if exists {
t.Error("SQL 注入载荷在 ExistsByUsername 中不应返回 true")
}
}
// ============================================================
// 5. 分页边界值
// ============================================================
func TestRepo_Robust_List_ZeroOffset(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
for i := 0; i < 5; i++ {
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("pg%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
}
users, total, err := repo.List(ctx, 0, 3)
if err != nil {
t.Fatalf("List: %v", err)
}
if len(users) != 3 {
t.Errorf("offset=0, limit=3 应返回 3 条,实际 %d", len(users))
}
if total != 5 {
t.Errorf("total 应为 5实际 %d", total)
}
}
func TestRepo_Robust_List_OffsetBeyondTotal(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
for i := 0; i < 3; i++ {
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("ov%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
}
users, total, err := repo.List(ctx, 100, 10)
if err != nil {
t.Fatalf("List: %v", err)
}
if len(users) != 0 {
t.Errorf("offset 超过总数应返回空列表,实际 %d 条", len(users))
}
if total != 3 {
t.Errorf("total 应为 3实际 %d", total)
}
}
func TestRepo_Robust_List_LargeLimit(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
for i := 0; i < 10; i++ {
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("ll%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
}
users, _, err := repo.List(ctx, 0, 999999)
if err != nil {
t.Fatalf("List with huge limit: %v", err)
}
if len(users) != 10 {
t.Errorf("超大 limit 应返回全部 10 条,实际 %d", len(users))
}
}
func TestRepo_Robust_List_EmptyDB(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
users, total, err := repo.List(context.Background(), 0, 20)
if err != nil {
t.Fatalf("空 DB List 应无错误: %v", err)
}
if total != 0 {
t.Errorf("空 DB total 应为 0实际 %d", total)
}
if len(users) != 0 {
t.Errorf("空 DB 应返回空列表,实际 %d 条", len(users))
}
}
// ============================================================
// 6. 搜索边界值
// ============================================================
func TestRepo_Robust_Search_EmptyKeyword(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
for i := 0; i < 5; i++ {
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("sk%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
}
users, total, err := repo.Search(ctx, "", 0, 20)
// 空关键字 → LIKE '%%' 匹配所有;验证不报错
if err != nil {
t.Fatalf("空关键字 Search 应无错误: %v", err)
}
if total < 5 {
t.Errorf("空关键字应匹配所有用户(>=5实际 total=%drows=%d", total, len(users))
}
}
func TestRepo_Robust_Search_SpecialCharsKeyword(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.User{Username: "normaluser", Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
// 含 LIKE 元字符
for _, kw := range []string{"%", "_", "\\", "%_%", "%%"} {
_, _, err := repo.Search(ctx, kw, 0, 10)
if err != nil {
t.Logf("特殊关键字 %q 搜索出错(可接受): %v", kw, err)
}
// 主要验证不 panic
}
}
func TestRepo_Robust_Search_VeryLongKeyword(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
longKw := strings.Repeat("a", 10000)
_, _, err := repo.Search(ctx, longKw, 0, 10)
_ = err // 不应 panic
}
// ============================================================
// 7. 超长字段存储
// ============================================================
func TestRepo_Robust_LongFieldValues(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
u := &domain.User{
Username: strings.Repeat("x", 45), // varchar(50) 以内
Password: strings.Repeat("y", 200),
Nickname: strings.Repeat("n", 45),
Status: domain.UserStatusActive,
}
err := repo.Create(ctx, u)
// SQLite 不严格限制 varchar 长度,期望成功;其他数据库可能截断或报错
if err != nil {
t.Logf("超长字段创建结果: %vSQLite 可能允许)", err)
}
}
// ============================================================
// 8. UpdateLastLogin 特殊 IP
// ============================================================
func TestRepo_Robust_UpdateLastLogin_EmptyIP(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
u := &domain.User{Username: "iptest", Password: "h", Status: domain.UserStatusActive}
repo.Create(ctx, u) //nolint:errcheck
// 空 IP 不应报错
if err := repo.UpdateLastLogin(ctx, u.ID, ""); err != nil {
t.Errorf("空 IP UpdateLastLogin 应无错误: %v", err)
}
}
func TestRepo_Robust_UpdateLastLogin_LongIP(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
u := &domain.User{Username: "longiptest", Password: "h", Status: domain.UserStatusActive}
repo.Create(ctx, u) //nolint:errcheck
longIP := strings.Repeat("1", 500)
err := repo.UpdateLastLogin(ctx, u.ID, longIP)
_ = err // SQLite 宽容,不 panic 即可
}
// ============================================================
// 9. 并发写入安全SQLite 序列化写入)
// ============================================================
func TestRepo_Robust_ConcurrentCreate_NoDeadlock(t *testing.T) {
db := openTestDB(t)
// 启用 WAL 模式可减少锁冲突,这里使用默认设置
repo := NewUserRepository(db)
ctx := context.Background()
const goroutines = 20
var wg sync.WaitGroup
var mu sync.Mutex // SQLite 只允许单写,用互斥锁序列化
errorCount := 0
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
mu.Lock()
defer mu.Unlock()
err := repo.Create(ctx, &domain.User{
Username: fmt.Sprintf("concurrent_%d", idx),
Password: "hash",
Status: domain.UserStatusActive,
})
if err != nil {
errorCount++
}
}(i)
}
wg.Wait()
if errorCount > 0 {
t.Errorf("序列化并发写入:%d/%d 次失败", errorCount, goroutines)
}
}
func TestRepo_Robust_ConcurrentReadWrite_NoDataRace(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
// 预先插入数据
for i := 0; i < 10; i++ {
repo.Create(ctx, &domain.User{Username: fmt.Sprintf("rw%d", i), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
}
var wg sync.WaitGroup
var writeMu sync.Mutex
for i := 0; i < 30; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
if idx%5 == 0 {
writeMu.Lock()
repo.UpdateStatus(ctx, int64(idx%10)+1, domain.UserStatusActive) //nolint:errcheck
writeMu.Unlock()
} else {
repo.GetByID(ctx, int64(idx%10)+1) //nolint:errcheck
}
}(i)
}
wg.Wait()
// 无 panic / 数据竞争即通过
}
// ============================================================
// 10. Exists 方法边界
// ============================================================
func TestRepo_Robust_ExistsByUsername_EmptyString(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
// 查询空字符串用户名,不应 panic
exists, err := repo.ExistsByUsername(context.Background(), "")
if err != nil {
t.Logf("ExistsByUsername('') err: %v", err)
}
_ = exists
}
func TestRepo_Robust_ExistsByEmail_NilEquivalent(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
// 查询空邮箱
exists, err := repo.ExistsByEmail(context.Background(), "")
_ = err
_ = exists
}
func TestRepo_Robust_ExistsByPhone_SQLInjection(t *testing.T) {
db := openTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
repo.Create(ctx, &domain.User{Username: "phoneuser", Phone: domain.StrPtr("13900000001"), Password: "h", Status: domain.UserStatusActive}) //nolint:errcheck
exists, err := repo.ExistsByPhone(ctx, "' OR '1'='1")
if err != nil {
t.Logf("ExistsByPhone SQL注入 err: %v", err)
return
}
if exists {
t.Error("SQL 注入载荷在 ExistsByPhone 中不应返回 true")
}
}

View File

@@ -0,0 +1,466 @@
package repository
import (
"context"
"testing"
"time"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
func migrateRepositoryTables(t *testing.T, db *gorm.DB, tables ...interface{}) {
t.Helper()
if err := db.AutoMigrate(tables...); err != nil {
t.Fatalf("migrate repository tables failed: %v", err)
}
}
func int64Ptr(v int64) *int64 {
return &v
}
func TestDeviceRepositoryLifecycleAndQueries(t *testing.T) {
db := openTestDB(t)
migrateRepositoryTables(t, db, &domain.Device{})
repo := NewDeviceRepository(db)
ctx := context.Background()
now := time.Now().UTC()
devices := []*domain.Device{
{
UserID: 1,
DeviceID: "device-alpha",
DeviceName: "Alpha",
DeviceType: domain.DeviceTypeDesktop,
DeviceOS: "Windows",
DeviceBrowser: "Chrome",
IP: "10.0.0.1",
Location: "Shanghai",
Status: domain.DeviceStatusActive,
LastActiveTime: now.Add(-1 * time.Hour),
},
{
UserID: 1,
DeviceID: "device-beta",
DeviceName: "Beta",
DeviceType: domain.DeviceTypeWeb,
DeviceOS: "macOS",
DeviceBrowser: "Safari",
IP: "10.0.0.2",
Location: "Hangzhou",
Status: domain.DeviceStatusInactive,
LastActiveTime: now.Add(-2 * time.Hour),
},
{
UserID: 2,
DeviceID: "device-gamma",
DeviceName: "Gamma",
DeviceType: domain.DeviceTypeMobile,
DeviceOS: "Android",
DeviceBrowser: "WebView",
IP: "10.0.0.3",
Location: "Beijing",
Status: domain.DeviceStatusActive,
LastActiveTime: now.Add(-40 * 24 * time.Hour),
},
}
for _, device := range devices {
if err := repo.Create(ctx, device); err != nil {
t.Fatalf("Create(%s) failed: %v", device.DeviceID, err)
}
}
if allDevices, total, err := repo.List(ctx, 0, 10); err != nil {
t.Fatalf("List failed: %v", err)
} else if total != 3 || len(allDevices) != 3 {
t.Fatalf("expected 3 devices, got total=%d len=%d", total, len(allDevices))
}
loadedByDeviceID, err := repo.GetByDeviceID(ctx, 1, "device-beta")
if err != nil {
t.Fatalf("GetByDeviceID failed: %v", err)
}
if loadedByDeviceID.DeviceName != "Beta" {
t.Fatalf("expected device name Beta, got %q", loadedByDeviceID.DeviceName)
}
exists, err := repo.Exists(ctx, 1, "device-alpha")
if err != nil {
t.Fatalf("Exists(device-alpha) failed: %v", err)
}
if !exists {
t.Fatal("expected device-alpha to exist")
}
missing, err := repo.Exists(ctx, 1, "missing-device")
if err != nil {
t.Fatalf("Exists(missing-device) failed: %v", err)
}
if missing {
t.Fatal("expected missing-device to be absent")
}
userDevices, total, err := repo.ListByUserID(ctx, 1, 0, 10)
if err != nil {
t.Fatalf("ListByUserID failed: %v", err)
}
if total != 2 || len(userDevices) != 2 {
t.Fatalf("expected 2 devices for user 1, got total=%d len=%d", total, len(userDevices))
}
if userDevices[0].DeviceID != "device-alpha" {
t.Fatalf("expected latest active device first, got %q", userDevices[0].DeviceID)
}
activeDevices, total, err := repo.ListByStatus(ctx, domain.DeviceStatusActive, 0, 10)
if err != nil {
t.Fatalf("ListByStatus failed: %v", err)
}
if total != 2 || len(activeDevices) != 2 {
t.Fatalf("expected 2 active devices, got total=%d len=%d", total, len(activeDevices))
}
if err := repo.UpdateStatus(ctx, devices[1].ID, domain.DeviceStatusActive); err != nil {
t.Fatalf("UpdateStatus failed: %v", err)
}
beforeTouch, err := repo.GetByID(ctx, devices[1].ID)
if err != nil {
t.Fatalf("GetByID before UpdateLastActiveTime failed: %v", err)
}
time.Sleep(10 * time.Millisecond)
if err := repo.UpdateLastActiveTime(ctx, devices[1].ID); err != nil {
t.Fatalf("UpdateLastActiveTime failed: %v", err)
}
afterTouch, err := repo.GetByID(ctx, devices[1].ID)
if err != nil {
t.Fatalf("GetByID after UpdateLastActiveTime failed: %v", err)
}
if !afterTouch.LastActiveTime.After(beforeTouch.LastActiveTime) {
t.Fatal("expected last_active_time to move forward")
}
recentDevices, err := repo.GetActiveDevices(ctx, 1)
if err != nil {
t.Fatalf("GetActiveDevices failed: %v", err)
}
if len(recentDevices) != 2 {
t.Fatalf("expected 2 recent devices for user 1, got %d", len(recentDevices))
}
if err := repo.DeleteByUserID(ctx, 1); err != nil {
t.Fatalf("DeleteByUserID failed: %v", err)
}
remainingDevices, remainingTotal, err := repo.List(ctx, 0, 10)
if err != nil {
t.Fatalf("List after DeleteByUserID failed: %v", err)
}
if remainingTotal != 1 || len(remainingDevices) != 1 {
t.Fatalf("expected 1 remaining device, got total=%d len=%d", remainingTotal, len(remainingDevices))
}
if err := repo.Delete(ctx, devices[2].ID); err != nil {
t.Fatalf("Delete failed: %v", err)
}
if _, err := repo.GetByID(ctx, devices[2].ID); err == nil {
t.Fatal("expected deleted device lookup to fail")
}
}
func TestLoginLogRepositoryQueriesAndRetention(t *testing.T) {
db := openTestDB(t)
migrateRepositoryTables(t, db, &domain.LoginLog{})
repo := NewLoginLogRepository(db)
ctx := context.Background()
now := time.Now().UTC()
logs := []*domain.LoginLog{
{
UserID: int64Ptr(1),
LoginType: int(domain.LoginTypePassword),
DeviceID: "device-alpha",
IP: "10.0.0.1",
Location: "Shanghai",
Status: 1,
CreatedAt: now.Add(-1 * time.Hour),
},
{
UserID: int64Ptr(1),
LoginType: int(domain.LoginTypeSMSCode),
DeviceID: "device-beta",
IP: "10.0.0.2",
Location: "Hangzhou",
Status: 0,
FailReason: "code expired",
CreatedAt: now.Add(-30 * time.Minute),
},
{
UserID: int64Ptr(2),
LoginType: int(domain.LoginTypeOAuth),
DeviceID: "device-gamma",
IP: "10.0.0.3",
Location: "Beijing",
Status: 1,
CreatedAt: now.Add(-45 * 24 * time.Hour),
},
}
for _, log := range logs {
if err := repo.Create(ctx, log); err != nil {
t.Fatalf("Create login log failed: %v", err)
}
}
loaded, err := repo.GetByID(ctx, logs[0].ID)
if err != nil {
t.Fatalf("GetByID failed: %v", err)
}
if loaded.DeviceID != "device-alpha" {
t.Fatalf("expected device-alpha, got %q", loaded.DeviceID)
}
userLogs, total, err := repo.ListByUserID(ctx, 1, 0, 10)
if err != nil {
t.Fatalf("ListByUserID failed: %v", err)
}
if total != 2 || len(userLogs) != 2 {
t.Fatalf("expected 2 user logs, got total=%d len=%d", total, len(userLogs))
}
if userLogs[0].DeviceID != "device-beta" {
t.Fatalf("expected newest login log first, got %q", userLogs[0].DeviceID)
}
allLogs, total, err := repo.List(ctx, 0, 10)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if total != 3 || len(allLogs) != 3 {
t.Fatalf("expected 3 total logs, got total=%d len=%d", total, len(allLogs))
}
successLogs, total, err := repo.ListByStatus(ctx, 1, 0, 10)
if err != nil {
t.Fatalf("ListByStatus failed: %v", err)
}
if total != 2 || len(successLogs) != 2 {
t.Fatalf("expected 2 success logs, got total=%d len=%d", total, len(successLogs))
}
recentLogs, total, err := repo.ListByTimeRange(ctx, now.Add(-2*time.Hour), now, 0, 10)
if err != nil {
t.Fatalf("ListByTimeRange failed: %v", err)
}
if total != 2 || len(recentLogs) != 2 {
t.Fatalf("expected 2 recent logs, got total=%d len=%d", total, len(recentLogs))
}
if count := repo.CountByResultSince(ctx, true, now.Add(-2*time.Hour)); count != 1 {
t.Fatalf("expected 1 recent success login, got %d", count)
}
if count := repo.CountByResultSince(ctx, false, now.Add(-2*time.Hour)); count != 1 {
t.Fatalf("expected 1 recent failed login, got %d", count)
}
if err := repo.DeleteOlderThan(ctx, 30); err != nil {
t.Fatalf("DeleteOlderThan failed: %v", err)
}
retainedLogs, retainedTotal, err := repo.List(ctx, 0, 10)
if err != nil {
t.Fatalf("List after DeleteOlderThan failed: %v", err)
}
if retainedTotal != 2 || len(retainedLogs) != 2 {
t.Fatalf("expected 2 retained logs, got total=%d len=%d", retainedTotal, len(retainedLogs))
}
if err := repo.DeleteByUserID(ctx, 1); err != nil {
t.Fatalf("DeleteByUserID failed: %v", err)
}
finalLogs, finalTotal, err := repo.List(ctx, 0, 10)
if err != nil {
t.Fatalf("List after DeleteByUserID failed: %v", err)
}
if finalTotal != 0 || len(finalLogs) != 0 {
t.Fatalf("expected all logs removed, got total=%d len=%d", finalTotal, len(finalLogs))
}
}
func TestPasswordHistoryRepositoryKeepsNewestRecords(t *testing.T) {
db := openTestDB(t)
migrateRepositoryTables(t, db, &domain.PasswordHistory{})
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
now := time.Now().UTC()
histories := []*domain.PasswordHistory{
{UserID: 1, PasswordHash: "hash-1", CreatedAt: now.Add(-4 * time.Hour)},
{UserID: 1, PasswordHash: "hash-2", CreatedAt: now.Add(-3 * time.Hour)},
{UserID: 1, PasswordHash: "hash-3", CreatedAt: now.Add(-2 * time.Hour)},
{UserID: 1, PasswordHash: "hash-4", CreatedAt: now.Add(-1 * time.Hour)},
{UserID: 2, PasswordHash: "hash-foreign", CreatedAt: now.Add(-30 * time.Minute)},
}
for _, history := range histories {
if err := repo.Create(ctx, history); err != nil {
t.Fatalf("Create password history failed: %v", err)
}
}
latestTwo, err := repo.GetByUserID(ctx, 1, 2)
if err != nil {
t.Fatalf("GetByUserID(limit=2) failed: %v", err)
}
if len(latestTwo) != 2 {
t.Fatalf("expected 2 latest password histories, got %d", len(latestTwo))
}
if latestTwo[0].PasswordHash != "hash-4" || latestTwo[1].PasswordHash != "hash-3" {
t.Fatalf("expected newest password hashes to be retained, got %q and %q", latestTwo[0].PasswordHash, latestTwo[1].PasswordHash)
}
if err := repo.DeleteOldRecords(ctx, 1, 2); err != nil {
t.Fatalf("DeleteOldRecords failed: %v", err)
}
remainingHistories, err := repo.GetByUserID(ctx, 1, 10)
if err != nil {
t.Fatalf("GetByUserID after DeleteOldRecords failed: %v", err)
}
if len(remainingHistories) != 2 {
t.Fatalf("expected 2 remaining histories, got %d", len(remainingHistories))
}
if remainingHistories[0].PasswordHash != "hash-4" || remainingHistories[1].PasswordHash != "hash-3" {
t.Fatalf("unexpected remaining password hashes: %q and %q", remainingHistories[0].PasswordHash, remainingHistories[1].PasswordHash)
}
if err := repo.DeleteOldRecords(ctx, 999, 3); err != nil {
t.Fatalf("DeleteOldRecords for missing user failed: %v", err)
}
}
func TestOperationLogRepositorySearchAndRetention(t *testing.T) {
db := openTestDB(t)
migrateRepositoryTables(t, db, &domain.OperationLog{})
repo := NewOperationLogRepository(db)
ctx := context.Background()
now := time.Now().UTC()
logs := []*domain.OperationLog{
{
UserID: int64Ptr(1),
OperationType: "user",
OperationName: "create user",
RequestMethod: "POST",
RequestPath: "/api/v1/users",
RequestParams: `{"username":"alice"}`,
ResponseStatus: 201,
IP: "10.0.0.1",
UserAgent: "Chrome",
CreatedAt: now.Add(-20 * time.Minute),
},
{
UserID: int64Ptr(1),
OperationType: "dashboard",
OperationName: "view dashboard",
RequestMethod: "GET",
RequestPath: "/dashboard",
RequestParams: "{}",
ResponseStatus: 200,
IP: "10.0.0.2",
UserAgent: "Chrome",
CreatedAt: now.Add(-10 * time.Minute),
},
{
UserID: int64Ptr(2),
OperationType: "user",
OperationName: "delete user",
RequestMethod: "DELETE",
RequestPath: "/api/v1/users/7",
RequestParams: "{}",
ResponseStatus: 204,
IP: "10.0.0.3",
UserAgent: "Firefox",
CreatedAt: now.Add(-40 * 24 * time.Hour),
},
}
for _, log := range logs {
if err := repo.Create(ctx, log); err != nil {
t.Fatalf("Create operation log failed: %v", err)
}
}
loaded, err := repo.GetByID(ctx, logs[0].ID)
if err != nil {
t.Fatalf("GetByID failed: %v", err)
}
if loaded.OperationName != "create user" {
t.Fatalf("expected create user log, got %q", loaded.OperationName)
}
userLogs, total, err := repo.ListByUserID(ctx, 1, 0, 10)
if err != nil {
t.Fatalf("ListByUserID failed: %v", err)
}
if total != 2 || len(userLogs) != 2 {
t.Fatalf("expected 2 user operation logs, got total=%d len=%d", total, len(userLogs))
}
if userLogs[0].OperationName != "view dashboard" {
t.Fatalf("expected newest operation log first, got %q", userLogs[0].OperationName)
}
allLogs, total, err := repo.List(ctx, 0, 10)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if total != 3 || len(allLogs) != 3 {
t.Fatalf("expected 3 total operation logs, got total=%d len=%d", total, len(allLogs))
}
postLogs, total, err := repo.ListByMethod(ctx, "POST", 0, 10)
if err != nil {
t.Fatalf("ListByMethod failed: %v", err)
}
if total != 1 || len(postLogs) != 1 || postLogs[0].OperationName != "create user" {
t.Fatalf("expected a single POST operation log, got total=%d len=%d", total, len(postLogs))
}
recentLogs, total, err := repo.ListByTimeRange(ctx, now.Add(-1*time.Hour), now, 0, 10)
if err != nil {
t.Fatalf("ListByTimeRange failed: %v", err)
}
if total != 2 || len(recentLogs) != 2 {
t.Fatalf("expected 2 recent operation logs, got total=%d len=%d", total, len(recentLogs))
}
searchResults, total, err := repo.Search(ctx, "user", 0, 10)
if err != nil {
t.Fatalf("Search failed: %v", err)
}
if total != 2 || len(searchResults) != 2 {
t.Fatalf("expected 2 operation logs matching user, got total=%d len=%d", total, len(searchResults))
}
if err := repo.DeleteOlderThan(ctx, 30); err != nil {
t.Fatalf("DeleteOlderThan failed: %v", err)
}
retainedLogs, retainedTotal, err := repo.List(ctx, 0, 10)
if err != nil {
t.Fatalf("List after DeleteOlderThan failed: %v", err)
}
if retainedTotal != 2 || len(retainedLogs) != 2 {
t.Fatalf("expected 2 retained operation logs, got total=%d len=%d", retainedTotal, len(retainedLogs))
}
}

View File

@@ -0,0 +1,603 @@
package repository
import (
"context"
"testing"
"github.com/user-management-system/internal/domain"
)
func containsInt64(values []int64, target int64) bool {
for _, value := range values {
if value == target {
return true
}
}
return false
}
func TestRoleRepositoryLifecycleAndQueries(t *testing.T) {
db := openTestDB(t)
repo := NewRoleRepository(db)
ctx := context.Background()
admin := &domain.Role{
Name: "Admin Test",
Code: "admin-test",
Description: "root role",
Level: 1,
IsSystem: true,
Status: domain.RoleStatusEnabled,
}
if err := repo.Create(ctx, admin); err != nil {
t.Fatalf("Create(admin) failed: %v", err)
}
parentID := admin.ID
auditor := &domain.Role{
Name: "Auditor Test",
Code: "auditor-test",
Description: "audit role",
ParentID: &parentID,
Level: 2,
IsDefault: true,
Status: domain.RoleStatusDisabled,
}
viewer := &domain.Role{
Name: "Viewer Test",
Code: "viewer-test",
Description: "view role",
Level: 1,
Status: domain.RoleStatusEnabled,
}
for _, role := range []*domain.Role{auditor, viewer} {
if err := repo.Create(ctx, role); err != nil {
t.Fatalf("Create(%s) failed: %v", role.Code, err)
}
}
loadedByID, err := repo.GetByID(ctx, admin.ID)
if err != nil {
t.Fatalf("GetByID failed: %v", err)
}
if loadedByID.Code != "admin-test" {
t.Fatalf("expected admin-test, got %q", loadedByID.Code)
}
loadedByCode, err := repo.GetByCode(ctx, "auditor-test")
if err != nil {
t.Fatalf("GetByCode failed: %v", err)
}
if loadedByCode.ID != auditor.ID {
t.Fatalf("expected auditor id %d, got %d", auditor.ID, loadedByCode.ID)
}
allRoles, total, err := repo.List(ctx, 0, 10)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if total != 3 || len(allRoles) != 3 {
t.Fatalf("expected 3 roles, got total=%d len=%d", total, len(allRoles))
}
enabledRoles, total, err := repo.ListByStatus(ctx, domain.RoleStatusEnabled, 0, 10)
if err != nil {
t.Fatalf("ListByStatus failed: %v", err)
}
if total != 2 || len(enabledRoles) != 2 {
t.Fatalf("expected 2 enabled roles, got total=%d len=%d", total, len(enabledRoles))
}
defaultRoles, err := repo.GetDefaultRoles(ctx)
if err != nil {
t.Fatalf("GetDefaultRoles failed: %v", err)
}
if len(defaultRoles) != 1 || defaultRoles[0].ID != auditor.ID {
t.Fatalf("expected auditor as default role, got %+v", defaultRoles)
}
exists, err := repo.ExistsByCode(ctx, "viewer-test")
if err != nil {
t.Fatalf("ExistsByCode(viewer-test) failed: %v", err)
}
if !exists {
t.Fatal("expected viewer-test to exist")
}
missing, err := repo.ExistsByCode(ctx, "missing-role")
if err != nil {
t.Fatalf("ExistsByCode(missing-role) failed: %v", err)
}
if missing {
t.Fatal("expected missing-role to be absent")
}
auditor.Description = "audit role updated"
if err := repo.Update(ctx, auditor); err != nil {
t.Fatalf("Update failed: %v", err)
}
if err := repo.UpdateStatus(ctx, auditor.ID, domain.RoleStatusEnabled); err != nil {
t.Fatalf("UpdateStatus failed: %v", err)
}
searchResults, total, err := repo.Search(ctx, "audit", 0, 10)
if err != nil {
t.Fatalf("Search failed: %v", err)
}
if total != 1 || len(searchResults) != 1 || searchResults[0].ID != auditor.ID {
t.Fatalf("expected auditor search hit, got total=%d len=%d", total, len(searchResults))
}
childRoles, err := repo.ListByParentID(ctx, admin.ID)
if err != nil {
t.Fatalf("ListByParentID failed: %v", err)
}
if len(childRoles) != 1 || childRoles[0].ID != auditor.ID {
t.Fatalf("expected auditor child role, got %+v", childRoles)
}
roleSubset, err := repo.GetByIDs(ctx, []int64{admin.ID, auditor.ID})
if err != nil {
t.Fatalf("GetByIDs failed: %v", err)
}
if len(roleSubset) != 2 {
t.Fatalf("expected 2 roles from GetByIDs, got %d", len(roleSubset))
}
emptySubset, err := repo.GetByIDs(ctx, []int64{})
if err != nil {
t.Fatalf("GetByIDs(empty) failed: %v", err)
}
if len(emptySubset) != 0 {
t.Fatalf("expected empty slice for GetByIDs(empty), got %d", len(emptySubset))
}
if err := repo.Delete(ctx, viewer.ID); err != nil {
t.Fatalf("Delete failed: %v", err)
}
if _, err := repo.GetByID(ctx, viewer.ID); err == nil {
t.Fatal("expected deleted role lookup to fail")
}
}
func TestPermissionRepositoryLifecycleAndQueries(t *testing.T) {
db := openTestDB(t)
repo := NewPermissionRepository(db)
ctx := context.Background()
parent := &domain.Permission{
Name: "Dashboard",
Code: "dashboard:view",
Type: domain.PermissionTypeMenu,
Description: "dashboard menu",
Path: "/dashboard",
Sort: 1,
Status: domain.PermissionStatusEnabled,
}
if err := repo.Create(ctx, parent); err != nil {
t.Fatalf("Create(parent) failed: %v", err)
}
parentID := parent.ID
apiPermission := &domain.Permission{
Name: "Audit API",
Code: "audit:read",
Type: domain.PermissionTypeAPI,
Description: "audit api",
ParentID: &parentID,
Path: "/api/audit",
Method: "GET",
Sort: 2,
Status: domain.PermissionStatusDisabled,
}
buttonPermission := &domain.Permission{
Name: "Audit Button",
Code: "audit:button",
Type: domain.PermissionTypeButton,
Description: "audit action",
Sort: 3,
Status: domain.PermissionStatusEnabled,
}
for _, permission := range []*domain.Permission{apiPermission, buttonPermission} {
if err := repo.Create(ctx, permission); err != nil {
t.Fatalf("Create(%s) failed: %v", permission.Code, err)
}
}
role := &domain.Role{
Name: "Permission Role",
Code: "permission-role",
Description: "role for permission join queries",
Status: domain.RoleStatusEnabled,
}
if err := db.WithContext(ctx).Create(role).Error; err != nil {
t.Fatalf("create role for permission joins failed: %v", err)
}
for _, rolePermission := range []*domain.RolePermission{
{RoleID: role.ID, PermissionID: parent.ID},
{RoleID: role.ID, PermissionID: apiPermission.ID},
} {
if err := db.WithContext(ctx).Create(rolePermission).Error; err != nil {
t.Fatalf("create role_permission failed: %v", err)
}
}
loadedByID, err := repo.GetByID(ctx, parent.ID)
if err != nil {
t.Fatalf("GetByID failed: %v", err)
}
if loadedByID.Code != "dashboard:view" {
t.Fatalf("expected dashboard:view, got %q", loadedByID.Code)
}
loadedByCode, err := repo.GetByCode(ctx, "audit:read")
if err != nil {
t.Fatalf("GetByCode failed: %v", err)
}
if loadedByCode.ID != apiPermission.ID {
t.Fatalf("expected audit:read id %d, got %d", apiPermission.ID, loadedByCode.ID)
}
allPermissions, total, err := repo.List(ctx, 0, 10)
if err != nil {
t.Fatalf("List failed: %v", err)
}
if total != 3 || len(allPermissions) != 3 {
t.Fatalf("expected 3 permissions, got total=%d len=%d", total, len(allPermissions))
}
apiPermissions, total, err := repo.ListByType(ctx, domain.PermissionTypeAPI, 0, 10)
if err != nil {
t.Fatalf("ListByType failed: %v", err)
}
if total != 1 || len(apiPermissions) != 1 || apiPermissions[0].ID != apiPermission.ID {
t.Fatalf("expected audit api permission, got total=%d len=%d", total, len(apiPermissions))
}
enabledPermissions, total, err := repo.ListByStatus(ctx, domain.PermissionStatusEnabled, 0, 10)
if err != nil {
t.Fatalf("ListByStatus failed: %v", err)
}
if total != 2 || len(enabledPermissions) != 2 {
t.Fatalf("expected 2 enabled permissions, got total=%d len=%d", total, len(enabledPermissions))
}
rolePermissions, err := repo.GetByRoleIDs(ctx, []int64{role.ID})
if err != nil {
t.Fatalf("GetByRoleIDs failed: %v", err)
}
if len(rolePermissions) != 1 || rolePermissions[0].ID != parent.ID {
t.Fatalf("expected only enabled parent permission in join query, got %+v", rolePermissions)
}
exists, err := repo.ExistsByCode(ctx, "audit:button")
if err != nil {
t.Fatalf("ExistsByCode(audit:button) failed: %v", err)
}
if !exists {
t.Fatal("expected audit:button to exist")
}
missing, err := repo.ExistsByCode(ctx, "permission:missing")
if err != nil {
t.Fatalf("ExistsByCode(missing) failed: %v", err)
}
if missing {
t.Fatal("expected permission:missing to be absent")
}
apiPermission.Description = "audit api updated"
if err := repo.Update(ctx, apiPermission); err != nil {
t.Fatalf("Update failed: %v", err)
}
if err := repo.UpdateStatus(ctx, apiPermission.ID, domain.PermissionStatusEnabled); err != nil {
t.Fatalf("UpdateStatus failed: %v", err)
}
searchResults, total, err := repo.Search(ctx, "audit", 0, 10)
if err != nil {
t.Fatalf("Search failed: %v", err)
}
if total != 2 || len(searchResults) != 2 {
t.Fatalf("expected 2 audit-related permissions, got total=%d len=%d", total, len(searchResults))
}
childPermissions, err := repo.ListByParentID(ctx, parent.ID)
if err != nil {
t.Fatalf("ListByParentID failed: %v", err)
}
if len(childPermissions) != 1 || childPermissions[0].ID != apiPermission.ID {
t.Fatalf("expected api permission child, got %+v", childPermissions)
}
permissionSubset, err := repo.GetByIDs(ctx, []int64{parent.ID, apiPermission.ID})
if err != nil {
t.Fatalf("GetByIDs failed: %v", err)
}
if len(permissionSubset) != 2 {
t.Fatalf("expected 2 permissions from GetByIDs, got %d", len(permissionSubset))
}
emptySubset, err := repo.GetByIDs(ctx, []int64{})
if err != nil {
t.Fatalf("GetByIDs(empty) failed: %v", err)
}
if len(emptySubset) != 0 {
t.Fatalf("expected empty slice for GetByIDs(empty), got %d", len(emptySubset))
}
if err := repo.Delete(ctx, buttonPermission.ID); err != nil {
t.Fatalf("Delete failed: %v", err)
}
if _, err := repo.GetByID(ctx, buttonPermission.ID); err == nil {
t.Fatal("expected deleted permission lookup to fail")
}
}
func TestUserRoleAndRolePermissionRepositoriesLifecycle(t *testing.T) {
db := openTestDB(t)
userRoleRepo := NewUserRoleRepository(db)
rolePermissionRepo := NewRolePermissionRepository(db)
ctx := context.Background()
users := []*domain.User{
{Username: "repo-user-1", Password: "hash", Status: domain.UserStatusActive},
{Username: "repo-user-2", Password: "hash", Status: domain.UserStatusActive},
}
for _, user := range users {
if err := db.WithContext(ctx).Create(user).Error; err != nil {
t.Fatalf("create user failed: %v", err)
}
}
roles := []*domain.Role{
{Name: "Repo Role 1", Code: "repo-role-1", Status: domain.RoleStatusEnabled},
{Name: "Repo Role 2", Code: "repo-role-2", Status: domain.RoleStatusEnabled},
}
for _, role := range roles {
if err := db.WithContext(ctx).Create(role).Error; err != nil {
t.Fatalf("create role failed: %v", err)
}
}
permissions := []*domain.Permission{
{Name: "Repo Permission 1", Code: "repo:permission:1", Type: domain.PermissionTypeAPI, Status: domain.PermissionStatusEnabled},
{Name: "Repo Permission 2", Code: "repo:permission:2", Type: domain.PermissionTypeAPI, Status: domain.PermissionStatusEnabled},
}
for _, permission := range permissions {
if err := db.WithContext(ctx).Create(permission).Error; err != nil {
t.Fatalf("create permission failed: %v", err)
}
}
userRolePrimary := &domain.UserRole{UserID: users[0].ID, RoleID: roles[0].ID}
if err := userRoleRepo.Create(ctx, userRolePrimary); err != nil {
t.Fatalf("UserRole Create failed: %v", err)
}
if err := userRoleRepo.BatchCreate(ctx, []*domain.UserRole{}); err != nil {
t.Fatalf("UserRole BatchCreate(empty) failed: %v", err)
}
userRoleBatch := []*domain.UserRole{
{UserID: users[0].ID, RoleID: roles[1].ID},
{UserID: users[1].ID, RoleID: roles[0].ID},
}
if err := userRoleRepo.BatchCreate(ctx, userRoleBatch); err != nil {
t.Fatalf("UserRole BatchCreate failed: %v", err)
}
exists, err := userRoleRepo.Exists(ctx, users[0].ID, roles[0].ID)
if err != nil {
t.Fatalf("UserRole Exists failed: %v", err)
}
if !exists {
t.Fatal("expected primary user-role relation to exist")
}
missing, err := userRoleRepo.Exists(ctx, users[1].ID, roles[1].ID)
if err != nil {
t.Fatalf("UserRole Exists(missing) failed: %v", err)
}
if missing {
t.Fatal("expected missing user-role relation to be absent")
}
rolesForUserOne, err := userRoleRepo.GetByUserID(ctx, users[0].ID)
if err != nil {
t.Fatalf("GetByUserID failed: %v", err)
}
if len(rolesForUserOne) != 2 {
t.Fatalf("expected 2 roles for user one, got %d", len(rolesForUserOne))
}
usersForRoleOne, err := userRoleRepo.GetByRoleID(ctx, roles[0].ID)
if err != nil {
t.Fatalf("GetByRoleID failed: %v", err)
}
if len(usersForRoleOne) != 2 {
t.Fatalf("expected 2 users for role one, got %d", len(usersForRoleOne))
}
roleIDs, err := userRoleRepo.GetRoleIDsByUserID(ctx, users[0].ID)
if err != nil {
t.Fatalf("GetRoleIDsByUserID failed: %v", err)
}
if len(roleIDs) != 2 || !containsInt64(roleIDs, roles[0].ID) || !containsInt64(roleIDs, roles[1].ID) {
t.Fatalf("unexpected role IDs for user one: %+v", roleIDs)
}
userIDs, err := userRoleRepo.GetUserIDByRoleID(ctx, roles[0].ID)
if err != nil {
t.Fatalf("GetUserIDByRoleID failed: %v", err)
}
if len(userIDs) != 2 || !containsInt64(userIDs, users[0].ID) || !containsInt64(userIDs, users[1].ID) {
t.Fatalf("unexpected user IDs for role one: %+v", userIDs)
}
if err := userRoleRepo.BatchDelete(ctx, []*domain.UserRole{}); err != nil {
t.Fatalf("UserRole BatchDelete(empty) failed: %v", err)
}
if err := userRoleRepo.BatchDelete(ctx, []*domain.UserRole{userRoleBatch[0]}); err != nil {
t.Fatalf("UserRole BatchDelete failed: %v", err)
}
if err := userRoleRepo.Delete(ctx, userRolePrimary.ID); err != nil {
t.Fatalf("UserRole Delete failed: %v", err)
}
existsAfterDelete, err := userRoleRepo.Exists(ctx, users[0].ID, roles[0].ID)
if err != nil {
t.Fatalf("UserRole Exists after Delete failed: %v", err)
}
if existsAfterDelete {
t.Fatal("expected primary user-role relation to be removed")
}
if err := userRoleRepo.DeleteByUserID(ctx, users[1].ID); err != nil {
t.Fatalf("DeleteByUserID failed: %v", err)
}
if err := userRoleRepo.Create(ctx, &domain.UserRole{UserID: users[0].ID, RoleID: roles[1].ID}); err != nil {
t.Fatalf("recreate user-role failed: %v", err)
}
if err := userRoleRepo.DeleteByRoleID(ctx, roles[1].ID); err != nil {
t.Fatalf("DeleteByRoleID failed: %v", err)
}
remainingUserRoles, err := userRoleRepo.GetByRoleID(ctx, roles[1].ID)
if err != nil {
t.Fatalf("GetByRoleID after DeleteByRoleID failed: %v", err)
}
if len(remainingUserRoles) != 0 {
t.Fatalf("expected no user-role relations for role two, got %d", len(remainingUserRoles))
}
rolePermissionPrimary := &domain.RolePermission{RoleID: roles[0].ID, PermissionID: permissions[0].ID}
if err := rolePermissionRepo.Create(ctx, rolePermissionPrimary); err != nil {
t.Fatalf("RolePermission Create failed: %v", err)
}
if err := rolePermissionRepo.BatchCreate(ctx, []*domain.RolePermission{}); err != nil {
t.Fatalf("RolePermission BatchCreate(empty) failed: %v", err)
}
rolePermissionBatch := []*domain.RolePermission{
{RoleID: roles[0].ID, PermissionID: permissions[1].ID},
{RoleID: roles[1].ID, PermissionID: permissions[0].ID},
}
if err := rolePermissionRepo.BatchCreate(ctx, rolePermissionBatch); err != nil {
t.Fatalf("RolePermission BatchCreate failed: %v", err)
}
rpExists, err := rolePermissionRepo.Exists(ctx, roles[0].ID, permissions[0].ID)
if err != nil {
t.Fatalf("RolePermission Exists failed: %v", err)
}
if !rpExists {
t.Fatal("expected primary role-permission relation to exist")
}
rpMissing, err := rolePermissionRepo.Exists(ctx, roles[1].ID, permissions[1].ID)
if err != nil {
t.Fatalf("RolePermission Exists(missing) failed: %v", err)
}
if rpMissing {
t.Fatal("expected missing role-permission relation to be absent")
}
permissionsForRoleOne, err := rolePermissionRepo.GetByRoleID(ctx, roles[0].ID)
if err != nil {
t.Fatalf("GetByRoleID failed: %v", err)
}
if len(permissionsForRoleOne) != 2 {
t.Fatalf("expected 2 permissions for role one, got %d", len(permissionsForRoleOne))
}
rolesForPermissionOne, err := rolePermissionRepo.GetByPermissionID(ctx, permissions[0].ID)
if err != nil {
t.Fatalf("GetByPermissionID failed: %v", err)
}
if len(rolesForPermissionOne) != 2 {
t.Fatalf("expected 2 roles for permission one, got %d", len(rolesForPermissionOne))
}
permissionIDs, err := rolePermissionRepo.GetPermissionIDsByRoleID(ctx, roles[0].ID)
if err != nil {
t.Fatalf("GetPermissionIDsByRoleID failed: %v", err)
}
if len(permissionIDs) != 2 || !containsInt64(permissionIDs, permissions[0].ID) || !containsInt64(permissionIDs, permissions[1].ID) {
t.Fatalf("unexpected permission IDs for role one: %+v", permissionIDs)
}
roleIDsByPermission, err := rolePermissionRepo.GetRoleIDByPermissionID(ctx, permissions[0].ID)
if err != nil {
t.Fatalf("GetRoleIDByPermissionID failed: %v", err)
}
if len(roleIDsByPermission) != 2 || !containsInt64(roleIDsByPermission, roles[0].ID) || !containsInt64(roleIDsByPermission, roles[1].ID) {
t.Fatalf("unexpected role IDs for permission one: %+v", roleIDsByPermission)
}
loadedPermission, err := rolePermissionRepo.GetPermissionByID(ctx, permissions[0].ID)
if err != nil {
t.Fatalf("GetPermissionByID failed: %v", err)
}
if loadedPermission.Code != "repo:permission:1" {
t.Fatalf("expected repo:permission:1, got %q", loadedPermission.Code)
}
permissionIDsByRoleIDs, err := rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, []int64{roles[0].ID, roles[1].ID})
if err != nil {
t.Fatalf("GetPermissionIDsByRoleIDs failed: %v", err)
}
if len(permissionIDsByRoleIDs) != 3 {
t.Fatalf("expected 3 permission IDs from combined roles, got %d", len(permissionIDsByRoleIDs))
}
emptyPermissionIDs, err := rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, []int64{})
if err != nil {
t.Fatalf("GetPermissionIDsByRoleIDs(empty) failed: %v", err)
}
if len(emptyPermissionIDs) != 0 {
t.Fatalf("expected empty slice for GetPermissionIDsByRoleIDs(empty), got %d", len(emptyPermissionIDs))
}
if err := rolePermissionRepo.BatchDelete(ctx, []*domain.RolePermission{}); err != nil {
t.Fatalf("RolePermission BatchDelete(empty) failed: %v", err)
}
if err := rolePermissionRepo.BatchDelete(ctx, []*domain.RolePermission{rolePermissionBatch[0]}); err != nil {
t.Fatalf("RolePermission BatchDelete failed: %v", err)
}
if err := rolePermissionRepo.Delete(ctx, rolePermissionPrimary.ID); err != nil {
t.Fatalf("RolePermission Delete failed: %v", err)
}
if err := rolePermissionRepo.DeleteByPermissionID(ctx, permissions[0].ID); err != nil {
t.Fatalf("DeleteByPermissionID failed: %v", err)
}
if err := rolePermissionRepo.Create(ctx, &domain.RolePermission{RoleID: roles[0].ID, PermissionID: permissions[1].ID}); err != nil {
t.Fatalf("recreate role-permission failed: %v", err)
}
if err := rolePermissionRepo.DeleteByRoleID(ctx, roles[0].ID); err != nil {
t.Fatalf("DeleteByRoleID failed: %v", err)
}
remainingRolePermissions, err := rolePermissionRepo.GetByRoleID(ctx, roles[0].ID)
if err != nil {
t.Fatalf("GetByRoleID after DeleteByRoleID failed: %v", err)
}
if len(remainingRolePermissions) != 0 {
t.Fatalf("expected no role-permission relations for role one, got %d", len(remainingRolePermissions))
}
}

213
internal/repository/role.go Normal file
View File

@@ -0,0 +1,213 @@
package repository
import (
"context"
"errors"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// RoleRepository 角色数据访问层
type RoleRepository struct {
db *gorm.DB
}
// NewRoleRepository 创建角色数据访问层
func NewRoleRepository(db *gorm.DB) *RoleRepository {
return &RoleRepository{db: db}
}
// Create 创建角色
func (r *RoleRepository) Create(ctx context.Context, role *domain.Role) error {
// GORM omits zero values on insert for fields with DB defaults. Explicitly
// backfill disabled status so callers can persist status=0 roles.
requestedStatus := role.Status
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Create(role).Error; err != nil {
return err
}
if requestedStatus == domain.RoleStatusDisabled {
if err := tx.Model(&domain.Role{}).Where("id = ?", role.ID).Update("status", requestedStatus).Error; err != nil {
return err
}
role.Status = requestedStatus
}
return nil
})
}
// Update 更新角色
func (r *RoleRepository) Update(ctx context.Context, role *domain.Role) error {
return r.db.WithContext(ctx).Save(role).Error
}
// Delete 删除角色
func (r *RoleRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.Role{}, id).Error
}
// GetByID 根据ID获取角色
func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) {
var role domain.Role
err := r.db.WithContext(ctx).First(&role, id).Error
if err != nil {
return nil, err
}
return &role, nil
}
// GetByCode 根据代码获取角色
func (r *RoleRepository) GetByCode(ctx context.Context, code string) (*domain.Role, error) {
var role domain.Role
err := r.db.WithContext(ctx).Where("code = ?", code).First(&role).Error
if err != nil {
return nil, err
}
return &role, nil
}
// List 获取角色列表
func (r *RoleRepository) List(ctx context.Context, offset, limit int) ([]*domain.Role, int64, error) {
var roles []*domain.Role
var total int64
query := r.db.WithContext(ctx).Model(&domain.Role{})
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
return nil, 0, err
}
return roles, total, nil
}
// ListByStatus 根据状态获取角色列表
func (r *RoleRepository) ListByStatus(ctx context.Context, status domain.RoleStatus, offset, limit int) ([]*domain.Role, int64, error) {
var roles []*domain.Role
var total int64
query := r.db.WithContext(ctx).Model(&domain.Role{}).Where("status = ?", status)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
return nil, 0, err
}
return roles, total, nil
}
// GetDefaultRoles 获取默认角色
func (r *RoleRepository) GetDefaultRoles(ctx context.Context) ([]*domain.Role, error) {
var roles []*domain.Role
err := r.db.WithContext(ctx).Where("is_default = ?", true).Find(&roles).Error
if err != nil {
return nil, err
}
return roles, nil
}
// ExistsByCode 检查角色代码是否存在
func (r *RoleRepository) ExistsByCode(ctx context.Context, code string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.Role{}).Where("code = ?", code).Count(&count).Error
return count > 0, err
}
// UpdateStatus 更新角色状态
func (r *RoleRepository) UpdateStatus(ctx context.Context, id int64, status domain.RoleStatus) error {
return r.db.WithContext(ctx).Model(&domain.Role{}).Where("id = ?", id).Update("status", status).Error
}
// Search 搜索角色
func (r *RoleRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.Role, int64, error) {
var roles []*domain.Role
var total int64
query := r.db.WithContext(ctx).Model(&domain.Role{}).
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
return nil, 0, err
}
return roles, total, nil
}
// ListByParentID 根据父ID获取角色列表
func (r *RoleRepository) ListByParentID(ctx context.Context, parentID int64) ([]*domain.Role, error) {
var roles []*domain.Role
err := r.db.WithContext(ctx).Where("parent_id = ?", parentID).Find(&roles).Error
if err != nil {
return nil, err
}
return roles, nil
}
// GetByIDs 根据ID列表批量获取角色
func (r *RoleRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.Role, error) {
if len(ids) == 0 {
return []*domain.Role{}, nil
}
var roles []*domain.Role
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&roles).Error
if err != nil {
return nil, err
}
return roles, nil
}
// GetAncestorIDs 获取角色的所有祖先角色ID用于权限继承
func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
var ancestorIDs []int64
currentID := roleID
// 循环向上查找父角色,直到没有父角色为止
for {
var role domain.Role
err := r.db.WithContext(ctx).Select("id", "parent_id").First(&role, currentID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
break
}
return nil, err
}
if role.ParentID == nil {
break
}
ancestorIDs = append(ancestorIDs, *role.ParentID)
currentID = *role.ParentID
}
return ancestorIDs, nil
}
// GetAncestors 获取角色的完整继承链(从父到子)
func (r *RoleRepository) GetAncestors(ctx context.Context, roleID int64) ([]*domain.Role, error) {
ancestorIDs, err := r.GetAncestorIDs(ctx, roleID)
if err != nil {
return nil, err
}
if len(ancestorIDs) == 0 {
return []*domain.Role{}, nil
}
return r.GetByIDs(ctx, ancestorIDs)
}

View File

@@ -0,0 +1,150 @@
package repository
import (
"context"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// RolePermissionRepository 角色权限关联数据访问层
type RolePermissionRepository struct {
db *gorm.DB
}
// NewRolePermissionRepository 创建角色权限关联数据访问层
func NewRolePermissionRepository(db *gorm.DB) *RolePermissionRepository {
return &RolePermissionRepository{db: db}
}
// Create 创建角色权限关联
func (r *RolePermissionRepository) Create(ctx context.Context, rolePermission *domain.RolePermission) error {
return r.db.WithContext(ctx).Create(rolePermission).Error
}
// Delete 删除角色权限关联
func (r *RolePermissionRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.RolePermission{}, id).Error
}
// DeleteByRoleID 删除角色的所有权限
func (r *RolePermissionRepository) DeleteByRoleID(ctx context.Context, roleID int64) error {
return r.db.WithContext(ctx).Where("role_id = ?", roleID).Delete(&domain.RolePermission{}).Error
}
// DeleteByPermissionID 删除权限的所有角色
func (r *RolePermissionRepository) DeleteByPermissionID(ctx context.Context, permissionID int64) error {
return r.db.WithContext(ctx).Where("permission_id = ?", permissionID).Delete(&domain.RolePermission{}).Error
}
// GetByRoleID 根据角色ID获取权限列表
func (r *RolePermissionRepository) GetByRoleID(ctx context.Context, roleID int64) ([]*domain.RolePermission, error) {
var rolePermissions []*domain.RolePermission
err := r.db.WithContext(ctx).Where("role_id = ?", roleID).Find(&rolePermissions).Error
if err != nil {
return nil, err
}
return rolePermissions, nil
}
// GetByPermissionID 根据权限ID获取角色列表
func (r *RolePermissionRepository) GetByPermissionID(ctx context.Context, permissionID int64) ([]*domain.RolePermission, error) {
var rolePermissions []*domain.RolePermission
err := r.db.WithContext(ctx).Where("permission_id = ?", permissionID).Find(&rolePermissions).Error
if err != nil {
return nil, err
}
return rolePermissions, nil
}
// GetPermissionIDsByRoleID 根据角色ID获取权限ID列表
func (r *RolePermissionRepository) GetPermissionIDsByRoleID(ctx context.Context, roleID int64) ([]int64, error) {
var permissionIDs []int64
err := r.db.WithContext(ctx).Model(&domain.RolePermission{}).Where("role_id = ?", roleID).Pluck("permission_id", &permissionIDs).Error
if err != nil {
return nil, err
}
return permissionIDs, nil
}
// GetRoleIDByPermissionID 根据权限ID获取角色ID列表
func (r *RolePermissionRepository) GetRoleIDByPermissionID(ctx context.Context, permissionID int64) ([]int64, error) {
var roleIDs []int64
err := r.db.WithContext(ctx).Model(&domain.RolePermission{}).Where("permission_id = ?", permissionID).Pluck("role_id", &roleIDs).Error
if err != nil {
return nil, err
}
return roleIDs, nil
}
// Exists 检查角色权限关联是否存在
func (r *RolePermissionRepository) Exists(ctx context.Context, roleID, permissionID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.RolePermission{}).
Where("role_id = ? AND permission_id = ?", roleID, permissionID).
Count(&count).Error
return count > 0, err
}
// BatchCreate 批量创建角色权限关联
func (r *RolePermissionRepository) BatchCreate(ctx context.Context, rolePermissions []*domain.RolePermission) error {
if len(rolePermissions) == 0 {
return nil
}
return r.db.WithContext(ctx).Create(&rolePermissions).Error
}
// BatchDelete 批量删除角色权限关联
func (r *RolePermissionRepository) BatchDelete(ctx context.Context, rolePermissions []*domain.RolePermission) error {
if len(rolePermissions) == 0 {
return nil
}
var ids []int64
for _, rp := range rolePermissions {
ids = append(ids, rp.ID)
}
return r.db.WithContext(ctx).Delete(&domain.RolePermission{}, ids).Error
}
// GetPermissionByID 根据权限ID获取权限信息
func (r *RolePermissionRepository) GetPermissionByID(ctx context.Context, permissionID int64) (*domain.Permission, error) {
var permission domain.Permission
err := r.db.WithContext(ctx).First(&permission, permissionID).Error
if err != nil {
return nil, err
}
return &permission, nil
}
// GetPermissionIDsByRoleIDs 根据角色ID列表批量获取权限ID
func (r *RolePermissionRepository) GetPermissionIDsByRoleIDs(ctx context.Context, roleIDs []int64) ([]int64, error) {
if len(roleIDs) == 0 {
return []int64{}, nil
}
var permissionIDs []int64
err := r.db.WithContext(ctx).Model(&domain.RolePermission{}).
Where("role_id IN ?", roleIDs).
Pluck("permission_id", &permissionIDs).Error
if err != nil {
return nil, err
}
return permissionIDs, nil
}
// GetPermissionsByIDs 根据权限ID列表批量获取权限
func (r *RolePermissionRepository) GetPermissionsByIDs(ctx context.Context, permissionIDs []int64) ([]*domain.Permission, error) {
if len(permissionIDs) == 0 {
return []*domain.Permission{}, nil
}
var permissions []*domain.Permission
err := r.db.WithContext(ctx).Where("id IN ?", permissionIDs).Find(&permissions).Error
if err != nil {
return nil, err
}
return permissions, nil
}

View File

@@ -0,0 +1,68 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/user-management-system/internal/config"
"github.com/user-management-system/internal/service"
"github.com/stretchr/testify/require"
)
func TestSchedulerSnapshotOutboxReplay(t *testing.T) {
ctx := context.Background()
rdb := testRedis(t)
client := testEntClient(t)
_, _ = integrationDB.ExecContext(ctx, "TRUNCATE scheduler_outbox")
accountRepo := newAccountRepositoryWithSQL(client, integrationDB, nil)
outboxRepo := NewSchedulerOutboxRepository(integrationDB)
cache := NewSchedulerCache(rdb)
cfg := &config.Config{
RunMode: config.RunModeStandard,
Gateway: config.GatewayConfig{
Scheduling: config.GatewaySchedulingConfig{
OutboxPollIntervalSeconds: 1,
FullRebuildIntervalSeconds: 0,
DbFallbackEnabled: true,
},
},
}
account := &service.Account{
Name: "outbox-replay-" + time.Now().Format("150405.000000"),
Platform: service.PlatformOpenAI,
Type: service.AccountTypeAPIKey,
Status: service.StatusActive,
Schedulable: true,
Concurrency: 3,
Priority: 1,
Credentials: map[string]any{},
Extra: map[string]any{},
}
require.NoError(t, accountRepo.Create(ctx, account))
require.NoError(t, cache.SetAccount(ctx, account))
svc := service.NewSchedulerSnapshotService(cache, outboxRepo, accountRepo, nil, cfg)
svc.Start()
t.Cleanup(svc.Stop)
require.NoError(t, accountRepo.UpdateLastUsed(ctx, account.ID))
updated, err := accountRepo.GetByID(ctx, account.ID)
require.NoError(t, err)
require.NotNil(t, updated.LastUsedAt)
expectedUnix := updated.LastUsedAt.Unix()
require.Eventually(t, func() bool {
cached, err := cache.GetAccount(ctx, account.ID)
if err != nil || cached == nil || cached.LastUsedAt == nil {
return false
}
return cached.LastUsedAt.Unix() == expectedUnix
}, 5*time.Second, 100*time.Millisecond)
}

View File

@@ -0,0 +1,295 @@
package repository
import (
"context"
"database/sql"
"fmt"
"github.com/user-management-system/internal/domain"
"gorm.io/gorm"
)
// SocialAccountRepository 社交账号仓库接口
type SocialAccountRepository interface {
Create(ctx context.Context, account *domain.SocialAccount) error
Update(ctx context.Context, account *domain.SocialAccount) error
Delete(ctx context.Context, id int64) error
DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error
GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error)
GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error)
GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error)
List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error)
}
// SocialAccountRepositoryImpl 社交账号仓库实现
type SocialAccountRepositoryImpl struct {
db *sql.DB
}
// NewSocialAccountRepository 创建社交账号仓库(支持 gorm.DB 或 *sql.DB
func NewSocialAccountRepository(db interface{}) (SocialAccountRepository, error) {
var sqlDB *sql.DB
switch d := db.(type) {
case *gorm.DB:
var err error
sqlDB, err = d.DB()
if err != nil {
return nil, fmt.Errorf("resolve sql db from gorm db failed: %w", err)
}
case *sql.DB:
sqlDB = d
default:
return nil, fmt.Errorf("unsupported db type: %T", db)
}
if sqlDB == nil {
return nil, fmt.Errorf("sql db is nil")
}
return &SocialAccountRepositoryImpl{db: sqlDB}, nil
}
// Create 创建社交账号
func (r *SocialAccountRepositoryImpl) Create(ctx context.Context, account *domain.SocialAccount) error {
query := `
INSERT INTO user_social_accounts (user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
`
result, err := r.db.ExecContext(ctx, query,
account.UserID,
account.Provider,
account.OpenID,
account.UnionID,
account.Nickname,
account.Avatar,
account.Gender,
account.Email,
account.Phone,
account.Extra,
account.Status,
)
if err != nil {
return fmt.Errorf("failed to create social account: %w", err)
}
id, err := result.LastInsertId()
if err != nil {
return err
}
account.ID = id
return nil
}
// Update 更新社交账号
func (r *SocialAccountRepositoryImpl) Update(ctx context.Context, account *domain.SocialAccount) error {
query := `
UPDATE user_social_accounts
SET union_id = ?, nickname = ?, avatar = ?, gender = ?, email = ?, phone = ?, extra = ?, status = ?, updated_at = CURRENT_TIMESTAMP
WHERE id = ?
`
_, err := r.db.ExecContext(ctx, query,
account.UnionID,
account.Nickname,
account.Avatar,
account.Gender,
account.Email,
account.Phone,
account.Extra,
account.Status,
account.ID,
)
if err != nil {
return fmt.Errorf("failed to update social account: %w", err)
}
return nil
}
// Delete 删除社交账号
func (r *SocialAccountRepositoryImpl) Delete(ctx context.Context, id int64) error {
query := `DELETE FROM user_social_accounts WHERE id = ?`
_, err := r.db.ExecContext(ctx, query, id)
if err != nil {
return fmt.Errorf("failed to delete social account: %w", err)
}
return nil
}
// DeleteByProviderAndUserID 删除指定用户和提供商的社交账号
func (r *SocialAccountRepositoryImpl) DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error {
query := `DELETE FROM user_social_accounts WHERE provider = ? AND user_id = ?`
_, err := r.db.ExecContext(ctx, query, provider, userID)
if err != nil {
return fmt.Errorf("failed to delete social account: %w", err)
}
return nil
}
// GetByID 根据ID获取社交账号
func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error) {
query := `
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
FROM user_social_accounts
WHERE id = ?
`
var account domain.SocialAccount
err := r.db.QueryRowContext(ctx, query, id).Scan(
&account.ID,
&account.UserID,
&account.Provider,
&account.OpenID,
&account.UnionID,
&account.Nickname,
&account.Avatar,
&account.Gender,
&account.Email,
&account.Phone,
&account.Extra,
&account.Status,
&account.CreatedAt,
&account.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to get social account: %w", err)
}
return &account, nil
}
// GetByUserID 根据用户ID获取社交账号列表
func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
query := `
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
FROM user_social_accounts
WHERE user_id = ?
ORDER BY created_at DESC
`
rows, err := r.db.QueryContext(ctx, query, userID)
if err != nil {
return nil, fmt.Errorf("failed to query social accounts: %w", err)
}
defer rows.Close()
var accounts []*domain.SocialAccount
for rows.Next() {
var account domain.SocialAccount
err := rows.Scan(
&account.ID,
&account.UserID,
&account.Provider,
&account.OpenID,
&account.UnionID,
&account.Nickname,
&account.Avatar,
&account.Gender,
&account.Email,
&account.Phone,
&account.Extra,
&account.Status,
&account.CreatedAt,
&account.UpdatedAt,
)
if err != nil {
return nil, err
}
accounts = append(accounts, &account)
}
return accounts, nil
}
// GetByProviderAndOpenID 根据提供商和OpenID获取社交账号
func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error) {
query := `
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
FROM user_social_accounts
WHERE provider = ? AND open_id = ?
`
var account domain.SocialAccount
err := r.db.QueryRowContext(ctx, query, provider, openID).Scan(
&account.ID,
&account.UserID,
&account.Provider,
&account.OpenID,
&account.UnionID,
&account.Nickname,
&account.Avatar,
&account.Gender,
&account.Email,
&account.Phone,
&account.Extra,
&account.Status,
&account.CreatedAt,
&account.UpdatedAt,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, fmt.Errorf("failed to get social account: %w", err)
}
return &account, nil
}
// List 分页获取社交账号列表
func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error) {
// 获取总数
var total int64
countQuery := `SELECT COUNT(*) FROM user_social_accounts`
if err := r.db.QueryRowContext(ctx, countQuery).Scan(&total); err != nil {
return nil, 0, fmt.Errorf("failed to count social accounts: %w", err)
}
// 获取列表
query := `
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
FROM user_social_accounts
ORDER BY created_at DESC
LIMIT ? OFFSET ?
`
rows, err := r.db.QueryContext(ctx, query, limit, offset)
if err != nil {
return nil, 0, fmt.Errorf("failed to query social accounts: %w", err)
}
defer rows.Close()
var accounts []*domain.SocialAccount
for rows.Next() {
var account domain.SocialAccount
err := rows.Scan(
&account.ID,
&account.UserID,
&account.Provider,
&account.OpenID,
&account.UnionID,
&account.Nickname,
&account.Avatar,
&account.Gender,
&account.Email,
&account.Phone,
&account.Extra,
&account.Status,
&account.CreatedAt,
&account.UpdatedAt,
)
if err != nil {
return nil, 0, err
}
accounts = append(accounts, &account)
}
return accounts, total, nil
}

View File

@@ -0,0 +1,41 @@
package repository
import "testing"
func TestNewSocialAccountRepository_AcceptsGormDB(t *testing.T) {
db := openTestDB(t)
repo, err := NewSocialAccountRepository(db)
if err != nil {
t.Fatalf("expected constructor to succeed: %v", err)
}
if repo == nil {
t.Fatal("expected repository instance")
}
}
func TestNewSocialAccountRepository_AcceptsSQLDB(t *testing.T) {
db := openTestDB(t)
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("expected sql db handle: %v", err)
}
repo, err := NewSocialAccountRepository(sqlDB)
if err != nil {
t.Fatalf("expected constructor to succeed: %v", err)
}
if repo == nil {
t.Fatal("expected repository instance")
}
}
func TestNewSocialAccountRepository_RejectsUnsupportedType(t *testing.T) {
repo, err := NewSocialAccountRepository(struct{}{})
if err == nil {
t.Fatal("expected constructor error")
}
if repo != nil {
t.Fatal("did not expect repository instance")
}
}

View File

@@ -0,0 +1,42 @@
package repository
import (
"context"
"database/sql"
"errors"
)
type sqlQueryer interface {
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
}
// scanSingleRow 执行查询并扫描第一行到 dest。
// 若无结果,可通过 errors.Is(err, sql.ErrNoRows) 判断。
// 如果 Close 失败,会与原始错误合并返回。
// 设计目的:仅依赖 QueryContext避免 QueryRowContext 对 *sql.Tx 的强绑定,
// 让 ent.Tx 也能作为 sqlExecutor/Queryer 使用。
func scanSingleRow(ctx context.Context, q sqlQueryer, query string, args []any, dest ...any) (err error) {
rows, err := q.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
err = errors.Join(err, closeErr)
}
}()
if !rows.Next() {
if err = rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
if err = rows.Scan(dest...); err != nil {
return err
}
if err = rows.Err(); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,48 @@
package repository
import (
"fmt"
"sync/atomic"
"testing"
_ "modernc.org/sqlite" // 纯 Go SQLite注册 "sqlite" 驱动
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"github.com/user-management-system/internal/domain"
)
var repoDBCounter int64
// openTestDB 为每个测试打开独立的内存数据库(使用 modernc.org/sqlite无需 CGO
// 每次调用都生成唯一的 DSN避免多个测试共用同一内存 DB 导致 index 重复错误
func openTestDB(t *testing.T) *gorm.DB {
t.Helper()
id := atomic.AddInt64(&repoDBCounter, 1)
dsn := fmt.Sprintf("file:repotestdb%d?mode=memory&cache=private", id)
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: dsn,
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("打开测试数据库失败: %v", err)
}
tables := []interface{}{
&domain.User{},
&domain.Role{},
&domain.Permission{},
&domain.UserRole{},
&domain.RolePermission{},
}
if err := db.AutoMigrate(tables...); err != nil {
t.Fatalf("数据库迁移失败: %v", err)
}
return db
}

View File

@@ -0,0 +1,99 @@
package repository
import (
"context"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// ThemeConfigRepository 主题配置数据访问层
type ThemeConfigRepository struct {
db *gorm.DB
}
// NewThemeConfigRepository 创建主题配置数据访问层
func NewThemeConfigRepository(db *gorm.DB) *ThemeConfigRepository {
return &ThemeConfigRepository{db: db}
}
// Create 创建主题配置
func (r *ThemeConfigRepository) Create(ctx context.Context, theme *domain.ThemeConfig) error {
return r.db.WithContext(ctx).Create(theme).Error
}
// Update 更新主题配置
func (r *ThemeConfigRepository) Update(ctx context.Context, theme *domain.ThemeConfig) error {
return r.db.WithContext(ctx).Save(theme).Error
}
// Delete 删除主题配置
func (r *ThemeConfigRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.ThemeConfig{}, id).Error
}
// GetByID 根据ID获取主题配置
func (r *ThemeConfigRepository) GetByID(ctx context.Context, id int64) (*domain.ThemeConfig, error) {
var theme domain.ThemeConfig
err := r.db.WithContext(ctx).First(&theme, id).Error
if err != nil {
return nil, err
}
return &theme, nil
}
// GetByName 根据名称获取主题配置
func (r *ThemeConfigRepository) GetByName(ctx context.Context, name string) (*domain.ThemeConfig, error) {
var theme domain.ThemeConfig
err := r.db.WithContext(ctx).Where("name = ?", name).First(&theme).Error
if err != nil {
return nil, err
}
return &theme, nil
}
// GetDefault 获取默认主题
func (r *ThemeConfigRepository) GetDefault(ctx context.Context) (*domain.ThemeConfig, error) {
var theme domain.ThemeConfig
err := r.db.WithContext(ctx).Where("is_default = ?", true).First(&theme).Error
if err != nil {
// 如果没有默认主题,返回默认配置
if err == gorm.ErrRecordNotFound {
return domain.DefaultThemeConfig(), nil
}
return nil, err
}
return &theme, nil
}
// List 获取所有已启用的主题配置
func (r *ThemeConfigRepository) List(ctx context.Context) ([]*domain.ThemeConfig, error) {
var themes []*domain.ThemeConfig
err := r.db.WithContext(ctx).Where("enabled = ?", true).Order("is_default DESC, id ASC").Find(&themes).Error
if err != nil {
return nil, err
}
return themes, nil
}
// ListAll 获取所有主题配置
func (r *ThemeConfigRepository) ListAll(ctx context.Context) ([]*domain.ThemeConfig, error) {
var themes []*domain.ThemeConfig
err := r.db.WithContext(ctx).Order("is_default DESC, id ASC").Find(&themes).Error
if err != nil {
return nil, err
}
return themes, nil
}
// SetDefault 设置默认主题
func (r *ThemeConfigRepository) SetDefault(ctx context.Context, id int64) error {
// 先清除所有默认标记
if err := r.db.WithContext(ctx).Model(&domain.ThemeConfig{}).Where("is_default = ?", true).Update("is_default", false).Error; err != nil {
return err
}
// 设置新的默认主题
return r.db.WithContext(ctx).Model(&domain.ThemeConfig{}).Where("id = ?", id).Update("is_default", true).Error
}

View File

@@ -0,0 +1,73 @@
//go:build integration
package repository
import (
"errors"
"testing"
"time"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type UpdateCacheSuite struct {
IntegrationRedisSuite
cache *updateCache
}
func (s *UpdateCacheSuite) SetupTest() {
s.IntegrationRedisSuite.SetupTest()
s.cache = NewUpdateCache(s.rdb).(*updateCache)
}
func (s *UpdateCacheSuite) TestGetUpdateInfo_Missing() {
_, err := s.cache.GetUpdateInfo(s.ctx)
require.True(s.T(), errors.Is(err, redis.Nil), "expected redis.Nil for missing update info")
}
func (s *UpdateCacheSuite) TestSetAndGetUpdateInfo() {
updateTTL := 5 * time.Minute
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL), "SetUpdateInfo")
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err, "GetUpdateInfo")
require.Equal(s.T(), "v1.2.3", info, "update info mismatch")
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_TTL() {
updateTTL := 5 * time.Minute
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.2.3", updateTTL))
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
require.NoError(s.T(), err, "TTL updateCacheKey")
s.AssertTTLWithin(ttl, 1*time.Second, updateTTL)
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_Overwrite() {
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v1.0.0", 5*time.Minute))
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v2.0.0", 5*time.Minute))
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err)
require.Equal(s.T(), "v2.0.0", info, "expected overwritten value")
}
func (s *UpdateCacheSuite) TestSetUpdateInfo_ZeroTTL() {
// TTL=0 means persist forever (no expiry) in Redis SET command
require.NoError(s.T(), s.cache.SetUpdateInfo(s.ctx, "v0.0.0", 0))
info, err := s.cache.GetUpdateInfo(s.ctx)
require.NoError(s.T(), err)
require.Equal(s.T(), "v0.0.0", info)
ttl, err := s.rdb.TTL(s.ctx, updateCacheKey).Result()
require.NoError(s.T(), err)
// TTL=-1 means no expiry, TTL=-2 means key doesn't exist
require.Equal(s.T(), time.Duration(-1), ttl, "expected TTL=-1 for key with no expiry")
}
func TestUpdateCacheSuite(t *testing.T) {
suite.Run(t, new(UpdateCacheSuite))
}

314
internal/repository/user.go Normal file
View File

@@ -0,0 +1,314 @@
package repository
import (
"context"
"strings"
"time"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// escapeLikePattern 转义 LIKE 模式中的特殊字符(% 和 _
// 这些字符在 LIKE 查询中有特殊含义,需要转义才能作为普通字符匹配
func escapeLikePattern(s string) string {
// 先转义 \,再转义 % 和 _顺序很重要
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `%`, `\%`)
s = strings.ReplaceAll(s, `_`, `\_`)
return s
}
// UserRepository 用户数据访问层
type UserRepository struct {
db *gorm.DB
}
// NewUserRepository 创建用户数据访问层
func NewUserRepository(db *gorm.DB) *UserRepository {
return &UserRepository{db: db}
}
// Create 创建用户
func (r *UserRepository) Create(ctx context.Context, user *domain.User) error {
return r.db.WithContext(ctx).Create(user).Error
}
// Update 更新用户
func (r *UserRepository) Update(ctx context.Context, user *domain.User) error {
return r.db.WithContext(ctx).Save(user).Error
}
// Delete 删除用户(软删除)
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.User{}, id).Error
}
// GetByID 根据ID获取用户
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*domain.User, error) {
var user domain.User
err := r.db.WithContext(ctx).First(&user, id).Error
if err != nil {
return nil, err
}
return &user, nil
}
// GetByUsername 根据用户名获取用户
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*domain.User, error) {
var user domain.User
err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
// GetByEmail 根据邮箱获取用户
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
var user domain.User
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
// GetByPhone 根据手机号获取用户
func (r *UserRepository) GetByPhone(ctx context.Context, phone string) (*domain.User, error) {
var user domain.User
err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error
if err != nil {
return nil, err
}
return &user, nil
}
// List 获取用户列表
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
var users []*domain.User
var total int64
query := r.db.WithContext(ctx).Model(&domain.User{})
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
// ListByStatus 根据状态获取用户列表
func (r *UserRepository) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) {
var users []*domain.User
var total int64
query := r.db.WithContext(ctx).Model(&domain.User{}).Where("status = ?", status)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
// UpdateStatus 更新用户状态
func (r *UserRepository) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("status", status).Error
}
// UpdateLastLogin 更新最后登录信息
func (r *UserRepository) UpdateLastLogin(ctx context.Context, id int64, ip string) error {
now := time.Now()
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Updates(map[string]interface{}{
"last_login_time": &now,
"last_login_ip": ip,
}).Error
}
// ExistsByUsername 检查用户名是否存在
func (r *UserRepository) ExistsByUsername(ctx context.Context, username string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("username = ?", username).Count(&count).Error
return count > 0, err
}
// ExistsByEmail 检查邮箱是否存在
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("email = ?", email).Count(&count).Error
return count > 0, err
}
// ExistsByPhone 检查手机号是否存在
func (r *UserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("phone = ?", phone).Count(&count).Error
return count > 0, err
}
// Search 搜索用户
func (r *UserRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
var users []*domain.User
var total int64
// 转义 LIKE 特殊字符,防止搜索被意外干扰
escapedKeyword := escapeLikePattern(keyword)
pattern := "%" + escapedKeyword + "%"
query := r.db.WithContext(ctx).Model(&domain.User{}).Where(
"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
pattern, pattern, pattern, pattern,
)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
// UpdateTOTP 更新用户的 TOTP 字段
func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) error {
return r.db.WithContext(ctx).Model(user).Updates(map[string]interface{}{
"totp_enabled": user.TOTPEnabled,
"totp_secret": user.TOTPSecret,
"totp_recovery_codes": user.TOTPRecoveryCodes,
}).Error
}
// UpdatePassword 更新用户密码
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
}
// ListCreatedAfter 查询指定时间之后创建的用户limit=0表示不限制数量
func (r *UserRepository) ListCreatedAfter(ctx context.Context, since time.Time, offset, limit int) ([]*domain.User, int64, error) {
var users []*domain.User
var total int64
query := r.db.WithContext(ctx).Model(&domain.User{}).Where("created_at >= ?", since)
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if limit > 0 {
query = query.Offset(offset).Limit(limit)
}
if err := query.Find(&users).Error; err != nil {
return nil, 0, err
}
return users, total, nil
}
// AdvancedFilter 高级用户筛选请求
type AdvancedFilter struct {
Keyword string // 关键字(用户名/邮箱/手机号/昵称)
Status int // 状态:-1 全部0/1/2/3 对应 UserStatus
RoleIDs []int64 // 角色ID列表按角色筛选
CreatedFrom *time.Time // 注册时间范围(起始)
CreatedTo *time.Time // 注册时间范围(截止)
LastLoginFrom *time.Time // 最后登录时间范围(起始)
SortBy string // 排序字段created_at, last_login_time, username
SortOrder string // 排序方向asc, desc
Offset int
Limit int
}
// AdvancedSearch 高级用户搜索(支持多维度组合筛选)
func (r *UserRepository) AdvancedSearch(ctx context.Context, filter *AdvancedFilter) ([]*domain.User, int64, error) {
var users []*domain.User
var total int64
query := r.db.WithContext(ctx).Model(&domain.User{})
// 关键字搜索(转义 LIKE 特殊字符)
if filter.Keyword != "" {
like := "%" + escapeLikePattern(filter.Keyword) + "%"
query = query.Where(
"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
like, like, like, like,
)
}
// 状态筛选
if filter.Status >= 0 {
query = query.Where("status = ?", filter.Status)
}
// 注册时间范围
if filter.CreatedFrom != nil {
query = query.Where("created_at >= ?", filter.CreatedFrom)
}
if filter.CreatedTo != nil {
query = query.Where("created_at <= ?", filter.CreatedTo)
}
// 最后登录时间范围
if filter.LastLoginFrom != nil {
query = query.Where("last_login_time >= ?", filter.LastLoginFrom)
}
// 按角色筛选(子查询)
if len(filter.RoleIDs) > 0 {
query = query.Where(
"id IN (SELECT user_id FROM user_roles WHERE role_id IN ? AND deleted_at IS NULL)",
filter.RoleIDs,
)
}
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 排序
sortBy := "created_at"
sortOrder := "DESC"
if filter.SortBy != "" {
allowedFields := map[string]bool{
"created_at": true, "last_login_time": true,
"username": true, "updated_at": true,
}
if allowedFields[filter.SortBy] {
sortBy = filter.SortBy
}
}
if filter.SortOrder == "asc" {
sortOrder = "ASC"
}
query = query.Order(sortBy + " " + sortOrder)
// 分页
limit := filter.Limit
if limit <= 0 {
limit = 20
}
if limit > 200 {
limit = 200
}
query = query.Offset(filter.Offset).Limit(limit)
if err := query.Find(&users).Error; err != nil {
return nil, 0, err
}
return users, total, nil
}

View File

@@ -0,0 +1,537 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
dbent "github.com/user-management-system/ent"
"github.com/user-management-system/internal/pkg/pagination"
"github.com/user-management-system/internal/service"
"github.com/stretchr/testify/suite"
)
type UserRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *userRepository
}
func (s *UserRepoSuite) SetupTest() {
s.ctx = context.Background()
s.client = testEntClient(s.T())
s.repo = newUserRepositoryWithSQL(s.client, integrationDB)
// 清理测试数据,确保每个测试从干净状态开始
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_subscriptions")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM user_allowed_groups")
_, _ = integrationDB.ExecContext(s.ctx, "DELETE FROM users")
}
func TestUserRepoSuite(t *testing.T) {
suite.Run(t, new(UserRepoSuite))
}
func (s *UserRepoSuite) mustCreateUser(u *service.User) *service.User {
s.T().Helper()
if u.Email == "" {
u.Email = "user-" + time.Now().Format(time.RFC3339Nano) + "@example.com"
}
if u.PasswordHash == "" {
u.PasswordHash = "test-password-hash"
}
if u.Role == "" {
u.Role = service.RoleUser
}
if u.Status == "" {
u.Status = service.StatusActive
}
if u.Concurrency == 0 {
u.Concurrency = 5
}
s.Require().NoError(s.repo.Create(s.ctx, u), "create user")
return u
}
func (s *UserRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(s.ctx)
s.Require().NoError(err, "create group")
return groupEntityToService(g)
}
func (s *UserRepoSuite) mustCreateSubscription(userID, groupID int64, mutate func(*dbent.UserSubscriptionCreate)) *dbent.UserSubscription {
s.T().Helper()
now := time.Now()
create := s.client.UserSubscription.Create().
SetUserID(userID).
SetGroupID(groupID).
SetStartsAt(now.Add(-1 * time.Hour)).
SetExpiresAt(now.Add(24 * time.Hour)).
SetStatus(service.SubscriptionStatusActive).
SetAssignedAt(now).
SetNotes("")
if mutate != nil {
mutate(create)
}
sub, err := create.Save(s.ctx)
s.Require().NoError(err, "create subscription")
return sub
}
// --- Create / GetByID / GetByEmail / Update / Delete ---
func (s *UserRepoSuite) TestCreate() {
user := s.mustCreateUser(&service.User{
Email: "create@test.com",
Username: "testuser",
PasswordHash: "test-password-hash",
Role: service.RoleUser,
Status: service.StatusActive,
})
s.Require().NotZero(user.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal("create@test.com", got.Email)
}
func (s *UserRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UserRepoSuite) TestGetByEmail() {
user := s.mustCreateUser(&service.User{Email: "byemail@test.com"})
got, err := s.repo.GetByEmail(s.ctx, user.Email)
s.Require().NoError(err, "GetByEmail")
s.Require().Equal(user.ID, got.ID)
}
func (s *UserRepoSuite) TestGetByEmail_NotFound() {
_, err := s.repo.GetByEmail(s.ctx, "nonexistent@test.com")
s.Require().Error(err, "expected error for non-existent email")
}
func (s *UserRepoSuite) TestUpdate() {
user := s.mustCreateUser(&service.User{Email: "update@test.com", Username: "original"})
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
got.Username = "updated"
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
updated, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated", updated.Username)
}
func (s *UserRepoSuite) TestDelete() {
user := s.mustCreateUser(&service.User{Email: "delete@test.com"})
err := s.repo.Delete(s.ctx, user.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, user.ID)
s.Require().Error(err, "expected error after delete")
}
// --- List / ListWithFilters ---
func (s *UserRepoSuite) TestList() {
s.mustCreateUser(&service.User{Email: "list1@test.com"})
s.mustCreateUser(&service.User{Email: "list2@test.com"})
users, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "List")
s.Require().Len(users, 2)
s.Require().Equal(int64(2), page.Total)
}
func (s *UserRepoSuite) TestListWithFilters_Status() {
s.mustCreateUser(&service.User{Email: "active@test.com", Status: service.StatusActive})
s.mustCreateUser(&service.User{Email: "disabled@test.com", Status: service.StatusDisabled})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(service.StatusActive, users[0].Status)
}
func (s *UserRepoSuite) TestListWithFilters_Role() {
s.mustCreateUser(&service.User{Email: "user@test.com", Role: service.RoleUser})
s.mustCreateUser(&service.User{Email: "admin@test.com", Role: service.RoleAdmin})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Role: service.RoleAdmin})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal(service.RoleAdmin, users[0].Role)
}
func (s *UserRepoSuite) TestListWithFilters_Search() {
s.mustCreateUser(&service.User{Email: "alice@test.com", Username: "Alice"})
s.mustCreateUser(&service.User{Email: "bob@test.com", Username: "Bob"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "alice"})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Contains(users[0].Email, "alice")
}
func (s *UserRepoSuite) TestListWithFilters_SearchByUsername() {
s.mustCreateUser(&service.User{Email: "u1@test.com", Username: "JohnDoe"})
s.mustCreateUser(&service.User{Email: "u2@test.com", Username: "JaneSmith"})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "john"})
s.Require().NoError(err)
s.Require().Len(users, 1)
s.Require().Equal("JohnDoe", users[0].Username)
}
func (s *UserRepoSuite) TestListWithFilters_LoadsActiveSubscriptions() {
user := s.mustCreateUser(&service.User{Email: "sub@test.com", Status: service.StatusActive})
groupActive := s.mustCreateGroup("g-sub-active")
groupExpired := s.mustCreateGroup("g-sub-expired")
_ = s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusActive)
c.SetExpiresAt(time.Now().Add(1 * time.Hour))
})
_ = s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-1 * time.Hour))
})
users, _, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Search: "sub@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Len(users, 1, "expected 1 user")
s.Require().Len(users[0].Subscriptions, 1, "expected 1 active subscription")
s.Require().NotNil(users[0].Subscriptions[0].Group, "expected subscription group preload")
s.Require().Equal(groupActive.ID, users[0].Subscriptions[0].Group.ID, "group ID mismatch")
}
func (s *UserRepoSuite) TestListWithFilters_CombinedFilters() {
s.mustCreateUser(&service.User{
Email: "a@example.com",
Username: "Alice",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
})
target := s.mustCreateUser(&service.User{
Email: "b@example.com",
Username: "Bob",
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
})
s.mustCreateUser(&service.User{
Email: "c@example.com",
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
users, page, err := s.repo.ListWithFilters(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(target.ID, users[0].ID, "ListWithFilters result mismatch")
}
// --- Balance operations ---
func (s *UserRepoSuite) TestUpdateBalance() {
user := s.mustCreateUser(&service.User{Email: "bal@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, 2.5)
s.Require().NoError(err, "UpdateBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(12.5, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestUpdateBalance_Negative() {
user := s.mustCreateUser(&service.User{Email: "balneg@test.com", Balance: 10})
err := s.repo.UpdateBalance(s.ctx, user.ID, -3)
s.Require().NoError(err, "UpdateBalance with negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(7.0, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestDeductBalance() {
user := s.mustCreateUser(&service.User{Email: "deduct@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 5)
s.Require().NoError(err, "DeductBalance")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(5.0, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestDeductBalance_InsufficientFunds() {
user := s.mustCreateUser(&service.User{Email: "insuf@test.com", Balance: 5})
// 透支策略:允许扣除超过余额的金额
err := s.repo.DeductBalance(s.ctx, user.ID, 999)
s.Require().NoError(err, "DeductBalance should allow overdraft")
// 验证余额变为负数
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(-994.0, got.Balance, 1e-6, "Balance should be negative after overdraft")
}
func (s *UserRepoSuite) TestDeductBalance_ExactAmount() {
user := s.mustCreateUser(&service.User{Email: "exact@test.com", Balance: 10})
err := s.repo.DeductBalance(s.ctx, user.ID, 10)
s.Require().NoError(err, "DeductBalance exact amount")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(0.0, got.Balance, 1e-6)
}
func (s *UserRepoSuite) TestDeductBalance_AllowsOverdraft() {
user := s.mustCreateUser(&service.User{Email: "overdraft@test.com", Balance: 5.0})
// 扣除超过余额的金额 - 应该成功
err := s.repo.DeductBalance(s.ctx, user.ID, 10.0)
s.Require().NoError(err, "DeductBalance should allow overdraft")
// 验证余额为负
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().InDelta(-5.0, got.Balance, 1e-6, "Balance should be -5.0 after overdraft")
}
// --- Concurrency ---
func (s *UserRepoSuite) TestUpdateConcurrency() {
user := s.mustCreateUser(&service.User{Email: "conc@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, 3)
s.Require().NoError(err, "UpdateConcurrency")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(8, got.Concurrency)
}
func (s *UserRepoSuite) TestUpdateConcurrency_Negative() {
user := s.mustCreateUser(&service.User{Email: "concneg@test.com", Concurrency: 5})
err := s.repo.UpdateConcurrency(s.ctx, user.ID, -2)
s.Require().NoError(err, "UpdateConcurrency negative")
got, err := s.repo.GetByID(s.ctx, user.ID)
s.Require().NoError(err)
s.Require().Equal(3, got.Concurrency)
}
// --- ExistsByEmail ---
func (s *UserRepoSuite) TestExistsByEmail() {
s.mustCreateUser(&service.User{Email: "exists@test.com"})
exists, err := s.repo.ExistsByEmail(s.ctx, "exists@test.com")
s.Require().NoError(err, "ExistsByEmail")
s.Require().True(exists)
notExists, err := s.repo.ExistsByEmail(s.ctx, "notexists@test.com")
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- RemoveGroupFromAllowedGroups ---
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups() {
target := s.mustCreateGroup("target-42")
other := s.mustCreateGroup("other-7")
userA := s.mustCreateUser(&service.User{
Email: "a1@example.com",
AllowedGroups: []int64{target.ID, other.ID},
})
s.mustCreateUser(&service.User{
Email: "a2@example.com",
AllowedGroups: []int64{other.ID},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, target.ID)
s.Require().NoError(err, "RemoveGroupFromAllowedGroups")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
got, err := s.repo.GetByID(s.ctx, userA.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotContains(got.AllowedGroups, target.ID)
s.Require().Contains(got.AllowedGroups, other.ID)
}
func (s *UserRepoSuite) TestRemoveGroupFromAllowedGroups_NoMatch() {
groupA := s.mustCreateGroup("nomatch-a")
groupB := s.mustCreateGroup("nomatch-b")
s.mustCreateUser(&service.User{
Email: "nomatch@test.com",
AllowedGroups: []int64{groupA.ID, groupB.ID},
})
affected, err := s.repo.RemoveGroupFromAllowedGroups(s.ctx, 999999)
s.Require().NoError(err)
s.Require().Zero(affected, "expected no affected rows")
}
// --- GetFirstAdmin ---
func (s *UserRepoSuite) TestGetFirstAdmin() {
admin1 := s.mustCreateUser(&service.User{
Email: "admin1@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
})
s.mustCreateUser(&service.User{
Email: "admin2@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().NoError(err, "GetFirstAdmin")
s.Require().Equal(admin1.ID, got.ID, "GetFirstAdmin mismatch")
}
func (s *UserRepoSuite) TestGetFirstAdmin_NoAdmin() {
s.mustCreateUser(&service.User{
Email: "user@example.com",
Role: service.RoleUser,
Status: service.StatusActive,
})
_, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().Error(err, "expected error when no admin exists")
}
func (s *UserRepoSuite) TestGetFirstAdmin_DisabledAdminIgnored() {
s.mustCreateUser(&service.User{
Email: "disabled@example.com",
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
activeAdmin := s.mustCreateUser(&service.User{
Email: "active@example.com",
Role: service.RoleAdmin,
Status: service.StatusActive,
})
got, err := s.repo.GetFirstAdmin(s.ctx)
s.Require().NoError(err, "GetFirstAdmin")
s.Require().Equal(activeAdmin.ID, got.ID, "should return only active admin")
}
// --- Combined ---
func (s *UserRepoSuite) TestCRUD_And_Filters_And_AtomicUpdates() {
user1 := s.mustCreateUser(&service.User{
Email: "a@example.com",
Username: "Alice",
Role: service.RoleUser,
Status: service.StatusActive,
Balance: 10,
})
user2 := s.mustCreateUser(&service.User{
Email: "b@example.com",
Username: "Bob",
Role: service.RoleAdmin,
Status: service.StatusActive,
Balance: 1,
})
s.mustCreateUser(&service.User{
Email: "c@example.com",
Role: service.RoleAdmin,
Status: service.StatusDisabled,
})
got, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(user1.Email, got.Email, "GetByID email mismatch")
gotByEmail, err := s.repo.GetByEmail(s.ctx, user2.Email)
s.Require().NoError(err, "GetByEmail")
s.Require().Equal(user2.ID, gotByEmail.ID, "GetByEmail ID mismatch")
got.Username = "Alice2"
s.Require().NoError(s.repo.Update(s.ctx, got), "Update")
got2, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("Alice2", got2.Username, "Update did not persist")
s.Require().NoError(s.repo.UpdateBalance(s.ctx, user1.ID, 2.5), "UpdateBalance")
got3, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateBalance")
s.Require().InDelta(12.5, got3.Balance, 1e-6)
s.Require().NoError(s.repo.DeductBalance(s.ctx, user1.ID, 5), "DeductBalance")
got4, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after DeductBalance")
s.Require().InDelta(7.5, got4.Balance, 1e-6)
// 透支策略:允许扣除超过余额的金额
err = s.repo.DeductBalance(s.ctx, user1.ID, 999)
s.Require().NoError(err, "DeductBalance should allow overdraft")
gotOverdraft, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after overdraft")
s.Require().Less(gotOverdraft.Balance, 0.0, "Balance should be negative after overdraft")
s.Require().NoError(s.repo.UpdateConcurrency(s.ctx, user1.ID, 3), "UpdateConcurrency")
got5, err := s.repo.GetByID(s.ctx, user1.ID)
s.Require().NoError(err, "GetByID after UpdateConcurrency")
s.Require().Equal(user1.Concurrency+3, got5.Concurrency)
params := pagination.PaginationParams{Page: 1, PageSize: 10}
users, page, err := s.repo.ListWithFilters(s.ctx, params, service.UserListFilters{Status: service.StatusActive, Role: service.RoleAdmin, Search: "b@"})
s.Require().NoError(err, "ListWithFilters")
s.Require().Equal(int64(1), page.Total, "ListWithFilters total mismatch")
s.Require().Len(users, 1, "ListWithFilters len mismatch")
s.Require().Equal(user2.ID, users[0].ID, "ListWithFilters result mismatch")
}
// --- UpdateBalance/UpdateConcurrency 影响行数校验测试 ---
func (s *UserRepoSuite) TestUpdateBalance_NotFound() {
err := s.repo.UpdateBalance(s.ctx, 999999, 10.0)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestUpdateConcurrency_NotFound() {
err := s.repo.UpdateConcurrency(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
s.Require().ErrorIs(err, service.ErrUserNotFound)
}
func (s *UserRepoSuite) TestDeductBalance_NotFound() {
err := s.repo.DeductBalance(s.ctx, 999999, 5)
s.Require().Error(err, "expected error for non-existent user")
// DeductBalance 在用户不存在时返回 ErrUserNotFound
s.Require().ErrorIs(err, service.ErrUserNotFound)
}

View File

@@ -0,0 +1,198 @@
package repository
import (
"context"
"testing"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
func setupTestDB(t *testing.T) *gorm.DB {
return openTestDB(t)
}
// TestUserRepository_Create 测试创建用户
func TestUserRepository_Create(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
user := &domain.User{
Username: "testuser",
Email: domain.StrPtr("test@example.com"),
Phone: domain.StrPtr("13800138000"),
Password: "hashedpassword",
Status: domain.UserStatusActive,
}
if err := repo.Create(ctx, user); err != nil {
t.Fatalf("Create() error = %v", err)
}
if user.ID == 0 {
t.Error("创建后用户ID不应为0")
}
}
// TestUserRepository_GetByUsername 测试根据用户名查询
func TestUserRepository_GetByUsername(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
user := &domain.User{
Username: "findme",
Email: domain.StrPtr("findme@example.com"),
Password: "hash",
Status: domain.UserStatusActive,
}
repo.Create(ctx, user)
found, err := repo.GetByUsername(ctx, "findme")
if err != nil {
t.Fatalf("GetByUsername() error = %v", err)
}
if found.Username != "findme" {
t.Errorf("Username = %v, want findme", found.Username)
}
_, err = repo.GetByUsername(ctx, "notexist")
if err == nil {
t.Error("查找不存在的用户应返回错误")
}
}
// TestUserRepository_GetByEmail 测试根据邮箱查询
func TestUserRepository_GetByEmail(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
user := &domain.User{
Username: "emailuser",
Email: domain.StrPtr("email@example.com"),
Password: "hash",
Status: domain.UserStatusActive,
}
repo.Create(ctx, user)
found, err := repo.GetByEmail(ctx, "email@example.com")
if err != nil {
t.Fatalf("GetByEmail() error = %v", err)
}
if domain.DerefStr(found.Email) != "email@example.com" {
t.Errorf("Email = %v, want email@example.com", domain.DerefStr(found.Email))
}
}
// TestUserRepository_Update 测试更新用户
func TestUserRepository_Update(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
user := &domain.User{
Username: "updateme",
Email: domain.StrPtr("update@example.com"),
Password: "hash",
Status: domain.UserStatusActive,
}
repo.Create(ctx, user)
user.Nickname = "已更新"
if err := repo.Update(ctx, user); err != nil {
t.Fatalf("Update() error = %v", err)
}
found, _ := repo.GetByID(ctx, user.ID)
if found.Nickname != "已更新" {
t.Errorf("Nickname = %v, want 已更新", found.Nickname)
}
}
// TestUserRepository_Delete 测试删除用户
func TestUserRepository_Delete(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
user := &domain.User{
Username: "deleteme",
Email: domain.StrPtr("delete@example.com"),
Password: "hash",
Status: domain.UserStatusActive,
}
repo.Create(ctx, user)
if err := repo.Delete(ctx, user.ID); err != nil {
t.Fatalf("Delete() error = %v", err)
}
_, err := repo.GetByID(ctx, user.ID)
if err == nil {
t.Error("删除后查询应返回错误")
}
}
// TestUserRepository_ExistsBy 测试存在性检查
func TestUserRepository_ExistsBy(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
user := &domain.User{
Username: "existsuser",
Email: domain.StrPtr("exists@example.com"),
Phone: domain.StrPtr("13900139000"),
Password: "hash",
Status: domain.UserStatusActive,
}
repo.Create(ctx, user)
ok, _ := repo.ExistsByUsername(ctx, "existsuser")
if !ok {
t.Error("ExistsByUsername 应返回 true")
}
ok, _ = repo.ExistsByEmail(ctx, "exists@example.com")
if !ok {
t.Error("ExistsByEmail 应返回 true")
}
ok, _ = repo.ExistsByPhone(ctx, "13900139000")
if !ok {
t.Error("ExistsByPhone 应返回 true")
}
ok, _ = repo.ExistsByUsername(ctx, "notexist")
if ok {
t.Error("不存在的用户 ExistsByUsername 应返回 false")
}
}
// TestUserRepository_List 测试列表查询
func TestUserRepository_List(t *testing.T) {
db := setupTestDB(t)
repo := NewUserRepository(db)
ctx := context.Background()
for i := 0; i < 5; i++ {
repo.Create(ctx, &domain.User{
Username: "listuser" + string(rune('0'+i)),
Password: "hash",
Status: domain.UserStatusActive,
})
}
users, total, err := repo.List(ctx, 0, 10)
if err != nil {
t.Fatalf("List() error = %v", err)
}
if len(users) != 5 {
t.Errorf("len(users) = %d, want 5", len(users))
}
if total != 5 {
t.Errorf("total = %d, want 5", total)
}
}

View File

@@ -0,0 +1,175 @@
package repository
import (
"context"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// UserRoleRepository 用户角色关联数据访问层
type UserRoleRepository struct {
db *gorm.DB
}
// NewUserRoleRepository 创建用户角色关联数据访问层
func NewUserRoleRepository(db *gorm.DB) *UserRoleRepository {
return &UserRoleRepository{db: db}
}
// Create 创建用户角色关联
func (r *UserRoleRepository) Create(ctx context.Context, userRole *domain.UserRole) error {
return r.db.WithContext(ctx).Create(userRole).Error
}
// Delete 删除用户角色关联
func (r *UserRoleRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.UserRole{}, id).Error
}
// DeleteByUserID 删除用户的所有角色
func (r *UserRoleRepository) DeleteByUserID(ctx context.Context, userID int64) error {
return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&domain.UserRole{}).Error
}
// DeleteByRoleID 删除角色的所有用户
func (r *UserRoleRepository) DeleteByRoleID(ctx context.Context, roleID int64) error {
return r.db.WithContext(ctx).Where("role_id = ?", roleID).Delete(&domain.UserRole{}).Error
}
// GetByUserID 根据用户ID获取角色列表
func (r *UserRoleRepository) GetByUserID(ctx context.Context, userID int64) ([]*domain.UserRole, error) {
var userRoles []*domain.UserRole
err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&userRoles).Error
if err != nil {
return nil, err
}
return userRoles, nil
}
// GetByRoleID 根据角色ID获取用户列表
func (r *UserRoleRepository) GetByRoleID(ctx context.Context, roleID int64) ([]*domain.UserRole, error) {
var userRoles []*domain.UserRole
err := r.db.WithContext(ctx).Where("role_id = ?", roleID).Find(&userRoles).Error
if err != nil {
return nil, err
}
return userRoles, nil
}
// GetRoleIDsByUserID 根据用户ID获取角色ID列表
func (r *UserRoleRepository) GetRoleIDsByUserID(ctx context.Context, userID int64) ([]int64, error) {
var roleIDs []int64
err := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("user_id = ?", userID).Pluck("role_id", &roleIDs).Error
if err != nil {
return nil, err
}
return roleIDs, nil
}
// GetUserRolesAndPermissions 获取用户角色和权限PERF-01 优化:合并为单次 JOIN 查询)
func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) {
var results []struct {
RoleID int64
RoleName string
RoleCode string
RoleStatus int
PermissionID int64
PermissionCode string
PermissionName string
}
// 使用 LEFT JOIN 一次性获取用户角色和权限
err := r.db.WithContext(ctx).
Raw(`
SELECT DISTINCT r.id as role_id, r.name as role_name, r.code as role_code, r.status as role_status,
p.id as permission_id, p.code as permission_code, p.name as permission_name
FROM user_roles ur
JOIN roles r ON ur.role_id = r.id
LEFT JOIN role_permissions rp ON r.id = rp.role_id
LEFT JOIN permissions p ON rp.permission_id = p.id
WHERE ur.user_id = ? AND r.status = 1
`, userID).
Scan(&results).Error
if err != nil {
return nil, nil, err
}
// 构建角色和权限列表
roleMap := make(map[int64]*domain.Role)
permMap := make(map[int64]*domain.Permission)
for _, row := range results {
if _, ok := roleMap[row.RoleID]; !ok {
roleMap[row.RoleID] = &domain.Role{
ID: row.RoleID,
Name: row.RoleName,
Code: row.RoleCode,
Status: domain.RoleStatus(row.RoleStatus),
}
}
if row.PermissionID > 0 {
if _, ok := permMap[row.PermissionID]; !ok {
permMap[row.PermissionID] = &domain.Permission{
ID: row.PermissionID,
Code: row.PermissionCode,
Name: row.PermissionName,
}
}
}
}
roles := make([]*domain.Role, 0, len(roleMap))
for _, role := range roleMap {
roles = append(roles, role)
}
perms := make([]*domain.Permission, 0, len(permMap))
for _, perm := range permMap {
perms = append(perms, perm)
}
return roles, perms, nil
}
// GetUserIDByRoleID 根据角色ID获取用户ID列表
func (r *UserRoleRepository) GetUserIDByRoleID(ctx context.Context, roleID int64) ([]int64, error) {
var userIDs []int64
err := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("role_id = ?", roleID).Pluck("user_id", &userIDs).Error
if err != nil {
return nil, err
}
return userIDs, nil
}
// Exists 检查用户角色关联是否存在
func (r *UserRoleRepository) Exists(ctx context.Context, userID, roleID int64) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.UserRole{}).
Where("user_id = ? AND role_id = ?", userID, roleID).
Count(&count).Error
return count > 0, err
}
// BatchCreate 批量创建用户角色关联
func (r *UserRoleRepository) BatchCreate(ctx context.Context, userRoles []*domain.UserRole) error {
if len(userRoles) == 0 {
return nil
}
return r.db.WithContext(ctx).Create(&userRoles).Error
}
// BatchDelete 批量删除用户角色关联
func (r *UserRoleRepository) BatchDelete(ctx context.Context, userRoles []*domain.UserRole) error {
if len(userRoles) == 0 {
return nil
}
var ids []int64
for _, ur := range userRoles {
ids = append(ids, ur.ID)
}
return r.db.WithContext(ctx).Delete(&domain.UserRole{}, ids).Error
}

View File

@@ -0,0 +1,747 @@
//go:build integration
package repository
import (
"context"
"fmt"
"testing"
"time"
dbent "github.com/user-management-system/ent"
"github.com/user-management-system/internal/pkg/pagination"
"github.com/user-management-system/internal/service"
"github.com/stretchr/testify/suite"
)
type UserSubscriptionRepoSuite struct {
suite.Suite
ctx context.Context
client *dbent.Client
repo *userSubscriptionRepository
}
func (s *UserSubscriptionRepoSuite) SetupTest() {
s.ctx = context.Background()
tx := testEntTx(s.T())
s.client = tx.Client()
s.repo = NewUserSubscriptionRepository(s.client).(*userSubscriptionRepository)
}
func TestUserSubscriptionRepoSuite(t *testing.T) {
suite.Run(t, new(UserSubscriptionRepoSuite))
}
func (s *UserSubscriptionRepoSuite) mustCreateUser(email string, role string) *service.User {
s.T().Helper()
if role == "" {
role = service.RoleUser
}
u, err := s.client.User.Create().
SetEmail(email).
SetPasswordHash("test-password-hash").
SetStatus(service.StatusActive).
SetRole(role).
Save(s.ctx)
s.Require().NoError(err, "create user")
return userEntityToService(u)
}
func (s *UserSubscriptionRepoSuite) mustCreateGroup(name string) *service.Group {
s.T().Helper()
g, err := s.client.Group.Create().
SetName(name).
SetStatus(service.StatusActive).
Save(s.ctx)
s.Require().NoError(err, "create group")
return groupEntityToService(g)
}
func (s *UserSubscriptionRepoSuite) mustCreateSubscription(userID, groupID int64, mutate func(*dbent.UserSubscriptionCreate)) *dbent.UserSubscription {
s.T().Helper()
now := time.Now()
create := s.client.UserSubscription.Create().
SetUserID(userID).
SetGroupID(groupID).
SetStartsAt(now.Add(-1 * time.Hour)).
SetExpiresAt(now.Add(24 * time.Hour)).
SetStatus(service.SubscriptionStatusActive).
SetAssignedAt(now).
SetNotes("")
if mutate != nil {
mutate(create)
}
sub, err := create.Save(s.ctx)
s.Require().NoError(err, "create user subscription")
return sub
}
// --- Create / GetByID / Update / Delete ---
func (s *UserSubscriptionRepoSuite) TestCreate() {
user := s.mustCreateUser("sub-create@test.com", service.RoleUser)
group := s.mustCreateGroup("g-create")
sub := &service.UserSubscription{
UserID: user.ID,
GroupID: group.ID,
Status: service.SubscriptionStatusActive,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
err := s.repo.Create(s.ctx, sub)
s.Require().NoError(err, "Create")
s.Require().NotZero(sub.ID, "expected ID to be set")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID")
s.Require().Equal(sub.UserID, got.UserID)
s.Require().Equal(sub.GroupID, got.GroupID)
}
func (s *UserSubscriptionRepoSuite) TestGetByID_WithPreloads() {
user := s.mustCreateUser("preload@test.com", service.RoleUser)
group := s.mustCreateGroup("g-preload")
admin := s.mustCreateUser("admin@test.com", service.RoleAdmin)
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetAssignedBy(admin.ID)
})
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID")
s.Require().NotNil(got.User, "expected User preload")
s.Require().NotNil(got.Group, "expected Group preload")
s.Require().NotNil(got.AssignedByUser, "expected AssignedByUser preload")
s.Require().Equal(user.ID, got.User.ID)
s.Require().Equal(group.ID, got.Group.ID)
s.Require().Equal(admin.ID, got.AssignedByUser.ID)
}
func (s *UserSubscriptionRepoSuite) TestGetByID_NotFound() {
_, err := s.repo.GetByID(s.ctx, 999999)
s.Require().Error(err, "expected error for non-existent ID")
}
func (s *UserSubscriptionRepoSuite) TestUpdate() {
user := s.mustCreateUser("update@test.com", service.RoleUser)
group := s.mustCreateGroup("g-update")
created := s.mustCreateSubscription(user.ID, group.ID, nil)
sub, err := s.repo.GetByID(s.ctx, created.ID)
s.Require().NoError(err, "GetByID")
sub.Notes = "updated notes"
s.Require().NoError(s.repo.Update(s.ctx, sub), "Update")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err, "GetByID after update")
s.Require().Equal("updated notes", got.Notes)
}
func (s *UserSubscriptionRepoSuite) TestDelete() {
user := s.mustCreateUser("delete@test.com", service.RoleUser)
group := s.mustCreateGroup("g-delete")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
err := s.repo.Delete(s.ctx, sub.ID)
s.Require().NoError(err, "Delete")
_, err = s.repo.GetByID(s.ctx, sub.ID)
s.Require().Error(err, "expected error after delete")
}
func (s *UserSubscriptionRepoSuite) TestDelete_Idempotent() {
s.Require().NoError(s.repo.Delete(s.ctx, 42424242), "Delete should be idempotent")
}
// --- GetByUserIDAndGroupID / GetActiveByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID() {
user := s.mustCreateUser("byuser@test.com", service.RoleUser)
group := s.mustCreateGroup("g-byuser")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
got, err := s.repo.GetByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetByUserIDAndGroupID")
s.Require().Equal(sub.ID, got.ID)
s.Require().NotNil(got.Group, "expected Group preload")
}
func (s *UserSubscriptionRepoSuite) TestGetByUserIDAndGroupID_NotFound() {
_, err := s.repo.GetByUserIDAndGroupID(s.ctx, 999999, 999999)
s.Require().Error(err, "expected error for non-existent pair")
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID() {
user := s.mustCreateUser("active@test.com", service.RoleUser)
group := s.mustCreateGroup("g-active")
active := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(2 * time.Hour))
})
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
s.Require().Equal(active.ID, got.ID)
}
func (s *UserSubscriptionRepoSuite) TestGetActiveByUserIDAndGroupID_ExpiredIgnored() {
user := s.mustCreateUser("expired@test.com", service.RoleUser)
group := s.mustCreateGroup("g-expired")
s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(-2 * time.Hour))
})
_, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().Error(err, "expected error for expired subscription")
}
// --- ListByUserID / ListActiveByUserID ---
func (s *UserSubscriptionRepoSuite) TestListByUserID() {
user := s.mustCreateUser("listby@test.com", service.RoleUser)
g1 := s.mustCreateGroup("g-list1")
g2 := s.mustCreateGroup("g-list2")
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, err := s.repo.ListByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListByUserID")
s.Require().Len(subs, 2)
for _, sub := range subs {
s.Require().NotNil(sub.Group, "expected Group preload")
}
}
func (s *UserSubscriptionRepoSuite) TestListActiveByUserID() {
user := s.mustCreateUser("listactive@test.com", service.RoleUser)
g1 := s.mustCreateGroup("g-act1")
g2 := s.mustCreateGroup("g-act2")
s.mustCreateSubscription(user.ID, g1.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
})
s.mustCreateSubscription(user.ID, g2.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, err := s.repo.ListActiveByUserID(s.ctx, user.ID)
s.Require().NoError(err, "ListActiveByUserID")
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusActive, subs[0].Status)
}
// --- ListByGroupID ---
func (s *UserSubscriptionRepoSuite) TestListByGroupID() {
user1 := s.mustCreateUser("u1@test.com", service.RoleUser)
user2 := s.mustCreateUser("u2@test.com", service.RoleUser)
group := s.mustCreateGroup("g-listgrp")
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, page, err := s.repo.ListByGroupID(s.ctx, group.ID, pagination.PaginationParams{Page: 1, PageSize: 10})
s.Require().NoError(err, "ListByGroupID")
s.Require().Len(subs, 2)
s.Require().Equal(int64(2), page.Total)
for _, sub := range subs {
s.Require().NotNil(sub.User, "expected User preload")
s.Require().NotNil(sub.Group, "expected Group preload")
}
}
// --- List with filters ---
func (s *UserSubscriptionRepoSuite) TestList_NoFilters() {
user := s.mustCreateUser("list@test.com", service.RoleUser)
group := s.mustCreateGroup("g-list")
s.mustCreateSubscription(user.ID, group.ID, nil)
subs, page, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, "", "", "", "")
s.Require().NoError(err, "List")
s.Require().Len(subs, 1)
s.Require().Equal(int64(1), page.Total)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByUserID() {
user1 := s.mustCreateUser("filter1@test.com", service.RoleUser)
user2 := s.mustCreateUser("filter2@test.com", service.RoleUser)
group := s.mustCreateGroup("g-filter")
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, &user1.ID, nil, "", "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(user1.ID, subs[0].UserID)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByGroupID() {
user := s.mustCreateUser("grpfilter@test.com", service.RoleUser)
g1 := s.mustCreateGroup("g-f1")
g2 := s.mustCreateGroup("g-f2")
s.mustCreateSubscription(user.ID, g1.ID, nil)
s.mustCreateSubscription(user.ID, g2.ID, nil)
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, &g1.ID, "", "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(g1.ID, subs[0].GroupID)
}
func (s *UserSubscriptionRepoSuite) TestList_FilterByStatus() {
user1 := s.mustCreateUser("statfilter1@test.com", service.RoleUser)
user2 := s.mustCreateUser("statfilter2@test.com", service.RoleUser)
group1 := s.mustCreateGroup("g-stat-1")
group2 := s.mustCreateGroup("g-stat-2")
s.mustCreateSubscription(user1.ID, group1.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusActive)
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
})
s.mustCreateSubscription(user2.ID, group2.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
subs, _, err := s.repo.List(s.ctx, pagination.PaginationParams{Page: 1, PageSize: 10}, nil, nil, service.SubscriptionStatusExpired, "", "", "")
s.Require().NoError(err)
s.Require().Len(subs, 1)
s.Require().Equal(service.SubscriptionStatusExpired, subs[0].Status)
}
// --- Usage tracking ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage() {
user := s.mustCreateUser("usage@test.com", service.RoleUser)
group := s.mustCreateGroup("g-usage")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
err := s.repo.IncrementUsage(s.ctx, sub.ID, 1.25)
s.Require().NoError(err, "IncrementUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(1.25, got.DailyUsageUSD, 1e-6)
s.Require().InDelta(1.25, got.WeeklyUsageUSD, 1e-6)
s.Require().InDelta(1.25, got.MonthlyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Accumulates() {
user := s.mustCreateUser("accum@test.com", service.RoleUser)
group := s.mustCreateGroup("g-accum")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 1.0))
s.Require().NoError(s.repo.IncrementUsage(s.ctx, sub.ID, 2.5))
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(3.5, got.DailyUsageUSD, 1e-6)
}
func (s *UserSubscriptionRepoSuite) TestActivateWindows() {
user := s.mustCreateUser("activate@test.com", service.RoleUser)
group := s.mustCreateGroup("g-activate")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
activateAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC)
err := s.repo.ActivateWindows(s.ctx, sub.ID, activateAt)
s.Require().NoError(err, "ActivateWindows")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().NotNil(got.DailyWindowStart)
s.Require().NotNil(got.WeeklyWindowStart)
s.Require().NotNil(got.MonthlyWindowStart)
s.Require().WithinDuration(activateAt, *got.DailyWindowStart, time.Microsecond)
}
func (s *UserSubscriptionRepoSuite) TestResetDailyUsage() {
user := s.mustCreateUser("resetd@test.com", service.RoleUser)
group := s.mustCreateGroup("g-resetd")
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetDailyUsageUsd(10.0)
c.SetWeeklyUsageUsd(20.0)
})
resetAt := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetDailyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetDailyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(0.0, got.DailyUsageUSD, 1e-6)
s.Require().InDelta(20.0, got.WeeklyUsageUSD, 1e-6)
s.Require().NotNil(got.DailyWindowStart)
s.Require().WithinDuration(resetAt, *got.DailyWindowStart, time.Microsecond)
}
func (s *UserSubscriptionRepoSuite) TestResetWeeklyUsage() {
user := s.mustCreateUser("resetw@test.com", service.RoleUser)
group := s.mustCreateGroup("g-resetw")
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetWeeklyUsageUsd(15.0)
c.SetMonthlyUsageUsd(30.0)
})
resetAt := time.Date(2025, 1, 6, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetWeeklyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetWeeklyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(0.0, got.WeeklyUsageUSD, 1e-6)
s.Require().InDelta(30.0, got.MonthlyUsageUSD, 1e-6)
s.Require().NotNil(got.WeeklyWindowStart)
s.Require().WithinDuration(resetAt, *got.WeeklyWindowStart, time.Microsecond)
}
func (s *UserSubscriptionRepoSuite) TestResetMonthlyUsage() {
user := s.mustCreateUser("resetm@test.com", service.RoleUser)
group := s.mustCreateGroup("g-resetm")
sub := s.mustCreateSubscription(user.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetMonthlyUsageUsd(25.0)
})
resetAt := time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC)
err := s.repo.ResetMonthlyUsage(s.ctx, sub.ID, resetAt)
s.Require().NoError(err, "ResetMonthlyUsage")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().InDelta(0.0, got.MonthlyUsageUSD, 1e-6)
s.Require().NotNil(got.MonthlyWindowStart)
s.Require().WithinDuration(resetAt, *got.MonthlyWindowStart, time.Microsecond)
}
// --- UpdateStatus / ExtendExpiry / UpdateNotes ---
func (s *UserSubscriptionRepoSuite) TestUpdateStatus() {
user := s.mustCreateUser("status@test.com", service.RoleUser)
group := s.mustCreateGroup("g-status")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
err := s.repo.UpdateStatus(s.ctx, sub.ID, service.SubscriptionStatusExpired)
s.Require().NoError(err, "UpdateStatus")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal(service.SubscriptionStatusExpired, got.Status)
}
func (s *UserSubscriptionRepoSuite) TestExtendExpiry() {
user := s.mustCreateUser("extend@test.com", service.RoleUser)
group := s.mustCreateGroup("g-extend")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
newExpiry := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
err := s.repo.ExtendExpiry(s.ctx, sub.ID, newExpiry)
s.Require().NoError(err, "ExtendExpiry")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().WithinDuration(newExpiry, got.ExpiresAt, time.Microsecond)
}
func (s *UserSubscriptionRepoSuite) TestUpdateNotes() {
user := s.mustCreateUser("notes@test.com", service.RoleUser)
group := s.mustCreateGroup("g-notes")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
err := s.repo.UpdateNotes(s.ctx, sub.ID, "VIP user")
s.Require().NoError(err, "UpdateNotes")
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
s.Require().Equal("VIP user", got.Notes)
}
// --- ListExpired / BatchUpdateExpiredStatus ---
func (s *UserSubscriptionRepoSuite) TestListExpired() {
user := s.mustCreateUser("listexp@test.com", service.RoleUser)
groupActive := s.mustCreateGroup("g-listexp-active")
groupExpired := s.mustCreateGroup("g-listexp-expired")
s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
})
s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
expired, err := s.repo.ListExpired(s.ctx)
s.Require().NoError(err, "ListExpired")
s.Require().Len(expired, 1)
}
func (s *UserSubscriptionRepoSuite) TestBatchUpdateExpiredStatus() {
user := s.mustCreateUser("batch@test.com", service.RoleUser)
groupFuture := s.mustCreateGroup("g-batch-future")
groupPast := s.mustCreateGroup("g-batch-past")
active := s.mustCreateSubscription(user.ID, groupFuture.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
})
expiredActive := s.mustCreateSubscription(user.ID, groupPast.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
s.Require().NoError(err, "BatchUpdateExpiredStatus")
s.Require().Equal(int64(1), affected)
gotActive, _ := s.repo.GetByID(s.ctx, active.ID)
s.Require().Equal(service.SubscriptionStatusActive, gotActive.Status)
gotExpired, _ := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().Equal(service.SubscriptionStatusExpired, gotExpired.Status)
}
// --- ExistsByUserIDAndGroupID ---
func (s *UserSubscriptionRepoSuite) TestExistsByUserIDAndGroupID() {
user := s.mustCreateUser("exists@test.com", service.RoleUser)
group := s.mustCreateGroup("g-exists")
s.mustCreateSubscription(user.ID, group.ID, nil)
exists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, group.ID)
s.Require().NoError(err, "ExistsByUserIDAndGroupID")
s.Require().True(exists)
notExists, err := s.repo.ExistsByUserIDAndGroupID(s.ctx, user.ID, 999999)
s.Require().NoError(err)
s.Require().False(notExists)
}
// --- CountByGroupID / CountActiveByGroupID ---
func (s *UserSubscriptionRepoSuite) TestCountByGroupID() {
user1 := s.mustCreateUser("cnt1@test.com", service.RoleUser)
user2 := s.mustCreateUser("cnt2@test.com", service.RoleUser)
group := s.mustCreateGroup("g-count")
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetStatus(service.SubscriptionStatusExpired)
c.SetExpiresAt(time.Now().Add(-24 * time.Hour))
})
count, err := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountByGroupID")
s.Require().Equal(int64(2), count)
}
func (s *UserSubscriptionRepoSuite) TestCountActiveByGroupID() {
user1 := s.mustCreateUser("cntact1@test.com", service.RoleUser)
user2 := s.mustCreateUser("cntact2@test.com", service.RoleUser)
group := s.mustCreateGroup("g-cntact")
s.mustCreateSubscription(user1.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(24 * time.Hour))
})
s.mustCreateSubscription(user2.ID, group.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(-24 * time.Hour)) // expired by time
})
count, err := s.repo.CountActiveByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "CountActiveByGroupID")
s.Require().Equal(int64(1), count, "only future expiry counts as active")
}
// --- DeleteByGroupID ---
func (s *UserSubscriptionRepoSuite) TestDeleteByGroupID() {
user1 := s.mustCreateUser("delgrp1@test.com", service.RoleUser)
user2 := s.mustCreateUser("delgrp2@test.com", service.RoleUser)
group := s.mustCreateGroup("g-delgrp")
s.mustCreateSubscription(user1.ID, group.ID, nil)
s.mustCreateSubscription(user2.ID, group.ID, nil)
affected, err := s.repo.DeleteByGroupID(s.ctx, group.ID)
s.Require().NoError(err, "DeleteByGroupID")
s.Require().Equal(int64(2), affected)
count, _ := s.repo.CountByGroupID(s.ctx, group.ID)
s.Require().Zero(count)
}
// --- Combined scenario ---
func (s *UserSubscriptionRepoSuite) TestActiveExpiredBoundaries_UsageAndReset_BatchUpdateExpiredStatus() {
user := s.mustCreateUser("subr@example.com", service.RoleUser)
groupActive := s.mustCreateGroup("g-subr-active")
groupExpired := s.mustCreateGroup("g-subr-expired")
active := s.mustCreateSubscription(user.ID, groupActive.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(2 * time.Hour))
})
expiredActive := s.mustCreateSubscription(user.ID, groupExpired.ID, func(c *dbent.UserSubscriptionCreate) {
c.SetExpiresAt(time.Now().Add(-2 * time.Hour))
})
got, err := s.repo.GetActiveByUserIDAndGroupID(s.ctx, user.ID, groupActive.ID)
s.Require().NoError(err, "GetActiveByUserIDAndGroupID")
s.Require().Equal(active.ID, got.ID, "expected active subscription")
activateAt := time.Now().Add(-25 * time.Hour)
s.Require().NoError(s.repo.ActivateWindows(s.ctx, active.ID, activateAt), "ActivateWindows")
s.Require().NoError(s.repo.IncrementUsage(s.ctx, active.ID, 1.25), "IncrementUsage")
after, err := s.repo.GetByID(s.ctx, active.ID)
s.Require().NoError(err, "GetByID")
s.Require().InDelta(1.25, after.DailyUsageUSD, 1e-6)
s.Require().InDelta(1.25, after.WeeklyUsageUSD, 1e-6)
s.Require().InDelta(1.25, after.MonthlyUsageUSD, 1e-6)
s.Require().NotNil(after.DailyWindowStart, "expected DailyWindowStart activated")
s.Require().NotNil(after.WeeklyWindowStart, "expected WeeklyWindowStart activated")
s.Require().NotNil(after.MonthlyWindowStart, "expected MonthlyWindowStart activated")
resetAt := time.Now().Truncate(time.Microsecond) // truncate to microsecond for DB precision
s.Require().NoError(s.repo.ResetDailyUsage(s.ctx, active.ID, resetAt), "ResetDailyUsage")
afterReset, err := s.repo.GetByID(s.ctx, active.ID)
s.Require().NoError(err, "GetByID after reset")
s.Require().InDelta(0.0, afterReset.DailyUsageUSD, 1e-6)
s.Require().NotNil(afterReset.DailyWindowStart)
s.Require().WithinDuration(resetAt, *afterReset.DailyWindowStart, time.Microsecond)
affected, err := s.repo.BatchUpdateExpiredStatus(s.ctx)
s.Require().NoError(err, "BatchUpdateExpiredStatus")
s.Require().Equal(int64(1), affected, "expected 1 affected row")
updated, err := s.repo.GetByID(s.ctx, expiredActive.ID)
s.Require().NoError(err, "GetByID expired")
s.Require().Equal(service.SubscriptionStatusExpired, updated.Status, "expected status expired")
}
// --- 软删除过滤测试 ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_SoftDeletedGroup() {
user := s.mustCreateUser("softdeleted@test.com", service.RoleUser)
group := s.mustCreateGroup("g-softdeleted")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
// 软删除分组
_, err := s.client.Group.UpdateOneID(group.ID).SetDeletedAt(time.Now()).Save(s.ctx)
s.Require().NoError(err, "soft delete group")
// IncrementUsage 应该失败,因为分组已软删除
err = s.repo.IncrementUsage(s.ctx, sub.ID, 1.0)
s.Require().Error(err, "should fail for soft-deleted group")
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_NotFound() {
err := s.repo.IncrementUsage(s.ctx, 999999, 1.0)
s.Require().Error(err, "should fail for non-existent subscription")
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}
// --- nil 入参测试 ---
func (s *UserSubscriptionRepoSuite) TestCreate_NilInput() {
err := s.repo.Create(s.ctx, nil)
s.Require().Error(err, "Create should fail with nil input")
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
}
func (s *UserSubscriptionRepoSuite) TestUpdate_NilInput() {
err := s.repo.Update(s.ctx, nil)
s.Require().Error(err, "Update should fail with nil input")
s.Require().ErrorIs(err, service.ErrSubscriptionNilInput)
}
// --- 并发用量更新测试 ---
func (s *UserSubscriptionRepoSuite) TestIncrementUsage_Concurrent() {
user := s.mustCreateUser("concurrent@test.com", service.RoleUser)
group := s.mustCreateGroup("g-concurrent")
sub := s.mustCreateSubscription(user.ID, group.ID, nil)
const numGoroutines = 10
const incrementPerGoroutine = 1.5
// 启动多个 goroutine 并发调用 IncrementUsage
errCh := make(chan error, numGoroutines)
for i := 0; i < numGoroutines; i++ {
go func() {
errCh <- s.repo.IncrementUsage(s.ctx, sub.ID, incrementPerGoroutine)
}()
}
// 等待所有 goroutine 完成
for i := 0; i < numGoroutines; i++ {
err := <-errCh
s.Require().NoError(err, "IncrementUsage should succeed")
}
// 验证累加结果正确
got, err := s.repo.GetByID(s.ctx, sub.ID)
s.Require().NoError(err)
expectedUsage := float64(numGoroutines) * incrementPerGoroutine
s.Require().InDelta(expectedUsage, got.DailyUsageUSD, 1e-6, "daily usage should be correctly accumulated")
s.Require().InDelta(expectedUsage, got.WeeklyUsageUSD, 1e-6, "weekly usage should be correctly accumulated")
s.Require().InDelta(expectedUsage, got.MonthlyUsageUSD, 1e-6, "monthly usage should be correctly accumulated")
}
func (s *UserSubscriptionRepoSuite) TestTxContext_RollbackIsolation() {
baseClient := testEntClient(s.T())
tx, err := baseClient.Tx(context.Background())
s.Require().NoError(err, "begin tx")
defer func() {
if tx != nil {
_ = tx.Rollback()
}
}()
txCtx := dbent.NewTxContext(context.Background(), tx)
suffix := fmt.Sprintf("%d", time.Now().UnixNano())
userEnt, err := tx.Client().User.Create().
SetEmail("tx-user-" + suffix + "@example.com").
SetPasswordHash("test").
Save(txCtx)
s.Require().NoError(err, "create user in tx")
groupEnt, err := tx.Client().Group.Create().
SetName("tx-group-" + suffix).
Save(txCtx)
s.Require().NoError(err, "create group in tx")
repo := NewUserSubscriptionRepository(baseClient)
sub := &service.UserSubscription{
UserID: userEnt.ID,
GroupID: groupEnt.ID,
ExpiresAt: time.Now().AddDate(0, 0, 30),
Status: service.SubscriptionStatusActive,
AssignedAt: time.Now(),
Notes: "tx",
}
s.Require().NoError(repo.Create(txCtx, sub), "create subscription in tx")
s.Require().NoError(repo.UpdateNotes(txCtx, sub.ID, "tx-note"), "update subscription in tx")
s.Require().NoError(tx.Rollback(), "rollback tx")
tx = nil
_, err = repo.GetByID(context.Background(), sub.ID)
s.Require().ErrorIs(err, service.ErrSubscriptionNotFound)
}

View File

@@ -0,0 +1,126 @@
package repository
import (
"context"
"github.com/user-management-system/internal/domain"
"gorm.io/gorm"
)
// WebhookRepository Webhook 持久化仓储
type WebhookRepository struct {
db *gorm.DB
}
// NewWebhookRepository 创建 Webhook 仓储
func NewWebhookRepository(db *gorm.DB) *WebhookRepository {
return &WebhookRepository{db: db}
}
// Create 创建 Webhook
func (r *WebhookRepository) Create(ctx context.Context, wh *domain.Webhook) error {
// GORM omits zero values on insert for fields with DB defaults. Explicitly
// backfill inactive status so repository callers can persist status=0.
requestedStatus := wh.Status
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Create(wh).Error; err != nil {
return err
}
if requestedStatus == domain.WebhookStatusInactive {
if err := tx.Model(&domain.Webhook{}).Where("id = ?", wh.ID).Update("status", requestedStatus).Error; err != nil {
return err
}
wh.Status = requestedStatus
}
return nil
})
}
// Update 更新 Webhook 字段(只更新 updates map 中的字段)
func (r *WebhookRepository) Update(ctx context.Context, id int64, updates map[string]interface{}) error {
return r.db.WithContext(ctx).
Model(&domain.Webhook{}).
Where("id = ?", id).
Updates(updates).Error
}
// Delete 删除 Webhook软删除
func (r *WebhookRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.Webhook{}, id).Error
}
// GetByID 按 ID 获取 Webhook
func (r *WebhookRepository) GetByID(ctx context.Context, id int64) (*domain.Webhook, error) {
var wh domain.Webhook
err := r.db.WithContext(ctx).First(&wh, id).Error
if err != nil {
return nil, err
}
return &wh, nil
}
// ListByCreator 按创建者列出 WebhookcreatedBy=0 表示列出所有)
func (r *WebhookRepository) ListByCreator(ctx context.Context, createdBy int64) ([]*domain.Webhook, error) {
var webhooks []*domain.Webhook
query := r.db.WithContext(ctx)
if createdBy > 0 {
query = query.Where("created_by = ?", createdBy)
}
if err := query.Find(&webhooks).Error; err != nil {
return nil, err
}
return webhooks, nil
}
// ListByCreatorPaginated 按创建者分页列出 WebhookcreatedBy=0 表示列出所有)
func (r *WebhookRepository) ListByCreatorPaginated(ctx context.Context, createdBy int64, offset, limit int) ([]*domain.Webhook, int64, error) {
var webhooks []*domain.Webhook
var total int64
query := r.db.WithContext(ctx).Model(&domain.Webhook{})
if createdBy > 0 {
query = query.Where("created_by = ?", createdBy)
}
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
if offset > 0 {
query = query.Offset(offset)
}
if limit > 0 {
query = query.Limit(limit)
}
if err := query.Order("created_at DESC").Find(&webhooks).Error; err != nil {
return nil, 0, err
}
return webhooks, total, nil
}
// ListActive 列出所有状态为活跃的 Webhook
func (r *WebhookRepository) ListActive(ctx context.Context) ([]*domain.Webhook, error) {
var webhooks []*domain.Webhook
err := r.db.WithContext(ctx).
Where("status = ?", domain.WebhookStatusActive).
Find(&webhooks).Error
return webhooks, err
}
// CreateDelivery 记录投递日志
func (r *WebhookRepository) CreateDelivery(ctx context.Context, delivery *domain.WebhookDelivery) error {
return r.db.WithContext(ctx).Create(delivery).Error
}
// ListDeliveries 按 Webhook ID 分页查询投递记录(最新在前)
func (r *WebhookRepository) ListDeliveries(ctx context.Context, webhookID int64, limit int) ([]*domain.WebhookDelivery, error) {
var deliveries []*domain.WebhookDelivery
err := r.db.WithContext(ctx).
Where("webhook_id = ?", webhookID).
Order("created_at DESC").
Limit(limit).
Find(&deliveries).Error
return deliveries, err
}

View File

@@ -0,0 +1,190 @@
package repository
import (
"context"
"testing"
"time"
"github.com/user-management-system/internal/domain"
)
func setupWebhookRepository(t *testing.T) *WebhookRepository {
t.Helper()
db := openTestDB(t)
if err := db.AutoMigrate(&domain.Webhook{}, &domain.WebhookDelivery{}); err != nil {
t.Fatalf("migrate webhook tables failed: %v", err)
}
return NewWebhookRepository(db)
}
func newWebhookFixture(name string, createdBy int64, status domain.WebhookStatus) *domain.Webhook {
return &domain.Webhook{
Name: name,
URL: "https://example.com/webhook",
Secret: "secret-demo",
Events: `["user.registered"]`,
Status: status,
MaxRetries: 3,
TimeoutSec: 10,
CreatedBy: createdBy,
}
}
func TestWebhookRepositoryCreateGetUpdateAndDelete(t *testing.T) {
repo := setupWebhookRepository(t)
ctx := context.Background()
webhook := newWebhookFixture("alpha", 101, domain.WebhookStatusActive)
if err := repo.Create(ctx, webhook); err != nil {
t.Fatalf("Create failed: %v", err)
}
if webhook.ID == 0 {
t.Fatal("expected webhook id to be assigned")
}
loaded, err := repo.GetByID(ctx, webhook.ID)
if err != nil {
t.Fatalf("GetByID failed: %v", err)
}
if loaded.Name != "alpha" {
t.Fatalf("expected loaded webhook name alpha, got %q", loaded.Name)
}
if err := repo.Update(ctx, webhook.ID, map[string]interface{}{
"name": "alpha-updated",
"status": domain.WebhookStatusInactive,
}); err != nil {
t.Fatalf("Update failed: %v", err)
}
updated, err := repo.GetByID(ctx, webhook.ID)
if err != nil {
t.Fatalf("GetByID after update failed: %v", err)
}
if updated.Name != "alpha-updated" {
t.Fatalf("expected updated name alpha-updated, got %q", updated.Name)
}
if updated.Status != domain.WebhookStatusInactive {
t.Fatalf("expected updated status inactive, got %d", updated.Status)
}
if err := repo.Delete(ctx, webhook.ID); err != nil {
t.Fatalf("Delete failed: %v", err)
}
if _, err := repo.GetByID(ctx, webhook.ID); err == nil {
t.Fatal("expected deleted webhook lookup to fail")
}
}
func TestWebhookRepositoryListsByCreatorAndActiveStatus(t *testing.T) {
repo := setupWebhookRepository(t)
ctx := context.Background()
fixtures := []*domain.Webhook{
newWebhookFixture("creator-1-active", 1, domain.WebhookStatusActive),
newWebhookFixture("creator-1-inactive", 1, domain.WebhookStatusInactive),
newWebhookFixture("creator-2-active", 2, domain.WebhookStatusActive),
}
for _, webhook := range fixtures {
if err := repo.Create(ctx, webhook); err != nil {
t.Fatalf("Create(%s) failed: %v", webhook.Name, err)
}
}
creatorOneHooks, err := repo.ListByCreator(ctx, 1)
if err != nil {
t.Fatalf("ListByCreator(1) failed: %v", err)
}
if len(creatorOneHooks) != 2 {
t.Fatalf("expected 2 hooks for creator 1, got %d", len(creatorOneHooks))
}
allHooks, err := repo.ListByCreator(ctx, 0)
if err != nil {
t.Fatalf("ListByCreator(0) failed: %v", err)
}
if len(allHooks) != 3 {
t.Fatalf("expected 3 hooks when listing all creators, got %d", len(allHooks))
}
activeHooks, err := repo.ListActive(ctx)
if err != nil {
t.Fatalf("ListActive failed: %v", err)
}
if len(activeHooks) != 2 {
t.Fatalf("expected 2 active hooks, got %d", len(activeHooks))
}
for _, hook := range activeHooks {
if hook.Status != domain.WebhookStatusActive {
t.Fatalf("expected active hook status, got %d", hook.Status)
}
}
}
func TestWebhookRepositoryCreateAndListDeliveries(t *testing.T) {
repo := setupWebhookRepository(t)
ctx := context.Background()
webhook := newWebhookFixture("delivery-hook", 7, domain.WebhookStatusActive)
if err := repo.Create(ctx, webhook); err != nil {
t.Fatalf("Create webhook failed: %v", err)
}
olderTime := time.Now().Add(-time.Minute)
newerTime := time.Now()
firstDelivery := &domain.WebhookDelivery{
WebhookID: webhook.ID,
EventType: domain.EventUserRegistered,
Payload: `{"user":"older"}`,
StatusCode: 200,
ResponseBody: `{"ok":true}`,
Attempt: 1,
Success: true,
CreatedAt: olderTime,
}
secondDelivery := &domain.WebhookDelivery{
WebhookID: webhook.ID,
EventType: domain.EventUserLogin,
Payload: `{"user":"newer"}`,
StatusCode: 500,
ResponseBody: `{"ok":false}`,
Attempt: 2,
Success: false,
Error: "delivery failed",
CreatedAt: newerTime,
}
if err := repo.CreateDelivery(ctx, firstDelivery); err != nil {
t.Fatalf("CreateDelivery(first) failed: %v", err)
}
if err := repo.CreateDelivery(ctx, secondDelivery); err != nil {
t.Fatalf("CreateDelivery(second) failed: %v", err)
}
latestOnly, err := repo.ListDeliveries(ctx, webhook.ID, 1)
if err != nil {
t.Fatalf("ListDeliveries(limit=1) failed: %v", err)
}
if len(latestOnly) != 1 {
t.Fatalf("expected 1 latest delivery, got %d", len(latestOnly))
}
if latestOnly[0].ID != secondDelivery.ID {
t.Fatalf("expected latest delivery id %d, got %d", secondDelivery.ID, latestOnly[0].ID)
}
allDeliveries, err := repo.ListDeliveries(ctx, webhook.ID, 10)
if err != nil {
t.Fatalf("ListDeliveries(limit=10) failed: %v", err)
}
if len(allDeliveries) != 2 {
t.Fatalf("expected 2 deliveries, got %d", len(allDeliveries))
}
if allDeliveries[0].ID != secondDelivery.ID || allDeliveries[1].ID != firstDelivery.ID {
t.Fatal("expected deliveries to be returned in reverse created_at order")
}
}