feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers
This commit is contained in:
369
internal/service/auth_runtime.go
Normal file
369
internal/service/auth_runtime.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type oauthRegistrar interface {
|
||||
RegisterProvider(provider auth.OAuthProvider, config *auth.OAuthConfig)
|
||||
}
|
||||
|
||||
func (s *AuthService) RegisterOAuthProvider(provider auth.OAuthProvider, cfg *auth.OAuthConfig) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if registrar, ok := s.oauthManager.(oauthRegistrar); ok {
|
||||
registrar.RegisterProvider(provider, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) findUserForLogin(ctx context.Context, account string) (*domain.User, error) {
|
||||
user, err := s.userRepo.GetByUsername(ctx, account)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
if !isUserNotFoundError(err) {
|
||||
return nil, fmt.Errorf("lookup user by username failed: %w", err)
|
||||
}
|
||||
|
||||
user, err = s.userRepo.GetByEmail(ctx, account)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
if !isUserNotFoundError(err) {
|
||||
return nil, fmt.Errorf("lookup user by email failed: %w", err)
|
||||
}
|
||||
|
||||
user, err = s.userRepo.GetByPhone(ctx, account)
|
||||
if err != nil && !isUserNotFoundError(err) {
|
||||
return nil, fmt.Errorf("lookup user by phone failed: %w", err)
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
|
||||
func isUserNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return true
|
||||
}
|
||||
|
||||
lowerErr := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
return strings.Contains(lowerErr, "record not found") ||
|
||||
strings.Contains(lowerErr, "user not found") ||
|
||||
strings.Contains(err.Error(), "用户不存在") ||
|
||||
strings.Contains(lowerErr, "not found")
|
||||
}
|
||||
|
||||
func (s *AuthService) bestEffortAssignDefaultRoles(ctx context.Context, userID int64, source string) {
|
||||
if s == nil || s.userRoleRepo == nil || s.roleRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
defaultRoles, err := s.roleRepo.GetDefaultRoles(ctx)
|
||||
if err != nil {
|
||||
log.Printf("auth: load default roles failed, source=%s user_id=%d err=%v", source, userID, err)
|
||||
return
|
||||
}
|
||||
if len(defaultRoles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
userRoles := make([]*domain.UserRole, 0, len(defaultRoles))
|
||||
for _, role := range defaultRoles {
|
||||
userRoles = append(userRoles, &domain.UserRole{
|
||||
UserID: userID,
|
||||
RoleID: role.ID,
|
||||
})
|
||||
}
|
||||
|
||||
if err := s.userRoleRepo.BatchCreate(ctx, userRoles); err != nil {
|
||||
log.Printf("auth: assign default roles failed, source=%s user_id=%d role_count=%d err=%v", source, userID, len(userRoles), err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) bestEffortUpdateLastLogin(ctx context.Context, userID int64, ip, source string) {
|
||||
if s == nil || s.userRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.userRepo.UpdateLastLogin(ctx, userID, ip); err != nil {
|
||||
log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err)
|
||||
}
|
||||
}
|
||||
|
||||
func loginAttemptKey(account string, user *domain.User) string {
|
||||
if user != nil {
|
||||
return fmt.Sprintf("login_attempt:user:%d", user.ID)
|
||||
}
|
||||
return "login_attempt:account:" + strings.ToLower(strings.TrimSpace(account))
|
||||
}
|
||||
|
||||
func attemptCount(value interface{}) int {
|
||||
if count, ok := intValue(value); ok {
|
||||
return count
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func intValue(value interface{}) (int, bool) {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v, true
|
||||
case int64:
|
||||
return int(v), true
|
||||
case float64:
|
||||
return int(v), true
|
||||
case json.Number:
|
||||
n, err := v.Int64()
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int(n), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func int64Value(value interface{}) (int64, bool) {
|
||||
switch v := value.(type) {
|
||||
case int64:
|
||||
return v, true
|
||||
case int:
|
||||
return int64(v), true
|
||||
case float64:
|
||||
return int64(v), true
|
||||
case json.Number:
|
||||
n, err := v.Int64()
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return n, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) verifyPhoneRegistration(ctx context.Context, req *RegisterRequest) error {
|
||||
if req == nil || req.Phone == "" {
|
||||
return nil
|
||||
}
|
||||
if s.smsCodeSvc == nil {
|
||||
return errors.New("手机注册未启用")
|
||||
}
|
||||
if req.PhoneCode == "" {
|
||||
return errors.New("手机验证码不能为空")
|
||||
}
|
||||
return s.smsCodeSvc.VerifyCode(ctx, req.Phone, "register", req.PhoneCode)
|
||||
}
|
||||
|
||||
const (
|
||||
oauthStateCachePrefix = "oauth_state:"
|
||||
oauthHandoffCachePrefix = "oauth_handoff:"
|
||||
oauthStateTTL = 10 * time.Minute
|
||||
oauthHandoffTTL = time.Minute
|
||||
)
|
||||
|
||||
type OAuthStatePurpose string
|
||||
|
||||
const (
|
||||
OAuthStatePurposeLogin OAuthStatePurpose = "login"
|
||||
OAuthStatePurposeBind OAuthStatePurpose = "bind"
|
||||
)
|
||||
|
||||
type OAuthStatePayload struct {
|
||||
Purpose OAuthStatePurpose `json:"purpose"`
|
||||
ReturnTo string `json:"return_to"`
|
||||
UserID int64 `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
func generateOAuthEphemeralCode() (string, error) {
|
||||
buffer := make([]byte, 32)
|
||||
if _, err := rand.Read(buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buffer), nil
|
||||
}
|
||||
|
||||
func (s *AuthService) CreateOAuthState(ctx context.Context, returnTo string) (string, error) {
|
||||
return s.createOAuthStatePayload(ctx, &OAuthStatePayload{
|
||||
Purpose: OAuthStatePurposeLogin,
|
||||
ReturnTo: strings.TrimSpace(returnTo),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthService) CreateOAuthBindState(ctx context.Context, userID int64, returnTo string) (string, error) {
|
||||
if userID <= 0 {
|
||||
return "", errors.New("oauth binding user is required")
|
||||
}
|
||||
|
||||
return s.createOAuthStatePayload(ctx, &OAuthStatePayload{
|
||||
Purpose: OAuthStatePurposeBind,
|
||||
ReturnTo: strings.TrimSpace(returnTo),
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthService) createOAuthStatePayload(ctx context.Context, payload *OAuthStatePayload) (string, error) {
|
||||
if s == nil || s.cache == nil {
|
||||
return "", errors.New("oauth state storage unavailable")
|
||||
}
|
||||
if payload == nil {
|
||||
return "", errors.New("oauth state payload is required")
|
||||
}
|
||||
if payload.Purpose == "" {
|
||||
payload.Purpose = OAuthStatePurposeLogin
|
||||
}
|
||||
|
||||
state, err := generateOAuthEphemeralCode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := s.cache.Set(ctx, oauthStateCachePrefix+state, payload, oauthStateTTL, oauthStateTTL); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) ConsumeOAuthState(ctx context.Context, state string) (string, error) {
|
||||
payload, err := s.ConsumeOAuthStatePayload(ctx, state)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if payload == nil {
|
||||
return "", nil
|
||||
}
|
||||
return strings.TrimSpace(payload.ReturnTo), nil
|
||||
}
|
||||
|
||||
func (s *AuthService) ConsumeOAuthStatePayload(ctx context.Context, state string) (*OAuthStatePayload, error) {
|
||||
if s == nil || s.cache == nil {
|
||||
return nil, errors.New("oauth state storage unavailable")
|
||||
}
|
||||
|
||||
cacheKey := oauthStateCachePrefix + strings.TrimSpace(state)
|
||||
value, ok := s.cache.Get(ctx, cacheKey)
|
||||
if !ok {
|
||||
return nil, errors.New("OAuth state validation failed")
|
||||
}
|
||||
_ = s.cache.Delete(ctx, cacheKey)
|
||||
|
||||
switch typed := value.(type) {
|
||||
case *OAuthStatePayload:
|
||||
payload := *typed
|
||||
if payload.Purpose == "" {
|
||||
payload.Purpose = OAuthStatePurposeLogin
|
||||
}
|
||||
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
|
||||
return &payload, nil
|
||||
case OAuthStatePayload:
|
||||
payload := typed
|
||||
if payload.Purpose == "" {
|
||||
payload.Purpose = OAuthStatePurposeLogin
|
||||
}
|
||||
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
|
||||
return &payload, nil
|
||||
case string:
|
||||
return &OAuthStatePayload{
|
||||
Purpose: OAuthStatePurposeLogin,
|
||||
ReturnTo: strings.TrimSpace(typed),
|
||||
}, nil
|
||||
case nil:
|
||||
return &OAuthStatePayload{Purpose: OAuthStatePurposeLogin}, nil
|
||||
case map[string]interface{}:
|
||||
payloadBytes, err := json.Marshal(typed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var payload OAuthStatePayload
|
||||
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if payload.Purpose == "" {
|
||||
payload.Purpose = OAuthStatePurposeLogin
|
||||
}
|
||||
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
|
||||
return &payload, nil
|
||||
default:
|
||||
return &OAuthStatePayload{
|
||||
Purpose: OAuthStatePurposeLogin,
|
||||
ReturnTo: strings.TrimSpace(fmt.Sprint(typed)),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) CreateOAuthHandoff(ctx context.Context, loginResp *LoginResponse) (string, error) {
|
||||
if s == nil || s.cache == nil {
|
||||
return "", errors.New("oauth handoff storage unavailable")
|
||||
}
|
||||
if loginResp == nil {
|
||||
return "", errors.New("oauth handoff payload is required")
|
||||
}
|
||||
|
||||
code, err := generateOAuthEphemeralCode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := s.cache.Set(ctx, oauthHandoffCachePrefix+code, loginResp, oauthHandoffTTL, oauthHandoffTTL); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) ConsumeOAuthHandoff(ctx context.Context, code string) (*LoginResponse, error) {
|
||||
if s == nil || s.cache == nil {
|
||||
return nil, errors.New("oauth handoff storage unavailable")
|
||||
}
|
||||
|
||||
cacheKey := oauthHandoffCachePrefix + strings.TrimSpace(code)
|
||||
value, ok := s.cache.Get(ctx, cacheKey)
|
||||
if !ok {
|
||||
return nil, errors.New("OAuth handoff code is invalid or expired")
|
||||
}
|
||||
_ = s.cache.Delete(ctx, cacheKey)
|
||||
|
||||
switch typed := value.(type) {
|
||||
case *LoginResponse:
|
||||
return typed, nil
|
||||
case LoginResponse:
|
||||
resp := typed
|
||||
return &resp, nil
|
||||
case map[string]interface{}:
|
||||
payload, err := json.Marshal(typed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resp LoginResponse
|
||||
if err := json.Unmarshal(payload, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
default:
|
||||
payload, err := json.Marshal(typed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resp LoginResponse
|
||||
if err := json.Unmarshal(payload, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user