fix: P0-02 prevent login attempt counter race condition

Add atomic Increment method to cache layers:
- L2Cache interface: add Increment method signature
- RedisCache: implement using Redis INCRBY
- L1Cache: implement with mutex-protected counter
- CacheManager: add Increment that updates both L1 and L2

Update incrementFailAttempts to use atomic Increment instead
of Get-Increment-Set pattern, preventing TOCTOU race.
This commit is contained in:
2026-04-18 13:45:09 +08:00
parent 32a3d4c9e0
commit ca7ba5ccdf
4 changed files with 84 additions and 9 deletions

View File

@@ -106,3 +106,16 @@ func (cm *CacheManager) GetL1() *L1Cache {
func (cm *CacheManager) GetL2() L2Cache { func (cm *CacheManager) GetL2() L2Cache {
return cm.l2 return cm.l2
} }
// Increment 原子递增同时更新L1和L2
func (cm *CacheManager) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) {
// 先更新L1
cm.l1.Increment(key, delta, ttl)
// 再更新L2
if cm.l2 != nil {
return cm.l2.Increment(ctx, key, delta, ttl)
}
return cm.l1.Increment(key, 0, 0), nil
}

41
internal/cache/l1.go vendored
View File

@@ -169,3 +169,44 @@ func (c *L1Cache) Cleanup() {
c.removeFromAccessOrder(key) c.removeFromAccessOrder(key)
} }
} }
// Increment 原子递增(用于登录失败计数器等原子操作场景)
func (c *L1Cache) Increment(key string, delta int64, ttl time.Duration) int64 {
c.mu.Lock()
defer c.mu.Unlock()
var expiration int64
if ttl > 0 {
expiration = time.Now().Add(ttl).UnixNano()
}
current := int64(0)
if item, ok := c.items[key]; ok {
if item.Expired() {
delete(c.items, key)
c.removeFromAccessOrder(key)
} else {
if v, ok := item.Value.(int64); ok {
current = v
} else if v, ok := item.Value.(int); ok {
current = int64(v)
} else if v, ok := item.Value.(float64); ok {
current = int64(v)
}
}
}
newVal := current + delta
c.items[key] = &CacheItem{
Value: newVal,
Expiration: expiration,
}
if _, exists := c.items[key]; !exists {
c.accessOrder = append(c.accessOrder, key)
} else {
c.updateAccessOrder(key)
}
return newVal
}

15
internal/cache/l2.go vendored
View File

@@ -17,6 +17,7 @@ type L2Cache interface {
Delete(ctx context.Context, key string) error Delete(ctx context.Context, key string) error
Exists(ctx context.Context, key string) (bool, error) Exists(ctx context.Context, key string) (bool, error)
Clear(ctx context.Context) error Clear(ctx context.Context) error
Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error)
Close() error Close() error
} }
@@ -127,6 +128,20 @@ func (c *RedisCache) Close() error {
return c.client.Close() return c.client.Close()
} }
func (c *RedisCache) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) {
if !c.enabled || c.client == nil {
return 0, errors.New("redis is not enabled")
}
result, err := c.client.IncrBy(ctx, key, delta).Result()
if err != nil {
return 0, err
}
if ttl > 0 {
c.client.Expire(ctx, key, ttl)
}
return result, nil
}
func decodeRedisValue(raw string) (interface{}, error) { func decodeRedisValue(raw string) (interface{}, error) {
decoder := json.NewDecoder(strings.NewReader(raw)) decoder := json.NewDecoder(strings.NewReader(raw))
decoder.UseNumber() decoder.UseNumber()

View File

@@ -494,17 +494,23 @@ func (s *AuthService) incrementFailAttempts(ctx context.Context, key string) int
return 0 return 0
} }
current := 0 // 使用原子递增,避免竞态条件
if value, ok := s.cache.Get(ctx, key); ok { newVal, err := s.cache.Increment(ctx, key, 1, s.loginLockDuration)
current = attemptCount(value) if err != nil {
} log.Printf("auth: increment login attempts failed, key=%s err=%v", key, err)
current++ // 回退到原来的非原子方式
current := 0
if err := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); err != nil { if value, ok := s.cache.Get(ctx, key); ok {
log.Printf("auth: store login attempts failed, key=%s err=%v", key, err) current = attemptCount(value)
}
current++
if setErr := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); setErr != nil {
log.Printf("auth: store login attempts failed, key=%s err=%v", key, setErr)
}
return current
} }
return current return int(newVal)
} }
func isValidPhoneSimple(phone string) bool { func isValidPhoneSimple(phone string) bool {