370 lines
9.3 KiB
Go
370 lines
9.3 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/user-management-system/internal/auth"
|
|
"github.com/user-management-system/internal/domain"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type oauthRegistrar interface {
|
|
RegisterProvider(provider auth.OAuthProvider, config *auth.OAuthConfig)
|
|
}
|
|
|
|
func (s *AuthService) RegisterOAuthProvider(provider auth.OAuthProvider, cfg *auth.OAuthConfig) {
|
|
if cfg == nil {
|
|
return
|
|
}
|
|
if registrar, ok := s.oauthManager.(oauthRegistrar); ok {
|
|
registrar.RegisterProvider(provider, cfg)
|
|
}
|
|
}
|
|
|
|
func (s *AuthService) findUserForLogin(ctx context.Context, account string) (*domain.User, error) {
|
|
user, err := s.userRepo.GetByUsername(ctx, account)
|
|
if err == nil {
|
|
return user, nil
|
|
}
|
|
if !isUserNotFoundError(err) {
|
|
return nil, fmt.Errorf("lookup user by username failed: %w", err)
|
|
}
|
|
|
|
user, err = s.userRepo.GetByEmail(ctx, account)
|
|
if err == nil {
|
|
return user, nil
|
|
}
|
|
if !isUserNotFoundError(err) {
|
|
return nil, fmt.Errorf("lookup user by email failed: %w", err)
|
|
}
|
|
|
|
user, err = s.userRepo.GetByPhone(ctx, account)
|
|
if err != nil && !isUserNotFoundError(err) {
|
|
return nil, fmt.Errorf("lookup user by phone failed: %w", err)
|
|
}
|
|
return user, err
|
|
}
|
|
|
|
func isUserNotFoundError(err error) bool {
|
|
if err == nil {
|
|
return false
|
|
}
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
return true
|
|
}
|
|
|
|
lowerErr := strings.ToLower(strings.TrimSpace(err.Error()))
|
|
return strings.Contains(lowerErr, "record not found") ||
|
|
strings.Contains(lowerErr, "user not found") ||
|
|
strings.Contains(err.Error(), "用户不存在") ||
|
|
strings.Contains(lowerErr, "not found")
|
|
}
|
|
|
|
func (s *AuthService) bestEffortAssignDefaultRoles(ctx context.Context, userID int64, source string) {
|
|
if s == nil || s.userRoleRepo == nil || s.roleRepo == nil {
|
|
return
|
|
}
|
|
|
|
defaultRoles, err := s.roleRepo.GetDefaultRoles(ctx)
|
|
if err != nil {
|
|
log.Printf("auth: load default roles failed, source=%s user_id=%d err=%v", source, userID, err)
|
|
return
|
|
}
|
|
if len(defaultRoles) == 0 {
|
|
return
|
|
}
|
|
|
|
userRoles := make([]*domain.UserRole, 0, len(defaultRoles))
|
|
for _, role := range defaultRoles {
|
|
userRoles = append(userRoles, &domain.UserRole{
|
|
UserID: userID,
|
|
RoleID: role.ID,
|
|
})
|
|
}
|
|
|
|
if err := s.userRoleRepo.BatchCreate(ctx, userRoles); err != nil {
|
|
log.Printf("auth: assign default roles failed, source=%s user_id=%d role_count=%d err=%v", source, userID, len(userRoles), err)
|
|
}
|
|
}
|
|
|
|
func (s *AuthService) bestEffortUpdateLastLogin(ctx context.Context, userID int64, ip, source string) {
|
|
if s == nil || s.userRepo == nil {
|
|
return
|
|
}
|
|
|
|
if err := s.userRepo.UpdateLastLogin(ctx, userID, ip); err != nil {
|
|
log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err)
|
|
}
|
|
}
|
|
|
|
func loginAttemptKey(account string, user *domain.User) string {
|
|
if user != nil {
|
|
return fmt.Sprintf("login_attempt:user:%d", user.ID)
|
|
}
|
|
return "login_attempt:account:" + strings.ToLower(strings.TrimSpace(account))
|
|
}
|
|
|
|
func attemptCount(value interface{}) int {
|
|
if count, ok := intValue(value); ok {
|
|
return count
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// intValue converts a loosely typed numeric value into an int.
//
// Such values are, for example, decoded from JSON, where numbers arrive as
// float64 or json.Number. Accepted inputs:
//   - every built-in integer type;
//   - float32/float64, truncated toward zero;
//   - a json.Number holding an integer literal.
//
// ok is false, and the result is 0, in these cases:
//   - any other type;
//   - a json.Number that is not an integer;
//   - a value that does not fit in int, including NaN and ±Inf. The Go spec
//     leaves out-of-range float→int conversions implementation-dependent, so
//     they are rejected instead of producing garbage.
func intValue(value interface{}) (int, bool) {
	const maxInt = int(^uint(0) >> 1)
	const minInt = -maxInt - 1

	fromInt64 := func(n int64) (int, bool) {
		// Only reachable on 32-bit platforms, where int is narrower than int64.
		if n < int64(minInt) || n > int64(maxInt) {
			return 0, false
		}
		return int(n), true
	}
	fromUint64 := func(n uint64) (int, bool) {
		if n > uint64(maxInt) {
			return 0, false
		}
		return int(n), true
	}
	fromFloat64 := func(f float64) (int, bool) {
		// minInt and -minInt are powers of two, so both bounds are exact in
		// float64. NaN fails both comparisons; ±Inf fails one of them.
		if !(f >= float64(minInt) && f < -float64(minInt)) {
			return 0, false
		}
		return int(f), true
	}

	switch v := value.(type) {
	case int:
		return v, true
	case int8:
		return int(v), true
	case int16:
		return int(v), true
	case int32:
		return int(v), true
	case int64:
		return fromInt64(v)
	case uint:
		return fromUint64(uint64(v))
	case uint8:
		return int(v), true
	case uint16:
		return int(v), true
	case uint32:
		return fromUint64(uint64(v))
	case uint64:
		return fromUint64(v)
	case float32:
		return fromFloat64(float64(v))
	case float64:
		return fromFloat64(v)
	case json.Number:
		n, err := v.Int64()
		if err != nil {
			return 0, false
		}
		return fromInt64(n)
	default:
		return 0, false
	}
}
|
|
|
|
// int64Value converts a loosely typed numeric value into an int64.
//
// Such values are, for example, decoded from JSON, where numbers arrive as
// float64 or json.Number. Accepted inputs:
//   - every built-in integer type;
//   - float32/float64, truncated toward zero;
//   - a json.Number holding an integer literal.
//
// ok is false, and the result is 0, in these cases:
//   - any other type;
//   - a json.Number that is not an integer;
//   - a uint/uint64 above the int64 maximum;
//   - a float outside the int64 range, including NaN and ±Inf. The Go spec
//     leaves such conversions implementation-dependent, so they are rejected.
func int64Value(value interface{}) (int64, bool) {
	fromUint64 := func(n uint64) (int64, bool) {
		if n > uint64(1<<63-1) {
			return 0, false
		}
		return int64(n), true
	}
	fromFloat64 := func(f float64) (int64, bool) {
		// -2^63 and 2^63 are exact in float64. NaN fails both comparisons.
		if f >= -(1<<63) && f < 1<<63 {
			return int64(f), true
		}
		return 0, false
	}

	switch v := value.(type) {
	case int64:
		return v, true
	case int:
		return int64(v), true
	case int8:
		return int64(v), true
	case int16:
		return int64(v), true
	case int32:
		return int64(v), true
	case uint8:
		return int64(v), true
	case uint16:
		return int64(v), true
	case uint32:
		return int64(v), true
	case uint:
		return fromUint64(uint64(v))
	case uint64:
		return fromUint64(v)
	case float32:
		return fromFloat64(float64(v))
	case float64:
		return fromFloat64(v)
	case json.Number:
		n, err := v.Int64()
		if err != nil {
			return 0, false
		}
		return n, true
	default:
		return 0, false
	}
}
|
|
|
|
func (s *AuthService) verifyPhoneRegistration(ctx context.Context, req *RegisterRequest) error {
|
|
if req == nil || req.Phone == "" {
|
|
return nil
|
|
}
|
|
if s.smsCodeSvc == nil {
|
|
return errors.New("手机注册未启用")
|
|
}
|
|
if req.PhoneCode == "" {
|
|
return errors.New("手机验证码不能为空")
|
|
}
|
|
return s.smsCodeSvc.VerifyCode(ctx, req.Phone, "register", req.PhoneCode)
|
|
}
|
|
|
|
// Cache-key prefixes. A random code from generateOAuthEphemeralCode is
// appended to form the full key of an OAuth state or a login handoff entry.
const (
	oauthStateCachePrefix   = "oauth_state:"
	oauthHandoffCachePrefix = "oauth_handoff:"
)

// Lifetimes passed to cache.Set for the entries above.
const (
	oauthStateTTL   = 10 * time.Minute
	oauthHandoffTTL = 60 * time.Second
)
|
|
|
|
// OAuthStatePurpose records which flow issued an OAuth state token. It is
// stored in OAuthStatePayload.Purpose.
type OAuthStatePurpose string

// OAuthStatePurposeLogin marks a state issued by CreateOAuthState. A payload
// with an empty Purpose is treated as this value.
const OAuthStatePurposeLogin = OAuthStatePurpose("login")

// OAuthStatePurposeBind marks a state issued by CreateOAuthBindState for a
// specific user, whose ID is carried in OAuthStatePayload.UserID.
const OAuthStatePurposeBind = OAuthStatePurpose("bind")

// OAuthStatePayload is the value cached under an OAuth state token.
//
// The field order and JSON tags define its serialized form. That form matters
// because ConsumeOAuthStatePayload may receive it back as a generic JSON map.
type OAuthStatePayload struct {
	// Purpose is the flow that issued the state.
	Purpose OAuthStatePurpose `json:"purpose"`
	// ReturnTo is the caller-supplied return target; it is whitespace-trimmed
	// when the state is created and again when it is consumed.
	ReturnTo string `json:"return_to"`
	// UserID is set for bind states; it is omitted from JSON when zero.
	UserID int64 `json:"user_id,omitempty"`
}
|
|
|
|
// generateOAuthEphemeralCode draws 32 bytes from crypto/rand and returns them
// as unpadded URL-safe base64 (43 characters). The result is safe to embed in
// URLs and cache keys. An error from the random source is returned unchanged.
func generateOAuthEphemeralCode() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	encoded := base64.RawURLEncoding.EncodeToString(entropy[:])
	return encoded, nil
}
|
|
|
|
func (s *AuthService) CreateOAuthState(ctx context.Context, returnTo string) (string, error) {
|
|
return s.createOAuthStatePayload(ctx, &OAuthStatePayload{
|
|
Purpose: OAuthStatePurposeLogin,
|
|
ReturnTo: strings.TrimSpace(returnTo),
|
|
})
|
|
}
|
|
|
|
func (s *AuthService) CreateOAuthBindState(ctx context.Context, userID int64, returnTo string) (string, error) {
|
|
if userID <= 0 {
|
|
return "", errors.New("oauth binding user is required")
|
|
}
|
|
|
|
return s.createOAuthStatePayload(ctx, &OAuthStatePayload{
|
|
Purpose: OAuthStatePurposeBind,
|
|
ReturnTo: strings.TrimSpace(returnTo),
|
|
UserID: userID,
|
|
})
|
|
}
|
|
|
|
func (s *AuthService) createOAuthStatePayload(ctx context.Context, payload *OAuthStatePayload) (string, error) {
|
|
if s == nil || s.cache == nil {
|
|
return "", errors.New("oauth state storage unavailable")
|
|
}
|
|
if payload == nil {
|
|
return "", errors.New("oauth state payload is required")
|
|
}
|
|
if payload.Purpose == "" {
|
|
payload.Purpose = OAuthStatePurposeLogin
|
|
}
|
|
|
|
state, err := generateOAuthEphemeralCode()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if err := s.cache.Set(ctx, oauthStateCachePrefix+state, payload, oauthStateTTL, oauthStateTTL); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return state, nil
|
|
}
|
|
|
|
func (s *AuthService) ConsumeOAuthState(ctx context.Context, state string) (string, error) {
|
|
payload, err := s.ConsumeOAuthStatePayload(ctx, state)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if payload == nil {
|
|
return "", nil
|
|
}
|
|
return strings.TrimSpace(payload.ReturnTo), nil
|
|
}
|
|
|
|
func (s *AuthService) ConsumeOAuthStatePayload(ctx context.Context, state string) (*OAuthStatePayload, error) {
|
|
if s == nil || s.cache == nil {
|
|
return nil, errors.New("oauth state storage unavailable")
|
|
}
|
|
|
|
cacheKey := oauthStateCachePrefix + strings.TrimSpace(state)
|
|
value, ok := s.cache.Get(ctx, cacheKey)
|
|
if !ok {
|
|
return nil, errors.New("OAuth state validation failed")
|
|
}
|
|
_ = s.cache.Delete(ctx, cacheKey)
|
|
|
|
switch typed := value.(type) {
|
|
case *OAuthStatePayload:
|
|
payload := *typed
|
|
if payload.Purpose == "" {
|
|
payload.Purpose = OAuthStatePurposeLogin
|
|
}
|
|
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
|
|
return &payload, nil
|
|
case OAuthStatePayload:
|
|
payload := typed
|
|
if payload.Purpose == "" {
|
|
payload.Purpose = OAuthStatePurposeLogin
|
|
}
|
|
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
|
|
return &payload, nil
|
|
case string:
|
|
return &OAuthStatePayload{
|
|
Purpose: OAuthStatePurposeLogin,
|
|
ReturnTo: strings.TrimSpace(typed),
|
|
}, nil
|
|
case nil:
|
|
return &OAuthStatePayload{Purpose: OAuthStatePurposeLogin}, nil
|
|
case map[string]interface{}:
|
|
payloadBytes, err := json.Marshal(typed)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var payload OAuthStatePayload
|
|
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
|
return nil, err
|
|
}
|
|
if payload.Purpose == "" {
|
|
payload.Purpose = OAuthStatePurposeLogin
|
|
}
|
|
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
|
|
return &payload, nil
|
|
default:
|
|
return &OAuthStatePayload{
|
|
Purpose: OAuthStatePurposeLogin,
|
|
ReturnTo: strings.TrimSpace(fmt.Sprint(typed)),
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
func (s *AuthService) CreateOAuthHandoff(ctx context.Context, loginResp *LoginResponse) (string, error) {
|
|
if s == nil || s.cache == nil {
|
|
return "", errors.New("oauth handoff storage unavailable")
|
|
}
|
|
if loginResp == nil {
|
|
return "", errors.New("oauth handoff payload is required")
|
|
}
|
|
|
|
code, err := generateOAuthEphemeralCode()
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if err := s.cache.Set(ctx, oauthHandoffCachePrefix+code, loginResp, oauthHandoffTTL, oauthHandoffTTL); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return code, nil
|
|
}
|
|
|
|
func (s *AuthService) ConsumeOAuthHandoff(ctx context.Context, code string) (*LoginResponse, error) {
|
|
if s == nil || s.cache == nil {
|
|
return nil, errors.New("oauth handoff storage unavailable")
|
|
}
|
|
|
|
cacheKey := oauthHandoffCachePrefix + strings.TrimSpace(code)
|
|
value, ok := s.cache.Get(ctx, cacheKey)
|
|
if !ok {
|
|
return nil, errors.New("OAuth handoff code is invalid or expired")
|
|
}
|
|
_ = s.cache.Delete(ctx, cacheKey)
|
|
|
|
switch typed := value.(type) {
|
|
case *LoginResponse:
|
|
return typed, nil
|
|
case LoginResponse:
|
|
resp := typed
|
|
return &resp, nil
|
|
case map[string]interface{}:
|
|
payload, err := json.Marshal(typed)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var resp LoginResponse
|
|
if err := json.Unmarshal(payload, &resp); err != nil {
|
|
return nil, err
|
|
}
|
|
return &resp, nil
|
|
default:
|
|
payload, err := json.Marshal(typed)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var resp LoginResponse
|
|
if err := json.Unmarshal(payload, &resp); err != nil {
|
|
return nil, err
|
|
}
|
|
return &resp, nil
|
|
}
|
|
}
|