Backend fixes: - auth_handler: P0 认证逻辑修复 - ratelimit: 限速中间件增强 + 新增单元测试 - auth_service: 认证服务逻辑完善 + 新增测试 - server: server 配置增强 + 新增测试 - handler_test: 新增 handler 层集成测试 - auth_bootstrap_test: bootstrap 路径测试 Frontend patches: - LoginPage/RegisterPage: CSRF + 表单交互修复 - BootstrapAdminPage: 引导流程修复 - DevicesPage: 设备管理页修复 - auth/social-accounts/users/webhooks services: 类型修正 - csrf.ts: CSRF token 处理修正 - E2E 脚本: CDP smoke + auth e2e 增强 Docs: - FULL_CODE_REVIEW_REPORT_2026-04-20 - report-v6 执行计划 - REAL_PROJECT_STATUS 更新 - .gitignore: 新增 .gocache-*/config.yaml 排除 验证: go build/vet 0错误, go test 42/42 PASS, 0 FAIL
1465 lines
42 KiB
Go
1465 lines
42 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/user-management-system/internal/auth"
|
|
"github.com/user-management-system/internal/cache"
|
|
"github.com/user-management-system/internal/domain"
|
|
"github.com/user-management-system/internal/repository"
|
|
"github.com/user-management-system/internal/security"
|
|
gormsqlite "gorm.io/driver/sqlite"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
// =============================================================================
|
|
// Auth Service Unit Tests
|
|
// =============================================================================
|
|
|
|
func TestPasswordStrength(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
password string
|
|
wantInfo PasswordStrengthInfo
|
|
}{
|
|
{
|
|
name: "empty_password",
|
|
password: "",
|
|
wantInfo: PasswordStrengthInfo{Score: 0, Length: 0, HasUpper: false, HasLower: false, HasDigit: false, HasSpecial: false},
|
|
},
|
|
{
|
|
name: "lowercase_only",
|
|
password: "abcdefgh",
|
|
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: true, HasDigit: false, HasSpecial: false},
|
|
},
|
|
{
|
|
name: "uppercase_only",
|
|
password: "ABCDEFGH",
|
|
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: true, HasLower: false, HasDigit: false, HasSpecial: false},
|
|
},
|
|
{
|
|
name: "digits_only",
|
|
password: "12345678",
|
|
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: false, HasDigit: true, HasSpecial: false},
|
|
},
|
|
{
|
|
name: "mixed_case_with_digits",
|
|
password: "Abcd1234",
|
|
wantInfo: PasswordStrengthInfo{Score: 3, Length: 8, HasUpper: true, HasLower: true, HasDigit: true, HasSpecial: false},
|
|
},
|
|
{
|
|
name: "mixed_with_special",
|
|
password: "Abcd1234!",
|
|
wantInfo: PasswordStrengthInfo{Score: 4, Length: 9, HasUpper: true, HasLower: true, HasDigit: true, HasSpecial: true},
|
|
},
|
|
{
|
|
name: "chinese_characters",
|
|
password: "密码123456",
|
|
wantInfo: PasswordStrengthInfo{Score: 1, Length: 8, HasUpper: false, HasLower: false, HasDigit: true, HasSpecial: false},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
info := GetPasswordStrength(tt.password)
|
|
if info.Score != tt.wantInfo.Score {
|
|
t.Errorf("Score: got %d, want %d", info.Score, tt.wantInfo.Score)
|
|
}
|
|
if info.Length != tt.wantInfo.Length {
|
|
t.Errorf("Length: got %d, want %d", info.Length, tt.wantInfo.Length)
|
|
}
|
|
if info.HasUpper != tt.wantInfo.HasUpper {
|
|
t.Errorf("HasUpper: got %v, want %v", info.HasUpper, tt.wantInfo.HasUpper)
|
|
}
|
|
if info.HasLower != tt.wantInfo.HasLower {
|
|
t.Errorf("HasLower: got %v, want %v", info.HasLower, tt.wantInfo.HasLower)
|
|
}
|
|
if info.HasDigit != tt.wantInfo.HasDigit {
|
|
t.Errorf("HasDigit: got %v, want %v", info.HasDigit, tt.wantInfo.HasDigit)
|
|
}
|
|
if info.HasSpecial != tt.wantInfo.HasSpecial {
|
|
t.Errorf("HasSpecial: got %v, want %v", info.HasSpecial, tt.wantInfo.HasSpecial)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidatePasswordStrength(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
password string
|
|
minLength int
|
|
strict bool
|
|
wantErr bool
|
|
}{
|
|
{
|
|
name: "valid_password_strict",
|
|
password: "Abcd1234!",
|
|
minLength: 8,
|
|
strict: true,
|
|
wantErr: false,
|
|
},
|
|
{
|
|
name: "too_short",
|
|
password: "Ab1!",
|
|
minLength: 8,
|
|
strict: false,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "weak_password",
|
|
password: "abcdefgh",
|
|
minLength: 8,
|
|
strict: false,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "strict_missing_uppercase",
|
|
password: "abcd1234!",
|
|
minLength: 8,
|
|
strict: true,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "strict_missing_lowercase",
|
|
password: "ABCD1234!",
|
|
minLength: 8,
|
|
strict: true,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "strict_missing_digit",
|
|
password: "Abcdefgh!",
|
|
minLength: 8,
|
|
strict: true,
|
|
wantErr: true,
|
|
},
|
|
{
|
|
name: "valid_weak_password_non_strict",
|
|
password: "Abcd1234",
|
|
minLength: 8,
|
|
strict: false,
|
|
wantErr: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validatePasswordStrength(tt.password, tt.minLength, tt.strict)
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("validatePasswordStrength() error = %v, wantErr %v", err, tt.wantErr)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestSanitizeUsername(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
want string
|
|
}{
|
|
{
|
|
name: "normal_username",
|
|
input: "john_doe",
|
|
want: "john_doe",
|
|
},
|
|
{
|
|
name: "username_with_spaces",
|
|
input: "john doe",
|
|
want: "john_doe",
|
|
},
|
|
{
|
|
name: "username_with_uppercase",
|
|
input: "JohnDoe",
|
|
want: "johndoe",
|
|
},
|
|
{
|
|
name: "username_with_special_chars",
|
|
input: "john@doe",
|
|
want: "johndoe",
|
|
},
|
|
{
|
|
name: "empty_username",
|
|
input: "",
|
|
want: "user",
|
|
},
|
|
{
|
|
name: "whitespace_only",
|
|
input: " ",
|
|
want: "user",
|
|
},
|
|
{
|
|
name: "username_with_emoji",
|
|
input: "john😀doe",
|
|
want: "johndoe", // emoji is filtered out as it's not letter/digit/./-/_
|
|
},
|
|
{
|
|
name: "username_with_leading_underscore",
|
|
input: "_john_",
|
|
want: "john", // leading and trailing _ are trimmed
|
|
},
|
|
{
|
|
name: "username_with_trailing_dots",
|
|
input: "john..doe...",
|
|
want: "john..doe", // trailing dots trimmed
|
|
},
|
|
{
|
|
name: "long_username_truncated",
|
|
input: "this_is_a_very_long_username_that_exceeds_fifty_characters_limit",
|
|
want: "this_is_a_very_long_username_that_exceeds_fifty_ch", // 50 chars max, cuts off "acters_limit"
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := sanitizeUsername(tt.input)
|
|
if got != tt.want {
|
|
t.Errorf("sanitizeUsername() = %q (len=%d), want %q (len=%d)", got, len(got), tt.want, len(tt.want))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIsValidPhoneSimple(t *testing.T) {
|
|
tests := []struct {
|
|
phone string
|
|
want bool
|
|
}{
|
|
{"13800138000", true},
|
|
{"+8613800138000", true}, // Valid: +86 prefix with 11 digit mobile
|
|
{"8613800138000", true}, // Valid: 86 prefix with 11 digit mobile
|
|
{"1234567890", false},
|
|
{"abcdefghij", false},
|
|
{"", false},
|
|
{"138001380001", false}, // 12 digits
|
|
{"1380013800", false}, // 10 digits
|
|
{"19800138000", true}, // 98 prefix
|
|
// +[1-9]\d{6,14} allows international numbers like +16171234567
|
|
{"+16171234567", true}, // 11 digits international, valid for \d{6,14}
|
|
{"+112345678901", true}, // 11 digits international, valid for \d{6,14}
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.phone, func(t *testing.T) {
|
|
got := isValidPhoneSimple(tt.phone)
|
|
if got != tt.want {
|
|
t.Errorf("isValidPhoneSimple(%q) = %v, want %v", tt.phone, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLoginRequestGetAccount(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
req *LoginRequest
|
|
want string
|
|
}{
|
|
{
|
|
name: "account_field",
|
|
req: &LoginRequest{Account: "john", Username: "jane", Email: "jane@test.com"},
|
|
want: "john",
|
|
},
|
|
{
|
|
name: "username_field",
|
|
req: &LoginRequest{Username: "jane", Email: "jane@test.com"},
|
|
want: "jane",
|
|
},
|
|
{
|
|
name: "email_field",
|
|
req: &LoginRequest{Email: "jane@test.com"},
|
|
want: "jane@test.com",
|
|
},
|
|
{
|
|
name: "phone_field",
|
|
req: &LoginRequest{Phone: "13800138000"},
|
|
want: "13800138000",
|
|
},
|
|
{
|
|
name: "all_fields_with_whitespace",
|
|
req: &LoginRequest{Account: " john ", Username: " jane ", Email: " jane@test.com "},
|
|
want: "john",
|
|
},
|
|
{
|
|
name: "empty_request",
|
|
req: &LoginRequest{},
|
|
want: "",
|
|
},
|
|
{
|
|
name: "nil_request",
|
|
req: nil,
|
|
want: "",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := tt.req.GetAccount()
|
|
if got != tt.want {
|
|
t.Errorf("GetAccount() = %q, want %q", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestBuildDeviceFingerprint(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
req *LoginRequest
|
|
want string
|
|
}{
|
|
{
|
|
name: "full_device_info",
|
|
req: &LoginRequest{
|
|
DeviceID: "device123",
|
|
DeviceName: "iPhone 15",
|
|
DeviceBrowser: "Safari",
|
|
DeviceOS: "iOS 17",
|
|
},
|
|
want: "device123|iPhone 15|Safari|iOS 17",
|
|
},
|
|
{
|
|
name: "partial_device_info",
|
|
req: &LoginRequest{
|
|
DeviceID: "device123",
|
|
DeviceName: "iPhone 15",
|
|
},
|
|
want: "device123|iPhone 15",
|
|
},
|
|
{
|
|
name: "only_device_id",
|
|
req: &LoginRequest{
|
|
DeviceID: "device123",
|
|
},
|
|
want: "device123",
|
|
},
|
|
{
|
|
name: "empty_device_info",
|
|
req: &LoginRequest{},
|
|
want: "",
|
|
},
|
|
{
|
|
name: "nil_request",
|
|
req: nil,
|
|
want: "",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := buildDeviceFingerprint(tt.req)
|
|
if got != tt.want {
|
|
t.Errorf("buildDeviceFingerprint() = %q, want %q", got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestLogin_IssuesTOTPChallengeTokenWhenSecondFactorIsRequired(t *testing.T) {
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:login_totp_challenge_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
|
HS256Secret: "totp-challenge-secret",
|
|
AccessTokenExpire: 15 * time.Minute,
|
|
RefreshTokenExpire: 7 * 24 * time.Hour,
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to create jwt manager: %v", err)
|
|
}
|
|
|
|
cacheManager := cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false))
|
|
userRepo := repository.NewUserRepository(db)
|
|
svc := NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
|
|
|
hashedPassword, err := auth.HashPassword("Password123!")
|
|
if err != nil {
|
|
t.Fatalf("failed to hash password: %v", err)
|
|
}
|
|
|
|
user := &domain.User{
|
|
Username: "totpchallenge",
|
|
Password: hashedPassword,
|
|
Status: domain.UserStatusActive,
|
|
TOTPEnabled: true,
|
|
TOTPSecret: "JBSWY3DPEHPK3PXP",
|
|
}
|
|
if err := db.Create(user).Error; err != nil {
|
|
t.Fatalf("failed to create user: %v", err)
|
|
}
|
|
|
|
resp, err := svc.Login(context.Background(), &LoginRequest{
|
|
Account: "totpchallenge",
|
|
Password: "Password123!",
|
|
DeviceID: "device-1",
|
|
}, "127.0.0.1")
|
|
if err != nil {
|
|
t.Fatalf("login failed: %v", err)
|
|
}
|
|
|
|
if !resp.RequiresTOTP {
|
|
t.Fatalf("expected requires_totp response, got %+v", resp)
|
|
}
|
|
if resp.UserID != user.ID {
|
|
t.Fatalf("expected user id %d, got %d", user.ID, resp.UserID)
|
|
}
|
|
if strings.TrimSpace(resp.TempToken) == "" {
|
|
t.Fatalf("expected temp token when TOTP is required, got %+v", resp)
|
|
}
|
|
if resp.AccessToken != "" || resp.RefreshToken != "" {
|
|
t.Fatalf("expected no full session tokens before TOTP verification, got %+v", resp)
|
|
}
|
|
}
|
|
|
|
func TestAuthServiceDefaultConfig(t *testing.T) {
|
|
// Test that default configuration is applied correctly
|
|
svc := NewAuthService(nil, nil, nil, nil, 0, 0, 0)
|
|
|
|
if svc == nil {
|
|
t.Fatal("NewAuthService returned nil")
|
|
}
|
|
|
|
// Check default password minimum length
|
|
if svc.passwordMinLength != defaultPasswordMinLen {
|
|
t.Errorf("passwordMinLength: got %d, want %d", svc.passwordMinLength, defaultPasswordMinLen)
|
|
}
|
|
|
|
// Check default max login attempts
|
|
if svc.maxLoginAttempts != 5 {
|
|
t.Errorf("maxLoginAttempts: got %d, want %d", svc.maxLoginAttempts, 5)
|
|
}
|
|
|
|
// Check default login lock duration
|
|
if svc.loginLockDuration != 15*time.Minute {
|
|
t.Errorf("loginLockDuration: got %v, want %v", svc.loginLockDuration, 15*time.Minute)
|
|
}
|
|
}
|
|
|
|
func TestAuthServiceNilSafety(t *testing.T) {
|
|
t.Run("validatePassword_nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
err := svc.validatePassword("Abcd1234!")
|
|
if err != nil {
|
|
t.Errorf("nil service should not error: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("accessTokenTTL_nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
ttl := svc.accessTokenTTLSeconds()
|
|
if ttl != 0 {
|
|
t.Errorf("nil service should return 0: got %d", ttl)
|
|
}
|
|
})
|
|
|
|
t.Run("RefreshTokenTTL_nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
ttl := svc.RefreshTokenTTLSeconds()
|
|
if ttl != 0 {
|
|
t.Errorf("nil service should return 0: got %d", ttl)
|
|
}
|
|
})
|
|
|
|
t.Run("generateUniqueUsername_nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
username, err := svc.generateUniqueUsername(context.Background(), "testuser")
|
|
if err != nil {
|
|
t.Errorf("nil service should return username: %v", err)
|
|
}
|
|
if username != "testuser" {
|
|
t.Errorf("username: got %q, want %q", username, "testuser")
|
|
}
|
|
})
|
|
|
|
t.Run("buildUserInfo_nil_user", func(t *testing.T) {
|
|
var svc *AuthService
|
|
info := svc.buildUserInfo(nil)
|
|
if info != nil {
|
|
t.Errorf("nil user should return nil info: got %v", info)
|
|
}
|
|
})
|
|
|
|
t.Run("ensureUserActive_nil_user", func(t *testing.T) {
|
|
var svc *AuthService
|
|
err := svc.ensureUserActive(nil)
|
|
if err == nil {
|
|
t.Error("nil user should return error")
|
|
}
|
|
})
|
|
|
|
t.Run("blacklistToken_nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
err := svc.blacklistTokenClaims(context.Background(), "token", nil)
|
|
if err != nil {
|
|
t.Errorf("nil service should not error: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("Logout_nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
err := svc.Logout(context.Background(), "user", nil)
|
|
if err != nil {
|
|
t.Errorf("nil service should not error: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("IsTokenBlacklisted_nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
blacklisted := svc.IsTokenBlacklisted(context.Background(), "jti")
|
|
if blacklisted {
|
|
t.Error("nil service should not blacklist tokens")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestUserInfoFromCacheValue(t *testing.T) {
|
|
t.Run("valid_UserInfo_pointer", func(t *testing.T) {
|
|
info := &UserInfo{ID: 1, Username: "testuser"}
|
|
got, ok := userInfoFromCacheValue(info)
|
|
if !ok {
|
|
t.Error("should parse *UserInfo")
|
|
}
|
|
if got.ID != 1 || got.Username != "testuser" {
|
|
t.Errorf("got %+v, want %+v", got, info)
|
|
}
|
|
})
|
|
|
|
t.Run("valid_UserInfo_value", func(t *testing.T) {
|
|
info := UserInfo{ID: 2, Username: "testuser2"}
|
|
got, ok := userInfoFromCacheValue(info)
|
|
if !ok {
|
|
t.Error("should parse UserInfo value")
|
|
}
|
|
if got.ID != 2 || got.Username != "testuser2" {
|
|
t.Errorf("got %+v, want %+v", got, info)
|
|
}
|
|
})
|
|
|
|
t.Run("invalid_type", func(t *testing.T) {
|
|
got, ok := userInfoFromCacheValue("invalid string")
|
|
if ok || got != nil {
|
|
t.Errorf("should not parse string: ok=%v, got=%+v", ok, got)
|
|
}
|
|
})
|
|
|
|
t.Run("map_string_interface", func(t *testing.T) {
|
|
info := map[string]interface{}{
|
|
"id": float64(3),
|
|
"username": "mapuser",
|
|
"email": "map@test.com",
|
|
}
|
|
got, ok := userInfoFromCacheValue(info)
|
|
if !ok {
|
|
t.Error("should parse map[string]interface{}")
|
|
}
|
|
if got == nil {
|
|
t.Fatal("got nil")
|
|
}
|
|
if got.ID != 3 || got.Username != "mapuser" {
|
|
t.Errorf("got ID=%d, Username=%s, want ID=3, Username=mapuser", got.ID, got.Username)
|
|
}
|
|
})
|
|
|
|
t.Run("map_with_invalid_data", func(t *testing.T) {
|
|
info := map[string]interface{}{
|
|
"id": "not_a_number",
|
|
}
|
|
got, ok := userInfoFromCacheValue(info)
|
|
// Should fail to parse
|
|
if ok {
|
|
t.Errorf("should not parse invalid map: ok=%v, got=%+v", ok, got)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestEnsureUserActive(t *testing.T) {
|
|
t.Run("nil_user", func(t *testing.T) {
|
|
var svc *AuthService
|
|
err := svc.ensureUserActive(nil)
|
|
if err == nil {
|
|
t.Error("nil user should error")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestAttemptCount(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
value interface{}
|
|
want int
|
|
}{
|
|
{"int_value", 5, 5},
|
|
{"int64_value", int64(3), 3},
|
|
{"float64_value", float64(4.0), 4},
|
|
{"string_int", "3", 0}, // strings are not converted
|
|
{"invalid_type", "abc", 0},
|
|
{"nil", nil, 0},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
got := attemptCount(tt.value)
|
|
if got != tt.want {
|
|
t.Errorf("attemptCount(%v) = %d, want %d", tt.value, got, tt.want)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestIncrementFailAttempts(t *testing.T) {
|
|
t.Run("nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
count := svc.incrementFailAttempts(context.Background(), "key")
|
|
if count != 0 {
|
|
t.Errorf("nil service should return 0, got %d", count)
|
|
}
|
|
})
|
|
|
|
t.Run("empty_key", func(t *testing.T) {
|
|
svc := NewAuthService(nil, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
count := svc.incrementFailAttempts(context.Background(), "")
|
|
if count != 0 {
|
|
t.Errorf("empty key should return 0, got %d", count)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestWriteLoginLog_Nil(t *testing.T) {
|
|
t.Run("nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
userID := int64(1)
|
|
// Should not panic
|
|
svc.writeLoginLog(context.Background(), &userID, 1, "127.0.0.1", true, "")
|
|
})
|
|
|
|
t.Run("nil_user_id", func(t *testing.T) {
|
|
svc := NewAuthService(nil, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
// Should not panic
|
|
svc.writeLoginLog(context.Background(), nil, 1, "127.0.0.1", true, "")
|
|
})
|
|
}
|
|
|
|
func TestRecordLoginAnomaly_Nil(t *testing.T) {
|
|
t.Run("nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
userID := int64(1)
|
|
// Should not panic
|
|
svc.recordLoginAnomaly(context.Background(), &userID, "127.0.0.1", "location", "device", true)
|
|
})
|
|
|
|
t.Run("nil_user_id", func(t *testing.T) {
|
|
svc := NewAuthService(nil, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
// Should not panic
|
|
svc.recordLoginAnomaly(context.Background(), nil, "127.0.0.1", "location", "device", true)
|
|
})
|
|
}
|
|
|
|
func TestPublishEvent_Nil(t *testing.T) {
|
|
t.Run("nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
// Should not panic
|
|
svc.publishEvent(context.Background(), domain.EventUserRegistered, map[string]interface{}{"user_id": 1})
|
|
})
|
|
}
|
|
|
|
func TestCacheUserInfo_Nil(t *testing.T) {
|
|
t.Run("nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
// Should not panic
|
|
svc.cacheUserInfo(context.Background(), nil)
|
|
})
|
|
}
|
|
|
|
func TestBestEffortRegisterDevice_Nil(t *testing.T) {
|
|
t.Run("nil_service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
// Should not panic
|
|
svc.bestEffortRegisterDevice(context.Background(), 1, nil)
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Write Login Log Integration Tests
|
|
// =============================================================================
|
|
|
|
func TestWriteLoginLog_Integration(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:loginlog_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.LoginLog{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
loginLogRepo := repository.NewLoginLogRepository(db)
|
|
svc := NewAuthService(nil, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
svc.SetLoginLogRepository(loginLogRepo)
|
|
|
|
userID := int64(123)
|
|
|
|
t.Run("write successful login log", func(t *testing.T) {
|
|
svc.writeLoginLog(context.Background(), &userID, domain.LoginTypePassword, "192.168.1.1", true, "")
|
|
|
|
// Wait for async goroutine
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
var logs []domain.LoginLog
|
|
db.Find(&logs)
|
|
if len(logs) != 1 {
|
|
t.Errorf("Expected 1 log, got %d", len(logs))
|
|
}
|
|
if len(logs) > 0 {
|
|
if logs[0].Status != 1 {
|
|
t.Errorf("Expected status 1, got %d", logs[0].Status)
|
|
}
|
|
if logs[0].IP != "192.168.1.1" {
|
|
t.Errorf("Expected IP '192.168.1.1', got %s", logs[0].IP)
|
|
}
|
|
}
|
|
})
|
|
|
|
t.Run("write failed login log", func(t *testing.T) {
|
|
svc.writeLoginLog(context.Background(), &userID, domain.LoginTypePassword, "10.0.0.1", false, "wrong password")
|
|
|
|
// Wait for async goroutine
|
|
time.Sleep(100 * time.Millisecond)
|
|
|
|
var logs []domain.LoginLog
|
|
db.Where("ip = ?", "10.0.0.1").Find(&logs)
|
|
if len(logs) != 1 {
|
|
t.Errorf("Expected 1 log, got %d", len(logs))
|
|
}
|
|
if len(logs) > 0 && logs[0].Status != 0 {
|
|
t.Errorf("Expected status 0 for failed login, got %d", logs[0].Status)
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Record Login Anomaly Tests
|
|
// =============================================================================
|
|
|
|
// mockAnomalyDetector is a mock implementation of anomalyRecorder
|
|
type mockAnomalyDetector struct {
|
|
events []security.AnomalyEvent
|
|
}
|
|
|
|
func (m *mockAnomalyDetector) RecordLogin(ctx context.Context, userID int64, ip, location, deviceFingerprint string, success bool) []security.AnomalyEvent {
|
|
return m.events
|
|
}
|
|
|
|
func TestRecordLoginAnomaly_WithDetector(t *testing.T) {
|
|
t.Run("with anomaly detector returning events", func(t *testing.T) {
|
|
svc := NewAuthService(nil, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
detector := &mockAnomalyDetector{
|
|
events: []security.AnomalyEvent{security.AnomalyBruteForce},
|
|
}
|
|
svc.SetAnomalyDetector(detector)
|
|
|
|
userID := int64(1)
|
|
// Should not panic
|
|
svc.recordLoginAnomaly(context.Background(), &userID, "127.0.0.1", "location", "device", false)
|
|
})
|
|
|
|
t.Run("with anomaly detector returning no events", func(t *testing.T) {
|
|
svc := NewAuthService(nil, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
detector := &mockAnomalyDetector{events: nil}
|
|
svc.SetAnomalyDetector(detector)
|
|
|
|
userID := int64(1)
|
|
// Should not panic
|
|
svc.recordLoginAnomaly(context.Background(), &userID, "127.0.0.1", "location", "device", true)
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Generate Unique Username Integration Tests
|
|
// =============================================================================
|
|
|
|
func TestGenerateUniqueUsername_Integration(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:username_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
svc := NewAuthService(userRepo, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
|
|
t.Run("generate unique username with existing user", func(t *testing.T) {
|
|
// Create existing user
|
|
existingUser := &domain.User{
|
|
Username: "testuser",
|
|
Password: "$2a$10$hash",
|
|
Status: domain.UserStatusActive,
|
|
}
|
|
db.Create(existingUser)
|
|
|
|
// Should generate unique username
|
|
username, err := svc.generateUniqueUsername(context.Background(), "testuser")
|
|
if err != nil {
|
|
t.Fatalf("generateUniqueUsername failed: %v", err)
|
|
}
|
|
if username == "testuser" {
|
|
t.Error("Expected different username since testuser already exists")
|
|
}
|
|
})
|
|
|
|
t.Run("generate unique username with new base", func(t *testing.T) {
|
|
username, err := svc.generateUniqueUsername(context.Background(), "newuser123")
|
|
if err != nil {
|
|
t.Fatalf("generateUniqueUsername failed: %v", err)
|
|
}
|
|
if username != "newuser123" {
|
|
t.Errorf("Expected 'newuser123', got %s", username)
|
|
}
|
|
})
|
|
|
|
t.Run("generate unique username with long base", func(t *testing.T) {
|
|
longBase := "this_is_a_very_long_username_that_exceeds_the_normal_limit"
|
|
username, err := svc.generateUniqueUsername(context.Background(), longBase)
|
|
if err != nil {
|
|
t.Fatalf("generateUniqueUsername failed: %v", err)
|
|
}
|
|
if len(username) > 50 {
|
|
t.Errorf("Username should be truncated to 50 chars, got %d", len(username))
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Upsert OAuth Social Account Tests
|
|
// =============================================================================
|
|
|
|
func TestUpsertOAuthSocialAccount_Nil(t *testing.T) {
|
|
t.Run("nil service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
_, err := svc.upsertOAuthSocialAccount(context.Background(), 1, "github", nil)
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestUpsertOAuthSocialAccount_Integration(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:upsert_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}, &domain.SocialAccount{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
socialRepo, _ := repository.NewSocialAccountRepository(db)
|
|
svc := NewAuthService(userRepo, socialRepo, nil, nil, 8, 5, 15*time.Minute)
|
|
|
|
// Create test user
|
|
user := &domain.User{
|
|
Username: "oauthuser",
|
|
Password: "$2a$10$hash",
|
|
Status: domain.UserStatusActive,
|
|
}
|
|
db.Create(user)
|
|
|
|
t.Run("create new social account", func(t *testing.T) {
|
|
oauthUser := &auth.OAuthUser{
|
|
OpenID: "github123",
|
|
Nickname: "GitHubUser",
|
|
Email: "github@example.com",
|
|
}
|
|
account, err := svc.upsertOAuthSocialAccount(context.Background(), user.ID, "github", oauthUser)
|
|
if err != nil {
|
|
t.Fatalf("upsertOAuthSocialAccount failed: %v", err)
|
|
}
|
|
if account == nil {
|
|
t.Fatal("Expected account to be created")
|
|
}
|
|
if account.Provider != "github" {
|
|
t.Errorf("Expected provider 'github', got %s", account.Provider)
|
|
}
|
|
if account.OpenID != "github123" {
|
|
t.Errorf("Expected OpenID 'github123', got %s", account.OpenID)
|
|
}
|
|
})
|
|
|
|
t.Run("update existing social account", func(t *testing.T) {
|
|
oauthUser := &auth.OAuthUser{
|
|
OpenID: "github123",
|
|
Nickname: "UpdatedUser",
|
|
Email: "updated@example.com",
|
|
}
|
|
account, err := svc.upsertOAuthSocialAccount(context.Background(), user.ID, "github", oauthUser)
|
|
if err != nil {
|
|
t.Fatalf("upsertOAuthSocialAccount failed: %v", err)
|
|
}
|
|
if account.Nickname != "UpdatedUser" {
|
|
t.Errorf("Expected nickname 'UpdatedUser', got %s", account.Nickname)
|
|
}
|
|
})
|
|
|
|
t.Run("nil oauth user", func(t *testing.T) {
|
|
_, err := svc.upsertOAuthSocialAccount(context.Background(), user.ID, "github", nil)
|
|
if err == nil {
|
|
t.Error("Expected error for nil oauth user")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Login By Code Integration Tests
|
|
// =============================================================================
|
|
|
|
func TestLoginByCode_Integration(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:logincode_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}, &domain.LoginLog{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
loginLogRepo := repository.NewLoginLogRepository(db)
|
|
jwtManager, _ := auth.NewJWTWithOptions(auth.JWTOptions{
|
|
HS256Secret: fmt.Sprintf("test-secret-%d", time.Now().UnixNano()),
|
|
AccessTokenExpire: 15 * time.Minute,
|
|
RefreshTokenExpire: 7 * 24 * time.Hour,
|
|
})
|
|
|
|
svc := NewAuthService(userRepo, nil, jwtManager, nil, 8, 5, 15*time.Minute)
|
|
svc.SetLoginLogRepository(loginLogRepo)
|
|
|
|
// Create test user with phone
|
|
phone := "13800138000"
|
|
user := &domain.User{
|
|
Username: "logincodeuser",
|
|
Phone: &phone,
|
|
Password: "$2a$10$hash",
|
|
Status: domain.UserStatusActive,
|
|
}
|
|
db.Create(user)
|
|
|
|
t.Run("LoginByCode without SMS service configured", func(t *testing.T) {
|
|
_, err := svc.LoginByCode(context.Background(), "13800138000", "123456", "127.0.0.1")
|
|
if err == nil {
|
|
t.Error("Expected error when SMS service not configured")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// OAuth Callback Tests
|
|
// =============================================================================
|
|
|
|
func TestOAuthCallback_Nil(t *testing.T) {
|
|
t.Run("nil service", func(t *testing.T) {
|
|
var svc *AuthService
|
|
_, err := svc.OAuthCallback(context.Background(), "github", "code123")
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestOAuthCallback_Integration(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:oauth_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}, &domain.SocialAccount{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
socialRepo, _ := repository.NewSocialAccountRepository(db)
|
|
jwtManager, _ := auth.NewJWTWithOptions(auth.JWTOptions{
|
|
HS256Secret: fmt.Sprintf("test-secret-%d", time.Now().UnixNano()),
|
|
AccessTokenExpire: 15 * time.Minute,
|
|
RefreshTokenExpire: 7 * 24 * time.Hour,
|
|
})
|
|
|
|
svc := NewAuthService(userRepo, socialRepo, jwtManager, nil, 8, 5, 15*time.Minute)
|
|
|
|
t.Run("OAuthCallback without OAuth manager configured", func(t *testing.T) {
|
|
_, err := svc.OAuthCallback(context.Background(), "github", "code123")
|
|
if err == nil {
|
|
t.Error("Expected error when OAuth manager not configured")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// OAuth Bind Callback Tests
|
|
// =============================================================================
|
|
|
|
func TestOAuthBindCallback_Integration(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:oauthbind_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}, &domain.SocialAccount{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
socialRepo, _ := repository.NewSocialAccountRepository(db)
|
|
svc := NewAuthService(userRepo, socialRepo, nil, nil, 8, 5, 15*time.Minute)
|
|
|
|
// Create test user
|
|
user := &domain.User{
|
|
Username: "oauthbinduser",
|
|
Password: "$2a$10$hash",
|
|
Status: domain.UserStatusActive,
|
|
}
|
|
db.Create(user)
|
|
|
|
t.Run("OAuthBindCallback without OAuth manager configured", func(t *testing.T) {
|
|
_, err := svc.OAuthBindCallback(context.Background(), user.ID, "github", "code123")
|
|
if err == nil {
|
|
t.Error("Expected error when OAuth manager not configured")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Best Effort Register Device Tests
|
|
// =============================================================================
|
|
|
|
func TestBestEffortRegisterDevice_Integration(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:device_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}, &domain.Device{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
deviceRepo := repository.NewDeviceRepository(db)
|
|
deviceSvc := NewDeviceService(deviceRepo, userRepo)
|
|
|
|
svc := NewAuthService(userRepo, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
svc.SetDeviceService(deviceSvc)
|
|
|
|
// Create test user
|
|
user := &domain.User{
|
|
Username: "deviceuser",
|
|
Password: "$2a$10$hash",
|
|
Status: domain.UserStatusActive,
|
|
}
|
|
db.Create(user)
|
|
|
|
t.Run("register device with device info", func(t *testing.T) {
|
|
req := &LoginRequest{
|
|
DeviceID: "device123",
|
|
DeviceName: "iPhone 15",
|
|
DeviceBrowser: "Safari",
|
|
DeviceOS: "iOS 17",
|
|
}
|
|
svc.bestEffortRegisterDevice(context.Background(), user.ID, req)
|
|
// Should not panic
|
|
})
|
|
|
|
t.Run("register device with nil request", func(t *testing.T) {
|
|
svc.bestEffortRegisterDevice(context.Background(), user.ID, nil)
|
|
// Should not panic
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Verify Sensitive Action Tests
|
|
// =============================================================================
|
|
|
|
func TestVerifySensitiveAction_Integration(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:sensitive_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
svc := NewAuthService(userRepo, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
|
|
hashedPassword, _ := auth.HashPassword("Password123!")
|
|
|
|
t.Run("verify with password", func(t *testing.T) {
|
|
user := &domain.User{
|
|
Username: "sensitiveuser",
|
|
Password: hashedPassword,
|
|
Status: domain.UserStatusActive,
|
|
}
|
|
db.Create(user)
|
|
|
|
err := svc.verifySensitiveAction(context.Background(), user, "Password123!", "")
|
|
if err != nil {
|
|
t.Errorf("Expected no error for correct password, got: %v", err)
|
|
}
|
|
})
|
|
|
|
t.Run("verify with wrong password", func(t *testing.T) {
|
|
user := &domain.User{
|
|
Username: "wrongpassuser",
|
|
Password: hashedPassword,
|
|
Status: domain.UserStatusActive,
|
|
}
|
|
db.Create(user)
|
|
|
|
err := svc.verifySensitiveAction(context.Background(), user, "wrongpassword", "")
|
|
if err == nil {
|
|
t.Error("Expected error for wrong password")
|
|
}
|
|
})
|
|
|
|
t.Run("verify with TOTP user", func(t *testing.T) {
|
|
user := &domain.User{
|
|
Username: "totpuser",
|
|
Password: hashedPassword,
|
|
Status: domain.UserStatusActive,
|
|
TOTPEnabled: true,
|
|
TOTPSecret: "JBSWY3DPEHPK3PXP",
|
|
}
|
|
db.Create(user)
|
|
|
|
// TOTP requires valid code, so this should fail
|
|
err := svc.verifySensitiveAction(context.Background(), user, "", "invalid_totp")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid TOTP code")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Verify TOTP Code Or Recovery Code Tests
|
|
// =============================================================================
|
|
|
|
func TestVerifyTOTPCodeOrRecoveryCode_Integration(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:totp_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
svc := NewAuthService(userRepo, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
|
|
t.Run("user without TOTP", func(t *testing.T) {
|
|
user := &domain.User{
|
|
Username: "nototpuser",
|
|
Password: "$2a$10$hash",
|
|
Status: domain.UserStatusActive,
|
|
TOTPEnabled: false,
|
|
}
|
|
db.Create(user)
|
|
|
|
err := svc.verifyTOTPCodeOrRecoveryCode(context.Background(), user, "123456")
|
|
if err == nil {
|
|
t.Error("Expected error for user without TOTP")
|
|
}
|
|
})
|
|
|
|
t.Run("user with TOTP but wrong code", func(t *testing.T) {
|
|
user := &domain.User{
|
|
Username: "totpuser2",
|
|
Password: "$2a$10$hash",
|
|
Status: domain.UserStatusActive,
|
|
TOTPEnabled: true,
|
|
TOTPSecret: "JBSWY3DPEHPK3PXP",
|
|
}
|
|
db.Create(user)
|
|
|
|
err := svc.verifyTOTPCodeOrRecoveryCode(context.Background(), user, "invalid_code")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid TOTP code")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Start Social Account Binding Tests
|
|
// =============================================================================
|
|
|
|
func TestStartSocialAccountBinding_Integration(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:startbind_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}, &domain.SocialAccount{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
socialRepo, _ := repository.NewSocialAccountRepository(db)
|
|
svc := NewAuthService(userRepo, socialRepo, nil, nil, 8, 5, 15*time.Minute)
|
|
|
|
hashedPassword, _ := auth.HashPassword("Password123!")
|
|
|
|
t.Run("Start binding without OAuth manager", func(t *testing.T) {
|
|
user := &domain.User{
|
|
Username: "startbinduser",
|
|
Password: hashedPassword,
|
|
Status: domain.UserStatusActive,
|
|
}
|
|
db.Create(user)
|
|
|
|
_, _, err := svc.StartSocialAccountBinding(context.Background(), user.ID, "github", "http://localhost", "Password123!", "")
|
|
if err == nil {
|
|
t.Error("Expected error when OAuth manager not configured")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// Verify TOTP Code Or Recovery Code Extended Tests
|
|
// =============================================================================
|
|
|
|
func TestVerifyTOTPCodeOrRecoveryCode_NilUser(t *testing.T) {
|
|
svc := NewAuthService(nil, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
|
|
err := svc.verifyTOTPCodeOrRecoveryCode(context.Background(), nil, "123456")
|
|
if err == nil {
|
|
t.Error("Expected error for nil user")
|
|
}
|
|
}
|
|
|
|
func TestVerifyTOTPCodeOrRecoveryCode_RecoveryCode(t *testing.T) {
|
|
// Create in-memory database
|
|
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
|
DriverName: "sqlite",
|
|
DSN: fmt.Sprintf("file:totp_recovery_test_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
|
}), &gorm.Config{
|
|
Logger: logger.Default.LogMode(logger.Silent),
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("failed to connect database: %v", err)
|
|
}
|
|
|
|
if err := db.AutoMigrate(&domain.User{}); err != nil {
|
|
t.Fatalf("failed to migrate: %v", err)
|
|
}
|
|
|
|
userRepo := repository.NewUserRepository(db)
|
|
svc := NewAuthService(userRepo, nil, nil, nil, 8, 5, 15*time.Minute)
|
|
|
|
t.Run("user with empty TOTP secret", func(t *testing.T) {
|
|
user := &domain.User{
|
|
Username: "emptysecret",
|
|
Password: "$2a$10$hash",
|
|
Status: domain.UserStatusActive,
|
|
TOTPEnabled: true,
|
|
TOTPSecret: "",
|
|
}
|
|
db.Create(user)
|
|
|
|
err := svc.verifyTOTPCodeOrRecoveryCode(context.Background(), user, "123456")
|
|
if err == nil {
|
|
t.Error("Expected error for empty TOTP secret")
|
|
}
|
|
})
|
|
|
|
t.Run("user with TOTP enabled but no recovery codes", func(t *testing.T) {
|
|
user := &domain.User{
|
|
Username: "norecovery",
|
|
Password: "$2a$10$hash",
|
|
Status: domain.UserStatusActive,
|
|
TOTPEnabled: true,
|
|
TOTPSecret: "JBSWY3DPEHPK3PXP",
|
|
TOTPRecoveryCodes: "",
|
|
}
|
|
db.Create(user)
|
|
|
|
err := svc.verifyTOTPCodeOrRecoveryCode(context.Background(), user, "invalidcode")
|
|
if err == nil {
|
|
t.Error("Expected error for invalid code without recovery codes")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// RefreshTokenTTLSeconds Tests
|
|
// =============================================================================
|
|
|
|
func TestRefreshTokenTTLSeconds(t *testing.T) {
|
|
t.Run("nil service returns 0", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
ttl := nilSvc.RefreshTokenTTLSeconds()
|
|
if ttl != 0 {
|
|
t.Errorf("Expected 0, got %d", ttl)
|
|
}
|
|
})
|
|
|
|
t.Run("service without jwt manager returns 0", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
ttl := svc.RefreshTokenTTLSeconds()
|
|
if ttl != 0 {
|
|
t.Errorf("Expected 0, got %d", ttl)
|
|
}
|
|
})
|
|
|
|
t.Run("service with jwt manager", func(t *testing.T) {
|
|
jwtManager, _ := auth.NewJWTWithOptions(auth.JWTOptions{
|
|
HS256Secret: "test-secret",
|
|
AccessTokenExpire: 15 * time.Minute,
|
|
RefreshTokenExpire: 7 * 24 * time.Hour,
|
|
})
|
|
svc := &AuthService{jwtManager: jwtManager}
|
|
ttl := svc.RefreshTokenTTLSeconds()
|
|
if ttl == 0 {
|
|
t.Error("Expected non-zero TTL")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// PublishEvent Tests
|
|
// =============================================================================
|
|
|
|
func TestPublishEvent(t *testing.T) {
|
|
t.Run("nil service does not panic", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
nilSvc.publishEvent(context.Background(), domain.EventUserLogin, nil)
|
|
})
|
|
|
|
t.Run("service without webhook service does not panic", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
svc.publishEvent(context.Background(), domain.EventUserLogin, map[string]interface{}{"user_id": 1})
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// OAuthLogin Tests
|
|
// =============================================================================
|
|
|
|
func TestOAuthLogin(t *testing.T) {
|
|
t.Run("nil service returns error", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
_, err := nilSvc.OAuthLogin(context.Background(), "github", "http://localhost/callback")
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
|
|
t.Run("service without oauth manager returns error", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
_, err := svc.OAuthLogin(context.Background(), "github", "http://localhost/callback")
|
|
if err == nil {
|
|
t.Error("Expected error when oauth manager not configured")
|
|
}
|
|
})
|
|
}
|
|
|
|
// =============================================================================
|
|
// StartSocialAccountBinding Extended Tests
|
|
// =============================================================================
|
|
|
|
func TestStartSocialAccountBinding_Extended(t *testing.T) {
|
|
t.Run("nil service returns error", func(t *testing.T) {
|
|
var nilSvc *AuthService
|
|
_, _, err := nilSvc.StartSocialAccountBinding(context.Background(), 1, "github", "http://localhost", "password", "")
|
|
if err == nil {
|
|
t.Error("Expected error for nil service")
|
|
}
|
|
})
|
|
|
|
t.Run("service without oauth manager returns error", func(t *testing.T) {
|
|
svc := &AuthService{}
|
|
_, _, err := svc.StartSocialAccountBinding(context.Background(), 1, "github", "http://localhost", "password", "")
|
|
if err == nil {
|
|
t.Error("Expected error when oauth manager not configured")
|
|
}
|
|
})
|
|
}
|