Files
user-system/internal/service/auth_runtime.go
long-agent 2ecd1fef1e refactor: 提取 service 层 best-effort 超时常量
- 新增 defaultBETimeout = 5 * time.Second
- 替换 auth/auth_runtime/password_reset/user_service/webhook 中 6 处硬编码 5*time.Second
2026-05-08 12:44:05 +08:00

368 lines
9.3 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/domain"
"gorm.io/gorm"
)
// oauthRegistrar is an optional capability of the OAuth manager.
// RegisterOAuthProvider discovers it with a type assertion on s.oauthManager
// and silently does nothing when the manager does not implement it.
type oauthRegistrar interface {
// RegisterProvider receives the provider identifier and its configuration.
RegisterProvider(provider auth.OAuthProvider, config *auth.OAuthConfig)
}
// RegisterOAuthProvider forwards a provider configuration to the OAuth
// manager when that manager supports registration.
//
// The call is a no-op when cfg is nil or when s.oauthManager does not
// implement oauthRegistrar.
func (s *AuthService) RegisterOAuthProvider(provider auth.OAuthProvider, cfg *auth.OAuthConfig) {
	if cfg == nil {
		return
	}
	registrar, supported := s.oauthManager.(oauthRegistrar)
	if !supported {
		return
	}
	registrar.RegisterProvider(provider, cfg)
}
// findUserForLogin resolves a login account string to a user record.
//
// P1 performance optimization: a single FindByAccount query replaces the
// serial username -> email -> phone lookups to reduce database round trips.
//
// A "not found" error (as judged by isUserNotFoundError) is returned
// unchanged so callers can recognise it; any other repository error is
// wrapped with "lookup user failed".
func (s *AuthService) findUserForLogin(ctx context.Context, account string) (*domain.User, error) {
	found, err := s.userRepo.FindByAccount(ctx, account)
	switch {
	case err == nil:
		return found, nil
	case isUserNotFoundError(err):
		return nil, err
	default:
		return nil, fmt.Errorf("lookup user failed: %w", err)
	}
}
// isUserNotFoundError reports whether err represents a missing user.
//
// It returns true when err wraps gorm.ErrRecordNotFound, or when the error
// text contains one of the known markers. English markers are matched
// case-insensitively; the Chinese "user does not exist" marker is matched
// against the raw message.
//
// Note that the generic "not found" marker subsumes the two more specific
// English markers, so any error whose text mentions "not found" is treated
// as a missing user.
func isUserNotFoundError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return true
	}

	message := err.Error()
	if strings.Contains(message, "用户不存在") {
		return true
	}

	normalized := strings.ToLower(strings.TrimSpace(message))
	for _, marker := range []string{"record not found", "user not found", "not found"} {
		if strings.Contains(normalized, marker) {
			return true
		}
	}
	return false
}
// bestEffortAssignDefaultRoles attaches every default role to the given user.
//
// It is best-effort: failures to load the default roles or to persist the
// assignments are only logged (tagged with source) and never returned. The
// call is a no-op when the service or either repository is nil, or when
// there are no default roles.
func (s *AuthService) bestEffortAssignDefaultRoles(ctx context.Context, userID int64, source string) {
	if s == nil || s.userRoleRepo == nil || s.roleRepo == nil {
		return
	}

	roles, err := s.roleRepo.GetDefaultRoles(ctx)
	if err != nil {
		log.Printf("auth: load default roles failed, source=%s user_id=%d err=%v", source, userID, err)
		return
	}
	if len(roles) == 0 {
		return
	}

	// Build one assignment per default role, then persist them in one batch.
	assignments := make([]*domain.UserRole, len(roles))
	for i, role := range roles {
		assignments[i] = &domain.UserRole{UserID: userID, RoleID: role.ID}
	}

	err = s.userRoleRepo.BatchCreate(ctx, assignments)
	if err == nil {
		return
	}
	log.Printf("auth: assign default roles failed, source=%s user_id=%d role_count=%d err=%v", source, userID, len(assignments), err)
}
// bestEffortUpdateLastLogin records the user's last-login IP in the
// background.
//
// PERF-01: the update runs in its own goroutine so it does not block the
// login response. The ctx parameter is intentionally not used for the update:
// the goroutine works on a fresh background context bounded by
// defaultBETimeout. Failures and panics are logged (tagged with source) and
// never propagated. The call is a no-op when the service or user repository
// is nil.
func (s *AuthService) bestEffortUpdateLastLogin(ctx context.Context, userID int64, ip, source string) {
	if s == nil || s.userRepo == nil {
		return
	}

	go func() {
		// A panic inside the repository must not take down the process.
		defer func() {
			if recovered := recover(); recovered != nil {
				log.Printf("auth: update last login panic recovered, source=%s user_id=%d err=%v", source, userID, recovered)
			}
		}()

		timeoutCtx, cancel := context.WithTimeout(context.Background(), defaultBETimeout)
		defer cancel()

		err := s.userRepo.UpdateLastLogin(timeoutCtx, userID, ip)
		if err == nil {
			return
		}
		log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err)
	}()
}
// loginAttemptKey builds the key under which login attempts are tracked.
//
// When the user is known the key is derived from the user ID; otherwise it
// is derived from the account string, trimmed and lower-cased.
func loginAttemptKey(account string, user *domain.User) string {
	if user == nil {
		normalized := strings.ToLower(strings.TrimSpace(account))
		return "login_attempt:account:" + normalized
	}
	return fmt.Sprintf("login_attempt:user:%d", user.ID)
}
// attemptCount converts a loosely typed counter value to an int, returning
// 0 when intValue cannot interpret it.
func attemptCount(value interface{}) int {
	count, ok := intValue(value)
	if !ok {
		return 0
	}
	return count
}
// intValue extracts an int from a loosely typed value.
//
// Supported dynamic types are int, int64, float64 (fraction discarded by the
// int conversion) and json.Number (must parse as an integer). Any other
// type, or a json.Number that fails Int64 parsing, yields (0, false).
func intValue(value interface{}) (int, bool) {
	if n, ok := value.(int); ok {
		return n, true
	}
	if n, ok := value.(int64); ok {
		return int(n), true
	}
	if f, ok := value.(float64); ok {
		return int(f), true
	}
	if num, ok := value.(json.Number); ok {
		if parsed, err := num.Int64(); err == nil {
			return int(parsed), true
		}
	}
	return 0, false
}
// int64Value extracts an int64 from a loosely typed value.
//
// Supported dynamic types are int64, int, float64 (fraction discarded by the
// int64 conversion) and json.Number (must parse as an integer). Any other
// type, or a json.Number that fails Int64 parsing, yields (0, false).
func int64Value(value interface{}) (int64, bool) {
	if n, ok := value.(int64); ok {
		return n, true
	}
	if n, ok := value.(int); ok {
		return int64(n), true
	}
	if f, ok := value.(float64); ok {
		return int64(f), true
	}
	if num, ok := value.(json.Number); ok {
		if parsed, err := num.Int64(); err == nil {
			return parsed, true
		}
	}
	return 0, false
}
// verifyPhoneRegistration validates the SMS code of a phone-based
// registration request.
//
// Requests that are nil or carry no phone number are accepted without any
// check. Otherwise an error is returned when the SMS code service is not
// configured, when the request has no phone code, or when the service
// rejects the code for the "register" purpose.
func (s *AuthService) verifyPhoneRegistration(ctx context.Context, req *RegisterRequest) error {
	if req == nil || req.Phone == "" {
		return nil
	}
	switch {
	case s.smsCodeSvc == nil:
		return errors.New("手机注册未启用")
	case req.PhoneCode == "":
		return errors.New("手机验证码不能为空")
	}
	return s.smsCodeSvc.VerifyCode(ctx, req.Phone, "register", req.PhoneCode)
}
// Cache key prefixes and lifetimes for the short-lived OAuth entries.
const (
// oauthStateCachePrefix is prepended to a state value to form its cache key.
oauthStateCachePrefix = "oauth_state:"
// oauthHandoffCachePrefix is prepended to a handoff code to form its cache key.
oauthHandoffCachePrefix = "oauth_handoff:"
// oauthStateTTL is passed as both TTL arguments when a state payload is cached.
oauthStateTTL = 10 * time.Minute
// oauthHandoffTTL is passed as both TTL arguments when a handoff response is cached.
oauthHandoffTTL = time.Minute
)
// OAuthStatePurpose identifies which OAuth flow a cached state belongs to.
type OAuthStatePurpose string
// Known state purposes. An empty purpose is treated as OAuthStatePurposeLogin
// when a state payload is created or consumed.
const (
// OAuthStatePurposeLogin marks a state created by CreateOAuthState.
OAuthStatePurposeLogin OAuthStatePurpose = "login"
// OAuthStatePurposeBind marks a state created by CreateOAuthBindState.
OAuthStatePurposeBind OAuthStatePurpose = "bind"
)
// OAuthStatePayload is the value cached under an OAuth state key.
type OAuthStatePayload struct {
// Purpose is the flow the state was issued for.
Purpose OAuthStatePurpose `json:"purpose"`
// ReturnTo is the caller-supplied return target, stored whitespace-trimmed.
ReturnTo string `json:"return_to"`
// UserID is set for bind states; it is omitted from JSON when zero.
UserID int64 `json:"user_id,omitempty"`
}
// generateOAuthEphemeralCode returns 32 bytes from crypto/rand encoded as an
// unpadded URL-safe base64 string (43 characters). It fails only when the
// random source cannot be read.
func generateOAuthEphemeralCode() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(entropy[:])
	return encoded, nil
}
// CreateOAuthState issues and caches a state value for an OAuth login flow.
// returnTo is stored whitespace-trimmed. It returns the generated state, or
// the error reported by createOAuthStatePayload.
func (s *AuthService) CreateOAuthState(ctx context.Context, returnTo string) (string, error) {
	loginPayload := OAuthStatePayload{
		Purpose:  OAuthStatePurposeLogin,
		ReturnTo: strings.TrimSpace(returnTo),
	}
	return s.createOAuthStatePayload(ctx, &loginPayload)
}
// CreateOAuthBindState issues and caches a state value for binding an OAuth
// account to an existing user. userID must be positive; returnTo is stored
// whitespace-trimmed. It returns the generated state, or an error when the
// user ID is invalid or the payload cannot be stored.
func (s *AuthService) CreateOAuthBindState(ctx context.Context, userID int64, returnTo string) (string, error) {
	if userID <= 0 {
		return "", errors.New("oauth binding user is required")
	}
	bindPayload := OAuthStatePayload{
		Purpose:  OAuthStatePurposeBind,
		ReturnTo: strings.TrimSpace(returnTo),
		UserID:   userID,
	}
	return s.createOAuthStatePayload(ctx, &bindPayload)
}
// createOAuthStatePayload generates a random state value and caches payload
// under it for oauthStateTTL.
//
// An empty payload.Purpose is defaulted to OAuthStatePurposeLogin; this
// writes to the struct the caller passed in. Errors are returned when the
// service or cache is nil, when payload is nil, when the random code cannot
// be generated, or when the cache write fails.
func (s *AuthService) createOAuthStatePayload(ctx context.Context, payload *OAuthStatePayload) (string, error) {
	switch {
	case s == nil || s.cache == nil:
		return "", errors.New("oauth state storage unavailable")
	case payload == nil:
		return "", errors.New("oauth state payload is required")
	}
	if payload.Purpose == "" {
		payload.Purpose = OAuthStatePurposeLogin
	}

	token, err := generateOAuthEphemeralCode()
	if err != nil {
		return "", err
	}

	storeErr := s.cache.Set(ctx, oauthStateCachePrefix+token, payload, oauthStateTTL, oauthStateTTL)
	if storeErr != nil {
		return "", storeErr
	}
	return token, nil
}
// ConsumeOAuthState consumes a state value via ConsumeOAuthStatePayload and
// returns only its trimmed return-to target. A nil payload without error
// yields an empty string.
func (s *AuthService) ConsumeOAuthState(ctx context.Context, state string) (string, error) {
	consumed, err := s.ConsumeOAuthStatePayload(ctx, state)
	switch {
	case err != nil:
		return "", err
	case consumed == nil:
		return "", nil
	}
	return strings.TrimSpace(consumed.ReturnTo), nil
}
// ConsumeOAuthStatePayload looks up the payload cached for an OAuth state
// value, removes the cache entry, and returns a normalized copy.
//
// The state is whitespace-trimmed before lookup. The cached value may come
// back in several shapes, all of which are accepted:
//   - *OAuthStatePayload or OAuthStatePayload: copied, then normalized
//   - map[string]interface{}: converted through a JSON round trip
//   - string: used as the return-to target of a login flow
//   - nil, including a nil *OAuthStatePayload: a login flow with no return-to
//   - any other type: its fmt.Sprint form is the return-to of a login flow
//
// Normalization defaults an empty Purpose to OAuthStatePurposeLogin and
// trims ReturnTo. The returned payload never aliases the cached value.
//
// Errors are returned when the service or cache is nil, when the state is
// not found, or when the JSON conversion of a map value fails.
func (s *AuthService) ConsumeOAuthStatePayload(ctx context.Context, state string) (*OAuthStatePayload, error) {
	if s == nil || s.cache == nil {
		return nil, errors.New("oauth state storage unavailable")
	}

	cacheKey := oauthStateCachePrefix + strings.TrimSpace(state)
	value, ok := s.cache.Get(ctx, cacheKey)
	if !ok {
		return nil, errors.New("OAuth state validation failed")
	}
	// The state is single use. A failed delete leaves it replayable until it
	// expires, so surface the failure in the log; the flow itself continues.
	// The key is deliberately not logged because it contains the state secret.
	if err := s.cache.Delete(ctx, cacheKey); err != nil {
		log.Printf("auth: delete consumed oauth state failed, err=%v", err)
	}

	// normalize works on its own copy, so the cached value is never mutated.
	normalize := func(payload OAuthStatePayload) *OAuthStatePayload {
		if payload.Purpose == "" {
			payload.Purpose = OAuthStatePurposeLogin
		}
		payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
		return &payload
	}

	switch typed := value.(type) {
	case nil:
		return normalize(OAuthStatePayload{}), nil
	case *OAuthStatePayload:
		if typed == nil {
			// A typed nil pointer does not match `case nil`; treat it the
			// same way instead of dereferencing it.
			return normalize(OAuthStatePayload{}), nil
		}
		return normalize(*typed), nil
	case OAuthStatePayload:
		return normalize(typed), nil
	case string:
		return normalize(OAuthStatePayload{ReturnTo: typed}), nil
	case map[string]interface{}:
		raw, err := json.Marshal(typed)
		if err != nil {
			return nil, err
		}
		var decoded OAuthStatePayload
		if err := json.Unmarshal(raw, &decoded); err != nil {
			return nil, err
		}
		return normalize(decoded), nil
	default:
		return normalize(OAuthStatePayload{ReturnTo: fmt.Sprint(typed)}), nil
	}
}
// CreateOAuthHandoff caches a login response under a freshly generated
// one-time code for oauthHandoffTTL and returns that code.
//
// Errors are returned when the service or cache is nil, when loginResp is
// nil, when the random code cannot be generated, or when the cache write
// fails.
func (s *AuthService) CreateOAuthHandoff(ctx context.Context, loginResp *LoginResponse) (string, error) {
	switch {
	case s == nil || s.cache == nil:
		return "", errors.New("oauth handoff storage unavailable")
	case loginResp == nil:
		return "", errors.New("oauth handoff payload is required")
	}

	handoffCode, err := generateOAuthEphemeralCode()
	if err != nil {
		return "", err
	}

	storeErr := s.cache.Set(ctx, oauthHandoffCachePrefix+handoffCode, loginResp, oauthHandoffTTL, oauthHandoffTTL)
	if storeErr != nil {
		return "", storeErr
	}
	return handoffCode, nil
}
// ConsumeOAuthHandoff exchanges a one-time handoff code for the login
// response cached under it and removes the cache entry.
//
// The code is whitespace-trimmed before lookup. A cached *LoginResponse is
// returned as is, a cached LoginResponse value is returned as a pointer to a
// copy, and any other value (for example a map[string]interface{}) is
// converted through a JSON round trip.
//
// Errors are returned when the service or cache is nil, when the code is not
// found, when the cached value is nil (including a nil *LoginResponse), or
// when the JSON conversion fails. On error the returned response is nil.
func (s *AuthService) ConsumeOAuthHandoff(ctx context.Context, code string) (*LoginResponse, error) {
	if s == nil || s.cache == nil {
		return nil, errors.New("oauth handoff storage unavailable")
	}

	cacheKey := oauthHandoffCachePrefix + strings.TrimSpace(code)
	value, ok := s.cache.Get(ctx, cacheKey)
	if !ok {
		return nil, errors.New("OAuth handoff code is invalid or expired")
	}
	// The code is single use. A failed delete leaves it replayable until it
	// expires, so surface the failure in the log; the flow itself continues.
	// The key is deliberately not logged because it contains the code.
	if err := s.cache.Delete(ctx, cacheKey); err != nil {
		log.Printf("auth: delete consumed oauth handoff failed, err=%v", err)
	}

	switch typed := value.(type) {
	case nil:
		// An empty entry must not be turned into an empty "successful" login.
		return nil, errors.New("OAuth handoff code is invalid or expired")
	case *LoginResponse:
		if typed == nil {
			// Never hand callers a (nil, nil) result.
			return nil, errors.New("OAuth handoff code is invalid or expired")
		}
		return typed, nil
	case LoginResponse:
		return &typed, nil
	default:
		// Covers map[string]interface{} and any other shape the cache returns.
		raw, err := json.Marshal(typed)
		if err != nil {
			return nil, err
		}
		var decoded LoginResponse
		if err := json.Unmarshal(raw, &decoded); err != nil {
			return nil, err
		}
		return &decoded, nil
	}
}