fix: atomic TOTP recovery code consumption with repository-level transaction

- Add ConsumeTOTPRecoveryCode to UserRepository for atomic read-verify-update
- Update TOTPService.VerifyTOTP to prefer atomic consumption when available
- Update AuthService.verifyTOTPCodeOrRecoveryCode with same pattern
- Fix critical bug: ConsumeTOTPRecoveryCode now correctly returns consumed=false on mismatch
- Maintain backward compatibility: falls back to non-atomic path if repo doesn't implement interface
- Add comprehensive unit tests for atomic consumption path

Refs: review-fix-closure-2026-05-28 TOTP recovery code atomicity
This commit is contained in:
Your Name
2026-05-29 12:31:36 +08:00
parent 80c59e2c2c
commit 878ca731f4
4 changed files with 229 additions and 12 deletions

View File

@@ -2,11 +2,15 @@ package repository
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"gorm.io/gorm"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/pagination"
)
@@ -231,6 +235,63 @@ func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) erro
}).Error
}
// ConsumeTOTPRecoveryCode 原子性地消费一个恢复码
// 在事务中验证恢复码并更新,避免并发竞争窗口
func (r *UserRepository) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
var user domain.User
var consumed bool
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// 在事务中重新获取用户
// 注意SQLite 不完全支持 FOR UPDATE依赖事务隔离
if err := tx.First(&user, userID).Error; err != nil {
return err
}
if !user.TOTPEnabled {
return errors.New("TOTP 未启用")
}
// 解析存储的哈希恢复码
var hashedCodes []string
if user.TOTPRecoveryCodes != "" {
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
return fmt.Errorf("解析恢复码失败: %w", err)
}
}
// 验证恢复码(输入会被哈希后与存储的哈希比较)
idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
if !matched {
// 不匹配,标记消费失败但不返回错误
consumed = false
return nil
}
// 从列表中移除已使用的恢复码
hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...)
codesJSON, err := json.Marshal(hashedCodes)
if err != nil {
return fmt.Errorf("序列化恢复码失败: %w", err)
}
user.TOTPRecoveryCodes = string(codesJSON)
// 在同一事务中更新
if err := tx.Model(&user).Update("totp_recovery_codes", user.TOTPRecoveryCodes).Error; err != nil {
return err
}
consumed = true
return nil
})
if err != nil {
return nil, false, err
}
return &user, consumed, nil
}
// UpdatePassword 更新用户密码
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error

View File

@@ -1299,9 +1299,25 @@ func (s *AuthService) verifyTOTPCodeOrRecoveryCode(ctx context.Context, user *do
return nil
}
// 尝试原子性消费恢复码(如果 repo 支持)
if consumer, ok := s.userRepo.(atomicTOTPRecoveryCodeConsumer); ok {
_, consumed, err := consumer.ConsumeTOTPRecoveryCode(ctx, user.ID, code)
if err != nil {
return fmt.Errorf("消费恢复码失败: %w", err)
}
if consumed {
return nil
}
// 恢复码不匹配
return errors.New("TOTP code or recovery code is invalid")
}
// 降级到非原子性恢复码消费(兼容性模式)
var hashedCodes []string
if strings.TrimSpace(user.TOTPRecoveryCodes) != "" {
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes)
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes); err != nil {
return fmt.Errorf("解析恢复码失败: %w", err)
}
}
index, matched := auth.VerifyRecoveryCode(code, hashedCodes)
if !matched {
@@ -1311,7 +1327,7 @@ func (s *AuthService) verifyTOTPCodeOrRecoveryCode(ctx context.Context, user *do
hashedCodes = append(hashedCodes[:index], hashedCodes[index+1:]...)
payload, err := json.Marshal(hashedCodes)
if err != nil {
return err
return fmt.Errorf("序列化恢复码失败: %w", err)
}
user.TOTPRecoveryCodes = string(payload)
return s.userRepo.UpdateTOTP(ctx, user)

View File

@@ -7,8 +7,14 @@ import (
"fmt"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
)
// TOTPService manages 2FA setup, enable/disable, and verification.
type atomicTOTPRecoveryCodeConsumer interface {
ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error)
}
// TOTPService manages 2FA setup, enable/disable, and verification.
type TOTPService struct {
userRepo userRepositoryInterface
@@ -122,7 +128,7 @@ func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string
func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) error {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
return fmt.Errorf("用户不存在")
}
if !user.TOTPEnabled {
return nil
@@ -132,13 +138,27 @@ func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string)
return nil
}
// 尝试原子性消费恢复码(如果 repo 支持)
if consumer, ok := s.userRepo.(atomicTOTPRecoveryCodeConsumer); ok {
_, consumed, err := consumer.ConsumeTOTPRecoveryCode(ctx, userID, code)
if err != nil {
return fmt.Errorf("消费恢复码失败: %w", err)
}
if consumed {
return nil
}
// 恢复码不匹配,继续返回通用错误
return errors.New("验证码错误或已过期")
}
// 降级到非原子性恢复码消费(兼容性模式)
var storedCodes []string
if user.TOTPRecoveryCodes != "" {
if err := json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes); err != nil {
return fmt.Errorf("解析恢复码失败: %w", err)
}
}
idx, matched := auth.ValidateRecoveryCode(code, storedCodes)
idx, matched := auth.VerifyRecoveryCode(code, storedCodes)
if !matched {
return errors.New("验证码错误或已过期")
}

View File

