1451 lines
39 KiB
Go
1451 lines
39 KiB
Go
package service
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"log"
|
||
"strings"
|
||
"time"
|
||
"unicode"
|
||
"unicode/utf8"
|
||
|
||
"github.com/user-management-system/internal/auth"
|
||
"github.com/user-management-system/internal/cache"
|
||
"github.com/user-management-system/internal/domain"
|
||
"github.com/user-management-system/internal/repository"
|
||
"github.com/user-management-system/internal/security"
|
||
)
|
||
|
||
// Cache key prefixes owned by AuthService.
const (
	// userInfoCachePrefix prefixes the cache key that holds a user's UserInfo;
	// the suffix is the decimal user ID.
	userInfoCachePrefix = "auth_user_info:"
	// tokenBlacklistPrefix prefixes the cache key that marks a token JTI as revoked.
	tokenBlacklistPrefix = "auth_token_blacklist:"
)

// Fallback tunables used when nothing more specific is available.
const (
	// defaultUserCacheTTL is the lifetime of cached UserInfo entries.
	defaultUserCacheTTL = 15 * time.Minute
	// defaultBlacklistTTL is the blacklist-entry lifetime used when the token's
	// own remaining lifetime cannot be derived from its claims.
	defaultBlacklistTTL = 60 * time.Minute
	// defaultPasswordMinLen is the minimum password length (counted in runes by
	// validatePasswordStrength) used when no positive minimum is configured.
	defaultPasswordMinLen = 8
)
|
||
|
||
// userRepositoryInterface is the subset of user persistence operations that
// AuthService depends on. Declaring it here, at the consumer, lets callers
// plug in any implementation (for example a test fake).
type userRepositoryInterface interface {
	Create(ctx context.Context, user *domain.User) error
	Update(ctx context.Context, user *domain.User) error
	UpdateTOTP(ctx context.Context, user *domain.User) error
	Delete(ctx context.Context, id int64) error
	GetByID(ctx context.Context, id int64) (*domain.User, error)
	GetByUsername(ctx context.Context, username string) (*domain.User, error)
	GetByEmail(ctx context.Context, email string) (*domain.User, error)
	GetByPhone(ctx context.Context, phone string) (*domain.User, error)
	// List, ListByStatus and Search page through users; the int64 result is
	// presumably the total match count — TODO confirm against the implementation.
	List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error)
	ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error)
	UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error
	UpdateLastLogin(ctx context.Context, id int64, ip string) error
	// ExistsBy* report whether a user already holds the given value; Register
	// uses them for its uniqueness checks.
	ExistsByUsername(ctx context.Context, username string) (bool, error)
	ExistsByEmail(ctx context.Context, email string) (bool, error)
	ExistsByPhone(ctx context.Context, phone string) (bool, error)
	Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error)
}

// userRoleRepositoryInterface persists user↔role memberships. It is optional
// and wired via SetRoleRepositories.
type userRoleRepositoryInterface interface {
	BatchCreate(ctx context.Context, userRoles []*domain.UserRole) error
	GetUserIDByRoleID(ctx context.Context, roleID int64) ([]int64, error)
}

// roleRepositoryInterface reads role definitions. It is optional and wired via
// SetRoleRepositories; presumably it is consulted when default roles are
// assigned to new users (that code is not visible here).
type roleRepositoryInterface interface {
	GetDefaultRoles(ctx context.Context) ([]*domain.Role, error)
	GetByCode(ctx context.Context, code string) (*domain.Role, error)
}

// loginLogRepositoryInterface stores login audit records written by writeLoginLog.
type loginLogRepositoryInterface interface {
	Create(ctx context.Context, loginRecord *domain.LoginLog) error
}

// anomalyRecorder feeds login outcomes to an anomaly detector. A non-empty
// result causes recordLoginAnomaly to publish EventAnomalyDetected.
type anomalyRecorder interface {
	RecordLogin(ctx context.Context, userID int64, ip, location, deviceFingerprint string, success bool) []security.AnomalyEvent
}
|
||
|
||
// PasswordStrengthInfo describes the character composition of a password as
// computed by GetPasswordStrength.
type PasswordStrengthInfo struct {
	// Score counts how many of the four character classes (upper, lower,
	// digit, special) are present, so it ranges from 0 to 4.
	Score int `json:"score"`
	// Length is the password length in runes, not bytes.
	Length     int  `json:"length"`
	HasUpper   bool `json:"has_upper"`
	HasLower   bool `json:"has_lower"`
	HasDigit   bool `json:"has_digit"`
	HasSpecial bool `json:"has_special"`
}

// RegisterRequest is the payload for AuthService.Register. Username, Email and
// Phone are trimmed in place by Register; Email and Phone are optional.
type RegisterRequest struct {
	Username string `json:"username" binding:"required"`
	Email    string `json:"email"`
	Phone    string `json:"phone"`
	// PhoneCode is presumably the SMS verification code checked by
	// verifyPhoneRegistration (not visible here) — TODO confirm.
	PhoneCode string `json:"phone_code"`
	Password  string `json:"password" binding:"required"`
	// Nickname defaults to Username when blank.
	Nickname string `json:"nickname"`
}

