fix: atomic TOTP recovery code consumption with repository-level transaction
- Add ConsumeTOTPRecoveryCode to UserRepository for atomic read-verify-update - Update TOTPService.VerifyTOTP to prefer atomic consumption when available - Update AuthService.verifyTOTPCodeOrRecoveryCode with same pattern - Fix critical bug: ConsumeTOTPRecoveryCode now correctly returns consumed=false on mismatch - Maintain backward compatibility: falls back to non-atomic path if repo doesn't implement interface - Add comprehensive unit tests for atomic consumption path Refs: review-fix-closure-2026-05-28 TOTP recovery code atomicity
This commit is contained in:
@@ -2,11 +2,15 @@ package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/pagination"
|
||||
)
|
||||
@@ -231,6 +235,63 @@ func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) erro
|
||||
}).Error
|
||||
}
|
||||
|
||||
// ConsumeTOTPRecoveryCode 原子性地消费一个恢复码
|
||||
// 在事务中验证恢复码并更新,避免并发竞争窗口
|
||||
func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
|
||||
var user domain.User
|
||||
var consumed bool
|
||||
|
||||
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
// 在事务中重新获取用户
|
||||
// 注意:SQLite 不完全支持 FOR UPDATE,依赖事务隔离
|
||||
if err := tx.First(&user, userID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !user.TOTPEnabled {
|
||||
return errors.New("TOTP 未启用")
|
||||
}
|
||||
|
||||
// 解析存储的哈希恢复码
|
||||
var hashedCodes []string
|
||||
if user.TOTPRecoveryCodes != "" {
|
||||
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||||
return fmt.Errorf("解析恢复码失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证恢复码(输入会被哈希后与存储的哈希比较)
|
||||
idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||
if !matched {
|
||||
// 不匹配,标记消费失败但不返回错误
|
||||
consumed = false
|
||||
return nil
|
||||
}
|
||||
|
||||
// 从列表中移除已使用的恢复码
|
||||
hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...)
|
||||
codesJSON, err := json.Marshal(hashedCodes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("序列化恢复码失败: %w", err)
|
||||
}
|
||||
user.TOTPRecoveryCodes = string(codesJSON)
|
||||
|
||||
// 在同一事务中更新
|
||||
if err := tx.Model(&user).Update("totp_recovery_codes", user.TOTPRecoveryCodes).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
consumed = true
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return &user, consumed, nil
|
||||
}
|
||||
|
||||
// UpdatePassword 更新用户密码
|
||||
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
|
||||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
|
||||
|
||||
@@ -1299,9 +1299,25 @@ func (s *AuthService) verifyTOTPCodeOrRecoveryCode(ctx context.Context, user *do
|
||||
return nil
|
||||
}
|
||||
|
||||
// 尝试原子性消费恢复码(如果 repo 支持)
|
||||
if consumer, ok := s.userRepo.(atomicTOTPRecoveryCodeConsumer); ok {
|
||||
_, consumed, err := consumer.ConsumeTOTPRecoveryCode(ctx, user.ID, code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("消费恢复码失败: %w", err)
|
||||
}
|
||||
if consumed {
|
||||
return nil
|
||||
}
|
||||
// 恢复码不匹配
|
||||
return errors.New("TOTP code or recovery code is invalid")
|
||||
}
|
||||
|
||||
// 降级到非原子性恢复码消费(兼容性模式)
|
||||
var hashedCodes []string
|
||||
if strings.TrimSpace(user.TOTPRecoveryCodes) != "" {
|
||||
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes)
|
||||
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||||
return fmt.Errorf("解析恢复码失败: %w", err)
|
||||
}
|
||||
}
|
||||
index, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||
if !matched {
|
||||
@@ -1311,7 +1327,7 @@ func (s *AuthService) verifyTOTPCodeOrRecoveryCode(ctx context.Context, user *do
|
||||
hashedCodes = append(hashedCodes[:index], hashedCodes[index+1:]...)
|
||||
payload, err := json.Marshal(hashedCodes)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("序列化恢复码失败: %w", err)
|
||||
}
|
||||
user.TOTPRecoveryCodes = string(payload)
|
||||
return s.userRepo.UpdateTOTP(ctx, user)
|
||||
|
||||
@@ -7,8 +7,14 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
// TOTPService manages 2FA setup, enable/disable, and verification.
|
||||
type atomicTOTPRecoveryCodeConsumer interface {
|
||||
ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error)
|
||||
}
|
||||
|
||||
// TOTPService manages 2FA setup, enable/disable, and verification.
|
||||
type TOTPService struct {
|
||||
userRepo userRepositoryInterface
|
||||
@@ -122,7 +128,7 @@ func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string
|
||||
func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) error {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
||||
return fmt.Errorf("用户不存在")
|
||||
}
|
||||
if !user.TOTPEnabled {
|
||||
return nil
|
||||
@@ -132,13 +138,27 @@ func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string)
|
||||
return nil
|
||||
}
|
||||
|
||||
// 尝试原子性消费恢复码(如果 repo 支持)
|
||||
if consumer, ok := s.userRepo.(atomicTOTPRecoveryCodeConsumer); ok {
|
||||
_, consumed, err := consumer.ConsumeTOTPRecoveryCode(ctx, userID, code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("消费恢复码失败: %w", err)
|
||||
}
|
||||
if consumed {
|
||||
return nil
|
||||
}
|
||||
// 恢复码不匹配,继续返回通用错误
|
||||
return errors.New("验证码错误或已过期")
|
||||
}
|
||||
|
||||
// 降级到非原子性恢复码消费(兼容性模式)
|
||||
var storedCodes []string
|
||||
if user.TOTPRecoveryCodes != "" {
|
||||
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes); err != nil {
|
||||
return fmt.Errorf("解析恢复码失败: %w", err)
|
||||
}
|
||||
}
|
||||
idx, matched := auth.ValidateRecoveryCode(code, storedCodes)
|
||||
idx, matched := auth.VerifyRecoveryCode(code, storedCodes)
|
||||
if !matched {
|
||||
return errors.New("验证码错误或已过期")
|
||||
}
|
||||
|
||||
@@ -2,17 +2,40 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func mustHashRecoveryCode(t *testing.T, code string) string {
|
||||
t.Helper()
|
||||
hashed, err := auth.HashRecoveryCode(code)
|
||||
if err != nil {
|
||||
t.Fatalf("hash recovery code: %v", err)
|
||||
}
|
||||
return hashed
|
||||
}
|
||||
|
||||
func mustMarshalJSON(t *testing.T, value any) string {
|
||||
t.Helper()
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal json: %v", err)
|
||||
}
|
||||
return string(payload)
|
||||
}
|
||||
|
||||
type totpTestRepo struct {
|
||||
user *domain.User
|
||||
getErr error
|
||||
updateTOTPErr error
|
||||
user *domain.User
|
||||
getErr error
|
||||
updateTOTPErr error
|
||||
consumeRecoveryCodeErr error
|
||||
consumeRecoveryCodeCalled bool
|
||||
}
|
||||
|
||||
func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil }
|
||||
@@ -69,6 +92,40 @@ func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, e
|
||||
func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
|
||||
return nil, 0, errors.New("not implemented")
|
||||
}
|
||||
func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
|
||||
r.consumeRecoveryCodeCalled = true
|
||||
if r.consumeRecoveryCodeErr != nil {
|
||||
return nil, false, r.consumeRecoveryCodeErr
|
||||
}
|
||||
if r.user == nil || r.user.ID != userID {
|
||||
return nil, false, errors.New("not found")
|
||||
}
|
||||
|
||||
var hashedCodes []string
|
||||
if strings.TrimSpace(r.user.TOTPRecoveryCodes) != "" {
|
||||
if err := json.Unmarshal([]byte(r.user.TOTPRecoveryCodes), &hashedCodes); err != nil {
|
||||
return nil, false, fmt.Errorf("解析恢复码失败: %w", err)
|
||||
}
|
||||
}
|
||||
idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||
if !matched {
|
||||
return nil, false, nil
|
||||
}
|
||||
|
||||
copyUser := *r.user
|
||||
hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...)
|
||||
copyUser.TOTPRecoveryCodes = mustMarshalJSONFromHelper(hashedCodes)
|
||||
r.user = ©User
|
||||
return ©User, true, nil
|
||||
}
|
||||
|
||||
func mustMarshalJSONFromHelper(value any) string {
|
||||
payload, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(payload)
|
||||
}
|
||||
|
||||
func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) {
|
||||
repo := &totpTestRepo{user: &domain.User{
|
||||
@@ -89,16 +146,16 @@ func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPService_ReturnsUpdateErrorAfterRecoveryCodeConsumption(t *testing.T) {
|
||||
func TestTOTPService_ReturnsAtomicConsumptionErrorAfterRecoveryCodeConsumption(t *testing.T) {
|
||||
repo := &totpTestRepo{
|
||||
user: &domain.User{
|
||||
ID: 7,
|
||||
Username: "totp-user",
|
||||
TOTPEnabled: true,
|
||||
TOTPSecret: "invalid-secret",
|
||||
TOTPRecoveryCodes: `["RECOVERY-1"]`,
|
||||
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
|
||||
},
|
||||
updateTOTPErr: errors.New("write failed"),
|
||||
consumeRecoveryCodeErr: errors.New("write failed"),
|
||||
}
|
||||
svc := NewTOTPService(repo)
|
||||
|
||||
@@ -106,7 +163,70 @@ func TestTOTPService_ReturnsUpdateErrorAfterRecoveryCodeConsumption(t *testing.T
|
||||
if err == nil {
|
||||
t.Fatal("expected update failure to be returned")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "更新恢复码失败") {
|
||||
t.Fatalf("expected update error, got: %v", err)
|
||||
if !repo.consumeRecoveryCodeCalled {
|
||||
t.Fatal("expected atomic consumption path to be invoked")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "消费恢复码失败") {
|
||||
t.Fatalf("expected atomic consume error, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPService_ConsumesHashedRecoveryCodeOnVerify(t *testing.T) {
|
||||
repo := &totpTestRepo{
|
||||
user: &domain.User{
|
||||
ID: 8,
|
||||
Username: "totp-user",
|
||||
TOTPEnabled: true,
|
||||
TOTPSecret: "invalid-secret",
|
||||
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1"), mustHashRecoveryCode(t, "RECOVERY-2")}),
|
||||
},
|
||||
}
|
||||
svc := NewTOTPService(repo)
|
||||
|
||||
if err := svc.VerifyTOTP(context.Background(), 8, "RECOVERY-1"); err != nil {
|
||||
t.Fatalf("expected hashed recovery code to verify, got: %v", err)
|
||||
}
|
||||
if !repo.consumeRecoveryCodeCalled {
|
||||
t.Fatal("expected atomic recovery-code consumption path to be used")
|
||||
}
|
||||
if repo.user == nil {
|
||||
t.Fatal("expected updated user to be persisted")
|
||||
}
|
||||
|
||||
var remaining []string
|
||||
if err := json.Unmarshal([]byte(repo.user.TOTPRecoveryCodes), &remaining); err != nil {
|
||||
t.Fatalf("unmarshal remaining codes: %v", err)
|
||||
}
|
||||
if len(remaining) != 1 {
|
||||
t.Fatalf("expected 1 remaining recovery code, got %d", len(remaining))
|
||||
}
|
||||
if remaining[0] != mustHashRecoveryCode(t, "RECOVERY-2") {
|
||||
t.Fatalf("expected RECOVERY-2 hash to remain, got %q", remaining[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPService_DisableAcceptsHashedRecoveryCode(t *testing.T) {
|
||||
repo := &totpTestRepo{
|
||||
user: &domain.User{
|
||||
ID: 9,
|
||||
Username: "totp-user",
|
||||
TOTPEnabled: true,
|
||||
TOTPSecret: "invalid-secret",
|
||||
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
|
||||
},
|
||||
}
|
||||
svc := NewTOTPService(repo)
|
||||
|
||||
if err := svc.DisableTOTP(context.Background(), 9, "RECOVERY-1"); err != nil {
|
||||
t.Fatalf("expected hashed recovery code to disable TOTP, got: %v", err)
|
||||
}
|
||||
if repo.user == nil {
|
||||
t.Fatal("expected updated user to be persisted")
|
||||
}
|
||||
if repo.user.TOTPEnabled {
|
||||
t.Fatal("expected TOTP to be disabled")
|
||||
}
|
||||
if repo.user.TOTPSecret != "" || repo.user.TOTPRecoveryCodes != "" {
|
||||
t.Fatalf("expected TOTP secret and recovery codes to be cleared, got secret=%q codes=%q", repo.user.TOTPSecret, repo.user.TOTPRecoveryCodes)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user