test: add repository tests to improve coverage from 46.6% to 74%
New test files: - custom_field_repository_test.go: 10 tests for CustomFieldRepository & UserCustomFieldValueRepository - login_log_repository_test.go: 3 tests for ListCursor, ListByUserIDCursor, ListAllForExport - operation_log_repository_test.go: 1 test for ListCursor - role_repository_test.go: 2 tests for GetAncestorIDs, GetAncestors - social_account_repository_test.go: 8 CRUD tests - theme_repository_test.go: 10 tests for ThemeConfigRepository - user_role_repository_test.go: 1 test for DeleteByUserAndRole Modified test files: - device_repository_test.go: Added ListAllCursor tests - user_repository_test.go: Added AdvancedSearch tests - webhook_repository_test.go: Added ListByCreatorPaginated test Updated documentation with new coverage status.
This commit is contained in:
332
internal/repository/custom_field_repository_test.go
Normal file
332
internal/repository/custom_field_repository_test.go
Normal file
@@ -0,0 +1,332 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
var customFieldTestCounter int64
|
||||
|
||||
// openCustomFieldTestDB 为每个测试打开独立的内存数据库
|
||||
func openCustomFieldTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
id := atomic.AddInt64(&customFieldTestCounter, 1)
|
||||
dsn := fmt.Sprintf("file:customfieldtestdb%d?mode=memory&cache=private", id)
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: dsn,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("打开测试数据库失败: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.CustomField{}, &domain.UserCustomFieldValue{}); err != nil {
|
||||
t.Fatalf("数据库迁移失败: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// setupCustomFieldTestDB 兼容性别名
|
||||
func setupCustomFieldTestDB(t *testing.T) *gorm.DB {
|
||||
return openCustomFieldTestDB(t)
|
||||
}
|
||||
|
||||
// TestCustomFieldRepository_Create 测试创建自定义字段
|
||||
func TestCustomFieldRepository_Create(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
repo := NewCustomFieldRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
field := &domain.CustomField{
|
||||
Name: "测试字段",
|
||||
FieldKey: "test_field",
|
||||
Type: domain.CustomFieldTypeString,
|
||||
Required: false,
|
||||
Sort: 1,
|
||||
}
|
||||
|
||||
if err := repo.Create(ctx, field); err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if field.ID == 0 {
|
||||
t.Error("创建后字段ID不应为0")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomFieldRepository_GetByID 测试根据ID获取字段
|
||||
func TestCustomFieldRepository_GetByID(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
repo := NewCustomFieldRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
field := &domain.CustomField{
|
||||
Name: "getbyid-field",
|
||||
FieldKey: "getbyid_key",
|
||||
Type: domain.CustomFieldTypeNumber,
|
||||
}
|
||||
repo.Create(ctx, field)
|
||||
|
||||
found, err := repo.GetByID(ctx, field.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID() error = %v", err)
|
||||
}
|
||||
if found.Name != "getbyid-field" {
|
||||
t.Errorf("Name = %v, want getbyid-field", found.Name)
|
||||
}
|
||||
|
||||
_, err = repo.GetByID(ctx, 9999)
|
||||
if err == nil {
|
||||
t.Error("GetByID() should return error for non-existent ID")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomFieldRepository_GetByFieldKey 测试根据FieldKey获取字段
|
||||
func TestCustomFieldRepository_GetByFieldKey(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
repo := NewCustomFieldRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
field := &domain.CustomField{
|
||||
Name: "field-by-key",
|
||||
FieldKey: "unique_field_key",
|
||||
Type: domain.CustomFieldTypeBoolean,
|
||||
}
|
||||
repo.Create(ctx, field)
|
||||
|
||||
found, err := repo.GetByFieldKey(ctx, "unique_field_key")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByFieldKey() error = %v", err)
|
||||
}
|
||||
if found.Name != "field-by-key" {
|
||||
t.Errorf("Name = %v, want field-by-key", found.Name)
|
||||
}
|
||||
|
||||
_, err = repo.GetByFieldKey(ctx, "not_exist_key")
|
||||
if err == nil {
|
||||
t.Error("GetByFieldKey() should return error for non-existent key")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomFieldRepository_Update 测试更新字段
|
||||
func TestCustomFieldRepository_Update(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
repo := NewCustomFieldRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
field := &domain.CustomField{
|
||||
Name: "before-update",
|
||||
FieldKey: "update_key",
|
||||
Type: domain.CustomFieldTypeString,
|
||||
}
|
||||
repo.Create(ctx, field)
|
||||
|
||||
field.Name = "after-update"
|
||||
field.Required = true
|
||||
if err := repo.Update(ctx, field); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
found, _ := repo.GetByID(ctx, field.ID)
|
||||
if found.Name != "after-update" {
|
||||
t.Errorf("Name = %v, want after-update", found.Name)
|
||||
}
|
||||
if !found.Required {
|
||||
t.Error("Required should be true after update")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomFieldRepository_Delete 测试删除字段
|
||||
func TestCustomFieldRepository_Delete(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
repo := NewCustomFieldRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
field := &domain.CustomField{
|
||||
Name: "to-delete",
|
||||
FieldKey: "delete_key",
|
||||
Type: domain.CustomFieldTypeDate,
|
||||
}
|
||||
repo.Create(ctx, field)
|
||||
|
||||
if err := repo.Delete(ctx, field.ID); err != nil {
|
||||
t.Fatalf("Delete() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := repo.GetByID(ctx, field.ID)
|
||||
if err == nil {
|
||||
t.Error("删除后查询应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomFieldRepository_List 测试获取启用字段列表
|
||||
func TestCustomFieldRepository_List(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
repo := NewCustomFieldRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.CustomField{Name: "enabled1", FieldKey: "enabled1_key", Type: domain.CustomFieldTypeString})
|
||||
repo.Create(ctx, &domain.CustomField{Name: "enabled2", FieldKey: "enabled2_key", Type: domain.CustomFieldTypeNumber})
|
||||
repo.Create(ctx, &domain.CustomField{Name: "enabled3", FieldKey: "enabled3_key", Type: domain.CustomFieldTypeBoolean})
|
||||
|
||||
fields, err := repo.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
// List filters by status=1, all 3 have status=1 (default)
|
||||
if len(fields) != 3 {
|
||||
t.Errorf("len(fields) = %d, want 3", len(fields))
|
||||
}
|
||||
}
|
||||
|
||||
// TestCustomFieldRepository_ListAll 测试获取所有字段列表
|
||||
func TestCustomFieldRepository_ListAll(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
repo := NewCustomFieldRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.CustomField{Name: "all1", FieldKey: "all1_key", Type: domain.CustomFieldTypeString})
|
||||
repo.Create(ctx, &domain.CustomField{Name: "all2", FieldKey: "all2_key", Type: domain.CustomFieldTypeNumber})
|
||||
|
||||
fields, err := repo.ListAll(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAll() error = %v", err)
|
||||
}
|
||||
if len(fields) != 2 {
|
||||
t.Errorf("len(fields) = %d, want 2", len(fields))
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserCustomFieldValueRepository_GetByUserID 测试获取用户所有字段值
|
||||
func TestUserCustomFieldValueRepository_GetByUserID(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
valueRepo := NewUserCustomFieldValueRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// 直接使用 GORM Create 测试,因为 Set 使用 NOW() 不兼容 SQLite
|
||||
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
|
||||
UserID: 1,
|
||||
FieldID: 1,
|
||||
FieldKey: "field1_key",
|
||||
Value: "value1",
|
||||
})
|
||||
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
|
||||
UserID: 1,
|
||||
FieldID: 2,
|
||||
FieldKey: "field2_key",
|
||||
Value: "value2",
|
||||
})
|
||||
|
||||
values, err := valueRepo.GetByUserID(ctx, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUserID() error = %v", err)
|
||||
}
|
||||
if len(values) != 2 {
|
||||
t.Errorf("len(values) = %d, want 2", len(values))
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserCustomFieldValueRepository_GetByUserIDAndFieldKey 测试获取用户指定字段值
|
||||
func TestUserCustomFieldValueRepository_GetByUserIDAndFieldKey(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
valueRepo := NewUserCustomFieldValueRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
|
||||
UserID: 1,
|
||||
FieldID: 1,
|
||||
FieldKey: "specific_key",
|
||||
Value: "specific_value",
|
||||
})
|
||||
|
||||
found, err := valueRepo.GetByUserIDAndFieldKey(ctx, 1, "specific_key")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUserIDAndFieldKey() error = %v", err)
|
||||
}
|
||||
if found.Value != "specific_value" {
|
||||
t.Errorf("Value = %v, want specific_value", found.Value)
|
||||
}
|
||||
|
||||
_, err = valueRepo.GetByUserIDAndFieldKey(ctx, 1, "non_existent_key")
|
||||
if err == nil {
|
||||
t.Error("GetByUserIDAndFieldKey() should return error for non-existent key")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserCustomFieldValueRepository_Delete 测试删除用户字段值
|
||||
func TestUserCustomFieldValueRepository_Delete(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
valueRepo := NewUserCustomFieldValueRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
|
||||
UserID: 1,
|
||||
FieldID: 1,
|
||||
FieldKey: "delete_key",
|
||||
Value: "to_delete",
|
||||
})
|
||||
|
||||
err := valueRepo.Delete(ctx, 1, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("Delete() error = %v", err)
|
||||
}
|
||||
|
||||
_, err = valueRepo.GetByUserIDAndFieldKey(ctx, 1, "delete_key")
|
||||
if err == nil {
|
||||
t.Error("删除后查询应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserCustomFieldValueRepository_DeleteByUserID 测试删除用户所有字段值
|
||||
func TestUserCustomFieldValueRepository_DeleteByUserID(t *testing.T) {
|
||||
db := setupCustomFieldTestDB(t)
|
||||
valueRepo := NewUserCustomFieldValueRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
|
||||
UserID: 1,
|
||||
FieldID: 1,
|
||||
FieldKey: "multi1_key",
|
||||
Value: "v1",
|
||||
})
|
||||
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
|
||||
UserID: 1,
|
||||
FieldID: 2,
|
||||
FieldKey: "multi2_key",
|
||||
Value: "v2",
|
||||
})
|
||||
db.WithContext(ctx).Create(&domain.UserCustomFieldValue{
|
||||
UserID: 2,
|
||||
FieldID: 1,
|
||||
FieldKey: "multi1_key",
|
||||
Value: "v3",
|
||||
})
|
||||
|
||||
err := valueRepo.DeleteByUserID(ctx, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteByUserID() error = %v", err)
|
||||
}
|
||||
|
||||
values, _ := valueRepo.GetByUserID(ctx, 1)
|
||||
if len(values) != 0 {
|
||||
t.Errorf("len(values) = %d, want 0", len(values))
|
||||
}
|
||||
|
||||
// 用户2的值应该还在
|
||||
values2, _ := valueRepo.GetByUserID(ctx, 2)
|
||||
if len(values2) != 1 {
|
||||
t.Errorf("用户2的字段值应该保留, got %d", len(values2))
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
)
|
||||
|
||||
var deviceTestCounter int64
|
||||
@@ -484,3 +485,91 @@ func createDevice(t *testing.T, repo *DeviceRepository, ctx context.Context, use
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
// TestDeviceRepository_ListAllCursor 测试设备游标分页查询
|
||||
func TestDeviceRepository_ListAllCursor(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建设备(需要设置LastActiveTime以支持游标分页)
|
||||
now := time.Now()
|
||||
for i := 0; i < 5; i++ {
|
||||
repo.Create(ctx, &domain.Device{
|
||||
UserID: int64(i + 1),
|
||||
DeviceID: "cursor-device-" + string(rune('a'+i)),
|
||||
DeviceName: "设备" + string(rune('0'+i)),
|
||||
Status: domain.DeviceStatusActive,
|
||||
LastActiveTime: now.Add(-time.Duration(i) * time.Minute),
|
||||
})
|
||||
}
|
||||
|
||||
// 第一次查询,获取前3个
|
||||
devices, hasMore, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAllCursor() error = %v", err)
|
||||
}
|
||||
if len(devices) != 3 {
|
||||
t.Errorf("len(devices) = %d, want 3", len(devices))
|
||||
}
|
||||
if !hasMore {
|
||||
t.Error("hasMore should be true when more devices exist")
|
||||
}
|
||||
|
||||
// 使用游标继续查询
|
||||
lastDevice := devices[len(devices)-1]
|
||||
cursor := &pagination.Cursor{
|
||||
LastID: lastDevice.ID,
|
||||
LastValue: lastDevice.LastActiveTime,
|
||||
}
|
||||
devices2, hasMore2, err := repo.ListAllCursor(ctx, &ListDevicesParams{Offset: 0, Limit: 10}, 3, cursor)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAllCursor() error = %v", err)
|
||||
}
|
||||
if len(devices2) != 2 {
|
||||
t.Errorf("len(devices2) = %d, want 2", len(devices2))
|
||||
}
|
||||
if hasMore2 {
|
||||
t.Error("hasMore2 should be false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestDeviceRepository_ListAllCursor_WithFilters 测试带筛选条件的设备游标分页
|
||||
func TestDeviceRepository_ListAllCursor_WithFilters(t *testing.T) {
|
||||
db := setupDeviceTestDB(t)
|
||||
repo := NewDeviceRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
repo.Create(ctx, &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "filter-dev1",
|
||||
DeviceName: "用户1设备",
|
||||
Status: domain.DeviceStatusActive,
|
||||
LastActiveTime: now,
|
||||
})
|
||||
repo.Create(ctx, &domain.Device{
|
||||
UserID: 2,
|
||||
DeviceID: "filter-dev2",
|
||||
DeviceName: "用户2设备",
|
||||
Status: domain.DeviceStatusActive,
|
||||
LastActiveTime: now,
|
||||
})
|
||||
repo.Create(ctx, &domain.Device{
|
||||
UserID: 1,
|
||||
DeviceID: "filter-dev3",
|
||||
DeviceName: "用户1禁用设备",
|
||||
Status: domain.DeviceStatusInactive,
|
||||
LastActiveTime: now,
|
||||
})
|
||||
|
||||
// 按用户ID筛选
|
||||
status := domain.DeviceStatusActive
|
||||
devices, _, err := repo.ListAllCursor(ctx, &ListDevicesParams{UserID: 1, Status: &status, Offset: 0, Limit: 10}, 10, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAllCursor() error = %v", err)
|
||||
}
|
||||
if len(devices) != 1 {
|
||||
t.Errorf("len(devices) = %d, want 1", len(devices))
|
||||
}
|
||||
}
|
||||
|
||||
156
internal/repository/login_log_repository_test.go
Normal file
156
internal/repository/login_log_repository_test.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
)
|
||||
|
||||
var loginLogTestCounter int64
|
||||
|
||||
func openLoginLogTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
id := atomic.AddInt64(&loginLogTestCounter, 1)
|
||||
dsn := fmt.Sprintf("file:loginlogtestdb%d?mode=memory&cache=private", id)
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: dsn,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("打开测试数据库失败: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.LoginLog{}); err != nil {
|
||||
t.Fatalf("数据库迁移失败: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func setupLoginLogTestDB(t *testing.T) *gorm.DB {
|
||||
return openLoginLogTestDB(t)
|
||||
}
|
||||
|
||||
func TestLoginLogRepository_ListCursor(t *testing.T) {
|
||||
db := setupLoginLogTestDB(t)
|
||||
repo := NewLoginLogRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
for i := 0; i < 5; i++ {
|
||||
repo.Create(ctx, &domain.LoginLog{
|
||||
UserID: int64Ptr(int64(i + 1)),
|
||||
LoginType: 1,
|
||||
IP: "192.168.1." + string(rune('0'+i)),
|
||||
Status: 1,
|
||||
CreatedAt: now.Add(-time.Duration(i) * time.Minute),
|
||||
})
|
||||
}
|
||||
|
||||
// 第一次查询,获取前3个
|
||||
logs, hasMore, err := repo.ListCursor(ctx, 3, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ListCursor() error = %v", err)
|
||||
}
|
||||
if len(logs) != 3 {
|
||||
t.Errorf("len(logs) = %d, want 3", len(logs))
|
||||
}
|
||||
if !hasMore {
|
||||
t.Error("hasMore should be true when more logs exist")
|
||||
}
|
||||
|
||||
// 使用游标继续查询
|
||||
lastLog := logs[len(logs)-1]
|
||||
cursor := &pagination.Cursor{
|
||||
LastID: lastLog.ID,
|
||||
LastValue: lastLog.CreatedAt,
|
||||
}
|
||||
logs2, hasMore2, err := repo.ListCursor(ctx, 3, cursor)
|
||||
if err != nil {
|
||||
t.Fatalf("ListCursor() error = %v", err)
|
||||
}
|
||||
if len(logs2) != 2 {
|
||||
t.Errorf("len(logs2) = %d, want 2", len(logs2))
|
||||
}
|
||||
if hasMore2 {
|
||||
t.Error("hasMore2 should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginLogRepository_ListByUserIDCursor(t *testing.T) {
|
||||
db := setupLoginLogTestDB(t)
|
||||
repo := NewLoginLogRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
userID := int64(123)
|
||||
now := time.Now()
|
||||
for i := 0; i < 3; i++ {
|
||||
repo.Create(ctx, &domain.LoginLog{
|
||||
UserID: int64Ptr(userID),
|
||||
LoginType: 1,
|
||||
IP: "192.168.1." + string(rune('0'+i)),
|
||||
Status: 1,
|
||||
CreatedAt: now.Add(-time.Duration(i) * time.Minute),
|
||||
})
|
||||
}
|
||||
// 另一个用户的日志
|
||||
repo.Create(ctx, &domain.LoginLog{
|
||||
UserID: int64Ptr(999),
|
||||
LoginType: 1,
|
||||
IP: "10.0.0.1",
|
||||
Status: 1,
|
||||
})
|
||||
|
||||
// 查询指定用户的日志
|
||||
logs, hasMore, err := repo.ListByUserIDCursor(ctx, userID, 10, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByUserIDCursor() error = %v", err)
|
||||
}
|
||||
if len(logs) != 3 {
|
||||
t.Errorf("len(logs) = %d, want 3", len(logs))
|
||||
}
|
||||
if hasMore {
|
||||
t.Error("hasMore should be false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginLogRepository_ListAllForExport(t *testing.T) {
|
||||
db := setupLoginLogTestDB(t)
|
||||
repo := NewLoginLogRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.LoginLog{
|
||||
UserID: int64Ptr(1),
|
||||
LoginType: 1,
|
||||
IP: "192.168.1.1",
|
||||
Status: 1,
|
||||
})
|
||||
repo.Create(ctx, &domain.LoginLog{
|
||||
UserID: int64Ptr(2),
|
||||
LoginType: 2,
|
||||
IP: "192.168.1.2",
|
||||
Status: 0,
|
||||
FailReason: "invalid password",
|
||||
})
|
||||
|
||||
logs, err := repo.ListAllForExport(ctx, 0, -1, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAllForExport() error = %v", err)
|
||||
}
|
||||
if len(logs) != 2 {
|
||||
t.Errorf("len(logs) = %d, want 2", len(logs))
|
||||
}
|
||||
}
|
||||
94
internal/repository/operation_log_repository_test.go
Normal file
94
internal/repository/operation_log_repository_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
)
|
||||
|
||||
var operationLogTestCounter int64
|
||||
|
||||
func openOperationLogTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
id := atomic.AddInt64(&operationLogTestCounter, 1)
|
||||
dsn := fmt.Sprintf("file:operationlogtestdb%d?mode=memory&cache=private", id)
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: dsn,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("打开测试数据库失败: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.OperationLog{}); err != nil {
|
||||
t.Fatalf("数据库迁移失败: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func setupOperationLogTestDB(t *testing.T) *gorm.DB {
|
||||
return openOperationLogTestDB(t)
|
||||
}
|
||||
|
||||
func TestOperationLogRepository_ListCursor(t *testing.T) {
|
||||
db := setupOperationLogTestDB(t)
|
||||
repo := NewOperationLogRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
now := time.Now()
|
||||
for i := 0; i < 5; i++ {
|
||||
repo.Create(ctx, &domain.OperationLog{
|
||||
UserID: nil,
|
||||
OperationType: "test",
|
||||
OperationName: "测试操作" + string(rune('0'+i)),
|
||||
RequestMethod: "GET",
|
||||
RequestPath: "/api/test",
|
||||
ResponseStatus: 200,
|
||||
IP: "192.168.1." + string(rune('0'+i)),
|
||||
CreatedAt: now.Add(-time.Duration(i) * time.Minute),
|
||||
})
|
||||
}
|
||||
|
||||
// 第一次查询,获取前3个
|
||||
logs, hasMore, err := repo.ListCursor(ctx, 3, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("ListCursor() error = %v", err)
|
||||
}
|
||||
if len(logs) != 3 {
|
||||
t.Errorf("len(logs) = %d, want 3", len(logs))
|
||||
}
|
||||
if !hasMore {
|
||||
t.Error("hasMore should be true when more logs exist")
|
||||
}
|
||||
|
||||
// 使用游标继续查询
|
||||
lastLog := logs[len(logs)-1]
|
||||
cursor := &pagination.Cursor{
|
||||
LastID: lastLog.ID,
|
||||
LastValue: lastLog.CreatedAt,
|
||||
}
|
||||
logs2, hasMore2, err := repo.ListCursor(ctx, 3, cursor)
|
||||
if err != nil {
|
||||
t.Fatalf("ListCursor() error = %v", err)
|
||||
}
|
||||
if len(logs2) != 2 {
|
||||
t.Errorf("len(logs2) = %d, want 2", len(logs2))
|
||||
}
|
||||
if hasMore2 {
|
||||
t.Error("hasMore2 should be false")
|
||||
}
|
||||
}
|
||||
90
internal/repository/role_repository_test.go
Normal file
90
internal/repository/role_repository_test.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func TestRoleRepository_GetAncestorIDs(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewRoleRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建角色层级: grandchild -> child -> parent
|
||||
parentID := int64(0)
|
||||
parent := &domain.Role{Name: "parent", Code: "parent", ParentID: nil}
|
||||
if err := repo.Create(ctx, parent); err != nil {
|
||||
t.Fatalf("Create parent failed: %v", err)
|
||||
}
|
||||
parentID = parent.ID
|
||||
|
||||
child := &domain.Role{Name: "child", Code: "child", ParentID: &parentID}
|
||||
if err := repo.Create(ctx, child); err != nil {
|
||||
t.Fatalf("Create child failed: %v", err)
|
||||
}
|
||||
childID := child.ID
|
||||
|
||||
grandchild := &domain.Role{Name: "grandchild", Code: "grandchild", ParentID: &childID}
|
||||
if err := repo.Create(ctx, grandchild); err != nil {
|
||||
t.Fatalf("Create grandchild failed: %v", err)
|
||||
}
|
||||
|
||||
// 获取grandchild的祖先ID列表
|
||||
ancestorIDs, err := repo.GetAncestorIDs(ctx, grandchild.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAncestorIDs failed: %v", err)
|
||||
}
|
||||
if len(ancestorIDs) != 2 {
|
||||
t.Errorf("len(ancestorIDs) = %d, want 2", len(ancestorIDs))
|
||||
}
|
||||
if ancestorIDs[0] != childID {
|
||||
t.Errorf("ancestorIDs[0] = %d, want %d", ancestorIDs[0], childID)
|
||||
}
|
||||
if ancestorIDs[1] != parentID {
|
||||
t.Errorf("ancestorIDs[1] = %d, want %d", ancestorIDs[1], parentID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleRepository_GetAncestors(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewRoleRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建角色层级
|
||||
parentID := int64(0)
|
||||
parent := &domain.Role{Name: "parent-role", Code: "parent-role", Status: domain.RoleStatusEnabled}
|
||||
if err := repo.Create(ctx, parent); err != nil {
|
||||
t.Fatalf("Create parent failed: %v", err)
|
||||
}
|
||||
parentID = parent.ID
|
||||
|
||||
child := &domain.Role{Name: "child-role", Code: "child-role", ParentID: &parentID, Status: domain.RoleStatusEnabled}
|
||||
if err := repo.Create(ctx, child); err != nil {
|
||||
t.Fatalf("Create child failed: %v", err)
|
||||
}
|
||||
childID := child.ID
|
||||
|
||||
grandchild := &domain.Role{Name: "grandchild-role", Code: "grandchild-role", ParentID: &childID, Status: domain.RoleStatusEnabled}
|
||||
if err := repo.Create(ctx, grandchild); err != nil {
|
||||
t.Fatalf("Create grandchild failed: %v", err)
|
||||
}
|
||||
|
||||
// 获取grandchild的完整继承链
|
||||
ancestors, err := repo.GetAncestors(ctx, grandchild.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetAncestors failed: %v", err)
|
||||
}
|
||||
if len(ancestors) != 2 {
|
||||
t.Errorf("len(ancestors) = %d, want 2", len(ancestors))
|
||||
}
|
||||
// 第一个应该是parent
|
||||
if ancestors[0].Code != "parent-role" {
|
||||
t.Errorf("ancestors[0].Code = %s, want parent-role", ancestors[0].Code)
|
||||
}
|
||||
// 第二个应该是child
|
||||
if ancestors[1].Code != "child-role" {
|
||||
t.Errorf("ancestors[1].Code = %s, want child-role", ancestors[1].Code)
|
||||
}
|
||||
}
|
||||
263
internal/repository/social_account_repository_test.go
Normal file
263
internal/repository/social_account_repository_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
var socialAccountTestCounter int64
|
||||
|
||||
func openSocialAccountTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
id := atomic.AddInt64(&socialAccountTestCounter, 1)
|
||||
dsn := fmt.Sprintf("file:socialaccounttestdb%d?mode=memory&cache=private", id)
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: dsn,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("打开测试数据库失败: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.SocialAccount{}); err != nil {
|
||||
t.Fatalf("数据库迁移失败: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func setupSocialAccountTestDB(t *testing.T) *gorm.DB {
|
||||
return openSocialAccountTestDB(t)
|
||||
}
|
||||
|
||||
func TestSocialAccountRepository_Create(t *testing.T) {
|
||||
db := setupSocialAccountTestDB(t)
|
||||
repo, err := NewSocialAccountRepository(db)
|
||||
if err != nil {
|
||||
t.Fatalf("NewSocialAccountRepository() error = %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
account := &domain.SocialAccount{
|
||||
UserID: 1,
|
||||
Provider: "github",
|
||||
OpenID: "openid-123",
|
||||
Nickname: "testuser",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
}
|
||||
|
||||
if err := repo.Create(ctx, account); err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if account.ID == 0 {
|
||||
t.Error("创建后账户ID不应为0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocialAccountRepository_GetByID(t *testing.T) {
|
||||
db := setupSocialAccountTestDB(t)
|
||||
repo, err := NewSocialAccountRepository(db)
|
||||
if err != nil {
|
||||
t.Fatalf("NewSocialAccountRepository() error = %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
account := &domain.SocialAccount{
|
||||
UserID: 1,
|
||||
Provider: "github",
|
||||
OpenID: "openid-getbyid",
|
||||
Nickname: "getbyid-user",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
}
|
||||
repo.Create(ctx, account)
|
||||
|
||||
found, err := repo.GetByID(ctx, account.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID() error = %v", err)
|
||||
}
|
||||
if found.Nickname != "getbyid-user" {
|
||||
t.Errorf("Nickname = %v, want getbyid-user", found.Nickname)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocialAccountRepository_GetByUserID(t *testing.T) {
|
||||
db := setupSocialAccountTestDB(t)
|
||||
repo, err := NewSocialAccountRepository(db)
|
||||
if err != nil {
|
||||
t.Fatalf("NewSocialAccountRepository() error = %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.SocialAccount{
|
||||
UserID: 1,
|
||||
Provider: "github",
|
||||
OpenID: "openid-user1-1",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
})
|
||||
repo.Create(ctx, &domain.SocialAccount{
|
||||
UserID: 1,
|
||||
Provider: "wechat",
|
||||
OpenID: "openid-user1-2",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
})
|
||||
repo.Create(ctx, &domain.SocialAccount{
|
||||
UserID: 2,
|
||||
Provider: "github",
|
||||
OpenID: "openid-user2",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
})
|
||||
|
||||
accounts, err := repo.GetByUserID(ctx, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUserID() error = %v", err)
|
||||
}
|
||||
if len(accounts) != 2 {
|
||||
t.Errorf("len(accounts) = %d, want 2", len(accounts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocialAccountRepository_GetByProviderAndOpenID(t *testing.T) {
|
||||
db := setupSocialAccountTestDB(t)
|
||||
repo, err := NewSocialAccountRepository(db)
|
||||
if err != nil {
|
||||
t.Fatalf("NewSocialAccountRepository() error = %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
account := &domain.SocialAccount{
|
||||
UserID: 1,
|
||||
Provider: "github",
|
||||
OpenID: "unique-openid-123",
|
||||
Nickname: "github-user",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
}
|
||||
repo.Create(ctx, account)
|
||||
|
||||
found, err := repo.GetByProviderAndOpenID(ctx, "github", "unique-openid-123")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByProviderAndOpenID() error = %v", err)
|
||||
}
|
||||
if found.UserID != 1 {
|
||||
t.Errorf("UserID = %d, want 1", found.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocialAccountRepository_Update(t *testing.T) {
|
||||
db := setupSocialAccountTestDB(t)
|
||||
repo, err := NewSocialAccountRepository(db)
|
||||
if err != nil {
|
||||
t.Fatalf("NewSocialAccountRepository() error = %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
account := &domain.SocialAccount{
|
||||
UserID: 1,
|
||||
Provider: "github",
|
||||
OpenID: "openid-update",
|
||||
Nickname: "before-update",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
}
|
||||
repo.Create(ctx, account)
|
||||
|
||||
account.Nickname = "after-update"
|
||||
if err := repo.Update(ctx, account); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
found, _ := repo.GetByID(ctx, account.ID)
|
||||
if found.Nickname != "after-update" {
|
||||
t.Errorf("Nickname = %v, want after-update", found.Nickname)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocialAccountRepository_Delete(t *testing.T) {
|
||||
db := setupSocialAccountTestDB(t)
|
||||
repo, err := NewSocialAccountRepository(db)
|
||||
if err != nil {
|
||||
t.Fatalf("NewSocialAccountRepository() error = %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
account := &domain.SocialAccount{
|
||||
UserID: 1,
|
||||
Provider: "github",
|
||||
OpenID: "openid-delete",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
}
|
||||
repo.Create(ctx, account)
|
||||
|
||||
if err := repo.Delete(ctx, account.ID); err != nil {
|
||||
t.Fatalf("Delete() error = %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocialAccountRepository_DeleteByProviderAndUserID(t *testing.T) {
|
||||
db := setupSocialAccountTestDB(t)
|
||||
repo, err := NewSocialAccountRepository(db)
|
||||
if err != nil {
|
||||
t.Fatalf("NewSocialAccountRepository() error = %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.SocialAccount{
|
||||
UserID: 1,
|
||||
Provider: "github",
|
||||
OpenID: "openid-del-provider",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
})
|
||||
|
||||
err = repo.DeleteByProviderAndUserID(ctx, "github", 1)
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteByProviderAndUserID() error = %v", err)
|
||||
}
|
||||
|
||||
accounts, _ := repo.GetByUserID(ctx, 1)
|
||||
if len(accounts) != 0 {
|
||||
t.Errorf("len(accounts) = %d, want 0 after delete", len(accounts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocialAccountRepository_List(t *testing.T) {
|
||||
db := setupSocialAccountTestDB(t)
|
||||
repo, err := NewSocialAccountRepository(db)
|
||||
if err != nil {
|
||||
t.Fatalf("NewSocialAccountRepository() error = %v", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.SocialAccount{
|
||||
UserID: 1,
|
||||
Provider: "github",
|
||||
OpenID: "openid-list-1",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
})
|
||||
repo.Create(ctx, &domain.SocialAccount{
|
||||
UserID: 2,
|
||||
Provider: "wechat",
|
||||
OpenID: "openid-list-2",
|
||||
Status: domain.SocialAccountStatusActive,
|
||||
})
|
||||
|
||||
accounts, total, err := repo.List(ctx, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
if len(accounts) != 2 {
|
||||
t.Errorf("len(accounts) = %d, want 2", len(accounts))
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("total = %d, want 2", total)
|
||||
}
|
||||
}
|
||||
275
internal/repository/theme_repository_test.go
Normal file
275
internal/repository/theme_repository_test.go
Normal file
@@ -0,0 +1,275 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
_ "modernc.org/sqlite"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
var themeTestCounter int64
|
||||
|
||||
// openThemeTestDB 为每个测试打开独立的内存数据库
|
||||
func openThemeTestDB(t *testing.T) *gorm.DB {
|
||||
t.Helper()
|
||||
|
||||
id := atomic.AddInt64(&themeTestCounter, 1)
|
||||
dsn := fmt.Sprintf("file:themetestdb%d?mode=memory&cache=private", id)
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: dsn,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("打开测试数据库失败: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.ThemeConfig{}); err != nil {
|
||||
t.Fatalf("数据库迁移失败: %v", err)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// setupThemeTestDB 兼容性别名
|
||||
func setupThemeTestDB(t *testing.T) *gorm.DB {
|
||||
return openThemeTestDB(t)
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_Create 测试创建主题
|
||||
func TestThemeConfigRepository_Create(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
theme := &domain.ThemeConfig{
|
||||
Name: "test-theme",
|
||||
PrimaryColor: "#ff0000",
|
||||
SecondaryColor: "#00ff00",
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
if err := repo.Create(ctx, theme); err != nil {
|
||||
t.Fatalf("Create() error = %v", err)
|
||||
}
|
||||
if theme.ID == 0 {
|
||||
t.Error("创建后主题ID不应为0")
|
||||
}
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_GetByID 测试根据ID获取主题
|
||||
func TestThemeConfigRepository_GetByID(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
theme := &domain.ThemeConfig{
|
||||
Name: "getbyid-theme",
|
||||
PrimaryColor: "#0000ff",
|
||||
Enabled: true,
|
||||
}
|
||||
repo.Create(ctx, theme)
|
||||
|
||||
found, err := repo.GetByID(ctx, theme.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByID() error = %v", err)
|
||||
}
|
||||
if found.Name != "getbyid-theme" {
|
||||
t.Errorf("Name = %v, want getbyid-theme", found.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_GetByName 测试根据名称获取主题
|
||||
func TestThemeConfigRepository_GetByName(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
theme := &domain.ThemeConfig{
|
||||
Name: "unique-theme-name",
|
||||
PrimaryColor: "#ffff00",
|
||||
Enabled: true,
|
||||
}
|
||||
repo.Create(ctx, theme)
|
||||
|
||||
found, err := repo.GetByName(ctx, "unique-theme-name")
|
||||
if err != nil {
|
||||
t.Fatalf("GetByName() error = %v", err)
|
||||
}
|
||||
if found.ID != theme.ID {
|
||||
t.Errorf("ID = %v, want %v", found.ID, theme.ID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_GetByName_NotFound 测试名称不存在
|
||||
func TestThemeConfigRepository_GetByName_NotFound(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := repo.GetByName(ctx, "not-exist-theme")
|
||||
if err == nil {
|
||||
t.Error("GetByName() should return error for non-existent theme")
|
||||
}
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_Update 测试更新主题
|
||||
func TestThemeConfigRepository_Update(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
theme := &domain.ThemeConfig{
|
||||
Name: "update-test",
|
||||
PrimaryColor: "#000000",
|
||||
Enabled: true,
|
||||
}
|
||||
repo.Create(ctx, theme)
|
||||
|
||||
theme.PrimaryColor = "#ffffff"
|
||||
if err := repo.Update(ctx, theme); err != nil {
|
||||
t.Fatalf("Update() error = %v", err)
|
||||
}
|
||||
|
||||
found, _ := repo.GetByID(ctx, theme.ID)
|
||||
if found.PrimaryColor != "#ffffff" {
|
||||
t.Errorf("PrimaryColor = %v, want #ffffff", found.PrimaryColor)
|
||||
}
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_Delete 测试删除主题
|
||||
func TestThemeConfigRepository_Delete(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
theme := &domain.ThemeConfig{
|
||||
Name: "delete-test",
|
||||
Enabled: true,
|
||||
}
|
||||
repo.Create(ctx, theme)
|
||||
|
||||
if err := repo.Delete(ctx, theme.ID); err != nil {
|
||||
t.Fatalf("Delete() error = %v", err)
|
||||
}
|
||||
|
||||
_, err := repo.GetByID(ctx, theme.ID)
|
||||
if err == nil {
|
||||
t.Error("删除后查询应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_List 测试获取已启用主题列表
|
||||
func TestThemeConfigRepository_List(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.ThemeConfig{Name: "enabled1", Enabled: true})
|
||||
repo.Create(ctx, &domain.ThemeConfig{Name: "enabled2", Enabled: true})
|
||||
repo.Create(ctx, &domain.ThemeConfig{Name: "disabled1", Enabled: false})
|
||||
|
||||
themes, err := repo.List(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("List() error = %v", err)
|
||||
}
|
||||
// List filters by enabled=true
|
||||
if len(themes) < 2 {
|
||||
t.Errorf("len(themes) = %d, want at least 2", len(themes))
|
||||
}
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_ListAll 测试获取所有主题列表
|
||||
func TestThemeConfigRepository_ListAll(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.ThemeConfig{Name: "all1", Enabled: true})
|
||||
repo.Create(ctx, &domain.ThemeConfig{Name: "all2", Enabled: false})
|
||||
|
||||
themes, err := repo.ListAll(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListAll() error = %v", err)
|
||||
}
|
||||
if len(themes) != 2 {
|
||||
t.Errorf("len(themes) = %d, want 2", len(themes))
|
||||
}
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_GetDefault 测试获取默认主题
|
||||
func TestThemeConfigRepository_GetDefault(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建一个默认主题
|
||||
repo.Create(ctx, &domain.ThemeConfig{
|
||||
Name: "default-theme",
|
||||
IsDefault: true,
|
||||
Enabled: true,
|
||||
})
|
||||
|
||||
defaultTheme, err := repo.GetDefault(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetDefault() error = %v", err)
|
||||
}
|
||||
if defaultTheme.Name != "default-theme" {
|
||||
t.Errorf("Name = %v, want default-theme", defaultTheme.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_GetDefault_NoDefault 测试无默认主题时返回默认配置
|
||||
func TestThemeConfigRepository_GetDefault_NoDefault(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// 不创建任何主题
|
||||
defaultTheme, err := repo.GetDefault(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GetDefault() error = %v", err)
|
||||
}
|
||||
// 应该返回内置默认配置
|
||||
if defaultTheme.Name != "default" {
|
||||
t.Errorf("Name = %v, want default", defaultTheme.Name)
|
||||
}
|
||||
}
|
||||
|
||||
// TestThemeConfigRepository_SetDefault 测试设置默认主题
|
||||
func TestThemeConfigRepository_SetDefault(t *testing.T) {
|
||||
db := setupThemeTestDB(t)
|
||||
repo := NewThemeConfigRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建两个主题
|
||||
theme1 := &domain.ThemeConfig{Name: "theme1", IsDefault: true, Enabled: true}
|
||||
theme2 := &domain.ThemeConfig{Name: "theme2", IsDefault: false, Enabled: true}
|
||||
repo.Create(ctx, theme1)
|
||||
repo.Create(ctx, theme2)
|
||||
|
||||
// 设置 theme2 为默认
|
||||
if err := repo.SetDefault(ctx, theme2.ID); err != nil {
|
||||
t.Fatalf("SetDefault() error = %v", err)
|
||||
}
|
||||
|
||||
// 验证 theme1 不再是默认
|
||||
t1, _ := repo.GetByID(ctx, theme1.ID)
|
||||
if t1.IsDefault {
|
||||
t.Error("theme1 should not be default anymore")
|
||||
}
|
||||
|
||||
// 验证 theme2 现在是默认
|
||||
t2, _ := repo.GetByID(ctx, theme2.ID)
|
||||
if !t2.IsDefault {
|
||||
t.Error("theme2 should be default")
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
@@ -401,3 +402,259 @@ func TestUserRepository_Search_LikePattern(t *testing.T) {
|
||||
// Should not error and should escape properly
|
||||
_ = users
|
||||
}
|
||||
|
||||
// TestUserRepository_GetByIDs 测试批量获取用户
|
||||
func TestUserRepository_GetByIDs(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
u1 := &domain.User{Username: "batchuser1", Password: "hash", Status: domain.UserStatusActive}
|
||||
u2 := &domain.User{Username: "batchuser2", Password: "hash", Status: domain.UserStatusActive}
|
||||
u3 := &domain.User{Username: "batchuser3", Password: "hash", Status: domain.UserStatusActive}
|
||||
repo.Create(ctx, u1)
|
||||
repo.Create(ctx, u2)
|
||||
repo.Create(ctx, u3)
|
||||
|
||||
users, err := repo.GetByIDs(ctx, []int64{u1.ID, u3.ID})
|
||||
if err != nil {
|
||||
t.Fatalf("GetByIDs() error = %v", err)
|
||||
}
|
||||
if len(users) != 2 {
|
||||
t.Errorf("len(users) = %d, want 2", len(users))
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_GetByIDs_Empty 测试空ID列表
|
||||
func TestUserRepository_GetByIDs_Empty(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
users, err := repo.GetByIDs(ctx, []int64{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetByIDs() error = %v", err)
|
||||
}
|
||||
if len(users) != 0 {
|
||||
t.Errorf("len(users) = %d, want 0", len(users))
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_UpdatePassword 测试更新密码
|
||||
func TestUserRepository_UpdatePassword(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
Username: "pwduser",
|
||||
Password: "oldpassword",
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
repo.Create(ctx, user)
|
||||
|
||||
err := repo.UpdatePassword(ctx, user.ID, "newpasswordhash")
|
||||
if err != nil {
|
||||
t.Fatalf("UpdatePassword() error = %v", err)
|
||||
}
|
||||
|
||||
found, _ := repo.GetByID(ctx, user.ID)
|
||||
if found.Password != "newpasswordhash" {
|
||||
t.Errorf("Password = %v, want newpasswordhash", found.Password)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_UpdateTOTP 测试更新TOTP
|
||||
func TestUserRepository_UpdateTOTP(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
Username: "totpuser",
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
repo.Create(ctx, user)
|
||||
|
||||
user.TOTPEnabled = true
|
||||
user.TOTPSecret = "JBSWY3DPEHPK3PXP"
|
||||
err := repo.UpdateTOTP(ctx, user)
|
||||
if err != nil {
|
||||
t.Fatalf("UpdateTOTP() error = %v", err)
|
||||
}
|
||||
|
||||
found, _ := repo.GetByID(ctx, user.ID)
|
||||
if !found.TOTPEnabled {
|
||||
t.Error("TOTPEnabled should be true")
|
||||
}
|
||||
if found.TOTPSecret != "JBSWY3DPEHPK3PXP" {
|
||||
t.Errorf("TOTPSecret = %v, want JBSWY3DPEHPK3PXP", found.TOTPSecret)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_ListCreatedAfter 测试查询创建时间之后的用户
|
||||
func TestUserRepository_ListCreatedAfter(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
user := &domain.User{
|
||||
Username: "afteruser",
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
repo.Create(ctx, user)
|
||||
|
||||
since := user.CreatedAt.Add(-1 * time.Hour)
|
||||
users, total, err := repo.ListCreatedAfter(ctx, since, 0, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("ListCreatedAfter() error = %v", err)
|
||||
}
|
||||
if total < 1 {
|
||||
t.Errorf("total = %d, want at least 1", total)
|
||||
}
|
||||
_ = users
|
||||
}
|
||||
|
||||
// TestUserRepository_ListCreatedAfter_Limited 测试带limit的查询
|
||||
func TestUserRepository_ListCreatedAfter_Limited(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
repo.Create(ctx, &domain.User{
|
||||
Username: "limituser" + string(rune('0'+i)),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
})
|
||||
}
|
||||
|
||||
since := time.Now().Add(-1 * time.Hour)
|
||||
users, total, err := repo.ListCreatedAfter(ctx, since, 0, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("ListCreatedAfter() error = %v", err)
|
||||
}
|
||||
if len(users) != 3 {
|
||||
t.Errorf("len(users) = %d, want 3", len(users))
|
||||
}
|
||||
if total < 5 {
|
||||
t.Errorf("total = %d, want at least 5", total)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_AdvancedSearch 测试高级搜索
|
||||
func TestUserRepository_AdvancedSearch(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.User{
|
||||
Username: "searchuser1",
|
||||
Nickname: "张三",
|
||||
Email: domain.StrPtr("zhangsan@example.com"),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
})
|
||||
repo.Create(ctx, &domain.User{
|
||||
Username: "searchuser2",
|
||||
Nickname: "李四",
|
||||
Email: domain.StrPtr("lisi@example.com"),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
})
|
||||
repo.Create(ctx, &domain.User{
|
||||
Username: "searchuser3",
|
||||
Nickname: "王五",
|
||||
Email: domain.StrPtr("wangwu@example.com"),
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusInactive,
|
||||
})
|
||||
|
||||
// 按关键字搜索(Status=-1 表示全部状态)
|
||||
filter := &AdvancedFilter{Keyword: "searchuser1", Status: -1, Offset: 0, Limit: 10}
|
||||
users, total, err := repo.AdvancedSearch(ctx, filter)
|
||||
if err != nil {
|
||||
t.Fatalf("AdvancedSearch() error = %v", err)
|
||||
}
|
||||
if len(users) != 1 {
|
||||
t.Errorf("len(users) = %d, want 1", len(users))
|
||||
}
|
||||
if total != 1 {
|
||||
t.Errorf("total = %d, want 1", total)
|
||||
}
|
||||
|
||||
// 按状态筛选
|
||||
filter2 := &AdvancedFilter{Status: int(domain.UserStatusActive), Offset: 0, Limit: 10}
|
||||
users2, total2, err := repo.AdvancedSearch(ctx, filter2)
|
||||
if err != nil {
|
||||
t.Fatalf("AdvancedSearch() error = %v", err)
|
||||
}
|
||||
if len(users2) != 2 {
|
||||
t.Errorf("len(users2) = %d, want 2", len(users2))
|
||||
}
|
||||
if total2 != 2 {
|
||||
t.Errorf("total2 = %d, want 2", total2)
|
||||
}
|
||||
|
||||
// 按状态筛选 - 禁用用户
|
||||
filter3 := &AdvancedFilter{Status: int(domain.UserStatusInactive), Offset: 0, Limit: 10}
|
||||
users3, total3, err := repo.AdvancedSearch(ctx, filter3)
|
||||
if err != nil {
|
||||
t.Fatalf("AdvancedSearch() error = %v", err)
|
||||
}
|
||||
if len(users3) != 1 {
|
||||
t.Errorf("len(users3) = %d, want 1", len(users3))
|
||||
}
|
||||
if total3 != 1 {
|
||||
t.Errorf("total3 = %d, want 1", total3)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_AdvancedSearch_AllStatus 测试状态为-1返回全部
|
||||
func TestUserRepository_AdvancedSearch_AllStatus(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.User{Username: "active", Password: "hash", Status: domain.UserStatusActive})
|
||||
repo.Create(ctx, &domain.User{Username: "inactive", Password: "hash", Status: domain.UserStatusInactive})
|
||||
|
||||
filter := &AdvancedFilter{Status: -1, Offset: 0, Limit: 10}
|
||||
users, total, err := repo.AdvancedSearch(ctx, filter)
|
||||
if err != nil {
|
||||
t.Fatalf("AdvancedSearch() error = %v", err)
|
||||
}
|
||||
if len(users) != 2 {
|
||||
t.Errorf("len(users) = %d, want 2", len(users))
|
||||
}
|
||||
if total != 2 {
|
||||
t.Errorf("total = %d, want 2", total)
|
||||
}
|
||||
}
|
||||
|
||||
// TestUserRepository_AdvancedSearch_LikeSpecialChars 测试搜索LIKE特殊字符转义
|
||||
func TestUserRepository_AdvancedSearch_LikeSpecialChars(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
repo.Create(ctx, &domain.User{
|
||||
Username: "user%with%percent",
|
||||
Nickname: "测试用户",
|
||||
Password: "hash",
|
||||
Status: domain.UserStatusActive,
|
||||
})
|
||||
|
||||
// 搜索%应该不匹配任何记录(被转义)
|
||||
filter := &AdvancedFilter{Keyword: "%", Offset: 0, Limit: 10}
|
||||
users, _, err := repo.AdvancedSearch(ctx, filter)
|
||||
if err != nil {
|
||||
t.Fatalf("AdvancedSearch() error = %v", err)
|
||||
}
|
||||
if len(users) != 0 {
|
||||
t.Errorf("len(users) = %d, want 0 for escaped percent", len(users))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
36
internal/repository/user_role_repository_test.go
Normal file
36
internal/repository/user_role_repository_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func TestUserRoleRepository_DeleteByUserAndRole(t *testing.T) {
|
||||
db := setupTestDB(t)
|
||||
repo := NewUserRoleRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建用户和角色
|
||||
user := &domain.User{Username: "roleuser", Password: "hash", Status: domain.UserStatusActive}
|
||||
db.WithContext(ctx).Create(user)
|
||||
|
||||
role := &domain.Role{Code: "test_role", Name: "测试角色", Status: domain.RoleStatusEnabled}
|
||||
db.WithContext(ctx).Create(role)
|
||||
|
||||
// 创建用户角色关联
|
||||
repo.Create(ctx, &domain.UserRole{UserID: user.ID, RoleID: role.ID})
|
||||
|
||||
// 删除特定用户-角色关联
|
||||
err := repo.DeleteByUserAndRole(ctx, user.ID, role.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("DeleteByUserAndRole() error = %v", err)
|
||||
}
|
||||
|
||||
// 验证已删除
|
||||
exists, _ := repo.Exists(ctx, user.ID, role.ID)
|
||||
if exists {
|
||||
t.Error("DeleteByUserAndRole should have removed the association")
|
||||
}
|
||||
}
|
||||
@@ -188,3 +188,40 @@ func TestWebhookRepositoryCreateAndListDeliveries(t *testing.T) {
|
||||
t.Fatal("expected deliveries to be returned in reverse created_at order")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookRepositoryListByCreatorPaginated(t *testing.T) {
|
||||
repo := setupWebhookRepository(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// 创建多个webhook
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := repo.Create(ctx, newWebhookFixture("wh-creator1-"+string(rune('a'+i)), 1, domain.WebhookStatusActive)); err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
}
|
||||
// 另一个用户的webhook
|
||||
if err := repo.Create(ctx, newWebhookFixture("wh-creator2", 2, domain.WebhookStatusActive)); err != nil {
|
||||
t.Fatalf("Create failed: %v", err)
|
||||
}
|
||||
|
||||
// 测试分页查询创建者1的webhook
|
||||
webhooks, total, err := repo.ListByCreatorPaginated(ctx, 1, 0, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByCreatorPaginated failed: %v", err)
|
||||
}
|
||||
if len(webhooks) != 3 {
|
||||
t.Errorf("len(webhooks) = %d, want 3", len(webhooks))
|
||||
}
|
||||
if total != 5 {
|
||||
t.Errorf("total = %d, want 5", total)
|
||||
}
|
||||
|
||||
// 测试第二页
|
||||
webhooks2, _, err := repo.ListByCreatorPaginated(ctx, 1, 3, 3)
|
||||
if err != nil {
|
||||
t.Fatalf("ListByCreatorPaginated page 2 failed: %v", err)
|
||||
}
|
||||
if len(webhooks2) != 2 {
|
||||
t.Errorf("len(webhooks2) = %d, want 2", len(webhooks2))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user