467 lines
14 KiB
Go
467 lines
14 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"time"
|
|
|
|
"gorm.io/gorm"
|
|
|
|
"github.com/user-management-system/internal/domain"
|
|
)
|
|
|
|
func migrateRepositoryTables(t *testing.T, db *gorm.DB, tables ...interface{}) {
|
|
t.Helper()
|
|
|
|
if err := db.AutoMigrate(tables...); err != nil {
|
|
t.Fatalf("migrate repository tables failed: %v", err)
|
|
}
|
|
}
|
|
|
|
// int64Ptr returns a pointer to a freshly allocated copy of v, handy for
// populating optional *int64 fields (such as UserID) inside struct literals.
func int64Ptr(v int64) *int64 {
	p := new(int64)
	*p = v
	return p
}
|
|
|
|
func TestDeviceRepositoryLifecycleAndQueries(t *testing.T) {
|
|
db := openTestDB(t)
|
|
migrateRepositoryTables(t, db, &domain.Device{})
|
|
|
|
repo := NewDeviceRepository(db)
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
|
|
devices := []*domain.Device{
|
|
{
|
|
UserID: 1,
|
|
DeviceID: "device-alpha",
|
|
DeviceName: "Alpha",
|
|
DeviceType: domain.DeviceTypeDesktop,
|
|
DeviceOS: "Windows",
|
|
DeviceBrowser: "Chrome",
|
|
IP: "10.0.0.1",
|
|
Location: "Shanghai",
|
|
Status: domain.DeviceStatusActive,
|
|
LastActiveTime: now.Add(-1 * time.Hour),
|
|
},
|
|
{
|
|
UserID: 1,
|
|
DeviceID: "device-beta",
|
|
DeviceName: "Beta",
|
|
DeviceType: domain.DeviceTypeWeb,
|
|
DeviceOS: "macOS",
|
|
DeviceBrowser: "Safari",
|
|
IP: "10.0.0.2",
|
|
Location: "Hangzhou",
|
|
Status: domain.DeviceStatusInactive,
|
|
LastActiveTime: now.Add(-2 * time.Hour),
|
|
},
|
|
{
|
|
UserID: 2,
|
|
DeviceID: "device-gamma",
|
|
DeviceName: "Gamma",
|
|
DeviceType: domain.DeviceTypeMobile,
|
|
DeviceOS: "Android",
|
|
DeviceBrowser: "WebView",
|
|
IP: "10.0.0.3",
|
|
Location: "Beijing",
|
|
Status: domain.DeviceStatusActive,
|
|
LastActiveTime: now.Add(-40 * 24 * time.Hour),
|
|
},
|
|
}
|
|
|
|
for _, device := range devices {
|
|
if err := repo.Create(ctx, device); err != nil {
|
|
t.Fatalf("Create(%s) failed: %v", device.DeviceID, err)
|
|
}
|
|
}
|
|
|
|
if allDevices, total, err := repo.List(ctx, 0, 10); err != nil {
|
|
t.Fatalf("List failed: %v", err)
|
|
} else if total != 3 || len(allDevices) != 3 {
|
|
t.Fatalf("expected 3 devices, got total=%d len=%d", total, len(allDevices))
|
|
}
|
|
|
|
loadedByDeviceID, err := repo.GetByDeviceID(ctx, 1, "device-beta")
|
|
if err != nil {
|
|
t.Fatalf("GetByDeviceID failed: %v", err)
|
|
}
|
|
if loadedByDeviceID.DeviceName != "Beta" {
|
|
t.Fatalf("expected device name Beta, got %q", loadedByDeviceID.DeviceName)
|
|
}
|
|
|
|
exists, err := repo.Exists(ctx, 1, "device-alpha")
|
|
if err != nil {
|
|
t.Fatalf("Exists(device-alpha) failed: %v", err)
|
|
}
|
|
if !exists {
|
|
t.Fatal("expected device-alpha to exist")
|
|
}
|
|
|
|
missing, err := repo.Exists(ctx, 1, "missing-device")
|
|
if err != nil {
|
|
t.Fatalf("Exists(missing-device) failed: %v", err)
|
|
}
|
|
if missing {
|
|
t.Fatal("expected missing-device to be absent")
|
|
}
|
|
|
|
userDevices, total, err := repo.ListByUserID(ctx, 1, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("ListByUserID failed: %v", err)
|
|
}
|
|
if total != 2 || len(userDevices) != 2 {
|
|
t.Fatalf("expected 2 devices for user 1, got total=%d len=%d", total, len(userDevices))
|
|
}
|
|
if userDevices[0].DeviceID != "device-alpha" {
|
|
t.Fatalf("expected latest active device first, got %q", userDevices[0].DeviceID)
|
|
}
|
|
|
|
activeDevices, total, err := repo.ListByStatus(ctx, domain.DeviceStatusActive, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("ListByStatus failed: %v", err)
|
|
}
|
|
if total != 2 || len(activeDevices) != 2 {
|
|
t.Fatalf("expected 2 active devices, got total=%d len=%d", total, len(activeDevices))
|
|
}
|
|
|
|
if err := repo.UpdateStatus(ctx, devices[1].ID, domain.DeviceStatusActive); err != nil {
|
|
t.Fatalf("UpdateStatus failed: %v", err)
|
|
}
|
|
|
|
beforeTouch, err := repo.GetByID(ctx, devices[1].ID)
|
|
if err != nil {
|
|
t.Fatalf("GetByID before UpdateLastActiveTime failed: %v", err)
|
|
}
|
|
|
|
time.Sleep(10 * time.Millisecond)
|
|
if err := repo.UpdateLastActiveTime(ctx, devices[1].ID); err != nil {
|
|
t.Fatalf("UpdateLastActiveTime failed: %v", err)
|
|
}
|
|
|
|
afterTouch, err := repo.GetByID(ctx, devices[1].ID)
|
|
if err != nil {
|
|
t.Fatalf("GetByID after UpdateLastActiveTime failed: %v", err)
|
|
}
|
|
if !afterTouch.LastActiveTime.After(beforeTouch.LastActiveTime) {
|
|
t.Fatal("expected last_active_time to move forward")
|
|
}
|
|
|
|
recentDevices, err := repo.GetActiveDevices(ctx, 1)
|
|
if err != nil {
|
|
t.Fatalf("GetActiveDevices failed: %v", err)
|
|
}
|
|
if len(recentDevices) != 2 {
|
|
t.Fatalf("expected 2 recent devices for user 1, got %d", len(recentDevices))
|
|
}
|
|
|
|
if err := repo.DeleteByUserID(ctx, 1); err != nil {
|
|
t.Fatalf("DeleteByUserID failed: %v", err)
|
|
}
|
|
|
|
remainingDevices, remainingTotal, err := repo.List(ctx, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("List after DeleteByUserID failed: %v", err)
|
|
}
|
|
if remainingTotal != 1 || len(remainingDevices) != 1 {
|
|
t.Fatalf("expected 1 remaining device, got total=%d len=%d", remainingTotal, len(remainingDevices))
|
|
}
|
|
|
|
if err := repo.Delete(ctx, devices[2].ID); err != nil {
|
|
t.Fatalf("Delete failed: %v", err)
|
|
}
|
|
|
|
if _, err := repo.GetByID(ctx, devices[2].ID); err == nil {
|
|
t.Fatal("expected deleted device lookup to fail")
|
|
}
|
|
}
|
|
|
|
func TestLoginLogRepositoryQueriesAndRetention(t *testing.T) {
|
|
db := openTestDB(t)
|
|
migrateRepositoryTables(t, db, &domain.LoginLog{})
|
|
|
|
repo := NewLoginLogRepository(db)
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
|
|
logs := []*domain.LoginLog{
|
|
{
|
|
UserID: int64Ptr(1),
|
|
LoginType: int(domain.LoginTypePassword),
|
|
DeviceID: "device-alpha",
|
|
IP: "10.0.0.1",
|
|
Location: "Shanghai",
|
|
Status: 1,
|
|
CreatedAt: now.Add(-1 * time.Hour),
|
|
},
|
|
{
|
|
UserID: int64Ptr(1),
|
|
LoginType: int(domain.LoginTypeSMSCode),
|
|
DeviceID: "device-beta",
|
|
IP: "10.0.0.2",
|
|
Location: "Hangzhou",
|
|
Status: 0,
|
|
FailReason: "code expired",
|
|
CreatedAt: now.Add(-30 * time.Minute),
|
|
},
|
|
{
|
|
UserID: int64Ptr(2),
|
|
LoginType: int(domain.LoginTypeOAuth),
|
|
DeviceID: "device-gamma",
|
|
IP: "10.0.0.3",
|
|
Location: "Beijing",
|
|
Status: 1,
|
|
CreatedAt: now.Add(-45 * 24 * time.Hour),
|
|
},
|
|
}
|
|
|
|
for _, log := range logs {
|
|
if err := repo.Create(ctx, log); err != nil {
|
|
t.Fatalf("Create login log failed: %v", err)
|
|
}
|
|
}
|
|
|
|
loaded, err := repo.GetByID(ctx, logs[0].ID)
|
|
if err != nil {
|
|
t.Fatalf("GetByID failed: %v", err)
|
|
}
|
|
if loaded.DeviceID != "device-alpha" {
|
|
t.Fatalf("expected device-alpha, got %q", loaded.DeviceID)
|
|
}
|
|
|
|
userLogs, total, err := repo.ListByUserID(ctx, 1, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("ListByUserID failed: %v", err)
|
|
}
|
|
if total != 2 || len(userLogs) != 2 {
|
|
t.Fatalf("expected 2 user logs, got total=%d len=%d", total, len(userLogs))
|
|
}
|
|
if userLogs[0].DeviceID != "device-beta" {
|
|
t.Fatalf("expected newest login log first, got %q", userLogs[0].DeviceID)
|
|
}
|
|
|
|
allLogs, total, err := repo.List(ctx, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("List failed: %v", err)
|
|
}
|
|
if total != 3 || len(allLogs) != 3 {
|
|
t.Fatalf("expected 3 total logs, got total=%d len=%d", total, len(allLogs))
|
|
}
|
|
|
|
successLogs, total, err := repo.ListByStatus(ctx, 1, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("ListByStatus failed: %v", err)
|
|
}
|
|
if total != 2 || len(successLogs) != 2 {
|
|
t.Fatalf("expected 2 success logs, got total=%d len=%d", total, len(successLogs))
|
|
}
|
|
|
|
recentLogs, total, err := repo.ListByTimeRange(ctx, now.Add(-2*time.Hour), now, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("ListByTimeRange failed: %v", err)
|
|
}
|
|
if total != 2 || len(recentLogs) != 2 {
|
|
t.Fatalf("expected 2 recent logs, got total=%d len=%d", total, len(recentLogs))
|
|
}
|
|
|
|
if count := repo.CountByResultSince(ctx, true, now.Add(-2*time.Hour)); count != 1 {
|
|
t.Fatalf("expected 1 recent success login, got %d", count)
|
|
}
|
|
if count := repo.CountByResultSince(ctx, false, now.Add(-2*time.Hour)); count != 1 {
|
|
t.Fatalf("expected 1 recent failed login, got %d", count)
|
|
}
|
|
|
|
if err := repo.DeleteOlderThan(ctx, 30); err != nil {
|
|
t.Fatalf("DeleteOlderThan failed: %v", err)
|
|
}
|
|
|
|
retainedLogs, retainedTotal, err := repo.List(ctx, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("List after DeleteOlderThan failed: %v", err)
|
|
}
|
|
if retainedTotal != 2 || len(retainedLogs) != 2 {
|
|
t.Fatalf("expected 2 retained logs, got total=%d len=%d", retainedTotal, len(retainedLogs))
|
|
}
|
|
|
|
if err := repo.DeleteByUserID(ctx, 1); err != nil {
|
|
t.Fatalf("DeleteByUserID failed: %v", err)
|
|
}
|
|
|
|
finalLogs, finalTotal, err := repo.List(ctx, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("List after DeleteByUserID failed: %v", err)
|
|
}
|
|
if finalTotal != 0 || len(finalLogs) != 0 {
|
|
t.Fatalf("expected all logs removed, got total=%d len=%d", finalTotal, len(finalLogs))
|
|
}
|
|
}
|
|
|
|
func TestPasswordHistoryRepositoryKeepsNewestRecords(t *testing.T) {
|
|
db := openTestDB(t)
|
|
migrateRepositoryTables(t, db, &domain.PasswordHistory{})
|
|
|
|
repo := NewPasswordHistoryRepository(db)
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
|
|
histories := []*domain.PasswordHistory{
|
|
{UserID: 1, PasswordHash: "hash-1", CreatedAt: now.Add(-4 * time.Hour)},
|
|
{UserID: 1, PasswordHash: "hash-2", CreatedAt: now.Add(-3 * time.Hour)},
|
|
{UserID: 1, PasswordHash: "hash-3", CreatedAt: now.Add(-2 * time.Hour)},
|
|
{UserID: 1, PasswordHash: "hash-4", CreatedAt: now.Add(-1 * time.Hour)},
|
|
{UserID: 2, PasswordHash: "hash-foreign", CreatedAt: now.Add(-30 * time.Minute)},
|
|
}
|
|
|
|
for _, history := range histories {
|
|
if err := repo.Create(ctx, history); err != nil {
|
|
t.Fatalf("Create password history failed: %v", err)
|
|
}
|
|
}
|
|
|
|
latestTwo, err := repo.GetByUserID(ctx, 1, 2)
|
|
if err != nil {
|
|
t.Fatalf("GetByUserID(limit=2) failed: %v", err)
|
|
}
|
|
if len(latestTwo) != 2 {
|
|
t.Fatalf("expected 2 latest password histories, got %d", len(latestTwo))
|
|
}
|
|
if latestTwo[0].PasswordHash != "hash-4" || latestTwo[1].PasswordHash != "hash-3" {
|
|
t.Fatalf("expected newest password hashes to be retained, got %q and %q", latestTwo[0].PasswordHash, latestTwo[1].PasswordHash)
|
|
}
|
|
|
|
if err := repo.DeleteOldRecords(ctx, 1, 2); err != nil {
|
|
t.Fatalf("DeleteOldRecords failed: %v", err)
|
|
}
|
|
|
|
remainingHistories, err := repo.GetByUserID(ctx, 1, 10)
|
|
if err != nil {
|
|
t.Fatalf("GetByUserID after DeleteOldRecords failed: %v", err)
|
|
}
|
|
if len(remainingHistories) != 2 {
|
|
t.Fatalf("expected 2 remaining histories, got %d", len(remainingHistories))
|
|
}
|
|
if remainingHistories[0].PasswordHash != "hash-4" || remainingHistories[1].PasswordHash != "hash-3" {
|
|
t.Fatalf("unexpected remaining password hashes: %q and %q", remainingHistories[0].PasswordHash, remainingHistories[1].PasswordHash)
|
|
}
|
|
|
|
if err := repo.DeleteOldRecords(ctx, 999, 3); err != nil {
|
|
t.Fatalf("DeleteOldRecords for missing user failed: %v", err)
|
|
}
|
|
}
|
|
|
|
func TestOperationLogRepositorySearchAndRetention(t *testing.T) {
|
|
db := openTestDB(t)
|
|
migrateRepositoryTables(t, db, &domain.OperationLog{})
|
|
|
|
repo := NewOperationLogRepository(db)
|
|
ctx := context.Background()
|
|
now := time.Now().UTC()
|
|
|
|
logs := []*domain.OperationLog{
|
|
{
|
|
UserID: int64Ptr(1),
|
|
OperationType: "user",
|
|
OperationName: "create user",
|
|
RequestMethod: "POST",
|
|
RequestPath: "/api/v1/users",
|
|
RequestParams: `{"username":"alice"}`,
|
|
ResponseStatus: 201,
|
|
IP: "10.0.0.1",
|
|
UserAgent: "Chrome",
|
|
CreatedAt: now.Add(-20 * time.Minute),
|
|
},
|
|
{
|
|
UserID: int64Ptr(1),
|
|
OperationType: "dashboard",
|
|
OperationName: "view dashboard",
|
|
RequestMethod: "GET",
|
|
RequestPath: "/dashboard",
|
|
RequestParams: "{}",
|
|
ResponseStatus: 200,
|
|
IP: "10.0.0.2",
|
|
UserAgent: "Chrome",
|
|
CreatedAt: now.Add(-10 * time.Minute),
|
|
},
|
|
{
|
|
UserID: int64Ptr(2),
|
|
OperationType: "user",
|
|
OperationName: "delete user",
|
|
RequestMethod: "DELETE",
|
|
RequestPath: "/api/v1/users/7",
|
|
RequestParams: "{}",
|
|
ResponseStatus: 204,
|
|
IP: "10.0.0.3",
|
|
UserAgent: "Firefox",
|
|
CreatedAt: now.Add(-40 * 24 * time.Hour),
|
|
},
|
|
}
|
|
|
|
for _, log := range logs {
|
|
if err := repo.Create(ctx, log); err != nil {
|
|
t.Fatalf("Create operation log failed: %v", err)
|
|
}
|
|
}
|
|
|
|
loaded, err := repo.GetByID(ctx, logs[0].ID)
|
|
if err != nil {
|
|
t.Fatalf("GetByID failed: %v", err)
|
|
}
|
|
if loaded.OperationName != "create user" {
|
|
t.Fatalf("expected create user log, got %q", loaded.OperationName)
|
|
}
|
|
|
|
userLogs, total, err := repo.ListByUserID(ctx, 1, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("ListByUserID failed: %v", err)
|
|
}
|
|
if total != 2 || len(userLogs) != 2 {
|
|
t.Fatalf("expected 2 user operation logs, got total=%d len=%d", total, len(userLogs))
|
|
}
|
|
if userLogs[0].OperationName != "view dashboard" {
|
|
t.Fatalf("expected newest operation log first, got %q", userLogs[0].OperationName)
|
|
}
|
|
|
|
allLogs, total, err := repo.List(ctx, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("List failed: %v", err)
|
|
}
|
|
if total != 3 || len(allLogs) != 3 {
|
|
t.Fatalf("expected 3 total operation logs, got total=%d len=%d", total, len(allLogs))
|
|
}
|
|
|
|
postLogs, total, err := repo.ListByMethod(ctx, "POST", 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("ListByMethod failed: %v", err)
|
|
}
|
|
if total != 1 || len(postLogs) != 1 || postLogs[0].OperationName != "create user" {
|
|
t.Fatalf("expected a single POST operation log, got total=%d len=%d", total, len(postLogs))
|
|
}
|
|
|
|
recentLogs, total, err := repo.ListByTimeRange(ctx, now.Add(-1*time.Hour), now, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("ListByTimeRange failed: %v", err)
|
|
}
|
|
if total != 2 || len(recentLogs) != 2 {
|
|
t.Fatalf("expected 2 recent operation logs, got total=%d len=%d", total, len(recentLogs))
|
|
}
|
|
|
|
searchResults, total, err := repo.Search(ctx, "user", 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("Search failed: %v", err)
|
|
}
|
|
if total != 2 || len(searchResults) != 2 {
|
|
t.Fatalf("expected 2 operation logs matching user, got total=%d len=%d", total, len(searchResults))
|
|
}
|
|
|
|
if err := repo.DeleteOlderThan(ctx, 30); err != nil {
|
|
t.Fatalf("DeleteOlderThan failed: %v", err)
|
|
}
|
|
|
|
retainedLogs, retainedTotal, err := repo.List(ctx, 0, 10)
|
|
if err != nil {
|
|
t.Fatalf("List after DeleteOlderThan failed: %v", err)
|
|
}
|
|
if retainedTotal != 2 || len(retainedLogs) != 2 {
|
|
t.Fatalf("expected 2 retained operation logs, got total=%d len=%d", retainedTotal, len(retainedLogs))
|
|
}
|
|
}
|