// LoginRequest is the payload for AuthService.Login. The login identifier may
// arrive in any of Account, Username, Email or Phone; see GetAccount.
type LoginRequest struct {
	Account  string `json:"account"`
	Username string `json:"username"`
	Email    string `json:"email"`
	Phone    string `json:"phone"`
	Password string `json:"password"`
	// Remember is the "remember me" flag, forwarded to token generation.
	Remember bool `json:"remember"`
	// DeviceID uniquely identifies the client device; device auto-registration
	// only happens when it is set.
	DeviceID string `json:"device_id,omitempty"`
	// DeviceName is the human-readable device name.
	DeviceName string `json:"device_name,omitempty"`
	// DeviceBrowser is the client browser.
	DeviceBrowser string `json:"device_browser,omitempty"`
	// DeviceOS is the client operating system.
	DeviceOS string `json:"device_os,omitempty"`
}
|
||
|
||
func (r *LoginRequest) GetAccount() string {
|
||
if r == nil {
|
||
return ""
|
||
}
|
||
for _, candidate := range []string{r.Account, r.Username, r.Email, r.Phone} {
|
||
if trimmed := strings.TrimSpace(candidate); trimmed != "" {
|
||
return trimmed
|
||
}
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// UserInfo is the public user profile returned by the auth endpoints and
// cached under userInfoCachePrefix. Email and Phone are dereferenced from the
// domain user's optional pointers (empty when unset) and omitted from JSON
// when empty.
type UserInfo struct {
	ID       int64             `json:"id"`
	Username string            `json:"username"`
	Email    string            `json:"email,omitempty"`
	Phone    string            `json:"phone,omitempty"`
	Nickname string            `json:"nickname,omitempty"`
	Avatar   string            `json:"avatar,omitempty"`
	Status   domain.UserStatus `json:"status"`
}

// LoginResponse is returned on successful authentication and token refresh.
type LoginResponse struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	// ExpiresIn is presumably the access-token lifetime in seconds (see
	// accessTokenTTLSeconds); it is filled by generateLoginResponse, which is
	// not visible here — TODO confirm.
	ExpiresIn int64     `json:"expires_in"`
	User      *UserInfo `json:"user"`
}

// LogoutRequest carries the tokens to revoke on logout. Either field may be
// empty; Logout skips blank tokens.
type LogoutRequest struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
}
|
||
|
||
// AuthService implements registration, password and OAuth login, token
// refresh/revocation and social-account binding. The required dependencies
// are set by NewAuthService; the optional collaborators are wired afterwards
// through the Set* methods and are skipped when nil.
type AuthService struct {
	userRepo   userRepositoryInterface
	socialRepo repository.SocialAccountRepository
	jwtManager *auth.JWT
	// cache backs the user-info cache, the token blacklist and login-attempt
	// counters; every use is skipped when it is nil.
	cache *cache.CacheManager
	// passwordMinLength is used by validatePassword unless a policy was set.
	passwordMinLength int
	// maxLoginAttempts failed logins within loginLockDuration lock the account key.
	maxLoginAttempts  int
	loginLockDuration time.Duration

	// Optional role and audit repositories.
	userRoleRepo userRoleRepositoryInterface
	roleRepo     roleRepositoryInterface
	loginLogRepo loginLogRepositoryInterface

	webhookSvc *WebhookService
	// passwordPolicy replaces the built-in strength check once
	// passwordPolicySet is true (see SetPasswordPolicy / validatePassword).
	passwordPolicy     security.PasswordPolicy
	passwordPolicySet  bool
	anomalyDetector    anomalyRecorder
	smsCodeSvc         *SMSCodeService
	emailActivationSvc *EmailActivationService
	emailCodeSvc       *EmailCodeService
	oauthManager       auth.OAuthManager
	deviceService      *DeviceService
}
|
||
|
||
func NewAuthService(
|
||
userRepo userRepositoryInterface,
|
||
socialRepo repository.SocialAccountRepository,
|
||
jwtManager *auth.JWT,
|
||
cacheManager *cache.CacheManager,
|
||
passwordMinLength int,
|
||
maxLoginAttempts int,
|
||
loginLockDuration time.Duration,
|
||
) *AuthService {
|
||
if passwordMinLength <= 0 {
|
||
passwordMinLength = defaultPasswordMinLen
|
||
}
|
||
if maxLoginAttempts <= 0 {
|
||
maxLoginAttempts = 5
|
||
}
|
||
if loginLockDuration <= 0 {
|
||
loginLockDuration = 15 * time.Minute
|
||
}
|
||
|
||
return &AuthService{
|
||
userRepo: userRepo,
|
||
socialRepo: socialRepo,
|
||
jwtManager: jwtManager,
|
||
cache: cacheManager,
|
||
passwordMinLength: passwordMinLength,
|
||
maxLoginAttempts: maxLoginAttempts,
|
||
loginLockDuration: loginLockDuration,
|
||
oauthManager: auth.NewOAuthManager(),
|
||
}
|
||
}
|
||
|
||
func (s *AuthService) SetWebhookService(webhookSvc *WebhookService) {
|
||
s.webhookSvc = webhookSvc
|
||
}
|
||
|
||
func (s *AuthService) SetRoleRepositories(userRoleRepo userRoleRepositoryInterface, roleRepo roleRepositoryInterface) {
|
||
s.userRoleRepo = userRoleRepo
|
||
s.roleRepo = roleRepo
|
||
}
|
||
|
||
func (s *AuthService) SetLoginLogRepository(loginLogRepo loginLogRepositoryInterface) {
|
||
s.loginLogRepo = loginLogRepo
|
||
}
|
||
|
||
func (s *AuthService) SetPasswordPolicy(policy security.PasswordPolicy) {
|
||
s.passwordPolicy = policy.Normalize()
|
||
s.passwordPolicySet = true
|
||
}
|
||
|
||
func (s *AuthService) SetAnomalyDetector(detector anomalyRecorder) {
|
||
s.anomalyDetector = detector
|
||
}
|
||
|
||
func (s *AuthService) SetDeviceService(svc *DeviceService) {
|
||
s.deviceService = svc
|
||
}
|
||
|
||
func (s *AuthService) SetSMSCodeService(svc *SMSCodeService) {
|
||
s.smsCodeSvc = svc
|
||
}
|
||
|
||
// sanitizeUsername derives a username-safe string from arbitrary text, such as
// an OAuth nickname or an e-mail local part.
//
// Letters and digits are kept (lower-cased), '.', '-' and '_' are kept as-is,
// each run of whitespace collapses to a single '_', and every other rune is
// dropped. Leading and trailing '.', '_' and '-' are stripped and the result is
// capped at 50 runes. The cap never re-introduces a trailing separator. Input
// that yields nothing usable becomes "user".
func sanitizeUsername(value string) string {
	const maxRunes = 50

	trimmed := strings.TrimSpace(value)
	if trimmed == "" {
		return "user"
	}

	var builder strings.Builder
	lastUnderscore := false
	for _, r := range trimmed {
		switch {
		case unicode.IsLetter(r) || unicode.IsDigit(r):
			builder.WriteRune(unicode.ToLower(r))
			lastUnderscore = false
		case r == '.' || r == '-' || r == '_':
			builder.WriteRune(r)
			lastUnderscore = false
		case unicode.IsSpace(r):
			// Collapse whitespace runs into one '_' and never emit it first.
			if !lastUnderscore && builder.Len() > 0 {
				builder.WriteByte('_')
				lastUnderscore = true
			}
		}
	}

	result := strings.Trim(builder.String(), "._-")
	if runes := []rune(result); len(runes) > maxRunes {
		// The cut can land right after a separator. Trim again so the result
		// keeps the "no leading/trailing separator" guarantee. The first rune
		// is never a separator here, so the result cannot become empty.
		result = strings.TrimRight(string(runes[:maxRunes]), "._-")
	}
	if result == "" {
		return "user"
	}
	return result
}
|
||
|
||
func (s *AuthService) generateUniqueUsername(ctx context.Context, base string) (string, error) {
|
||
username := sanitizeUsername(base)
|
||
if s == nil || s.userRepo == nil {
|
||
return username, nil
|
||
}
|
||
|
||
exists, err := s.userRepo.ExistsByUsername(ctx, username)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
if !exists {
|
||
return username, nil
|
||
}
|
||
|
||
baseRunes := []rune(username)
|
||
if len(baseRunes) > 40 {
|
||
username = string(baseRunes[:40])
|
||
}
|
||
|
||
for i := 1; i <= 1000; i++ {
|
||
candidate := fmt.Sprintf("%s_%d", username, i)
|
||
exists, err = s.userRepo.ExistsByUsername(ctx, candidate)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
if !exists {
|
||
return candidate, nil
|
||
}
|
||
}
|
||
|
||
return "", errors.New("unable to generate unique username")
|
||
}
|
||
|
||
func validatePasswordStrength(password string, minLength int, strict bool) error {
|
||
if minLength <= 0 {
|
||
minLength = defaultPasswordMinLen
|
||
}
|
||
|
||
info := GetPasswordStrength(password)
|
||
if info.Length < minLength {
|
||
return fmt.Errorf("密码长度不能少于%d位", minLength)
|
||
}
|
||
|
||
if strict {
|
||
if !info.HasUpper || !info.HasLower || !info.HasDigit {
|
||
return errors.New("密码必须包含大小写字母和数字")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
if info.Score < 2 {
|
||
return errors.New("密码强度不足")
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func GetPasswordStrength(password string) PasswordStrengthInfo {
|
||
info := PasswordStrengthInfo{
|
||
Length: utf8.RuneCountInString(password),
|
||
}
|
||
|
||
for _, r := range password {
|
||
switch {
|
||
case unicode.IsUpper(r):
|
||
info.HasUpper = true
|
||
case unicode.IsLower(r):
|
||
info.HasLower = true
|
||
case unicode.IsDigit(r):
|
||
info.HasDigit = true
|
||
case unicode.IsPunct(r) || unicode.IsSymbol(r):
|
||
info.HasSpecial = true
|
||
}
|
||
}
|
||
|
||
if info.HasUpper {
|
||
info.Score++
|
||
}
|
||
if info.HasLower {
|
||
info.Score++
|
||
}
|
||
if info.HasDigit {
|
||
info.Score++
|
||
}
|
||
if info.HasSpecial {
|
||
info.Score++
|
||
}
|
||
|
||
return info
|
||
}
|
||
|
||
func (s *AuthService) validatePassword(password string) error {
|
||
if s != nil && s.passwordPolicySet {
|
||
return s.passwordPolicy.Validate(password)
|
||
}
|
||
minLength := defaultPasswordMinLen
|
||
if s != nil && s.passwordMinLength > 0 {
|
||
minLength = s.passwordMinLength
|
||
}
|
||
return validatePasswordStrength(password, minLength, false)
|
||
}
|
||
|
||
func (s *AuthService) accessTokenTTLSeconds() int64 {
|
||
if s == nil || s.jwtManager == nil {
|
||
return 0
|
||
}
|
||
return int64(s.jwtManager.GetAccessTokenExpire().Seconds())
|
||
}
|
||
|
||
func (s *AuthService) RefreshTokenTTLSeconds() int64 {
|
||
if s == nil || s.jwtManager == nil {
|
||
return 0
|
||
}
|
||
return int64(s.jwtManager.GetRefreshTokenExpire().Seconds())
|
||
}
|
||
|
||
func (s *AuthService) buildUserInfo(user *domain.User) *UserInfo {
|
||
if user == nil {
|
||
return nil
|
||
}
|
||
|
||
return &UserInfo{
|
||
ID: user.ID,
|
||
Username: user.Username,
|
||
Email: domain.DerefStr(user.Email),
|
||
Phone: domain.DerefStr(user.Phone),
|
||
Nickname: user.Nickname,
|
||
Avatar: user.Avatar,
|
||
Status: user.Status,
|
||
}
|
||
}
|
||
|
||
func (s *AuthService) ensureUserActive(user *domain.User) error {
|
||
if user == nil {
|
||
return errors.New("用户不存在")
|
||
}
|
||
|
||
switch user.Status {
|
||
case domain.UserStatusActive:
|
||
return nil
|
||
case domain.UserStatusInactive:
|
||
return errors.New("账号未激活")
|
||
case domain.UserStatusLocked:
|
||
return errors.New("账号已锁定")
|
||
case domain.UserStatusDisabled:
|
||
return errors.New("账号已禁用")
|
||
default:
|
||
return errors.New("账号状态异常")
|
||
}
|
||
}
|
||
|
||
func (s *AuthService) blacklistTokenClaims(ctx context.Context, token string, validate func(string) (*auth.Claims, error)) error {
|
||
if s == nil || s.cache == nil {
|
||
return nil
|
||
}
|
||
|
||
token = strings.TrimSpace(token)
|
||
if token == "" || validate == nil {
|
||
return nil
|
||
}
|
||
|
||
claims, err := validate(token)
|
||
if err != nil || claims == nil || strings.TrimSpace(claims.JTI) == "" {
|
||
return nil
|
||
}
|
||
|
||
ttl := defaultBlacklistTTL
|
||
if claims.ExpiresAt != nil {
|
||
if until := time.Until(claims.ExpiresAt.Time); until > 0 {
|
||
ttl = until
|
||
}
|
||
}
|
||
|
||
return s.cache.Set(ctx, tokenBlacklistPrefix+claims.JTI, true, ttl, ttl)
|
||
}
|
||
|
||
func (s *AuthService) recordLoginAnomaly(ctx context.Context, userID *int64, ip, location, deviceFingerprint string, success bool) {
|
||
if s == nil || s.anomalyDetector == nil || userID == nil {
|
||
return
|
||
}
|
||
|
||
events := s.anomalyDetector.RecordLogin(ctx, *userID, ip, location, deviceFingerprint, success)
|
||
if len(events) == 0 {
|
||
return
|
||
}
|
||
|
||
s.publishEvent(ctx, domain.EventAnomalyDetected, map[string]interface{}{
|
||
"user_id": *userID,
|
||
"ip": ip,
|
||
"location": location,
|
||
"device": deviceFingerprint,
|
||
"events": events,
|
||
"success": success,
|
||
})
|
||
}
|
||
|
||
func (s *AuthService) publishEvent(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
|
||
if s == nil || s.webhookSvc == nil {
|
||
return
|
||
}
|
||
|
||
go s.webhookSvc.Publish(ctx, eventType, data)
|
||
}
|
||
|
||
func (s *AuthService) writeLoginLog(
|
||
ctx context.Context,
|
||
userID *int64,
|
||
loginType domain.LoginType,
|
||
ip string,
|
||
success bool,
|
||
failReason string,
|
||
) {
|
||
if s == nil || s.loginLogRepo == nil {
|
||
return
|
||
}
|
||
|
||
status := 0
|
||
if success {
|
||
status = 1
|
||
}
|
||
|
||
loginRecord := &domain.LoginLog{
|
||
UserID: userID,
|
||
LoginType: int(loginType),
|
||
IP: ip,
|
||
Status: status,
|
||
FailReason: failReason,
|
||
}
|
||
|
||
go func() {
|
||
if err := s.loginLogRepo.Create(context.Background(), loginRecord); err != nil {
|
||
log.Printf("auth: write login log failed, user_id=%v login_type=%d err=%v", userID, loginType, err)
|
||
}
|
||
}()
|
||
}
|
||
|
||
func (s *AuthService) incrementFailAttempts(ctx context.Context, key string) int {
|
||
if s == nil || s.cache == nil || key == "" {
|
||
return 0
|
||
}
|
||
|
||
current := 0
|
||
if value, ok := s.cache.Get(ctx, key); ok {
|
||
current = attemptCount(value)
|
||
}
|
||
current++
|
||
|
||
if err := s.cache.Set(ctx, key, current, s.loginLockDuration, s.loginLockDuration); err != nil {
|
||
log.Printf("auth: store login attempts failed, key=%s err=%v", key, err)
|
||
}
|
||
|
||
return current
|
||
}
|
||
|
||
func isValidPhoneSimple(phone string) bool {
|
||
return isValidPhone(phone)
|
||
}
|
||
|
||
// buildDeviceFingerprint 构建设备指纹字符串
|
||
func buildDeviceFingerprint(req *LoginRequest) string {
|
||
if req == nil {
|
||
return ""
|
||
}
|
||
var parts []string
|
||
if req.DeviceID != "" {
|
||
parts = append(parts, req.DeviceID)
|
||
}
|
||
if req.DeviceName != "" {
|
||
parts = append(parts, req.DeviceName)
|
||
}
|
||
if req.DeviceBrowser != "" {
|
||
parts = append(parts, req.DeviceBrowser)
|
||
}
|
||
if req.DeviceOS != "" {
|
||
parts = append(parts, req.DeviceOS)
|
||
}
|
||
result := strings.Join(parts, "|")
|
||
if result == "" {
|
||
return ""
|
||
}
|
||
return result
|
||
}
|
||
|
||
// bestEffortRegisterDevice 尝试自动注册/更新设备记录
|
||
func (s *AuthService) bestEffortRegisterDevice(ctx context.Context, userID int64, req *LoginRequest) {
|
||
if s == nil || s.deviceService == nil || req == nil || req.DeviceID == "" {
|
||
return
|
||
}
|
||
|
||
createReq := &CreateDeviceRequest{
|
||
DeviceID: req.DeviceID,
|
||
DeviceName: req.DeviceName,
|
||
DeviceBrowser: req.DeviceBrowser,
|
||
DeviceOS: req.DeviceOS,
|
||
}
|
||
_, _ = s.deviceService.CreateDevice(ctx, userID, createReq)
|
||
}
|
||
|
||
func (s *AuthService) cacheUserInfo(ctx context.Context, user *domain.User) {
|
||
if s == nil || s.cache == nil || user == nil {
|
||
return
|
||
}
|
||
info := s.buildUserInfo(user)
|
||
if info == nil {
|
||
return
|
||
}
|
||
_ = s.cache.Set(ctx, userInfoCachePrefix+fmt.Sprintf("%d", user.ID), info, defaultUserCacheTTL, defaultUserCacheTTL)
|
||
}
|
||
|
||
func userInfoFromCacheValue(value interface{}) (*UserInfo, bool) {
|
||
switch typed := value.(type) {
|
||
case *UserInfo:
|
||
return typed, true
|
||
case UserInfo:
|
||
userInfo := typed
|
||
return &userInfo, true
|
||
case map[string]interface{}:
|
||
payload, err := json.Marshal(typed)
|
||
if err != nil {
|
||
return nil, false
|
||
}
|
||
var userInfo UserInfo
|
||
if err := json.Unmarshal(payload, &userInfo); err != nil {
|
||
return nil, false
|
||
}
|
||
return &userInfo, true
|
||
default:
|
||
return nil, false
|
||
}
|
||
}
|
||
|
||
func (s *AuthService) Register(ctx context.Context, req *RegisterRequest) (*UserInfo, error) {
|
||
if req == nil {
|
||
return nil, errors.New("注册请求不能为空")
|
||
}
|
||
if s == nil || s.userRepo == nil {
|
||
return nil, errors.New("user repository is not configured")
|
||
}
|
||
|
||
req.Username = strings.TrimSpace(req.Username)
|
||
req.Email = strings.TrimSpace(req.Email)
|
||
req.Phone = strings.TrimSpace(req.Phone)
|
||
|
||
if req.Username == "" {
|
||
return nil, errors.New("用户名不能为空")
|
||
}
|
||
if req.Password == "" {
|
||
return nil, errors.New("密码不能为空")
|
||
}
|
||
if req.Phone != "" && !isValidPhoneSimple(req.Phone) {
|
||
return nil, errors.New("手机号格式不正确")
|
||
}
|
||
if err := s.validatePassword(req.Password); err != nil {
|
||
return nil, err
|
||
}
|
||
if err := s.verifyPhoneRegistration(ctx, req); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
exists, err := s.userRepo.ExistsByUsername(ctx, req.Username)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if exists {
|
||
return nil, errors.New("用户名已存在")
|
||
}
|
||
|
||
if req.Email != "" {
|
||
exists, err = s.userRepo.ExistsByEmail(ctx, req.Email)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if exists {
|
||
return nil, errors.New("邮箱已存在")
|
||
}
|
||
}
|
||
|
||
if req.Phone != "" {
|
||
exists, err = s.userRepo.ExistsByPhone(ctx, req.Phone)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if exists {
|
||
return nil, errors.New("手机号已存在")
|
||
}
|
||
}
|
||
|
||
hashedPassword, err := auth.HashPassword(req.Password)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
nickname := strings.TrimSpace(req.Nickname)
|
||
if nickname == "" {
|
||
nickname = req.Username
|
||
}
|
||
|
||
user := &domain.User{
|
||
Username: req.Username,
|
||
Email: domain.StrPtr(req.Email),
|
||
Phone: domain.StrPtr(req.Phone),
|
||
Password: hashedPassword,
|
||
Nickname: nickname,
|
||
Status: domain.UserStatusActive,
|
||
}
|
||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
s.bestEffortAssignDefaultRoles(ctx, user.ID, "register")
|
||
s.cacheUserInfo(ctx, user)
|
||
|
||
userInfo := s.buildUserInfo(user)
|
||
s.publishEvent(ctx, domain.EventUserRegistered, userInfo)
|
||
return userInfo, nil
|
||
}
|
||
|
||
func (s *AuthService) Login(ctx context.Context, req *LoginRequest, ip string) (*LoginResponse, error) {
|
||
if req == nil {
|
||
return nil, errors.New("登录请求不能为空")
|
||
}
|
||
if s == nil || s.userRepo == nil || s.jwtManager == nil {
|
||
return nil, errors.New("auth service is not fully configured")
|
||
}
|
||
|
||
account := req.GetAccount()
|
||
if account == "" {
|
||
return nil, errors.New("账号不能为空")
|
||
}
|
||
if strings.TrimSpace(req.Password) == "" {
|
||
return nil, errors.New("密码不能为空")
|
||
}
|
||
|
||
// 构建设备指纹
|
||
deviceFingerprint := buildDeviceFingerprint(req)
|
||
|
||
user, err := s.findUserForLogin(ctx, account)
|
||
if err != nil && !isUserNotFoundError(err) {
|
||
s.writeLoginLog(ctx, nil, domain.LoginTypePassword, ip, false, err.Error())
|
||
return nil, err
|
||
}
|
||
|
||
attemptKey := loginAttemptKey(account, user)
|
||
if s.cache != nil {
|
||
if value, ok := s.cache.Get(ctx, attemptKey); ok && attemptCount(value) >= s.maxLoginAttempts {
|
||
lockErr := errors.New("账号已锁定,请稍后再试")
|
||
s.writeLoginLog(ctx, nil, domain.LoginTypePassword, ip, false, lockErr.Error())
|
||
return nil, lockErr
|
||
}
|
||
}
|
||
|
||
if user == nil {
|
||
s.incrementFailAttempts(ctx, attemptKey)
|
||
s.writeLoginLog(ctx, nil, domain.LoginTypePassword, ip, false, "用户不存在")
|
||
return nil, errors.New("账号或密码错误")
|
||
}
|
||
|
||
if err := s.ensureUserActive(user); err != nil {
|
||
s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, false, err.Error())
|
||
s.recordLoginAnomaly(ctx, &user.ID, ip, "", deviceFingerprint, false)
|
||
return nil, err
|
||
}
|
||
|
||
if !auth.VerifyPassword(user.Password, req.Password) {
|
||
failCount := s.incrementFailAttempts(ctx, attemptKey)
|
||
failErr := errors.New("账号或密码错误")
|
||
if failCount >= s.maxLoginAttempts {
|
||
s.publishEvent(ctx, domain.EventUserLocked, map[string]interface{}{
|
||
"user_id": user.ID,
|
||
"username": user.Username,
|
||
"ip": ip,
|
||
})
|
||
}
|
||
s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, false, failErr.Error())
|
||
s.recordLoginAnomaly(ctx, &user.ID, ip, "", deviceFingerprint, false)
|
||
s.publishEvent(ctx, domain.EventLoginFailed, map[string]interface{}{
|
||
"user_id": user.ID,
|
||
"username": user.Username,
|
||
"ip": ip,
|
||
})
|
||
return nil, failErr
|
||
}
|
||
|
||
if s.cache != nil {
|
||
_ = s.cache.Delete(ctx, attemptKey)
|
||
}
|
||
|
||
s.bestEffortUpdateLastLogin(ctx, user.ID, ip, "password")
|
||
s.cacheUserInfo(ctx, user)
|
||
s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, true, "")
|
||
s.recordLoginAnomaly(ctx, &user.ID, ip, "", deviceFingerprint, true)
|
||
s.bestEffortRegisterDevice(ctx, user.ID, req)
|
||
s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
|
||
"user_id": user.ID,
|
||
"username": user.Username,
|
||
"ip": ip,
|
||
"method": "password",
|
||
})
|
||
|
||
return s.generateLoginResponse(ctx, user, req.Remember)
|
||
}
|
||
|
||
func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*LoginResponse, error) {
|
||
if s == nil || s.jwtManager == nil || s.userRepo == nil {
|
||
return nil, errors.New("auth service is not fully configured")
|
||
}
|
||
|
||
claims, err := s.jwtManager.ValidateRefreshToken(strings.TrimSpace(refreshToken))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if s.IsTokenBlacklisted(ctx, claims.JTI) {
|
||
return nil, errors.New("refresh token has been revoked")
|
||
}
|
||
|
||
user, err := s.userRepo.GetByID(ctx, claims.UserID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if err := s.ensureUserActive(user); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return s.generateLoginResponse(ctx, user, claims.Remember)
|
||
}
|
||
|
||
func (s *AuthService) GetUserInfo(ctx context.Context, userID int64) (*UserInfo, error) {
|
||
if s == nil || s.userRepo == nil {
|
||
return nil, errors.New("user repository is not configured")
|
||
}
|
||
|
||
if s.cache != nil {
|
||
cacheKey := userInfoCachePrefix + fmt.Sprintf("%d", userID)
|
||
if value, ok := s.cache.Get(ctx, cacheKey); ok {
|
||
if info, ok := userInfoFromCacheValue(value); ok {
|
||
return info, nil
|
||
}
|
||
}
|
||
}
|
||
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
s.cacheUserInfo(ctx, user)
|
||
return s.buildUserInfo(user), nil
|
||
}
|
||
|
||
func (s *AuthService) Logout(ctx context.Context, username string, req *LogoutRequest) error {
|
||
if s == nil {
|
||
return nil
|
||
}
|
||
if req == nil {
|
||
return nil
|
||
}
|
||
|
||
_ = s.blacklistTokenClaims(ctx, req.AccessToken, func(token string) (*auth.Claims, error) {
|
||
if s.jwtManager == nil {
|
||
return nil, nil
|
||
}
|
||
return s.jwtManager.ValidateAccessToken(token)
|
||
})
|
||
_ = s.blacklistTokenClaims(ctx, req.RefreshToken, func(token string) (*auth.Claims, error) {
|
||
if s.jwtManager == nil {
|
||
return nil, nil
|
||
}
|
||
return s.jwtManager.ValidateRefreshToken(token)
|
||
})
|
||
|
||
if strings.TrimSpace(username) != "" {
|
||
s.publishEvent(ctx, domain.EventUserLogout, map[string]interface{}{
|
||
"username": strings.TrimSpace(username),
|
||
})
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func (s *AuthService) IsTokenBlacklisted(ctx context.Context, jti string) bool {
|
||
if s == nil || s.cache == nil {
|
||
return false
|
||
}
|
||
jti = strings.TrimSpace(jti)
|
||
if jti == "" {
|
||
return false
|
||
}
|
||
_, ok := s.cache.Get(ctx, tokenBlacklistPrefix+jti)
|
||
return ok
|
||
}
|
||
|
||
func (s *AuthService) OAuthLogin(ctx context.Context, provider, state string) (string, error) {
|
||
if s == nil || s.oauthManager == nil {
|
||
return "", errors.New("oauth manager is not configured")
|
||
}
|
||
return s.oauthManager.GetAuthURL(auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider))), state)
|
||
}
|
||
|
||
func (s *AuthService) OAuthCallback(ctx context.Context, provider, code string) (*LoginResponse, error) {
|
||
if s == nil || s.oauthManager == nil || s.socialRepo == nil || s.userRepo == nil {
|
||
return nil, errors.New("oauth login is not fully configured")
|
||
}
|
||
|
||
oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
|
||
token, err := s.oauthManager.ExchangeCode(oauthProvider, strings.TrimSpace(code))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
oauthUser, err := s.oauthManager.GetUserInfo(oauthProvider, token)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if oauthUser == nil {
|
||
return nil, errors.New("oauth user info is empty")
|
||
}
|
||
|
||
socialAccount, err := s.socialRepo.GetByProviderAndOpenID(ctx, string(oauthProvider), oauthUser.OpenID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
var user *domain.User
|
||
if socialAccount != nil {
|
||
user, err = s.userRepo.GetByID(ctx, socialAccount.UserID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
socialAccount.UnionID = oauthUser.UnionID
|
||
socialAccount.Nickname = oauthUser.Nickname
|
||
socialAccount.Avatar = oauthUser.Avatar
|
||
socialAccount.Gender = oauthUser.Gender
|
||
socialAccount.Email = oauthUser.Email
|
||
socialAccount.Phone = oauthUser.Phone
|
||
socialAccount.Status = domain.SocialAccountStatusActive
|
||
if oauthUser.Extra != nil {
|
||
socialAccount.Extra = oauthUser.Extra
|
||
}
|
||
if err := s.socialRepo.Update(ctx, socialAccount); err != nil {
|
||
log.Printf("auth: update social account failed, provider=%s open_id=%s err=%v", oauthProvider, oauthUser.OpenID, err)
|
||
}
|
||
} else {
|
||
if strings.TrimSpace(oauthUser.Email) != "" {
|
||
user, err = s.userRepo.GetByEmail(ctx, strings.TrimSpace(oauthUser.Email))
|
||
if err != nil {
|
||
if !isUserNotFoundError(err) {
|
||
return nil, err
|
||
}
|
||
user = nil
|
||
}
|
||
}
|
||
|
||
if user == nil {
|
||
baseUsername := oauthUser.Nickname
|
||
if baseUsername == "" && oauthUser.Email != "" {
|
||
baseUsername = strings.Split(strings.TrimSpace(oauthUser.Email), "@")[0]
|
||
}
|
||
if baseUsername == "" {
|
||
baseUsername = string(oauthProvider) + "_" + oauthUser.OpenID
|
||
}
|
||
|
||
username, err := s.generateUniqueUsername(ctx, baseUsername)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
user = &domain.User{
|
||
Username: username,
|
||
Email: domain.StrPtr(strings.TrimSpace(oauthUser.Email)),
|
||
Phone: domain.StrPtr(strings.TrimSpace(oauthUser.Phone)),
|
||
Nickname: strings.TrimSpace(oauthUser.Nickname),
|
||
Avatar: strings.TrimSpace(oauthUser.Avatar),
|
||
Status: domain.UserStatusActive,
|
||
}
|
||
if user.Nickname == "" {
|
||
user.Nickname = user.Username
|
||
}
|
||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||
return nil, err
|
||
}
|
||
s.bestEffortAssignDefaultRoles(ctx, user.ID, "oauth")
|
||
s.publishEvent(ctx, domain.EventUserRegistered, s.buildUserInfo(user))
|
||
}
|
||
|
||
socialAccount = &domain.SocialAccount{
|
||
UserID: user.ID,
|
||
Provider: string(oauthProvider),
|
||
OpenID: oauthUser.OpenID,
|
||
UnionID: oauthUser.UnionID,
|
||
Nickname: oauthUser.Nickname,
|
||
Avatar: oauthUser.Avatar,
|
||
Gender: oauthUser.Gender,
|
||
Email: oauthUser.Email,
|
||
Phone: oauthUser.Phone,
|
||
Status: domain.SocialAccountStatusActive,
|
||
}
|
||
if oauthUser.Extra != nil {
|
||
socialAccount.Extra = oauthUser.Extra
|
||
}
|
||
if err := s.socialRepo.Create(ctx, socialAccount); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
|
||
if err := s.ensureUserActive(user); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
s.bestEffortUpdateLastLogin(ctx, user.ID, "", "oauth")
|
||
s.cacheUserInfo(ctx, user)
|
||
s.writeLoginLog(ctx, &user.ID, domain.LoginTypeOAuth, "", true, "")
|
||
s.recordLoginAnomaly(ctx, &user.ID, "", "", "", true)
|
||
s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
|
||
"user_id": user.ID,
|
||
"username": user.Username,
|
||
"method": "oauth",
|
||
"provider": string(oauthProvider),
|
||
})
|
||
|
||
return s.generateLoginResponseWithoutRemember(ctx, user)
|
||
}
|
||
|
||
func (s *AuthService) StartSocialAccountBinding(
|
||
ctx context.Context,
|
||
userID int64,
|
||
provider string,
|
||
returnTo string,
|
||
currentPassword string,
|
||
totpCode string,
|
||
) (string, string, error) {
|
||
if s == nil || s.oauthManager == nil || s.socialRepo == nil || s.userRepo == nil {
|
||
return "", "", errors.New("social account binding is not fully configured")
|
||
}
|
||
|
||
normalizedProvider := strings.ToLower(strings.TrimSpace(provider))
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return "", "", err
|
||
}
|
||
if err := s.ensureUserActive(user); err != nil {
|
||
return "", "", err
|
||
}
|
||
if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
|
||
return "", "", err
|
||
}
|
||
|
||
accounts, err := s.GetSocialAccounts(ctx, userID)
|
||
if err != nil {
|
||
return "", "", err
|
||
}
|
||
if existing := findSocialAccountByProvider(accounts, normalizedProvider); existing != nil {
|
||
return "", "", auth.ErrOAuthAlreadyBound
|
||
}
|
||
|
||
state, err := s.CreateOAuthBindState(ctx, userID, returnTo)
|
||
if err != nil {
|
||
return "", "", err
|
||
}
|
||
|
||
authURL, err := s.OAuthLogin(ctx, normalizedProvider, state)
|
||
if err != nil {
|
||
return "", "", err
|
||
}
|
||
|
||
return authURL, state, nil
|
||
}
|
||
|
||
func (s *AuthService) OAuthBindCallback(ctx context.Context, userID int64, provider, code string) (*domain.SocialAccountInfo, error) {
|
||
if s == nil || s.oauthManager == nil || s.socialRepo == nil || s.userRepo == nil {
|
||
return nil, errors.New("social account binding is not fully configured")
|
||
}
|
||
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if err := s.ensureUserActive(user); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
oauthProvider := auth.OAuthProvider(strings.ToLower(strings.TrimSpace(provider)))
|
||
token, err := s.oauthManager.ExchangeCode(oauthProvider, strings.TrimSpace(code))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
oauthUser, err := s.oauthManager.GetUserInfo(oauthProvider, token)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if oauthUser == nil {
|
||
return nil, errors.New("oauth user info is empty")
|
||
}
|
||
|
||
account, err := s.upsertOAuthSocialAccount(ctx, userID, oauthProvider, oauthUser)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return account.ToInfo(), nil
|
||
}
|
||
|
||
func (s *AuthService) upsertOAuthSocialAccount(
|
||
ctx context.Context,
|
||
userID int64,
|
||
provider auth.OAuthProvider,
|
||
oauthUser *auth.OAuthUser,
|
||
) (*domain.SocialAccount, error) {
|
||
if s == nil || s.socialRepo == nil || s.userRepo == nil {
|
||
return nil, errors.New("social account binding is not configured")
|
||
}
|
||
if oauthUser == nil {
|
||
return nil, errors.New("oauth user info is empty")
|
||
}
|
||
|
||
normalizedProvider := strings.ToLower(strings.TrimSpace(string(provider)))
|
||
accounts, err := s.GetSocialAccounts(ctx, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if currentProviderBinding := findSocialAccountByProvider(accounts, normalizedProvider); currentProviderBinding != nil &&
|
||
!strings.EqualFold(strings.TrimSpace(currentProviderBinding.OpenID), strings.TrimSpace(oauthUser.OpenID)) {
|
||
return nil, errors.New("provider already bound to current account")
|
||
}
|
||
|
||
existing, err := s.socialRepo.GetByProviderAndOpenID(ctx, normalizedProvider, strings.TrimSpace(oauthUser.OpenID))
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if existing != nil {
|
||
if existing.UserID != userID {
|
||
return nil, auth.ErrOAuthAlreadyBound
|
||
}
|
||
existing.UnionID = oauthUser.UnionID
|
||
existing.Nickname = oauthUser.Nickname
|
||
existing.Avatar = oauthUser.Avatar
|
||
existing.Gender = oauthUser.Gender
|
||
existing.Email = oauthUser.Email
|
||
existing.Phone = oauthUser.Phone
|
||
existing.Status = domain.SocialAccountStatusActive
|
||
if oauthUser.Extra != nil {
|
||
existing.Extra = oauthUser.Extra
|
||
}
|
||
if err := s.socialRepo.Update(ctx, existing); err != nil {
|
||
return nil, err
|
||
}
|
||
return existing, nil
|
||
}
|
||
|
||
account := &domain.SocialAccount{
|
||
UserID: userID,
|
||
Provider: normalizedProvider,
|
||
OpenID: strings.TrimSpace(oauthUser.OpenID),
|
||
UnionID: oauthUser.UnionID,
|
||
Nickname: oauthUser.Nickname,
|
||
Avatar: oauthUser.Avatar,
|
||
Gender: oauthUser.Gender,
|
||
Email: oauthUser.Email,
|
||
Phone: oauthUser.Phone,
|
||
Status: domain.SocialAccountStatusActive,
|
||
}
|
||
if oauthUser.Extra != nil {
|
||
account.Extra = oauthUser.Extra
|
||
}
|
||
if err := s.socialRepo.Create(ctx, account); err != nil {
|
||
return nil, err
|
||
}
|
||
return account, nil
|
||
}
|
||
|
||
func (s *AuthService) verifySensitiveAction(
|
||
ctx context.Context,
|
||
user *domain.User,
|
||
currentPassword string,
|
||
totpCode string,
|
||
) error {
|
||
if user == nil {
|
||
return errors.New("user is required")
|
||
}
|
||
|
||
password := strings.TrimSpace(currentPassword)
|
||
code := strings.TrimSpace(totpCode)
|
||
hasPassword := strings.TrimSpace(user.Password) != ""
|
||
hasTOTP := user.TOTPEnabled && strings.TrimSpace(user.TOTPSecret) != ""
|
||
|
||
// 如果用户既没有密码也没有启用TOTP,禁止执行敏感操作
|
||
if !hasPassword && !hasTOTP {
|
||
return errors.New("请先设置密码或启用两步验证")
|
||
}
|
||
|
||
if password != "" {
|
||
if !hasPassword || !auth.VerifyPassword(user.Password, password) {
|
||
return errors.New("当前密码不正确")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
if code != "" {
|
||
if !hasTOTP {
|
||
return errors.New("TOTP verification is not available")
|
||
}
|
||
return s.verifyTOTPCodeOrRecoveryCode(ctx, user, code)
|
||
}
|
||
|
||
return errors.New("password or TOTP verification is required")
|
||
}
|
||
|
||
func (s *AuthService) verifyTOTPCodeOrRecoveryCode(ctx context.Context, user *domain.User, code string) error {
|
||
if user == nil {
|
||
return errors.New("user is required")
|
||
}
|
||
if !user.TOTPEnabled || strings.TrimSpace(user.TOTPSecret) == "" {
|
||
return errors.New("TOTP verification is not available")
|
||
}
|
||
|
||
manager := auth.NewTOTPManager()
|
||
if manager.ValidateCode(user.TOTPSecret, code) {
|
||
return nil
|
||
}
|
||
|
||
var hashedCodes []string
|
||
if strings.TrimSpace(user.TOTPRecoveryCodes) != "" {
|
||
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes)
|
||
}
|
||
index, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||
if !matched {
|
||
return errors.New("TOTP code or recovery code is invalid")
|
||
}
|
||
|
||
hashedCodes = append(hashedCodes[:index], hashedCodes[index+1:]...)
|
||
payload, err := json.Marshal(hashedCodes)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
user.TOTPRecoveryCodes = string(payload)
|
||
return s.userRepo.UpdateTOTP(ctx, user)
|
||
}
|
||
|
||
// VerifyTOTP 验证 TOTP(支持设备信任跳过)
|
||
// 如果设备已信任且未过期,跳过 TOTP 验证
|
||
func (s *AuthService) VerifyTOTP(ctx context.Context, userID int64, code, deviceID string) error {
|
||
if s == nil || s.userRepo == nil {
|
||
return errors.New("auth service is not fully configured")
|
||
}
|
||
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return errors.New("用户不存在")
|
||
}
|
||
|
||
// 检查设备信任状态
|
||
if deviceID != "" && s.deviceService != nil {
|
||
device, err := s.deviceService.GetDeviceByDeviceID(ctx, userID, deviceID)
|
||
if err == nil && device.IsTrusted {
|
||
// 检查信任是否过期
|
||
if device.TrustExpiresAt == nil || device.TrustExpiresAt.After(time.Now()) {
|
||
return nil // 设备已信任,跳过 TOTP 验证
|
||
}
|
||
}
|
||
}
|
||
|
||
// 执行 TOTP 验证
|
||
return s.verifyTOTPCodeOrRecoveryCode(ctx, user, code)
|
||
}
|
||
|
||
func findSocialAccountByProvider(accounts []*domain.SocialAccount, provider string) *domain.SocialAccount {
|
||
normalizedProvider := strings.ToLower(strings.TrimSpace(provider))
|
||
for _, account := range accounts {
|
||
if account == nil {
|
||
continue
|
||
}
|
||
if strings.EqualFold(strings.TrimSpace(account.Provider), normalizedProvider) {
|
||
return account
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (s *AuthService) availableLoginMethodCount(
|
||
user *domain.User,
|
||
accounts []*domain.SocialAccount,
|
||
excludeProvider string,
|
||
) int {
|
||
if user == nil {
|
||
return 0
|
||
}
|
||
|
||
count := 0
|
||
if strings.TrimSpace(user.Password) != "" {
|
||
count++
|
||
}
|
||
if s.emailCodeSvc != nil && strings.TrimSpace(domain.DerefStr(user.Email)) != "" {
|
||
count++
|
||
}
|
||
if s.smsCodeSvc != nil && strings.TrimSpace(domain.DerefStr(user.Phone)) != "" {
|
||
count++
|
||
}
|
||
|
||
normalizedExcludeProvider := strings.ToLower(strings.TrimSpace(excludeProvider))
|
||
for _, account := range accounts {
|
||
if account == nil || account.Status != domain.SocialAccountStatusActive {
|
||
continue
|
||
}
|
||
if strings.EqualFold(strings.TrimSpace(account.Provider), normalizedExcludeProvider) {
|
||
continue
|
||
}
|
||
count++
|
||
}
|
||
|
||
return count
|
||
}
|
||
|
||
func (s *AuthService) generateLoginResponse(ctx context.Context, user *domain.User, remember bool) (*LoginResponse, error) {
|
||
if s == nil || s.jwtManager == nil {
|
||
return nil, errors.New("jwt manager is not configured")
|
||
}
|
||
if user == nil {
|
||
return nil, errors.New("user is required")
|
||
}
|
||
|
||
var accessToken, refreshToken string
|
||
var err error
|
||
|
||
if remember {
|
||
accessToken, refreshToken, err = s.jwtManager.GenerateTokenPairWithRemember(user.ID, user.Username, remember)
|
||
} else {
|
||
accessToken, refreshToken, err = s.jwtManager.GenerateTokenPair(user.ID, user.Username)
|
||
}
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
s.cacheUserInfo(ctx, user)
|
||
|
||
return &LoginResponse{
|
||
AccessToken: accessToken,
|
||
RefreshToken: refreshToken,
|
||
ExpiresIn: s.accessTokenTTLSeconds(),
|
||
User: s.buildUserInfo(user),
|
||
}, nil
|
||
}
|
||
|
||
// generateLoginResponseWithoutRemember 生成登录响应(不支持记住登录)
|
||
func (s *AuthService) generateLoginResponseWithoutRemember(ctx context.Context, user *domain.User) (*LoginResponse, error) {
|
||
return s.generateLoginResponse(ctx, user, false)
|
||
}
|
||
|
||
func (s *AuthService) BindSocialAccount(ctx context.Context, userID int64, provider, openID string) error {
|
||
if s == nil || s.socialRepo == nil || s.userRepo == nil {
|
||
return errors.New("social account binding is not configured")
|
||
}
|
||
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if err := s.ensureUserActive(user); err != nil {
|
||
return err
|
||
}
|
||
|
||
normalizedProvider := strings.ToLower(strings.TrimSpace(provider))
|
||
normalizedOpenID := strings.TrimSpace(openID)
|
||
if normalizedProvider == "" || normalizedOpenID == "" {
|
||
return errors.New("provider and open_id are required")
|
||
}
|
||
|
||
accounts, err := s.GetSocialAccounts(ctx, userID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if existingProvider := findSocialAccountByProvider(accounts, normalizedProvider); existingProvider != nil &&
|
||
!strings.EqualFold(strings.TrimSpace(existingProvider.OpenID), normalizedOpenID) {
|
||
return errors.New("provider already bound to current account")
|
||
}
|
||
|
||
existing, err := s.socialRepo.GetByProviderAndOpenID(ctx, normalizedProvider, normalizedOpenID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if existing != nil {
|
||
if existing.UserID == userID {
|
||
return nil
|
||
}
|
||
return auth.ErrOAuthAlreadyBound
|
||
}
|
||
|
||
return s.socialRepo.Create(ctx, &domain.SocialAccount{
|
||
UserID: userID,
|
||
Provider: normalizedProvider,
|
||
OpenID: normalizedOpenID,
|
||
Status: domain.SocialAccountStatusActive,
|
||
})
|
||
}
|
||
|
||
func (s *AuthService) UnbindSocialAccount(ctx context.Context, userID int64, provider, currentPassword, totpCode string) error {
|
||
if s == nil || s.socialRepo == nil || s.userRepo == nil {
|
||
return errors.New("social account binding is not configured")
|
||
}
|
||
|
||
user, err := s.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
if err := s.ensureUserActive(user); err != nil {
|
||
return err
|
||
}
|
||
|
||
accounts, err := s.GetSocialAccounts(ctx, userID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
normalizedProvider := strings.ToLower(strings.TrimSpace(provider))
|
||
if findSocialAccountByProvider(accounts, normalizedProvider) == nil {
|
||
return auth.ErrOAuthNotFound
|
||
}
|
||
if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
|
||
return err
|
||
}
|
||
if s.availableLoginMethodCount(user, accounts, normalizedProvider) == 0 {
|
||
return errors.New("at least one login method must remain after unbinding")
|
||
}
|
||
|
||
return s.socialRepo.DeleteByProviderAndUserID(ctx, normalizedProvider, userID)
|
||
}
|
||
|
||
func (s *AuthService) GetSocialAccounts(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
|
||
if s == nil || s.socialRepo == nil {
|
||
return []*domain.SocialAccount{}, nil
|
||
}
|
||
|
||
accounts, err := s.socialRepo.GetByUserID(ctx, userID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if accounts == nil {
|
||
return []*domain.SocialAccount{}, nil
|
||
}
|
||
return accounts, nil
|
||
}
|
||
|
||
func (s *AuthService) GetEnabledOAuthProviders() []auth.OAuthProviderInfo {
|
||
if s == nil || s.oauthManager == nil {
|
||
return []auth.OAuthProviderInfo{}
|
||
}
|
||
|
||
providers := s.oauthManager.GetEnabledProviders()
|
||
if providers == nil {
|
||
return []auth.OAuthProviderInfo{}
|
||
}
|
||
return providers
|
||
}
|
||
|
||
func (s *AuthService) LoginByCode(ctx context.Context, phone, code, ip string) (*LoginResponse, error) {
|
||
if s == nil || s.smsCodeSvc == nil || s.userRepo == nil {
|
||
return nil, errors.New("sms code login is not configured")
|
||
}
|
||
|
||
phone = strings.TrimSpace(phone)
|
||
if phone == "" {
|
||
return nil, errors.New("手机号不能为空")
|
||
}
|
||
|
||
if err := s.smsCodeSvc.VerifyCode(ctx, phone, "login", strings.TrimSpace(code)); err != nil {
|
||
s.writeLoginLog(ctx, nil, domain.LoginTypeSMSCode, ip, false, err.Error())
|
||
return nil, err
|
||
}
|
||
|
||
user, err := s.userRepo.GetByPhone(ctx, phone)
|
||
if err != nil {
|
||
if isUserNotFoundError(err) {
|
||
s.writeLoginLog(ctx, nil, domain.LoginTypeSMSCode, ip, false, "手机号未注册")
|
||
return nil, errors.New("手机号未注册")
|
||
}
|
||
s.writeLoginLog(ctx, nil, domain.LoginTypeSMSCode, ip, false, err.Error())
|
||
return nil, err
|
||
}
|
||
|
||
if err := s.ensureUserActive(user); err != nil {
|
||
s.writeLoginLog(ctx, &user.ID, domain.LoginTypeSMSCode, ip, false, err.Error())
|
||
s.recordLoginAnomaly(ctx, &user.ID, ip, "", "", false)
|
||
return nil, err
|
||
}
|
||
|
||
s.bestEffortUpdateLastLogin(ctx, user.ID, ip, "sms_code")
|
||
s.cacheUserInfo(ctx, user)
|
||
s.writeLoginLog(ctx, &user.ID, domain.LoginTypeSMSCode, ip, true, "")
|
||
s.recordLoginAnomaly(ctx, &user.ID, ip, "", "", true)
|
||
s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
|
||
"user_id": user.ID,
|
||
"username": user.Username,
|
||
"ip": ip,
|
||
"method": "sms_code",
|
||
})
|
||
|
||
return s.generateLoginResponseWithoutRemember(ctx, user)
|
||
}
|