fix: P0-02 prevent login attempt counter race condition
Add atomic Increment method to cache layers: - L2Cache interface: add Increment method signature - RedisCache: implement using Redis INCRBY - L1Cache: implement with mutex-protected counter - CacheManager: add Increment that updates both L1 and L2 Update incrementFailAttempts to use atomic Increment instead of Get-Increment-Set pattern, preventing TOCTOU race.
This commit is contained in:
13
internal/cache/cache_manager.go
vendored
13
internal/cache/cache_manager.go
vendored
@@ -106,3 +106,16 @@ func (cm *CacheManager) GetL1() *L1Cache {
|
|||||||
func (cm *CacheManager) GetL2() L2Cache {
|
func (cm *CacheManager) GetL2() L2Cache {
|
||||||
return cm.l2
|
return cm.l2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Increment 原子递增(同时更新L1和L2)
|
||||||
|
func (cm *CacheManager) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) {
|
||||||
|
// 先更新L1
|
||||||
|
cm.l1.Increment(key, delta, ttl)
|
||||||
|
|
||||||
|
// 再更新L2
|
||||||
|
if cm.l2 != nil {
|
||||||
|
return cm.l2.Increment(ctx, key, delta, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cm.l1.Increment(key, 0, 0), nil
|
||||||
|
}
|
||||||
|
|||||||
41
internal/cache/l1.go
vendored
41
internal/cache/l1.go
vendored
@@ -169,3 +169,44 @@ func (c *L1Cache) Cleanup() {
|
|||||||
c.removeFromAccessOrder(key)
|
c.removeFromAccessOrder(key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Increment 原子递增(用于登录失败计数器等原子操作场景)
|
||||||
|
func (c *L1Cache) Increment(key string, delta int64, ttl time.Duration) int64 {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
var expiration int64
|
||||||
|
if ttl > 0 {
|
||||||
|
expiration = time.Now().Add(ttl).UnixNano()
|
||||||
|
}
|
||||||
|
|
||||||
|
current := int64(0)
|
||||||
|
if item, ok := c.items[key]; ok {
|
||||||
|
if item.Expired() {
|
||||||
|
delete(c.items, key)
|
||||||
|
c.removeFromAccessOrder(key)
|
||||||
|
} else {
|
||||||
|
if v, ok := item.Value.(int64); ok {
|
||||||
|
current = v
|
||||||
|
} else if v, ok := item.Value.(int); ok {
|
||||||
|
current = int64(v)
|
||||||
|
} else if v, ok := item.Value.(float64); ok {
|
||||||
|
current = int64(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
newVal := current + delta
|
||||||
|
c.items[key] = &CacheItem{
|
||||||
|
Value: newVal,
|
||||||
|
Expiration: expiration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, exists := c.items[key]; !exists {
|
||||||
|
c.accessOrder = append(c.accessOrder, key)
|
||||||
|
} else {
|
||||||
|
c.updateAccessOrder(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return newVal
|
||||||
|
}
|
||||||
|
|||||||
15
internal/cache/l2.go
vendored
15
internal/cache/l2.go
vendored
@@ -17,6 +17,7 @@ type L2Cache interface {
|
|||||||
Delete(ctx context.Context, key string) error
|
Delete(ctx context.Context, key string) error
|
||||||
Exists(ctx context.Context, key string) (bool, error)
|
Exists(ctx context.Context, key string) (bool, error)
|
||||||
Clear(ctx context.Context) error
|
Clear(ctx context.Context) error
|
||||||
|
Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error)
|
||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -127,6 +128,20 @@ func (c *RedisCache) Close() error {
|
|||||||
return c.client.Close()
|
return c.client.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *RedisCache) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) {
|
||||||
|
if !c.enabled || c.client == nil {
|
||||||
|
return 0, errors.New("redis is not enabled")
|
||||||
|
}
|
||||||
|
result, err := c.client.IncrBy(ctx, key, delta).Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
if ttl > 0 {
|
||||||
|
c.client.Expire(ctx, key, ttl)
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
func decodeRedisValue(raw string) (interface{}, error) {
|
func decodeRedisValue(raw string) (interface{}, error) {
|
||||||
decoder := json.NewDecoder(strings.NewReader(raw))
|
decoder := json.NewDecoder(strings.NewReader(raw))
|
||||||
decoder.UseNumber()
|
decoder.UseNumber()
|
||||||
|
|||||||
@@ -494,17 +494,23 @@ func (s *AuthService) incrementFailAttempts(ctx context.Context, key string) int
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
current := 0
|
// 使用原子递增,避免竞态条件
|
||||||
if value, ok := s.cache.Get(ctx, key); ok {
|
newVal, err := s.cache.Increment(ctx, key, 1, s.loginLockDuration)
|
||||||
current = attemptCount(value)
|
if err != nil {
|
||||||
}
|
log.Printf("auth: increment login attempts failed, key=%s err=%v", key, err)
|
||||||
current++
|
// 回退到原来的非原子方式
|
||||||
|
current := 0
|
||||||
if err := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); err != nil {
|
if value, ok := s.cache.Get(ctx, key); ok {
|
||||||
log.Printf("auth: store login attempts failed, key=%s err=%v", key, err)
|
current = attemptCount(value)
|
||||||
|
}
|
||||||
|
current++
|
||||||
|
if setErr := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); setErr != nil {
|
||||||
|
log.Printf("auth: store login attempts failed, key=%s err=%v", key, setErr)
|
||||||
|
}
|
||||||
|
return current
|
||||||
}
|
}
|
||||||
|
|
||||||
return current
|
return int(newVal)
|
||||||
}
|
}
|
||||||
|
|
||||||
func isValidPhoneSimple(phone string) bool {
|
func isValidPhoneSimple(phone string) bool {
|
||||||
|
|||||||
Reference in New Issue
Block a user