Files
user-system/internal/auth/totp_test.go
long-agent 2a18a6fb47 fix(n+1): 批量查询替代循环单查
- IsAdminBootstrapRequired: userRepo.GetByID 循环 → GetByIDs 批量
- AssignRoles: roleRepo.GetByID 循环 → GetByIDs 批量
- 在 userRepositoryInterface 补充 GetByIDs 方法签名
2026-05-08 08:05:26 +08:00

209 lines
5.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package auth
import (
"strings"
"testing"
)
// TestTOTPManager_GenerateAndValidate exercises the TOTP round trip: the
// generated setup must carry a secret, a Base64 QR code and RecoveryCodeCount
// recovery codes, and a code produced from that secret must validate.
func TestTOTPManager_GenerateAndValidate(t *testing.T) {
	mgr := NewTOTPManager()

	// Generate a fresh secret.
	setup, err := mgr.GenerateSecret("testuser@example.com")
	if err != nil {
		t.Fatalf("GenerateSecret 失败: %v", err)
	}

	// The setup payload must be fully populated.
	switch {
	case setup.Secret == "":
		t.Fatal("生成的 Secret 不应为空")
	case setup.QRCodeBase64 == "":
		t.Fatal("QRCode Base64 不应为空")
	}
	if got := len(setup.RecoveryCodes); got != RecoveryCodeCount {
		t.Fatalf("恢复码数量期望 %d实际 %d", RecoveryCodeCount, got)
	}
	t.Logf("生成 Secret: %s", setup.Secret)
	t.Logf("恢复码示例: %s", setup.RecoveryCodes[0])

	// Produce the current TOTP code from the generated secret, then validate it.
	current, err := mgr.GenerateCurrentCode(setup.Secret)
	if err != nil {
		t.Fatalf("GenerateCurrentCode 失败: %v", err)
	}
	if accepted := mgr.ValidateCode(setup.Secret, current); !accepted {
		t.Fatalf("有效 TOTP 码应该通过验证code=%s", current)
	}
	t.Logf("TOTP 验证通过code=%s", current)
}
// TestTOTPManager_InvalidCode expects the fixed code "000000" to be rejected
// for a freshly generated secret.
func TestTOTPManager_InvalidCode(t *testing.T) {
	mgr := NewTOTPManager()

	setup, err := mgr.GenerateSecret("user")
	if err != nil {
		t.Fatalf("GenerateSecret 失败: %v", err)
	}

	// Try a wrong code. On rare occasions "000000" may happen to be the
	// currently valid code, so skip rather than fail in that case.
	accepted := mgr.ValidateCode(setup.Secret, "000000")
	if accepted {
		t.Skip("000000 碰巧是有效码,跳过测试")
	}

	t.Log("无效验证码正确拒绝")
}
// TestTOTPManager_RecoveryCodeFormat verifies that every generated recovery
// code is made of exactly two "-"-separated parts, each 5 bytes long
// (the XXXXX-XXXXX shape).
func TestTOTPManager_RecoveryCodeFormat(t *testing.T) {
	m := NewTOTPManager()
	setup, err := m.GenerateSecret("user2")
	if err != nil {
		t.Fatalf("GenerateSecret 失败: %v", err)
	}
	for i, code := range setup.RecoveryCodes {
		parts := strings.Split(code, "-")
		if len(parts) != 2 {
			t.Errorf("恢复码 [%d] 格式错误(期望 XXXXX-XXXXX: %s", i, code)
			// The part-length check below indexes parts[1]; with fewer than
			// two parts that would panic and abort the whole test run, so
			// move on to the next code once the format error is reported.
			continue
		}
		if len(parts[0]) != 5 || len(parts[1]) != 5 {
			t.Errorf("恢复码 [%d] 各部分长度应为 5: %s", i, code)
		}
	}
}
// TestValidateRecoveryCode checks ValidateRecoveryCode against a fixed list of
// plaintext codes: an exact match, a lower-cased input, an input padded with
// spaces, and an input that is not in the list.
func TestValidateRecoveryCode(t *testing.T) {
	stored := []string{"ABCDE-FGHIJ", "KLMNO-PQRST", "UVWXY-ZABCD"}

	// Exact match returns the index of the matching entry.
	if pos, found := ValidateRecoveryCode("ABCDE-FGHIJ", stored); !found || pos != 0 {
		t.Fatalf("有效恢复码应该匹配idx=%d ok=%v", pos, found)
	}

	// Matching is expected to ignore case.
	if pos, found := ValidateRecoveryCode("klmno-pqrst", stored); !found || pos != 1 {
		t.Fatalf("大小写不敏感匹配失败idx=%d ok=%v", pos, found)
	}

	// Surrounding spaces are expected to be stripped before matching.
	if pos, found := ValidateRecoveryCode(" UVWXY-ZABCD ", stored); !found || pos != 2 {
		t.Fatalf("去除空格匹配失败idx=%d ok=%v", pos, found)
	}

	// A code that is not in the list must not match.
	if _, found := ValidateRecoveryCode("XXXXX-YYYYY", stored); found {
		t.Fatal("无效恢复码不应该匹配")
	}

	t.Log("恢复码验证全部通过")
}
// TestHashRecoveryCode checks that HashRecoveryCode yields a non-empty hash
// with the "$2a$" prefix, that the original code verifies against it, and that
// it does not verify against the hash of a different code.
func TestHashRecoveryCode(t *testing.T) {
	const plain = "ABCDE-FGHIJ"

	digest, err := HashRecoveryCode(plain)
	if err != nil {
		t.Fatalf("HashRecoveryCode failed: %v", err)
	}
	if len(digest) == 0 {
		t.Fatal("HashRecoveryCode should return non-empty hash")
	}

	// The code must verify against its own hash. Hashes are not compared
	// directly: per the original author's note, bcrypt salts randomly, so two
	// hashes of the same code differ.
	if _, matched := VerifyRecoveryCode(plain, []string{digest}); !matched {
		t.Error("Same code should verify against its own hash")
	}

	// The code must NOT verify against the hash of a different code.
	otherDigest, err := HashRecoveryCode("DIFFERENT-CODE")
	if err != nil {
		t.Fatalf("HashRecoveryCode for different code failed: %v", err)
	}
	if _, matched := VerifyRecoveryCode(plain, []string{otherDigest}); matched {
		t.Error("Different codes should NOT verify against each other's hash")
	}

	// The hash is expected to be in bcrypt format.
	if isBcrypt := strings.HasPrefix(digest, "$2a$"); !isBcrypt {
		t.Errorf("Hash should be bcrypt format, got: %s", digest)
	}
	t.Logf("Hashed code (bcrypt): %s", digest)
}
// TestVerifyRecoveryCode hashes three known codes and checks that
// VerifyRecoveryCode returns the right index for each of them, rejects an
// unknown code, and rejects any code when the hash list is empty.
func TestVerifyRecoveryCode(t *testing.T) {
	// Build the hashed fixture from the plaintext codes.
	plain := []string{"ABCDE-FGHIJ", "KLMNO-PQRST", "UVWXY-ZABCD"}
	hashedCodes := make([]string, 0, len(plain))
	for _, p := range plain {
		digest, err := HashRecoveryCode(p)
		if err != nil {
			t.Fatalf("HashRecoveryCode failed: %v", err)
		}
		hashedCodes = append(hashedCodes, digest)
	}

	// First code: exact match at index 0.
	if pos, found := VerifyRecoveryCode("ABCDE-FGHIJ", hashedCodes); !found || pos != 0 {
		t.Fatalf("Valid recovery code should match, idx=%d ok=%v", pos, found)
	}

	// Second code: index 1.
	if pos, found := VerifyRecoveryCode("KLMNO-PQRST", hashedCodes); !found || pos != 1 {
		t.Fatalf("Second code match failed, idx=%d ok=%v", pos, found)
	}

	// Third code: index 2.
	if pos, found := VerifyRecoveryCode("UVWXY-ZABCD", hashedCodes); !found || pos != 2 {
		t.Fatalf("Third code match failed, idx=%d ok=%v", pos, found)
	}

	// A code that was never hashed must be rejected.
	if _, found := VerifyRecoveryCode("XXXXX-YYYYY", hashedCodes); found {
		t.Fatal("Invalid recovery code should not match")
	}

	// Nothing can match against an empty hash list.
	if _, found := VerifyRecoveryCode("ABCDE-FGHIJ", []string{}); found {
		t.Fatal("Should not match against empty list")
	}

	t.Log("VerifyRecoveryCode tests passed")
}
// TestVerifyRecoveryCode_TimingSafety checks that VerifyRecoveryCode reports
// the correct index whether the match is the first or the last entry of the
// hash list.
//
// NOTE(review): despite its name, this test does not measure timing or observe
// whether all entries are visited; it only asserts position-independent
// correctness. The intent stated by the original author is that the function
// iterates through all codes regardless of where the match is found (timing
// attack prevention) — confirm against the implementation.
func TestVerifyRecoveryCode_TimingSafety(t *testing.T) {
	codes := []string{"CODE1-AAAAA", "CODE2-BBBBB", "CODE3-CCCCC"}
	hashedCodes := make([]string, len(codes))
	for i, code := range codes {
		hashed, err := HashRecoveryCode(code)
		if err != nil {
			// Fail fast: an empty hash in the fixture would otherwise show up
			// later as a misleading "match failed" error.
			t.Fatalf("HashRecoveryCode failed: %v", err)
		}
		hashedCodes[i] = hashed
	}

	// Match located at the first position.
	idx1, ok1 := VerifyRecoveryCode("CODE1-AAAAA", hashedCodes)
	if !ok1 || idx1 != 0 {
		t.Errorf("First code match failed, idx=%d ok=%v", idx1, ok1)
	}

	// Match located at the last position.
	idx3, ok3 := VerifyRecoveryCode("CODE3-CCCCC", hashedCodes)
	if !ok3 || idx3 != 2 {
		t.Errorf("Last code match failed, idx=%d ok=%v", idx3, ok3)
	}

	t.Log("Timing safety test passed")
}