package service import ( "context" "crypto/rand" "encoding/base64" "encoding/json" "errors" "fmt" "log" "strings" "time" "github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/domain" "gorm.io/gorm" ) type oauthRegistrar interface { RegisterProvider(provider auth.OAuthProvider, config *auth.OAuthConfig) } func (s *AuthService) RegisterOAuthProvider(provider auth.OAuthProvider, cfg *auth.OAuthConfig) { if cfg == nil { return } if registrar, ok := s.oauthManager.(oauthRegistrar); ok { registrar.RegisterProvider(provider, cfg) } } func (s *AuthService) findUserForLogin(ctx context.Context, account string) (*domain.User, error) { user, err := s.userRepo.GetByUsername(ctx, account) if err == nil { return user, nil } if !isUserNotFoundError(err) { return nil, fmt.Errorf("lookup user by username failed: %w", err) } user, err = s.userRepo.GetByEmail(ctx, account) if err == nil { return user, nil } if !isUserNotFoundError(err) { return nil, fmt.Errorf("lookup user by email failed: %w", err) } user, err = s.userRepo.GetByPhone(ctx, account) if err != nil && !isUserNotFoundError(err) { return nil, fmt.Errorf("lookup user by phone failed: %w", err) } return user, err } func isUserNotFoundError(err error) bool { if err == nil { return false } if errors.Is(err, gorm.ErrRecordNotFound) { return true } lowerErr := strings.ToLower(strings.TrimSpace(err.Error())) return strings.Contains(lowerErr, "record not found") || strings.Contains(lowerErr, "user not found") || strings.Contains(err.Error(), "用户不存在") || strings.Contains(lowerErr, "not found") } func (s *AuthService) bestEffortAssignDefaultRoles(ctx context.Context, userID int64, source string) { if s == nil || s.userRoleRepo == nil || s.roleRepo == nil { return } defaultRoles, err := s.roleRepo.GetDefaultRoles(ctx) if err != nil { log.Printf("auth: load default roles failed, source=%s user_id=%d err=%v", source, userID, err) return } if len(defaultRoles) == 0 { return } userRoles := make([]*domain.UserRole, 0, len(defaultRoles)) for _, role := range defaultRoles { userRoles = append(userRoles, &domain.UserRole{ UserID: userID, RoleID: role.ID, }) } if err := s.userRoleRepo.BatchCreate(ctx, userRoles); err != nil { log.Printf("auth: assign default roles failed, source=%s user_id=%d role_count=%d err=%v", source, userID, len(userRoles), err) } } func (s *AuthService) bestEffortUpdateLastLogin(ctx context.Context, userID int64, ip, source string) { if s == nil || s.userRepo == nil { return } if err := s.userRepo.UpdateLastLogin(ctx, userID, ip); err != nil { log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err) } } func loginAttemptKey(account string, user *domain.User) string { if user != nil { return fmt.Sprintf("login_attempt:user:%d", user.ID) } return "login_attempt:account:" + strings.ToLower(strings.TrimSpace(account)) } func attemptCount(value interface{}) int { if count, ok := intValue(value); ok { return count } return 0 } func intValue(value interface{}) (int, bool) { switch v := value.(type) { case int: return v, true case int64: return int(v), true case float64: return int(v), true case json.Number: n, err := v.Int64() if err != nil { return 0, false } return int(n), true default: return 0, false } } func int64Value(value interface{}) (int64, bool) { switch v := value.(type) { case int64: return v, true case int: return int64(v), true case float64: return int64(v), true case json.Number: n, err := v.Int64() if err != nil { return 0, false } return n, true default: return 0, false } } func (s *AuthService) verifyPhoneRegistration(ctx context.Context, req *RegisterRequest) error { if req == nil || req.Phone == "" { return nil } if s.smsCodeSvc == nil { return errors.New("手机注册未启用") } if req.PhoneCode == "" { return errors.New("手机验证码不能为空") } return s.smsCodeSvc.VerifyCode(ctx, req.Phone, "register", req.PhoneCode) } const ( oauthStateCachePrefix = "oauth_state:" oauthHandoffCachePrefix = "oauth_handoff:" oauthStateTTL = 10 * time.Minute oauthHandoffTTL = time.Minute ) type OAuthStatePurpose string const ( OAuthStatePurposeLogin OAuthStatePurpose = "login" OAuthStatePurposeBind OAuthStatePurpose = "bind" ) type OAuthStatePayload struct { Purpose OAuthStatePurpose `json:"purpose"` ReturnTo string `json:"return_to"` UserID int64 `json:"user_id,omitempty"` } func generateOAuthEphemeralCode() (string, error) { buffer := make([]byte, 32) if _, err := rand.Read(buffer); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(buffer), nil } func (s *AuthService) CreateOAuthState(ctx context.Context, returnTo string) (string, error) { return s.createOAuthStatePayload(ctx, &OAuthStatePayload{ Purpose: OAuthStatePurposeLogin, ReturnTo: strings.TrimSpace(returnTo), }) } func (s *AuthService) CreateOAuthBindState(ctx context.Context, userID int64, returnTo string) (string, error) { if userID <= 0 { return "", errors.New("oauth binding user is required") } return s.createOAuthStatePayload(ctx, &OAuthStatePayload{ Purpose: OAuthStatePurposeBind, ReturnTo: strings.TrimSpace(returnTo), UserID: userID, }) } func (s *AuthService) createOAuthStatePayload(ctx context.Context, payload *OAuthStatePayload) (string, error) { if s == nil || s.cache == nil { return "", errors.New("oauth state storage unavailable") } if payload == nil { return "", errors.New("oauth state payload is required") } if payload.Purpose == "" { payload.Purpose = OAuthStatePurposeLogin } state, err := generateOAuthEphemeralCode() if err != nil { return "", err } if err := s.cache.Set(ctx, oauthStateCachePrefix+state, payload, oauthStateTTL, oauthStateTTL); err != nil { return "", err } return state, nil } func (s *AuthService) ConsumeOAuthState(ctx context.Context, state string) (string, error) { payload, err := s.ConsumeOAuthStatePayload(ctx, state) if err != nil { return "", err } if payload == nil { return "", nil } return strings.TrimSpace(payload.ReturnTo), nil } func (s *AuthService) ConsumeOAuthStatePayload(ctx context.Context, state string) (*OAuthStatePayload, error) { if s == nil || s.cache == nil { return nil, errors.New("oauth state storage unavailable") } cacheKey := oauthStateCachePrefix + strings.TrimSpace(state) value, ok := s.cache.Get(ctx, cacheKey) if !ok { return nil, errors.New("OAuth state validation failed") } _ = s.cache.Delete(ctx, cacheKey) switch typed := value.(type) { case *OAuthStatePayload: payload := *typed if payload.Purpose == "" { payload.Purpose = OAuthStatePurposeLogin } payload.ReturnTo = strings.TrimSpace(payload.ReturnTo) return &payload, nil case OAuthStatePayload: payload := typed if payload.Purpose == "" { payload.Purpose = OAuthStatePurposeLogin } payload.ReturnTo = strings.TrimSpace(payload.ReturnTo) return &payload, nil case string: return &OAuthStatePayload{ Purpose: OAuthStatePurposeLogin, ReturnTo: strings.TrimSpace(typed), }, nil case nil: return &OAuthStatePayload{Purpose: OAuthStatePurposeLogin}, nil case map[string]interface{}: payloadBytes, err := json.Marshal(typed) if err != nil { return nil, err } var payload OAuthStatePayload if err := json.Unmarshal(payloadBytes, &payload); err != nil { return nil, err } if payload.Purpose == "" { payload.Purpose = OAuthStatePurposeLogin } payload.ReturnTo = strings.TrimSpace(payload.ReturnTo) return &payload, nil default: return &OAuthStatePayload{ Purpose: OAuthStatePurposeLogin, ReturnTo: strings.TrimSpace(fmt.Sprint(typed)), }, nil } } func (s *AuthService) CreateOAuthHandoff(ctx context.Context, loginResp *LoginResponse) (string, error) { if s == nil || s.cache == nil { return "", errors.New("oauth handoff storage unavailable") } if loginResp == nil { return "", errors.New("oauth handoff payload is required") } code, err := generateOAuthEphemeralCode() if err != nil { return "", err } if err := s.cache.Set(ctx, oauthHandoffCachePrefix+code, loginResp, oauthHandoffTTL, oauthHandoffTTL); err != nil { return "", err } return code, nil } func (s *AuthService) ConsumeOAuthHandoff(ctx context.Context, code string) (*LoginResponse, error) { if s == nil || s.cache == nil { return nil, errors.New("oauth handoff storage unavailable") } cacheKey := oauthHandoffCachePrefix + strings.TrimSpace(code) value, ok := s.cache.Get(ctx, cacheKey) if !ok { return nil, errors.New("OAuth handoff code is invalid or expired") } _ = s.cache.Delete(ctx, cacheKey) switch typed := value.(type) { case *LoginResponse: return typed, nil case LoginResponse: resp := typed return &resp, nil case map[string]interface{}: payload, err := json.Marshal(typed) if err != nil { return nil, err } var resp LoginResponse if err := json.Unmarshal(payload, &resp); err != nil { return nil, err } return &resp, nil default: payload, err := json.Marshal(typed) if err != nil { return nil, err } var resp LoginResponse if err := json.Unmarshal(payload, &resp); err != nil { return nil, err } return &resp, nil } }