@@ -2,17 +2,40 @@ package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"testing"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
)
func mustHashRecoveryCode(t *testing.T, code string) string {
t.Helper()
hashed, err := auth.HashRecoveryCode(code)
if err != nil {
t.Fatalf("hash recovery code: %v", err)
}
return hashed
}
func mustMarshalJSON(t *testing.T, value any) string {
t.Helper()
payload, err := json.Marshal(value)
if err != nil {
t.Fatalf("marshal json: %v", err)
}
return string(payload)
}
type totpTestRepo struct {
user *domain.User
getErr error
updateTOTPErr error
user *domain.User
getErr error
updateTOTPErr error
consumeRecoveryCodeErr error
consumeRecoveryCodeCalled bool
}
func (r *totpTestRepo) Create(ctx context.Context, user *domain.User) error { return nil }
@@ -69,6 +92,40 @@ func (r *totpTestRepo) ExistsByPhone(ctx context.Context, phone string) (bool, e
func (r *totpTestRepo) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
return nil, 0, errors.New("not implemented")
}
func (r *totpTestRepo) ConsumeTOTPRecoveryCode(ctx context.Context, userID int64, code string) (*domain.User, bool, error) {
r.consumeRecoveryCodeCalled = true
if r.consumeRecoveryCodeErr != nil {
return nil, false, r.consumeRecoveryCodeErr
}
if r.user == nil || r.user.ID != userID {
return nil, false, errors.New("not found")
}
var hashedCodes []string
if strings.TrimSpace(r.user.TOTPRecoveryCodes) != "" {
if err := json.Unmarshal([]byte(r.user.TOTPRecoveryCodes), &hashedCodes); err != nil {
return nil, false, fmt.Errorf("解析恢复码失败: %w", err)
}
}
idx, matched := auth.VerifyRecoveryCode(code, hashedCodes)
if !matched {
return nil, false, nil
}
copyUser := *r.user
hashedCodes = append(hashedCodes[:idx], hashedCodes[idx+1:]...)
copyUser.TOTPRecoveryCodes = mustMarshalJSONFromHelper(hashedCodes)
r.user = &copyUser
return &copyUser, true, nil
}
func mustMarshalJSONFromHelper(value any) string {
payload, err := json.Marshal(value)
if err != nil {
panic(err)
}
return string(payload)
}
func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) {
repo := &totpTestRepo{user: &domain.User{
@@ -89,16 +146,16 @@ func TestTOTPService_ReturnsDecodeErrorForCorruptedRecoveryCodes(t *testing.T) {
}
}
func TestTOTPService_ReturnsUpdateErrorAfterRecoveryCodeConsumption(t *testing.T) {
func TestTOTPService_ReturnsAtomicConsumptionErrorAfterRecoveryCodeConsumption(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 7,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "invalid-secret",
TOTPRecoveryCodes: `["RECOVERY-1"]`,
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
},
updateTOTPErr: errors.New("write failed"),
consumeRecoveryCodeErr: errors.New("write failed"),
}
svc := NewTOTPService(repo)
@@ -106,7 +163,70 @@ func TestTOTPService_ReturnsUpdateErrorAfterRecoveryCodeConsumption(t *testing.T
if err == nil {
t.Fatal("expected update failure to be returned")
}
if !strings.Contains(err.Error(), "更新恢复码失败") {
t.Fatalf("expected update error, got: %v", err)
if !repo.consumeRecoveryCodeCalled {
t.Fatal("expected atomic consumption path to be invoked")
}
if !strings.Contains(err.Error(), "消费恢复码失败") {
t.Fatalf("expected atomic consume error, got: %v", err)
}
}
func TestTOTPService_ConsumesHashedRecoveryCodeOnVerify(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 8,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "invalid-secret",
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1"), mustHashRecoveryCode(t, "RECOVERY-2")}),
},
}
svc := NewTOTPService(repo)
if err := svc.VerifyTOTP(context.Background(), 8, "RECOVERY-1"); err != nil {
t.Fatalf("expected hashed recovery code to verify, got: %v", err)
}
if !repo.consumeRecoveryCodeCalled {
t.Fatal("expected atomic recovery-code consumption path to be used")
}
if repo.user == nil {
t.Fatal("expected updated user to be persisted")
}
var remaining []string
if err := json.Unmarshal([]byte(repo.user.TOTPRecoveryCodes), &remaining); err != nil {
t.Fatalf("unmarshal remaining codes: %v", err)
}
if len(remaining) != 1 {
t.Fatalf("expected 1 remaining recovery code, got %d", len(remaining))
}
if remaining[0] != mustHashRecoveryCode(t, "RECOVERY-2") {
t.Fatalf("expected RECOVERY-2 hash to remain, got %q", remaining[0])
}
}
func TestTOTPService_DisableAcceptsHashedRecoveryCode(t *testing.T) {
repo := &totpTestRepo{
user: &domain.User{
ID: 9,
Username: "totp-user",
TOTPEnabled: true,
TOTPSecret: "invalid-secret",
TOTPRecoveryCodes: mustMarshalJSON(t, []string{mustHashRecoveryCode(t, "RECOVERY-1")}),
},
}
svc := NewTOTPService(repo)
if err := svc.DisableTOTP(context.Background(), 9, "RECOVERY-1"); err != nil {
t.Fatalf("expected hashed recovery code to disable TOTP, got: %v", err)
}
if repo.user == nil {
t.Fatal("expected updated user to be persisted")
}
if repo.user.TOTPEnabled {
t.Fatal("expected TOTP to be disabled")
}
if repo.user.TOTPSecret != "" || repo.user.TOTPRecoveryCodes != "" {
t.Fatalf("expected TOTP secret and recovery codes to be cleared, got secret=%q codes=%q", repo.user.TOTPSecret, repo.user.TOTPRecoveryCodes)
}
}