feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers
This commit is contained in:
1450
internal/service/auth.go
Normal file
1450
internal/service/auth.go
Normal file
File diff suppressed because it is too large
Load Diff
116
internal/service/auth_admin_bootstrap.go
Normal file
116
internal/service/auth_admin_bootstrap.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
var ErrAdminBootstrapUnavailable = errors.New("管理员初始化入口已关闭")
|
||||
|
||||
type BootstrapAdminRequest struct {
|
||||
Username string `json:"username" binding:"required"`
|
||||
Password string `json:"password" binding:"required"`
|
||||
Email string `json:"email"`
|
||||
Nickname string `json:"nickname"`
|
||||
}
|
||||
|
||||
func (s *AuthService) BootstrapAdmin(ctx context.Context, req *BootstrapAdminRequest, ip string) (*LoginResponse, error) {
|
||||
if req == nil {
|
||||
return nil, errors.New("管理员初始化请求不能为空")
|
||||
}
|
||||
if s == nil || s.userRepo == nil || s.userRoleRepo == nil || s.roleRepo == nil || s.jwtManager == nil {
|
||||
return nil, errors.New("管理员初始化能力未正确配置")
|
||||
}
|
||||
if !s.IsAdminBootstrapRequired(ctx) {
|
||||
return nil, ErrAdminBootstrapUnavailable
|
||||
}
|
||||
|
||||
username := strings.TrimSpace(req.Username)
|
||||
email := strings.TrimSpace(req.Email)
|
||||
nickname := strings.TrimSpace(req.Nickname)
|
||||
|
||||
if username == "" {
|
||||
return nil, errors.New("用户名不能为空")
|
||||
}
|
||||
if strings.TrimSpace(req.Password) == "" {
|
||||
return nil, errors.New("密码不能为空")
|
||||
}
|
||||
if err := s.validatePassword(req.Password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exists, err := s.userRepo.ExistsByUsername(ctx, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, errors.New("用户名已存在")
|
||||
}
|
||||
|
||||
if email != "" {
|
||||
exists, err = s.userRepo.ExistsByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, errors.New("邮箱已存在")
|
||||
}
|
||||
}
|
||||
|
||||
adminRole, err := s.roleRepo.GetByCode(ctx, adminRoleCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if adminRole == nil || adminRole.Status != domain.RoleStatusEnabled {
|
||||
return nil, errors.New("管理员角色不可用")
|
||||
}
|
||||
|
||||
passwordHash, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if nickname == "" {
|
||||
nickname = username
|
||||
}
|
||||
|
||||
user := &domain.User{
|
||||
Username: username,
|
||||
Email: domain.StrPtr(email),
|
||||
Password: passwordHash,
|
||||
Nickname: nickname,
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.userRoleRepo.BatchCreate(ctx, []*domain.UserRole{
|
||||
{UserID: user.ID, RoleID: adminRole.ID},
|
||||
}); err != nil {
|
||||
_ = s.userRepo.Delete(ctx, user.ID)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.bestEffortUpdateLastLogin(ctx, user.ID, ip, "admin_bootstrap")
|
||||
s.cacheUserInfo(ctx, user)
|
||||
s.writeLoginLog(ctx, &user.ID, domain.LoginTypePassword, ip, true, "")
|
||||
s.publishEvent(ctx, domain.EventUserRegistered, map[string]interface{}{
|
||||
"user_id": user.ID,
|
||||
"username": user.Username,
|
||||
"role": adminRoleCode,
|
||||
"source": "admin_bootstrap",
|
||||
})
|
||||
s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
|
||||
"user_id": user.ID,
|
||||
"username": user.Username,
|
||||
"ip": ip,
|
||||
"method": "admin_bootstrap",
|
||||
})
|
||||
|
||||
return s.generateLoginResponseWithoutRemember(ctx, user)
|
||||
}
|
||||
99
internal/service/auth_capabilities.go
Normal file
99
internal/service/auth_capabilities.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
const adminRoleCode = "admin"
|
||||
|
||||
type AuthCapabilities struct {
|
||||
Password bool `json:"password"`
|
||||
EmailActivation bool `json:"email_activation"`
|
||||
EmailCode bool `json:"email_code"`
|
||||
SMSCode bool `json:"sms_code"`
|
||||
PasswordReset bool `json:"password_reset"`
|
||||
AdminBootstrapRequired bool `json:"admin_bootstrap_required"`
|
||||
OAuthProviders []auth.OAuthProviderInfo `json:"oauth_providers"`
|
||||
}
|
||||
|
||||
func (s *AuthService) SupportsEmailActivation() bool {
|
||||
return s != nil && s.emailActivationSvc != nil
|
||||
}
|
||||
|
||||
func (s *AuthService) SupportsEmailCodeLogin() bool {
|
||||
return s != nil && s.emailCodeSvc != nil
|
||||
}
|
||||
|
||||
func (s *AuthService) SupportsSMSCodeLogin() bool {
|
||||
return s != nil && s.smsCodeSvc != nil
|
||||
}
|
||||
|
||||
func (s *AuthService) GetAuthCapabilities(ctx context.Context) AuthCapabilities {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
return AuthCapabilities{
|
||||
Password: true,
|
||||
EmailActivation: s.SupportsEmailActivation(),
|
||||
EmailCode: s.SupportsEmailCodeLogin(),
|
||||
SMSCode: s.SupportsSMSCodeLogin(),
|
||||
AdminBootstrapRequired: s.IsAdminBootstrapRequired(ctx),
|
||||
OAuthProviders: s.GetEnabledOAuthProviders(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) IsAdminBootstrapRequired(ctx context.Context) bool {
|
||||
if s == nil || s.userRepo == nil || s.roleRepo == nil || s.userRoleRepo == nil {
|
||||
return false
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
adminRole, err := s.roleRepo.GetByCode(ctx, adminRoleCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return true
|
||||
}
|
||||
log.Printf("auth: resolve auth capabilities failed while loading admin role: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
userIDs, err := s.userRoleRepo.GetUserIDByRoleID(ctx, adminRole.ID)
|
||||
if err != nil {
|
||||
log.Printf("auth: resolve auth capabilities failed while loading admin users: role_id=%d err=%v", adminRole.ID, err)
|
||||
return false
|
||||
}
|
||||
if len(userIDs) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
hadUnexpectedLookupError := false
|
||||
for _, userID := range userIDs {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
if isUserNotFoundError(err) {
|
||||
continue
|
||||
}
|
||||
hadUnexpectedLookupError = true
|
||||
log.Printf("auth: resolve auth capabilities failed while loading admin user: user_id=%d err=%v", userID, err)
|
||||
continue
|
||||
}
|
||||
if user != nil && user.Status == domain.UserStatusActive {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if hadUnexpectedLookupError {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
299
internal/service/auth_contact_binding.go
Normal file
299
internal/service/auth_contact_binding.go
Normal file
@@ -0,0 +1,299 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func (s *AuthService) SendEmailBindCode(ctx context.Context, userID int64, email string) error {
|
||||
if s == nil || s.userRepo == nil || s.emailCodeSvc == nil {
|
||||
return errors.New("email binding is not configured")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.ensureUserActive(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
normalizedEmail := strings.TrimSpace(email)
|
||||
if normalizedEmail == "" {
|
||||
return errors.New("email is required")
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(domain.DerefStr(user.Email)), normalizedEmail) {
|
||||
return errors.New("email is already bound to the current account")
|
||||
}
|
||||
|
||||
exists, err := s.userRepo.ExistsByEmail(ctx, normalizedEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return errors.New("email already in use")
|
||||
}
|
||||
|
||||
return s.emailCodeSvc.SendEmailCode(ctx, normalizedEmail, "bind")
|
||||
}
|
||||
|
||||
func (s *AuthService) BindEmail(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
email string,
|
||||
code string,
|
||||
currentPassword string,
|
||||
totpCode string,
|
||||
) error {
|
||||
if s == nil || s.userRepo == nil || s.emailCodeSvc == nil {
|
||||
return errors.New("email binding is not configured")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.ensureUserActive(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
normalizedEmail := strings.TrimSpace(email)
|
||||
if normalizedEmail == "" {
|
||||
return errors.New("email is required")
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(domain.DerefStr(user.Email)), normalizedEmail) {
|
||||
return errors.New("email is already bound to the current account")
|
||||
}
|
||||
|
||||
exists, err := s.userRepo.ExistsByEmail(ctx, normalizedEmail)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return errors.New("email already in use")
|
||||
}
|
||||
if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.emailCodeSvc.VerifyEmailCode(ctx, normalizedEmail, "bind", strings.TrimSpace(code)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Email = domain.StrPtr(normalizedEmail)
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cacheUserInfo(ctx, user)
|
||||
s.publishEvent(ctx, domain.EventUserUpdated, map[string]interface{}{
|
||||
"user_id": user.ID,
|
||||
"email": normalizedEmail,
|
||||
"action": "bind_email",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) UnbindEmail(ctx context.Context, userID int64, currentPassword, totpCode string) error {
|
||||
if s == nil || s.userRepo == nil {
|
||||
return errors.New("email binding is not configured")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.ensureUserActive(user); err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(domain.DerefStr(user.Email)) == "" {
|
||||
return errors.New("email is not bound")
|
||||
}
|
||||
if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
accounts, err := s.GetSocialAccounts(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.availableLoginMethodCountAfterContactRemoval(user, accounts, true, false) == 0 {
|
||||
return errors.New("at least one login method must remain after unbinding")
|
||||
}
|
||||
|
||||
user.Email = nil
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cacheUserInfo(ctx, user)
|
||||
s.publishEvent(ctx, domain.EventUserUpdated, map[string]interface{}{
|
||||
"user_id": user.ID,
|
||||
"action": "unbind_email",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) SendPhoneBindCode(ctx context.Context, userID int64, phone string) (*SendCodeResponse, error) {
|
||||
if s == nil || s.userRepo == nil || s.smsCodeSvc == nil {
|
||||
return nil, errors.New("phone binding is not configured")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.ensureUserActive(user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
normalizedPhone := strings.TrimSpace(phone)
|
||||
if normalizedPhone == "" {
|
||||
return nil, errors.New("phone is required")
|
||||
}
|
||||
if strings.TrimSpace(domain.DerefStr(user.Phone)) == normalizedPhone {
|
||||
return nil, errors.New("phone is already bound to the current account")
|
||||
}
|
||||
|
||||
exists, err := s.userRepo.ExistsByPhone(ctx, normalizedPhone)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, errors.New("phone already in use")
|
||||
}
|
||||
|
||||
return s.smsCodeSvc.SendCode(ctx, &SendCodeRequest{
|
||||
Phone: normalizedPhone,
|
||||
Purpose: "bind",
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthService) BindPhone(
|
||||
ctx context.Context,
|
||||
userID int64,
|
||||
phone string,
|
||||
code string,
|
||||
currentPassword string,
|
||||
totpCode string,
|
||||
) error {
|
||||
if s == nil || s.userRepo == nil || s.smsCodeSvc == nil {
|
||||
return errors.New("phone binding is not configured")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.ensureUserActive(user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
normalizedPhone := strings.TrimSpace(phone)
|
||||
if normalizedPhone == "" {
|
||||
return errors.New("phone is required")
|
||||
}
|
||||
if strings.TrimSpace(domain.DerefStr(user.Phone)) == normalizedPhone {
|
||||
return errors.New("phone is already bound to the current account")
|
||||
}
|
||||
|
||||
exists, err := s.userRepo.ExistsByPhone(ctx, normalizedPhone)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if exists {
|
||||
return errors.New("phone already in use")
|
||||
}
|
||||
if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.smsCodeSvc.VerifyCode(ctx, normalizedPhone, "bind", strings.TrimSpace(code)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.Phone = domain.StrPtr(normalizedPhone)
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cacheUserInfo(ctx, user)
|
||||
s.publishEvent(ctx, domain.EventUserUpdated, map[string]interface{}{
|
||||
"user_id": user.ID,
|
||||
"phone": normalizedPhone,
|
||||
"action": "bind_phone",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) UnbindPhone(ctx context.Context, userID int64, currentPassword, totpCode string) error {
|
||||
if s == nil || s.userRepo == nil {
|
||||
return errors.New("phone binding is not configured")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.ensureUserActive(user); err != nil {
|
||||
return err
|
||||
}
|
||||
if strings.TrimSpace(domain.DerefStr(user.Phone)) == "" {
|
||||
return errors.New("phone is not bound")
|
||||
}
|
||||
if err := s.verifySensitiveAction(ctx, user, currentPassword, totpCode); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
accounts, err := s.GetSocialAccounts(ctx, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if s.availableLoginMethodCountAfterContactRemoval(user, accounts, false, true) == 0 {
|
||||
return errors.New("at least one login method must remain after unbinding")
|
||||
}
|
||||
|
||||
user.Phone = nil
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.cacheUserInfo(ctx, user)
|
||||
s.publishEvent(ctx, domain.EventUserUpdated, map[string]interface{}{
|
||||
"user_id": user.ID,
|
||||
"action": "unbind_phone",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) availableLoginMethodCountAfterContactRemoval(
|
||||
user *domain.User,
|
||||
accounts []*domain.SocialAccount,
|
||||
removeEmail bool,
|
||||
removePhone bool,
|
||||
) int {
|
||||
if user == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
count := 0
|
||||
if strings.TrimSpace(user.Password) != "" {
|
||||
count++
|
||||
}
|
||||
if !removeEmail && s.emailCodeSvc != nil && strings.TrimSpace(domain.DerefStr(user.Email)) != "" {
|
||||
count++
|
||||
}
|
||||
if !removePhone && s.smsCodeSvc != nil && strings.TrimSpace(domain.DerefStr(user.Phone)) != "" {
|
||||
count++
|
||||
}
|
||||
|
||||
for _, account := range accounts {
|
||||
if account == nil || account.Status != domain.SocialAccountStatusActive {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
201
internal/service/auth_email.go
Normal file
201
internal/service/auth_email.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func (s *AuthService) SetEmailActivationService(svc *EmailActivationService) {
|
||||
s.emailActivationSvc = svc
|
||||
}
|
||||
|
||||
func (s *AuthService) SetEmailCodeService(svc *EmailCodeService) {
|
||||
s.emailCodeSvc = svc
|
||||
}
|
||||
|
||||
func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterRequest) (*UserInfo, error) {
|
||||
if err := s.validatePassword(req.Password); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.verifyPhoneRegistration(ctx, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
exists, err := s.userRepo.ExistsByUsername(ctx, req.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, errors.New("username already exists")
|
||||
}
|
||||
|
||||
if req.Email != "" {
|
||||
exists, err = s.userRepo.ExistsByEmail(ctx, req.Email)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, errors.New("email already exists")
|
||||
}
|
||||
}
|
||||
|
||||
if req.Phone != "" {
|
||||
exists, err = s.userRepo.ExistsByPhone(ctx, req.Phone)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, errors.New("phone already exists")
|
||||
}
|
||||
}
|
||||
|
||||
hashedPassword, err := auth.HashPassword(req.Password)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
initialStatus := domain.UserStatusActive
|
||||
if s.emailActivationSvc != nil && req.Email != "" {
|
||||
initialStatus = domain.UserStatusInactive
|
||||
}
|
||||
|
||||
user := &domain.User{
|
||||
Username: req.Username,
|
||||
Email: domain.StrPtr(req.Email),
|
||||
Phone: domain.StrPtr(req.Phone),
|
||||
Password: hashedPassword,
|
||||
Nickname: req.Nickname,
|
||||
Status: initialStatus,
|
||||
}
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.bestEffortAssignDefaultRoles(ctx, user.ID, "register_with_activation")
|
||||
|
||||
if s.emailActivationSvc != nil && req.Email != "" {
|
||||
nickname := req.Nickname
|
||||
if nickname == "" {
|
||||
nickname = req.Username
|
||||
}
|
||||
go func() {
|
||||
if err := s.emailActivationSvc.SendActivationEmail(ctx, user.ID, req.Email, nickname); err != nil {
|
||||
log.Printf("auth: send activation email failed, user_id=%d email=%s err=%v", user.ID, req.Email, err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
userInfo := s.buildUserInfo(user)
|
||||
s.publishEvent(ctx, domain.EventUserRegistered, userInfo)
|
||||
return userInfo, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) ActivateEmail(ctx context.Context, token string) error {
|
||||
if s.emailActivationSvc == nil {
|
||||
return errors.New("email activation service is not configured")
|
||||
}
|
||||
|
||||
userID, err := s.emailActivationSvc.ValidateActivationToken(ctx, token)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("user not found: %w", err)
|
||||
}
|
||||
|
||||
if user.Status == domain.UserStatusActive {
|
||||
return errors.New("account already activated")
|
||||
}
|
||||
if user.Status != domain.UserStatusInactive {
|
||||
return errors.New("account status does not allow activation")
|
||||
}
|
||||
|
||||
return s.userRepo.UpdateStatus(ctx, userID, domain.UserStatusActive)
|
||||
}
|
||||
|
||||
func (s *AuthService) ResendActivationEmail(ctx context.Context, email string) error {
|
||||
if s.emailActivationSvc == nil {
|
||||
return errors.New("email activation service is not configured")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if isUserNotFoundError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
if user.Status == domain.UserStatusActive {
|
||||
return nil
|
||||
}
|
||||
if user.Status != domain.UserStatusInactive {
|
||||
return errors.New("account status does not allow activation")
|
||||
}
|
||||
|
||||
nickname := user.Nickname
|
||||
if nickname == "" {
|
||||
nickname = user.Username
|
||||
}
|
||||
return s.emailActivationSvc.SendActivationEmail(ctx, user.ID, email, nickname)
|
||||
}
|
||||
|
||||
func (s *AuthService) SendEmailLoginCode(ctx context.Context, email string) error {
|
||||
if s.emailCodeSvc == nil {
|
||||
return errors.New("email code service is not configured")
|
||||
}
|
||||
|
||||
_, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if isUserNotFoundError(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return s.emailCodeSvc.SendEmailCode(ctx, email, "login")
|
||||
}
|
||||
|
||||
func (s *AuthService) LoginByEmailCode(ctx context.Context, email, code, ip string) (*LoginResponse, error) {
|
||||
if s.emailCodeSvc == nil {
|
||||
return nil, errors.New("email code login is disabled")
|
||||
}
|
||||
|
||||
if err := s.emailCodeSvc.VerifyEmailCode(ctx, email, "login", code); err != nil {
|
||||
s.writeLoginLog(ctx, nil, domain.LoginTypeEmailCode, ip, false, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
if isUserNotFoundError(err) {
|
||||
s.writeLoginLog(ctx, nil, domain.LoginTypeEmailCode, ip, false, "email not registered")
|
||||
return nil, errors.New("email not registered")
|
||||
}
|
||||
s.writeLoginLog(ctx, nil, domain.LoginTypeEmailCode, ip, false, err.Error())
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.ensureUserActive(user); err != nil {
|
||||
s.writeLoginLog(ctx, &user.ID, domain.LoginTypeEmailCode, ip, false, err.Error())
|
||||
s.recordLoginAnomaly(ctx, &user.ID, ip, "", "", false)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.bestEffortUpdateLastLogin(ctx, user.ID, ip, "email_code")
|
||||
s.writeLoginLog(ctx, &user.ID, domain.LoginTypeEmailCode, ip, true, "")
|
||||
s.recordLoginAnomaly(ctx, &user.ID, ip, "", "", true)
|
||||
s.publishEvent(ctx, domain.EventUserLogin, map[string]interface{}{
|
||||
"user_id": user.ID,
|
||||
"username": user.Username,
|
||||
"ip": ip,
|
||||
"method": "email_code",
|
||||
})
|
||||
|
||||
return s.generateLoginResponseWithoutRemember(ctx, user)
|
||||
}
|
||||
369
internal/service/auth_runtime.go
Normal file
369
internal/service/auth_runtime.go
Normal file
@@ -0,0 +1,369 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type oauthRegistrar interface {
|
||||
RegisterProvider(provider auth.OAuthProvider, config *auth.OAuthConfig)
|
||||
}
|
||||
|
||||
func (s *AuthService) RegisterOAuthProvider(provider auth.OAuthProvider, cfg *auth.OAuthConfig) {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
if registrar, ok := s.oauthManager.(oauthRegistrar); ok {
|
||||
registrar.RegisterProvider(provider, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) findUserForLogin(ctx context.Context, account string) (*domain.User, error) {
|
||||
user, err := s.userRepo.GetByUsername(ctx, account)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
if !isUserNotFoundError(err) {
|
||||
return nil, fmt.Errorf("lookup user by username failed: %w", err)
|
||||
}
|
||||
|
||||
user, err = s.userRepo.GetByEmail(ctx, account)
|
||||
if err == nil {
|
||||
return user, nil
|
||||
}
|
||||
if !isUserNotFoundError(err) {
|
||||
return nil, fmt.Errorf("lookup user by email failed: %w", err)
|
||||
}
|
||||
|
||||
user, err = s.userRepo.GetByPhone(ctx, account)
|
||||
if err != nil && !isUserNotFoundError(err) {
|
||||
return nil, fmt.Errorf("lookup user by phone failed: %w", err)
|
||||
}
|
||||
return user, err
|
||||
}
|
||||
|
||||
func isUserNotFoundError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return true
|
||||
}
|
||||
|
||||
lowerErr := strings.ToLower(strings.TrimSpace(err.Error()))
|
||||
return strings.Contains(lowerErr, "record not found") ||
|
||||
strings.Contains(lowerErr, "user not found") ||
|
||||
strings.Contains(err.Error(), "用户不存在") ||
|
||||
strings.Contains(lowerErr, "not found")
|
||||
}
|
||||
|
||||
func (s *AuthService) bestEffortAssignDefaultRoles(ctx context.Context, userID int64, source string) {
|
||||
if s == nil || s.userRoleRepo == nil || s.roleRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
defaultRoles, err := s.roleRepo.GetDefaultRoles(ctx)
|
||||
if err != nil {
|
||||
log.Printf("auth: load default roles failed, source=%s user_id=%d err=%v", source, userID, err)
|
||||
return
|
||||
}
|
||||
if len(defaultRoles) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
userRoles := make([]*domain.UserRole, 0, len(defaultRoles))
|
||||
for _, role := range defaultRoles {
|
||||
userRoles = append(userRoles, &domain.UserRole{
|
||||
UserID: userID,
|
||||
RoleID: role.ID,
|
||||
})
|
||||
}
|
||||
|
||||
if err := s.userRoleRepo.BatchCreate(ctx, userRoles); err != nil {
|
||||
log.Printf("auth: assign default roles failed, source=%s user_id=%d role_count=%d err=%v", source, userID, len(userRoles), err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) bestEffortUpdateLastLogin(ctx context.Context, userID int64, ip, source string) {
|
||||
if s == nil || s.userRepo == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.userRepo.UpdateLastLogin(ctx, userID, ip); err != nil {
|
||||
log.Printf("auth: update last login failed, source=%s user_id=%d ip=%s err=%v", source, userID, ip, err)
|
||||
}
|
||||
}
|
||||
|
||||
func loginAttemptKey(account string, user *domain.User) string {
|
||||
if user != nil {
|
||||
return fmt.Sprintf("login_attempt:user:%d", user.ID)
|
||||
}
|
||||
return "login_attempt:account:" + strings.ToLower(strings.TrimSpace(account))
|
||||
}
|
||||
|
||||
func attemptCount(value interface{}) int {
|
||||
if count, ok := intValue(value); ok {
|
||||
return count
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func intValue(value interface{}) (int, bool) {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v, true
|
||||
case int64:
|
||||
return int(v), true
|
||||
case float64:
|
||||
return int(v), true
|
||||
case json.Number:
|
||||
n, err := v.Int64()
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return int(n), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func int64Value(value interface{}) (int64, bool) {
|
||||
switch v := value.(type) {
|
||||
case int64:
|
||||
return v, true
|
||||
case int:
|
||||
return int64(v), true
|
||||
case float64:
|
||||
return int64(v), true
|
||||
case json.Number:
|
||||
n, err := v.Int64()
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
return n, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) verifyPhoneRegistration(ctx context.Context, req *RegisterRequest) error {
|
||||
if req == nil || req.Phone == "" {
|
||||
return nil
|
||||
}
|
||||
if s.smsCodeSvc == nil {
|
||||
return errors.New("手机注册未启用")
|
||||
}
|
||||
if req.PhoneCode == "" {
|
||||
return errors.New("手机验证码不能为空")
|
||||
}
|
||||
return s.smsCodeSvc.VerifyCode(ctx, req.Phone, "register", req.PhoneCode)
|
||||
}
|
||||
|
||||
const (
|
||||
oauthStateCachePrefix = "oauth_state:"
|
||||
oauthHandoffCachePrefix = "oauth_handoff:"
|
||||
oauthStateTTL = 10 * time.Minute
|
||||
oauthHandoffTTL = time.Minute
|
||||
)
|
||||
|
||||
type OAuthStatePurpose string
|
||||
|
||||
const (
|
||||
OAuthStatePurposeLogin OAuthStatePurpose = "login"
|
||||
OAuthStatePurposeBind OAuthStatePurpose = "bind"
|
||||
)
|
||||
|
||||
type OAuthStatePayload struct {
|
||||
Purpose OAuthStatePurpose `json:"purpose"`
|
||||
ReturnTo string `json:"return_to"`
|
||||
UserID int64 `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
func generateOAuthEphemeralCode() (string, error) {
|
||||
buffer := make([]byte, 32)
|
||||
if _, err := rand.Read(buffer); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(buffer), nil
|
||||
}
|
||||
|
||||
func (s *AuthService) CreateOAuthState(ctx context.Context, returnTo string) (string, error) {
|
||||
return s.createOAuthStatePayload(ctx, &OAuthStatePayload{
|
||||
Purpose: OAuthStatePurposeLogin,
|
||||
ReturnTo: strings.TrimSpace(returnTo),
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthService) CreateOAuthBindState(ctx context.Context, userID int64, returnTo string) (string, error) {
|
||||
if userID <= 0 {
|
||||
return "", errors.New("oauth binding user is required")
|
||||
}
|
||||
|
||||
return s.createOAuthStatePayload(ctx, &OAuthStatePayload{
|
||||
Purpose: OAuthStatePurposeBind,
|
||||
ReturnTo: strings.TrimSpace(returnTo),
|
||||
UserID: userID,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *AuthService) createOAuthStatePayload(ctx context.Context, payload *OAuthStatePayload) (string, error) {
|
||||
if s == nil || s.cache == nil {
|
||||
return "", errors.New("oauth state storage unavailable")
|
||||
}
|
||||
if payload == nil {
|
||||
return "", errors.New("oauth state payload is required")
|
||||
}
|
||||
if payload.Purpose == "" {
|
||||
payload.Purpose = OAuthStatePurposeLogin
|
||||
}
|
||||
|
||||
state, err := generateOAuthEphemeralCode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := s.cache.Set(ctx, oauthStateCachePrefix+state, payload, oauthStateTTL, oauthStateTTL); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) ConsumeOAuthState(ctx context.Context, state string) (string, error) {
|
||||
payload, err := s.ConsumeOAuthStatePayload(ctx, state)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if payload == nil {
|
||||
return "", nil
|
||||
}
|
||||
return strings.TrimSpace(payload.ReturnTo), nil
|
||||
}
|
||||
|
||||
func (s *AuthService) ConsumeOAuthStatePayload(ctx context.Context, state string) (*OAuthStatePayload, error) {
|
||||
if s == nil || s.cache == nil {
|
||||
return nil, errors.New("oauth state storage unavailable")
|
||||
}
|
||||
|
||||
cacheKey := oauthStateCachePrefix + strings.TrimSpace(state)
|
||||
value, ok := s.cache.Get(ctx, cacheKey)
|
||||
if !ok {
|
||||
return nil, errors.New("OAuth state validation failed")
|
||||
}
|
||||
_ = s.cache.Delete(ctx, cacheKey)
|
||||
|
||||
switch typed := value.(type) {
|
||||
case *OAuthStatePayload:
|
||||
payload := *typed
|
||||
if payload.Purpose == "" {
|
||||
payload.Purpose = OAuthStatePurposeLogin
|
||||
}
|
||||
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
|
||||
return &payload, nil
|
||||
case OAuthStatePayload:
|
||||
payload := typed
|
||||
if payload.Purpose == "" {
|
||||
payload.Purpose = OAuthStatePurposeLogin
|
||||
}
|
||||
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
|
||||
return &payload, nil
|
||||
case string:
|
||||
return &OAuthStatePayload{
|
||||
Purpose: OAuthStatePurposeLogin,
|
||||
ReturnTo: strings.TrimSpace(typed),
|
||||
}, nil
|
||||
case nil:
|
||||
return &OAuthStatePayload{Purpose: OAuthStatePurposeLogin}, nil
|
||||
case map[string]interface{}:
|
||||
payloadBytes, err := json.Marshal(typed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var payload OAuthStatePayload
|
||||
if err := json.Unmarshal(payloadBytes, &payload); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if payload.Purpose == "" {
|
||||
payload.Purpose = OAuthStatePurposeLogin
|
||||
}
|
||||
payload.ReturnTo = strings.TrimSpace(payload.ReturnTo)
|
||||
return &payload, nil
|
||||
default:
|
||||
return &OAuthStatePayload{
|
||||
Purpose: OAuthStatePurposeLogin,
|
||||
ReturnTo: strings.TrimSpace(fmt.Sprint(typed)),
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) CreateOAuthHandoff(ctx context.Context, loginResp *LoginResponse) (string, error) {
|
||||
if s == nil || s.cache == nil {
|
||||
return "", errors.New("oauth handoff storage unavailable")
|
||||
}
|
||||
if loginResp == nil {
|
||||
return "", errors.New("oauth handoff payload is required")
|
||||
}
|
||||
|
||||
code, err := generateOAuthEphemeralCode()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := s.cache.Set(ctx, oauthHandoffCachePrefix+code, loginResp, oauthHandoffTTL, oauthHandoffTTL); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) ConsumeOAuthHandoff(ctx context.Context, code string) (*LoginResponse, error) {
|
||||
if s == nil || s.cache == nil {
|
||||
return nil, errors.New("oauth handoff storage unavailable")
|
||||
}
|
||||
|
||||
cacheKey := oauthHandoffCachePrefix + strings.TrimSpace(code)
|
||||
value, ok := s.cache.Get(ctx, cacheKey)
|
||||
if !ok {
|
||||
return nil, errors.New("OAuth handoff code is invalid or expired")
|
||||
}
|
||||
_ = s.cache.Delete(ctx, cacheKey)
|
||||
|
||||
switch typed := value.(type) {
|
||||
case *LoginResponse:
|
||||
return typed, nil
|
||||
case LoginResponse:
|
||||
resp := typed
|
||||
return &resp, nil
|
||||
case map[string]interface{}:
|
||||
payload, err := json.Marshal(typed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resp LoginResponse
|
||||
if err := json.Unmarshal(payload, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
default:
|
||||
payload, err := json.Marshal(typed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var resp LoginResponse
|
||||
if err := json.Unmarshal(payload, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
}
|
||||
343
internal/service/captcha.go
Normal file
343
internal/service/captcha.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
crand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
"image/color"
|
||||
"image/draw"
|
||||
"image/png"
|
||||
"math/big"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/cache"
|
||||
)
|
||||
|
||||
const (
|
||||
captchaWidth = 120
|
||||
captchaHeight = 40
|
||||
captchaLength = 4 // 验证码位数
|
||||
captchaTTL = 5 * time.Minute
|
||||
)
|
||||
|
||||
// captchaChars 验证码字符集(去掉容易混淆的字符 0/O/1/I/l)
|
||||
const captchaChars = "23456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghjkmnpqrstuvwxyz"
|
||||
|
||||
// CaptchaService 图形验证码服务
|
||||
type CaptchaService struct {
|
||||
cache *cache.CacheManager
|
||||
}
|
||||
|
||||
// NewCaptchaService 创建验证码服务
|
||||
func NewCaptchaService(cache *cache.CacheManager) *CaptchaService {
|
||||
return &CaptchaService{cache: cache}
|
||||
}
|
||||
|
||||
// CaptchaResult 验证码生成结果
|
||||
type CaptchaResult struct {
|
||||
CaptchaID string // 验证码ID(UUID)
|
||||
ImageData []byte // PNG图片字节
|
||||
}
|
||||
|
||||
// Generate 生成图形验证码
|
||||
func (s *CaptchaService) Generate(ctx context.Context) (*CaptchaResult, error) {
|
||||
// 生成随机验证码文字
|
||||
text, err := s.randomText(captchaLength)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成验证码文本失败: %w", err)
|
||||
}
|
||||
|
||||
// 生成验证码ID
|
||||
captchaID, err := s.generateID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成验证码ID失败: %w", err)
|
||||
}
|
||||
|
||||
// 生成图片
|
||||
imgData, err := s.renderImage(text)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成验证码图片失败: %w", err)
|
||||
}
|
||||
|
||||
// 存入缓存(不区分大小写,存小写)
|
||||
cacheKey := "captcha:" + captchaID
|
||||
s.cache.Set(ctx, cacheKey, strings.ToLower(text), captchaTTL, captchaTTL)
|
||||
|
||||
return &CaptchaResult{
|
||||
CaptchaID: captchaID,
|
||||
ImageData: imgData,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Verify 验证验证码(验证后立即删除,防止重放)
|
||||
func (s *CaptchaService) Verify(ctx context.Context, captchaID, answer string) bool {
|
||||
if captchaID == "" || answer == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
cacheKey := "captcha:" + captchaID
|
||||
val, ok := s.cache.Get(ctx, cacheKey)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 删除验证码(一次性使用)
|
||||
s.cache.Delete(ctx, cacheKey)
|
||||
|
||||
expected, ok := val.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.ToLower(answer) == expected
|
||||
}
|
||||
|
||||
// VerifyWithoutDelete 验证验证码但不删除(用于测试)
|
||||
func (s *CaptchaService) VerifyWithoutDelete(ctx context.Context, captchaID, answer string) bool {
|
||||
if captchaID == "" || answer == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
cacheKey := "captcha:" + captchaID
|
||||
val, ok := s.cache.Get(ctx, cacheKey)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
expected, ok := val.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return strings.ToLower(answer) == expected
|
||||
}
|
||||
|
||||
// ValidateCaptcha 验证验证码(对外暴露,验证后删除)
|
||||
func (s *CaptchaService) ValidateCaptcha(ctx context.Context, captchaID, answer string) error {
|
||||
if captchaID == "" {
|
||||
return errors.New("验证码ID不能为空")
|
||||
}
|
||||
if answer == "" {
|
||||
return errors.New("验证码答案不能为空")
|
||||
}
|
||||
if !s.Verify(ctx, captchaID, answer) {
|
||||
return errors.New("验证码错误或已过期")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// randomText 生成随机验证码文字
|
||||
func (s *CaptchaService) randomText(length int) (string, error) {
|
||||
chars := []byte(captchaChars)
|
||||
result := make([]byte, length)
|
||||
for i := range result {
|
||||
n, err := crand.Int(crand.Reader, big.NewInt(int64(len(chars))))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
result[i] = chars[n.Int64()]
|
||||
}
|
||||
return string(result), nil
|
||||
}
|
||||
|
||||
// generateID 生成验证码ID(crypto/rand 保证全局唯一,无碰撞)
|
||||
func (s *CaptchaService) generateID() (string, error) {
|
||||
b := make([]byte, 16)
|
||||
if _, err := crand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return fmt.Sprintf("%d-%s", time.Now().UnixNano(), hex.EncodeToString(b)), nil
|
||||
}
|
||||
|
||||
// renderImage 将文字渲染为PNG验证码图片(纯Go实现,无外部字体依赖)
|
||||
func (s *CaptchaService) renderImage(text string) ([]byte, error) {
|
||||
// 创建 RGBA 图像
|
||||
img := image.NewRGBA(image.Rect(0, 0, captchaWidth, captchaHeight))
|
||||
|
||||
// 随机背景色(浅色)
|
||||
rng := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
bgColor := color.RGBA{
|
||||
R: uint8(220 + rng.Intn(35)),
|
||||
G: uint8(220 + rng.Intn(35)),
|
||||
B: uint8(220 + rng.Intn(35)),
|
||||
A: 255,
|
||||
}
|
||||
draw.Draw(img, img.Bounds(), &image.Uniform{bgColor}, image.Point{}, draw.Src)
|
||||
|
||||
// 绘制干扰线
|
||||
for i := 0; i < 5; i++ {
|
||||
lineColor := color.RGBA{
|
||||
R: uint8(rng.Intn(200)),
|
||||
G: uint8(rng.Intn(200)),
|
||||
B: uint8(rng.Intn(200)),
|
||||
A: 255,
|
||||
}
|
||||
x1 := rng.Intn(captchaWidth)
|
||||
y1 := rng.Intn(captchaHeight)
|
||||
x2 := rng.Intn(captchaWidth)
|
||||
y2 := rng.Intn(captchaHeight)
|
||||
drawLine(img, x1, y1, x2, y2, lineColor)
|
||||
}
|
||||
|
||||
// 绘制文字(使用像素字体)
|
||||
for i, ch := range text {
|
||||
charColor := color.RGBA{
|
||||
R: uint8(rng.Intn(150)),
|
||||
G: uint8(rng.Intn(150)),
|
||||
B: uint8(rng.Intn(150)),
|
||||
A: 255,
|
||||
}
|
||||
x := 10 + i*25 + rng.Intn(5)
|
||||
y := 8 + rng.Intn(12)
|
||||
drawChar(img, x, y, byte(ch), charColor)
|
||||
}
|
||||
|
||||
// 绘制干扰点
|
||||
for i := 0; i < 80; i++ {
|
||||
dotColor := color.RGBA{
|
||||
R: uint8(rng.Intn(255)),
|
||||
G: uint8(rng.Intn(255)),
|
||||
B: uint8(rng.Intn(255)),
|
||||
A: uint8(100 + rng.Intn(100)),
|
||||
}
|
||||
img.Set(rng.Intn(captchaWidth), rng.Intn(captchaHeight), dotColor)
|
||||
}
|
||||
|
||||
// 编码为 PNG
|
||||
var buf bytes.Buffer
|
||||
if err := png.Encode(&buf, img); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// drawLine 画直线(Bresenham算法)
|
||||
func drawLine(img *image.RGBA, x1, y1, x2, y2 int, c color.RGBA) {
|
||||
dx := abs(x2 - x1)
|
||||
dy := abs(y2 - y1)
|
||||
sx, sy := 1, 1
|
||||
if x1 > x2 {
|
||||
sx = -1
|
||||
}
|
||||
if y1 > y2 {
|
||||
sy = -1
|
||||
}
|
||||
err := dx - dy
|
||||
for {
|
||||
img.Set(x1, y1, c)
|
||||
if x1 == x2 && y1 == y2 {
|
||||
break
|
||||
}
|
||||
e2 := 2 * err
|
||||
if e2 > -dy {
|
||||
err -= dy
|
||||
x1 += sx
|
||||
}
|
||||
if e2 < dx {
|
||||
err += dx
|
||||
y1 += sy
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func abs(x int) int {
|
||||
if x < 0 {
|
||||
return -x
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
// pixelFont 5x7 像素字体位图(ASCII 32-127)
|
||||
// 每个字符用5个uint8表示(5列),每个uint8的低7位是每行是否亮起
|
||||
var pixelFont = map[byte][5]uint8{
|
||||
'0': {0x3E, 0x51, 0x49, 0x45, 0x3E},
|
||||
'1': {0x00, 0x42, 0x7F, 0x40, 0x00},
|
||||
'2': {0x42, 0x61, 0x51, 0x49, 0x46},
|
||||
'3': {0x21, 0x41, 0x45, 0x4B, 0x31},
|
||||
'4': {0x18, 0x14, 0x12, 0x7F, 0x10},
|
||||
'5': {0x27, 0x45, 0x45, 0x45, 0x39},
|
||||
'6': {0x3C, 0x4A, 0x49, 0x49, 0x30},
|
||||
'7': {0x01, 0x71, 0x09, 0x05, 0x03},
|
||||
'8': {0x36, 0x49, 0x49, 0x49, 0x36},
|
||||
'9': {0x06, 0x49, 0x49, 0x29, 0x1E},
|
||||
'A': {0x7E, 0x11, 0x11, 0x11, 0x7E},
|
||||
'B': {0x7F, 0x49, 0x49, 0x49, 0x36},
|
||||
'C': {0x3E, 0x41, 0x41, 0x41, 0x22},
|
||||
'D': {0x7F, 0x41, 0x41, 0x22, 0x1C},
|
||||
'E': {0x7F, 0x49, 0x49, 0x49, 0x41},
|
||||
'F': {0x7F, 0x09, 0x09, 0x09, 0x01},
|
||||
'G': {0x3E, 0x41, 0x49, 0x49, 0x7A},
|
||||
'H': {0x7F, 0x08, 0x08, 0x08, 0x7F},
|
||||
'J': {0x20, 0x40, 0x41, 0x3F, 0x01},
|
||||
'K': {0x7F, 0x08, 0x14, 0x22, 0x41},
|
||||
'L': {0x7F, 0x40, 0x40, 0x40, 0x40},
|
||||
'M': {0x7F, 0x02, 0x0C, 0x02, 0x7F},
|
||||
'N': {0x7F, 0x04, 0x08, 0x10, 0x7F},
|
||||
'P': {0x7F, 0x09, 0x09, 0x09, 0x06},
|
||||
'Q': {0x3E, 0x41, 0x51, 0x21, 0x5E},
|
||||
'R': {0x7F, 0x09, 0x19, 0x29, 0x46},
|
||||
'S': {0x46, 0x49, 0x49, 0x49, 0x31},
|
||||
'T': {0x01, 0x01, 0x7F, 0x01, 0x01},
|
||||
'U': {0x3F, 0x40, 0x40, 0x40, 0x3F},
|
||||
'V': {0x1F, 0x20, 0x40, 0x20, 0x1F},
|
||||
'W': {0x3F, 0x40, 0x38, 0x40, 0x3F},
|
||||
'X': {0x63, 0x14, 0x08, 0x14, 0x63},
|
||||
'Y': {0x07, 0x08, 0x70, 0x08, 0x07},
|
||||
'Z': {0x61, 0x51, 0x49, 0x45, 0x43},
|
||||
'a': {0x20, 0x54, 0x54, 0x54, 0x78},
|
||||
'b': {0x7F, 0x48, 0x44, 0x44, 0x38},
|
||||
'c': {0x38, 0x44, 0x44, 0x44, 0x20},
|
||||
'd': {0x38, 0x44, 0x44, 0x48, 0x7F},
|
||||
'e': {0x38, 0x54, 0x54, 0x54, 0x18},
|
||||
'f': {0x08, 0x7E, 0x09, 0x01, 0x02},
|
||||
'g': {0x0C, 0x52, 0x52, 0x52, 0x3E},
|
||||
'h': {0x7F, 0x08, 0x04, 0x04, 0x78},
|
||||
'j': {0x20, 0x40, 0x44, 0x3D, 0x00},
|
||||
'k': {0x7F, 0x10, 0x28, 0x44, 0x00},
|
||||
'm': {0x7C, 0x04, 0x18, 0x04, 0x78},
|
||||
'n': {0x7C, 0x08, 0x04, 0x04, 0x78},
|
||||
'p': {0x7C, 0x14, 0x14, 0x14, 0x08},
|
||||
'q': {0x08, 0x14, 0x14, 0x18, 0x7C},
|
||||
'r': {0x7C, 0x08, 0x04, 0x04, 0x08},
|
||||
's': {0x48, 0x54, 0x54, 0x54, 0x20},
|
||||
't': {0x04, 0x3F, 0x44, 0x40, 0x20},
|
||||
'u': {0x3C, 0x40, 0x40, 0x20, 0x7C},
|
||||
'v': {0x1C, 0x20, 0x40, 0x20, 0x1C},
|
||||
'w': {0x3C, 0x40, 0x30, 0x40, 0x3C},
|
||||
'x': {0x44, 0x28, 0x10, 0x28, 0x44},
|
||||
'y': {0x0C, 0x50, 0x50, 0x50, 0x3C},
|
||||
'z': {0x44, 0x64, 0x54, 0x4C, 0x44},
|
||||
}
|
||||
|
||||
// drawChar 在图像上绘制单个字符
|
||||
func drawChar(img *image.RGBA, x, y int, ch byte, c color.RGBA) {
|
||||
glyph, ok := pixelFont[ch]
|
||||
if !ok {
|
||||
// 未知字符画个方块
|
||||
for dy := 0; dy < 7; dy++ {
|
||||
for dx := 0; dx < 5; dx++ {
|
||||
img.Set(x+dx*2, y+dy*2, c)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for col, colData := range glyph {
|
||||
for row := 0; row < 7; row++ {
|
||||
if colData&(1<<uint(row)) != 0 {
|
||||
// 放大2倍绘制
|
||||
img.Set(x+col*2, y+row*2, c)
|
||||
img.Set(x+col*2+1, y+row*2, c)
|
||||
img.Set(x+col*2, y+row*2+1, c)
|
||||
img.Set(x+col*2+1, y+row*2+1, c)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
41
internal/service/classified_error.go
Normal file
41
internal/service/classified_error.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package service
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrRateLimitExceeded = errors.New("rate limit exceeded")
|
||||
ErrValidationFailed = errors.New("validation failed")
|
||||
)
|
||||
|
||||
type classifiedError struct {
|
||||
message string
|
||||
cause error
|
||||
}
|
||||
|
||||
func (e *classifiedError) Error() string {
|
||||
if e.message != "" {
|
||||
return e.message
|
||||
}
|
||||
if e.cause != nil {
|
||||
return e.cause.Error()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (e *classifiedError) Unwrap() error {
|
||||
return e.cause
|
||||
}
|
||||
|
||||
func newRateLimitError(message string) error {
|
||||
return &classifiedError{
|
||||
message: message,
|
||||
cause: ErrRateLimitExceeded,
|
||||
}
|
||||
}
|
||||
|
||||
func newValidationError(message string) error {
|
||||
return &classifiedError{
|
||||
message: message,
|
||||
cause: ErrValidationFailed,
|
||||
}
|
||||
}
|
||||
319
internal/service/custom_field.go
Normal file
319
internal/service/custom_field.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
// CustomFieldService 自定义字段服务
|
||||
type CustomFieldService struct {
|
||||
fieldRepo *repository.CustomFieldRepository
|
||||
valueRepo *repository.UserCustomFieldValueRepository
|
||||
}
|
||||
|
||||
// NewCustomFieldService 创建自定义字段服务
|
||||
func NewCustomFieldService(
|
||||
fieldRepo *repository.CustomFieldRepository,
|
||||
valueRepo *repository.UserCustomFieldValueRepository,
|
||||
) *CustomFieldService {
|
||||
return &CustomFieldService{
|
||||
fieldRepo: fieldRepo,
|
||||
valueRepo: valueRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateFieldRequest 创建字段请求
|
||||
type CreateFieldRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
FieldKey string `json:"field_key" binding:"required"`
|
||||
Type int `json:"type" binding:"required"`
|
||||
Required bool `json:"required"`
|
||||
Default string `json:"default"`
|
||||
MinLen int `json:"min_len"`
|
||||
MaxLen int `json:"max_len"`
|
||||
MinVal float64 `json:"min_val"`
|
||||
MaxVal float64 `json:"max_val"`
|
||||
Options string `json:"options"`
|
||||
Sort int `json:"sort"`
|
||||
}
|
||||
|
||||
// UpdateFieldRequest 更新字段请求
|
||||
type UpdateFieldRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type int `json:"type"`
|
||||
Required *bool `json:"required"`
|
||||
Default string `json:"default"`
|
||||
MinLen int `json:"min_len"`
|
||||
MaxLen int `json:"max_len"`
|
||||
MinVal float64 `json:"min_val"`
|
||||
MaxVal float64 `json:"max_val"`
|
||||
Options string `json:"options"`
|
||||
Sort int `json:"sort"`
|
||||
Status *int `json:"status"`
|
||||
}
|
||||
|
||||
// CreateField 创建自定义字段
|
||||
func (s *CustomFieldService) CreateField(ctx context.Context, req *CreateFieldRequest) (*domain.CustomField, error) {
|
||||
// 检查field_key是否已存在
|
||||
existing, err := s.fieldRepo.GetByFieldKey(ctx, req.FieldKey)
|
||||
if err == nil && existing != nil {
|
||||
return nil, errors.New("字段标识符已存在")
|
||||
}
|
||||
|
||||
field := &domain.CustomField{
|
||||
Name: req.Name,
|
||||
FieldKey: req.FieldKey,
|
||||
Type: domain.CustomFieldType(req.Type),
|
||||
Required: req.Required,
|
||||
DefaultVal: req.Default,
|
||||
MinLen: req.MinLen,
|
||||
MaxLen: req.MaxLen,
|
||||
MinVal: req.MinVal,
|
||||
MaxVal: req.MaxVal,
|
||||
Options: req.Options,
|
||||
Sort: req.Sort,
|
||||
Status: 1,
|
||||
}
|
||||
|
||||
if err := s.fieldRepo.Create(ctx, field); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return field, nil
|
||||
}
|
||||
|
||||
// UpdateField 更新自定义字段
|
||||
func (s *CustomFieldService) UpdateField(ctx context.Context, id int64, req *UpdateFieldRequest) (*domain.CustomField, error) {
|
||||
field, err := s.fieldRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, errors.New("字段不存在")
|
||||
}
|
||||
|
||||
if req.Name != "" {
|
||||
field.Name = req.Name
|
||||
}
|
||||
if req.Type > 0 {
|
||||
field.Type = domain.CustomFieldType(req.Type)
|
||||
}
|
||||
if req.Required != nil {
|
||||
field.Required = *req.Required
|
||||
}
|
||||
if req.Default != "" {
|
||||
field.DefaultVal = req.Default
|
||||
}
|
||||
if req.MinLen > 0 {
|
||||
field.MinLen = req.MinLen
|
||||
}
|
||||
if req.MaxLen > 0 {
|
||||
field.MaxLen = req.MaxLen
|
||||
}
|
||||
if req.MinVal > 0 {
|
||||
field.MinVal = req.MinVal
|
||||
}
|
||||
if req.MaxVal > 0 {
|
||||
field.MaxVal = req.MaxVal
|
||||
}
|
||||
if req.Options != "" {
|
||||
field.Options = req.Options
|
||||
}
|
||||
if req.Sort > 0 {
|
||||
field.Sort = req.Sort
|
||||
}
|
||||
if req.Status != nil {
|
||||
field.Status = *req.Status
|
||||
}
|
||||
|
||||
if err := s.fieldRepo.Update(ctx, field); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return field, nil
|
||||
}
|
||||
|
||||
// DeleteField 删除自定义字段
|
||||
func (s *CustomFieldService) DeleteField(ctx context.Context, id int64) error {
|
||||
field, err := s.fieldRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return errors.New("字段不存在")
|
||||
}
|
||||
|
||||
// 删除字段定义
|
||||
if err := s.fieldRepo.Delete(ctx, id); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 清理用户的该字段值(可选,取决于业务需求)
|
||||
_ = field
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetField 获取自定义字段
|
||||
func (s *CustomFieldService) GetField(ctx context.Context, id int64) (*domain.CustomField, error) {
|
||||
return s.fieldRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// ListFields 获取所有启用的自定义字段
|
||||
func (s *CustomFieldService) ListFields(ctx context.Context) ([]*domain.CustomField, error) {
|
||||
return s.fieldRepo.List(ctx)
|
||||
}
|
||||
|
||||
// ListAllFields 获取所有自定义字段
|
||||
func (s *CustomFieldService) ListAllFields(ctx context.Context) ([]*domain.CustomField, error) {
|
||||
return s.fieldRepo.ListAll(ctx)
|
||||
}
|
||||
|
||||
// SetUserFieldValue 设置用户的自定义字段值
|
||||
func (s *CustomFieldService) SetUserFieldValue(ctx context.Context, userID int64, fieldKey string, value string) error {
|
||||
// 获取字段定义
|
||||
field, err := s.fieldRepo.GetByFieldKey(ctx, fieldKey)
|
||||
if err != nil {
|
||||
return errors.New("字段不存在")
|
||||
}
|
||||
|
||||
// 验证值
|
||||
if err := s.validateFieldValue(field, value); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.valueRepo.Set(ctx, userID, field.ID, fieldKey, value)
|
||||
}
|
||||
|
||||
// BatchSetUserFieldValues 批量设置用户的自定义字段值
|
||||
func (s *CustomFieldService) BatchSetUserFieldValues(ctx context.Context, userID int64, values map[string]string) error {
|
||||
// 获取所有启用的字段定义
|
||||
fields, err := s.fieldRepo.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fieldMap := make(map[string]*domain.CustomField)
|
||||
for _, f := range fields {
|
||||
fieldMap[f.FieldKey] = f
|
||||
}
|
||||
|
||||
// 验证每个值
|
||||
for fieldKey, value := range values {
|
||||
field, ok := fieldMap[fieldKey]
|
||||
if !ok {
|
||||
return fmt.Errorf("字段不存在: %s", fieldKey)
|
||||
}
|
||||
if err := s.validateFieldValue(field, value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 批量设置值
|
||||
return s.valueRepo.BatchSet(ctx, userID, values)
|
||||
}
|
||||
|
||||
// GetUserFieldValues 获取用户的所有自定义字段值
|
||||
func (s *CustomFieldService) GetUserFieldValues(ctx context.Context, userID int64) ([]*domain.CustomFieldValueResponse, error) {
|
||||
// 获取所有启用的字段定义
|
||||
fields, err := s.fieldRepo.List(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取用户的字段值
|
||||
values, err := s.valueRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建字段值映射
|
||||
valueMap := make(map[int64]*domain.UserCustomFieldValue)
|
||||
for _, v := range values {
|
||||
valueMap[v.FieldID] = v
|
||||
}
|
||||
|
||||
// 构建响应
|
||||
fieldMap := make(map[string]*domain.CustomField)
|
||||
for _, f := range fields {
|
||||
fieldMap[f.FieldKey] = f
|
||||
}
|
||||
|
||||
var result []*domain.CustomFieldValueResponse
|
||||
for _, field := range fields {
|
||||
resp := &domain.CustomFieldValueResponse{
|
||||
FieldKey: field.FieldKey,
|
||||
}
|
||||
|
||||
if val, ok := valueMap[field.ID]; ok {
|
||||
resp.Value = val.GetValueAsInterface(field)
|
||||
} else if field.DefaultVal != "" {
|
||||
resp.Value = field.DefaultVal
|
||||
} else {
|
||||
resp.Value = nil
|
||||
}
|
||||
|
||||
result = append(result, resp)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteUserFieldValue 删除用户的自定义字段值
|
||||
func (s *CustomFieldService) DeleteUserFieldValue(ctx context.Context, userID int64, fieldKey string) error {
|
||||
field, err := s.fieldRepo.GetByFieldKey(ctx, fieldKey)
|
||||
if err != nil {
|
||||
return errors.New("字段不存在")
|
||||
}
|
||||
|
||||
return s.valueRepo.Delete(ctx, userID, field.ID)
|
||||
}
|
||||
|
||||
// validateFieldValue 验证字段值
|
||||
func (s *CustomFieldService) validateFieldValue(field *domain.CustomField, value string) error {
|
||||
// 检查必填
|
||||
if field.Required && value == "" {
|
||||
return errors.New("字段值不能为空")
|
||||
}
|
||||
|
||||
// 如果值为空且有默认值,跳过验证
|
||||
if value == "" && field.DefaultVal != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch field.Type {
|
||||
case domain.CustomFieldTypeString:
|
||||
// 字符串长度验证
|
||||
if field.MinLen > 0 && len(value) < field.MinLen {
|
||||
return fmt.Errorf("值长度不能小于%d", field.MinLen)
|
||||
}
|
||||
if field.MaxLen > 0 && len(value) > field.MaxLen {
|
||||
return fmt.Errorf("值长度不能大于%d", field.MaxLen)
|
||||
}
|
||||
case domain.CustomFieldTypeNumber:
|
||||
// 数字验证
|
||||
numVal, err := strconv.ParseFloat(value, 64)
|
||||
if err != nil {
|
||||
return errors.New("值必须是数字")
|
||||
}
|
||||
if field.MinVal > 0 && numVal < field.MinVal {
|
||||
return fmt.Errorf("值不能小于%.2f", field.MinVal)
|
||||
}
|
||||
if field.MaxVal > 0 && numVal > field.MaxVal {
|
||||
return fmt.Errorf("值不能大于%.2f", field.MaxVal)
|
||||
}
|
||||
case domain.CustomFieldTypeBoolean:
|
||||
// 布尔验证
|
||||
if value != "true" && value != "false" && value != "1" && value != "0" {
|
||||
return errors.New("值必须是布尔值(true/false/1/0)")
|
||||
}
|
||||
case domain.CustomFieldTypeDate:
|
||||
// 日期验证
|
||||
_, err := time.Parse("2006-01-02", value)
|
||||
if err != nil {
|
||||
return errors.New("值必须是有效的日期格式(YYYY-MM-DD)")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
276
internal/service/device.go
Normal file
276
internal/service/device.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
// DeviceService 设备服务
|
||||
type DeviceService struct {
|
||||
deviceRepo *repository.DeviceRepository
|
||||
userRepo *repository.UserRepository
|
||||
}
|
||||
|
||||
// NewDeviceService 创建设备服务
|
||||
func NewDeviceService(
|
||||
deviceRepo *repository.DeviceRepository,
|
||||
userRepo *repository.UserRepository,
|
||||
) *DeviceService {
|
||||
return &DeviceService{
|
||||
deviceRepo: deviceRepo,
|
||||
userRepo: userRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateDeviceRequest 创建设备请求
|
||||
type CreateDeviceRequest struct {
|
||||
DeviceID string `json:"device_id" binding:"required"`
|
||||
DeviceName string `json:"device_name"`
|
||||
DeviceType int `json:"device_type"`
|
||||
DeviceOS string `json:"device_os"`
|
||||
DeviceBrowser string `json:"device_browser"`
|
||||
IP string `json:"ip"`
|
||||
Location string `json:"location"`
|
||||
}
|
||||
|
||||
// UpdateDeviceRequest 更新设备请求
|
||||
type UpdateDeviceRequest struct {
|
||||
DeviceName string `json:"device_name"`
|
||||
DeviceType int `json:"device_type"`
|
||||
DeviceOS string `json:"device_os"`
|
||||
DeviceBrowser string `json:"device_browser"`
|
||||
IP string `json:"ip"`
|
||||
Location string `json:"location"`
|
||||
Status int `json:"status"`
|
||||
}
|
||||
|
||||
// CreateDevice 创建设备
|
||||
func (s *DeviceService) CreateDevice(ctx context.Context, userID int64, req *CreateDeviceRequest) (*domain.Device, error) {
|
||||
// 检查用户是否存在
|
||||
_, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 检查设备是否已存在
|
||||
exists, err := s.deviceRepo.Exists(ctx, userID, req.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
// 设备已存在,更新最后活跃时间
|
||||
device, err := s.deviceRepo.GetByDeviceID(ctx, userID, req.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
device.LastActiveTime = time.Now()
|
||||
return device, s.deviceRepo.Update(ctx, device)
|
||||
}
|
||||
|
||||
// 创建设备
|
||||
device := &domain.Device{
|
||||
UserID: userID,
|
||||
DeviceID: req.DeviceID,
|
||||
DeviceName: req.DeviceName,
|
||||
DeviceType: domain.DeviceType(req.DeviceType),
|
||||
DeviceOS: req.DeviceOS,
|
||||
DeviceBrowser: req.DeviceBrowser,
|
||||
IP: req.IP,
|
||||
Location: req.Location,
|
||||
Status: domain.DeviceStatusActive,
|
||||
}
|
||||
|
||||
if err := s.deviceRepo.Create(ctx, device); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
// UpdateDevice 更新设备
|
||||
func (s *DeviceService) UpdateDevice(ctx context.Context, deviceID int64, req *UpdateDeviceRequest) (*domain.Device, error) {
|
||||
device, err := s.deviceRepo.GetByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return nil, errors.New("设备不存在")
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.DeviceName != "" {
|
||||
device.DeviceName = req.DeviceName
|
||||
}
|
||||
if req.DeviceType >= 0 {
|
||||
device.DeviceType = domain.DeviceType(req.DeviceType)
|
||||
}
|
||||
if req.DeviceOS != "" {
|
||||
device.DeviceOS = req.DeviceOS
|
||||
}
|
||||
if req.DeviceBrowser != "" {
|
||||
device.DeviceBrowser = req.DeviceBrowser
|
||||
}
|
||||
if req.IP != "" {
|
||||
device.IP = req.IP
|
||||
}
|
||||
if req.Location != "" {
|
||||
device.Location = req.Location
|
||||
}
|
||||
if req.Status >= 0 {
|
||||
device.Status = domain.DeviceStatus(req.Status)
|
||||
}
|
||||
|
||||
if err := s.deviceRepo.Update(ctx, device); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return device, nil
|
||||
}
|
||||
|
||||
// DeleteDevice 删除设备
|
||||
func (s *DeviceService) DeleteDevice(ctx context.Context, deviceID int64) error {
|
||||
return s.deviceRepo.Delete(ctx, deviceID)
|
||||
}
|
||||
|
||||
// GetDevice 获取设备信息
|
||||
func (s *DeviceService) GetDevice(ctx context.Context, deviceID int64) (*domain.Device, error) {
|
||||
return s.deviceRepo.GetByID(ctx, deviceID)
|
||||
}
|
||||
|
||||
// GetUserDevices 获取用户设备列表
|
||||
func (s *DeviceService) GetUserDevices(ctx context.Context, userID int64, page, pageSize int) ([]*domain.Device, int64, error) {
|
||||
offset := (page - 1) * pageSize
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return s.deviceRepo.ListByUserID(ctx, userID, offset, pageSize)
|
||||
}
|
||||
|
||||
// UpdateDeviceStatus 更新设备状态
|
||||
func (s *DeviceService) UpdateDeviceStatus(ctx context.Context, deviceID int64, status domain.DeviceStatus) error {
|
||||
return s.deviceRepo.UpdateStatus(ctx, deviceID, status)
|
||||
}
|
||||
|
||||
// UpdateLastActiveTime 更新最后活跃时间
|
||||
func (s *DeviceService) UpdateLastActiveTime(ctx context.Context, deviceID int64) error {
|
||||
return s.deviceRepo.UpdateLastActiveTime(ctx, deviceID)
|
||||
}
|
||||
|
||||
// GetActiveDevices 获取活跃设备
|
||||
func (s *DeviceService) GetActiveDevices(ctx context.Context, page, pageSize int) ([]*domain.Device, int64, error) {
|
||||
offset := (page - 1) * pageSize
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
|
||||
return s.deviceRepo.ListByStatus(ctx, domain.DeviceStatusActive, offset, pageSize)
|
||||
}
|
||||
|
||||
// TrustDevice 设置设备为信任状态
|
||||
func (s *DeviceService) TrustDevice(ctx context.Context, deviceID int64, trustDuration time.Duration) error {
|
||||
device, err := s.deviceRepo.GetByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return errors.New("设备不存在")
|
||||
}
|
||||
|
||||
var trustExpiresAt *time.Time
|
||||
if trustDuration > 0 {
|
||||
expiresAt := time.Now().Add(trustDuration)
|
||||
trustExpiresAt = &expiresAt
|
||||
}
|
||||
|
||||
return s.deviceRepo.TrustDevice(ctx, device.ID, trustExpiresAt)
|
||||
}
|
||||
|
||||
// TrustDeviceByDeviceID 根据设备标识字符串设置设备为信任状态
|
||||
func (s *DeviceService) TrustDeviceByDeviceID(ctx context.Context, userID int64, deviceID string, trustDuration time.Duration) error {
|
||||
device, err := s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return errors.New("设备不存在")
|
||||
}
|
||||
|
||||
var trustExpiresAt *time.Time
|
||||
if trustDuration > 0 {
|
||||
expiresAt := time.Now().Add(trustDuration)
|
||||
trustExpiresAt = &expiresAt
|
||||
}
|
||||
|
||||
return s.deviceRepo.TrustDevice(ctx, device.ID, trustExpiresAt)
|
||||
}
|
||||
|
||||
// UntrustDevice 取消设备信任状态
|
||||
func (s *DeviceService) UntrustDevice(ctx context.Context, deviceID int64) error {
|
||||
device, err := s.deviceRepo.GetByID(ctx, deviceID)
|
||||
if err != nil {
|
||||
return errors.New("设备不存在")
|
||||
}
|
||||
|
||||
return s.deviceRepo.UntrustDevice(ctx, device.ID)
|
||||
}
|
||||
|
||||
// LogoutAllOtherDevices 登出所有其他设备
|
||||
func (s *DeviceService) LogoutAllOtherDevices(ctx context.Context, userID int64, currentDeviceID int64) error {
|
||||
return s.deviceRepo.DeleteAllByUserIDExcept(ctx, userID, currentDeviceID)
|
||||
}
|
||||
|
||||
// GetTrustedDevices 获取用户的信任设备列表
|
||||
func (s *DeviceService) GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
|
||||
return s.deviceRepo.GetTrustedDevices(ctx, userID)
|
||||
}
|
||||
|
||||
// GetAllDevicesRequest 获取所有设备请求参数
|
||||
type GetAllDevicesRequest struct {
|
||||
Page int
|
||||
PageSize int
|
||||
UserID int64 `form:"user_id"`
|
||||
Status int `form:"status"`
|
||||
IsTrusted *bool `form:"is_trusted"`
|
||||
Keyword string `form:"keyword"`
|
||||
}
|
||||
|
||||
// GetAllDevices 获取所有设备(管理员用)
|
||||
func (s *DeviceService) GetAllDevices(ctx context.Context, req *GetAllDevicesRequest) ([]*domain.Device, int64, error) {
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
if req.PageSize > 100 {
|
||||
req.PageSize = 100
|
||||
}
|
||||
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
|
||||
params := &repository.ListDevicesParams{
|
||||
UserID: req.UserID,
|
||||
Keyword: req.Keyword,
|
||||
Offset: offset,
|
||||
Limit: req.PageSize,
|
||||
}
|
||||
|
||||
// 处理状态筛选
|
||||
if req.Status >= 0 {
|
||||
params.Status = domain.DeviceStatus(req.Status)
|
||||
}
|
||||
|
||||
// 处理信任状态筛选
|
||||
if req.IsTrusted != nil {
|
||||
params.IsTrusted = req.IsTrusted
|
||||
}
|
||||
|
||||
return s.deviceRepo.ListAll(ctx, params)
|
||||
}
|
||||
|
||||
// GetDeviceByDeviceID 根据设备标识获取设备(用于设备信任检查)
|
||||
func (s *DeviceService) GetDeviceByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
|
||||
return s.deviceRepo.GetByDeviceID(ctx, userID, deviceID)
|
||||
}
|
||||
308
internal/service/email.go
Normal file
308
internal/service/email.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EmailProvider interface {
|
||||
SendMail(ctx context.Context, to, subject, htmlBody string) error
|
||||
}
|
||||
|
||||
type SMTPEmailConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
FromEmail string
|
||||
FromName string
|
||||
TLS bool
|
||||
}
|
||||
|
||||
type SMTPEmailProvider struct {
|
||||
cfg SMTPEmailConfig
|
||||
}
|
||||
|
||||
func NewSMTPEmailProvider(cfg SMTPEmailConfig) EmailProvider {
|
||||
return &SMTPEmailProvider{cfg: cfg}
|
||||
}
|
||||
|
||||
func (p *SMTPEmailProvider) SendMail(ctx context.Context, to, subject, htmlBody string) error {
|
||||
_ = ctx
|
||||
|
||||
var authInfo smtp.Auth
|
||||
if p.cfg.Username != "" || p.cfg.Password != "" {
|
||||
authInfo = smtp.PlainAuth("", p.cfg.Username, p.cfg.Password, p.cfg.Host)
|
||||
}
|
||||
|
||||
from := p.cfg.FromEmail
|
||||
if p.cfg.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", p.cfg.FromName, p.cfg.FromEmail)
|
||||
}
|
||||
|
||||
headers := []string{
|
||||
fmt.Sprintf("From: %s", from),
|
||||
fmt.Sprintf("To: %s", to),
|
||||
fmt.Sprintf("Subject: %s", subject),
|
||||
"MIME-Version: 1.0",
|
||||
"Content-Type: text/html; charset=UTF-8",
|
||||
"",
|
||||
}
|
||||
|
||||
message := strings.Join(headers, "\r\n") + htmlBody
|
||||
addr := fmt.Sprintf("%s:%d", p.cfg.Host, p.cfg.Port)
|
||||
return smtp.SendMail(addr, authInfo, p.cfg.FromEmail, []string{to}, []byte(message))
|
||||
}
|
||||
|
||||
type MockEmailProvider struct{}
|
||||
|
||||
func (m *MockEmailProvider) SendMail(ctx context.Context, to, subject, htmlBody string) error {
|
||||
_ = ctx
|
||||
log.Printf("[email-mock] to=%s subject=%s body_bytes=%d", to, subject, len(htmlBody))
|
||||
return nil
|
||||
}
|
||||
|
||||
type EmailCodeConfig struct {
|
||||
CodeTTL time.Duration
|
||||
ResendCooldown time.Duration
|
||||
MaxDailyLimit int
|
||||
SiteURL string
|
||||
SiteName string
|
||||
}
|
||||
|
||||
func DefaultEmailCodeConfig() EmailCodeConfig {
|
||||
return EmailCodeConfig{
|
||||
CodeTTL: 5 * time.Minute,
|
||||
ResendCooldown: time.Minute,
|
||||
MaxDailyLimit: 10,
|
||||
SiteURL: "http://localhost:8080",
|
||||
SiteName: "User Management System",
|
||||
}
|
||||
}
|
||||
|
||||
type EmailCodeService struct {
|
||||
provider EmailProvider
|
||||
cache cacheInterface
|
||||
cfg EmailCodeConfig
|
||||
}
|
||||
|
||||
func NewEmailCodeService(provider EmailProvider, cache cacheInterface, cfg EmailCodeConfig) *EmailCodeService {
|
||||
if cfg.CodeTTL <= 0 {
|
||||
cfg.CodeTTL = 5 * time.Minute
|
||||
}
|
||||
if cfg.ResendCooldown <= 0 {
|
||||
cfg.ResendCooldown = time.Minute
|
||||
}
|
||||
if cfg.MaxDailyLimit <= 0 {
|
||||
cfg.MaxDailyLimit = 10
|
||||
}
|
||||
return &EmailCodeService{
|
||||
provider: provider,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *EmailCodeService) SendEmailCode(ctx context.Context, email, purpose string) error {
|
||||
cooldownKey := fmt.Sprintf("email_cooldown:%s:%s", purpose, email)
|
||||
if _, ok := s.cache.Get(ctx, cooldownKey); ok {
|
||||
return newRateLimitError(fmt.Sprintf("\u64cd\u4f5c\u8fc7\u4e8e\u9891\u7e41\uff0c\u8bf7 %d \u79d2\u540e\u518d\u8bd5", int(s.cfg.ResendCooldown.Seconds())))
|
||||
}
|
||||
|
||||
dailyKey := fmt.Sprintf("email_daily:%s:%s", email, time.Now().Format("2006-01-02"))
|
||||
var dailyCount int
|
||||
if value, ok := s.cache.Get(ctx, dailyKey); ok {
|
||||
if count, ok := intValue(value); ok {
|
||||
dailyCount = count
|
||||
}
|
||||
}
|
||||
if dailyCount >= s.cfg.MaxDailyLimit {
|
||||
return newRateLimitError("\u4eca\u65e5\u53d1\u9001\u6b21\u6570\u5df2\u8fbe\u4e0a\u9650\uff0c\u8bf7\u660e\u5929\u518d\u8bd5")
|
||||
}
|
||||
|
||||
code, err := generateEmailCode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
codeKey := fmt.Sprintf("email_code:%s:%s", purpose, email)
|
||||
if err := s.cache.Set(ctx, codeKey, code, s.cfg.CodeTTL, s.cfg.CodeTTL); err != nil {
|
||||
return fmt.Errorf("store email code failed: %w", err)
|
||||
}
|
||||
if err := s.cache.Set(ctx, cooldownKey, true, s.cfg.ResendCooldown, s.cfg.ResendCooldown); err != nil {
|
||||
_ = s.cache.Delete(ctx, codeKey)
|
||||
return fmt.Errorf("store email cooldown failed: %w", err)
|
||||
}
|
||||
if err := s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour, 24*time.Hour); err != nil {
|
||||
_ = s.cache.Delete(ctx, codeKey)
|
||||
_ = s.cache.Delete(ctx, cooldownKey)
|
||||
return fmt.Errorf("store email daily counter failed: %w", err)
|
||||
}
|
||||
|
||||
subject, body := buildEmailCodeContent(purpose, code, s.cfg.SiteName, s.cfg.CodeTTL)
|
||||
if err := s.provider.SendMail(ctx, email, subject, body); err != nil {
|
||||
_ = s.cache.Delete(ctx, codeKey)
|
||||
_ = s.cache.Delete(ctx, cooldownKey)
|
||||
return fmt.Errorf("email delivery failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *EmailCodeService) VerifyEmailCode(ctx context.Context, email, purpose, code string) error {
|
||||
if strings.TrimSpace(code) == "" {
|
||||
return fmt.Errorf("verification code is required")
|
||||
}
|
||||
|
||||
codeKey := fmt.Sprintf("email_code:%s:%s", purpose, email)
|
||||
value, ok := s.cache.Get(ctx, codeKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("verification code expired or missing")
|
||||
}
|
||||
|
||||
storedCode, ok := value.(string)
|
||||
if !ok || storedCode != code {
|
||||
return fmt.Errorf("verification code is invalid")
|
||||
}
|
||||
|
||||
if err := s.cache.Delete(ctx, codeKey); err != nil {
|
||||
return fmt.Errorf("consume email code failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type EmailActivationService struct {
|
||||
provider EmailProvider
|
||||
cache cacheInterface
|
||||
tokenTTL time.Duration
|
||||
siteURL string
|
||||
siteName string
|
||||
}
|
||||
|
||||
func NewEmailActivationService(provider EmailProvider, cache cacheInterface, siteURL, siteName string) *EmailActivationService {
|
||||
return &EmailActivationService{
|
||||
provider: provider,
|
||||
cache: cache,
|
||||
tokenTTL: 24 * time.Hour,
|
||||
siteURL: siteURL,
|
||||
siteName: siteName,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *EmailActivationService) SendActivationEmail(ctx context.Context, userID int64, email, username string) error {
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := cryptorand.Read(tokenBytes); err != nil {
|
||||
return fmt.Errorf("generate activation token failed: %w", err)
|
||||
}
|
||||
token := hex.EncodeToString(tokenBytes)
|
||||
|
||||
cacheKey := fmt.Sprintf("email_activation:%s", token)
|
||||
if err := s.cache.Set(ctx, cacheKey, userID, s.tokenTTL, s.tokenTTL); err != nil {
|
||||
return fmt.Errorf("store activation token failed: %w", err)
|
||||
}
|
||||
|
||||
activationURL := buildFrontendActivationURL(s.siteURL, token)
|
||||
subject := fmt.Sprintf("[%s] Activate Your Account", s.siteName)
|
||||
body := buildActivationEmailBody(username, activationURL, s.siteName, s.tokenTTL)
|
||||
return s.provider.SendMail(ctx, email, subject, body)
|
||||
}
|
||||
|
||||
func buildFrontendActivationURL(siteURL, token string) string {
|
||||
base := strings.TrimRight(strings.TrimSpace(siteURL), "/")
|
||||
if base == "" {
|
||||
base = DefaultEmailCodeConfig().SiteURL
|
||||
}
|
||||
return fmt.Sprintf("%s/activate-account?token=%s", base, url.QueryEscape(token))
|
||||
}
|
||||
|
||||
func (s *EmailActivationService) ValidateActivationToken(ctx context.Context, token string) (int64, error) {
|
||||
token = strings.TrimSpace(token)
|
||||
if token == "" {
|
||||
return 0, fmt.Errorf("activation token is required")
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf("email_activation:%s", token)
|
||||
value, ok := s.cache.Get(ctx, cacheKey)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("activation token expired or missing")
|
||||
}
|
||||
|
||||
userID, ok := int64Value(value)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("activation token payload is invalid")
|
||||
}
|
||||
if err := s.cache.Delete(ctx, cacheKey); err != nil {
|
||||
return 0, fmt.Errorf("consume activation token failed: %w", err)
|
||||
}
|
||||
|
||||
return userID, nil
|
||||
}
|
||||
|
||||
func buildEmailCodeContent(purpose, code, siteName string, ttl time.Duration) (subject, body string) {
|
||||
purposeText := map[string]string{
|
||||
"login": "login verification",
|
||||
"register": "registration verification",
|
||||
"reset": "password reset",
|
||||
"bind": "binding verification",
|
||||
}
|
||||
label := purposeText[purpose]
|
||||
if label == "" {
|
||||
label = "identity verification"
|
||||
}
|
||||
|
||||
subject = fmt.Sprintf("[%s] Your %s code: %s", siteName, label, code)
|
||||
body = fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html>
|
||||
<body style="font-family:Arial,sans-serif;max-width:600px;margin:0 auto;padding:20px;">
|
||||
<h2 style="color:#333;">%s</h2>
|
||||
<p>Your %s code is:</p>
|
||||
<div style="background:#f5f5f5;padding:20px;text-align:center;margin:20px 0;border-radius:8px;">
|
||||
<span style="font-size:36px;font-weight:bold;color:#2563eb;letter-spacing:8px;">%s</span>
|
||||
</div>
|
||||
<p>This code expires in <strong>%d minutes</strong>.</p>
|
||||
<p style="color:#999;font-size:12px;">If you did not request this code, you can ignore this email.</p>
|
||||
</body>
|
||||
</html>`, siteName, label, code, int(ttl.Minutes()))
|
||||
return subject, body
|
||||
}
|
||||
|
||||
func buildActivationEmailBody(username, activationURL, siteName string, ttl time.Duration) string {
|
||||
return fmt.Sprintf(`<!DOCTYPE html>
|
||||
<html>
|
||||
<body style="font-family:Arial,sans-serif;max-width:600px;margin:0 auto;padding:20px;">
|
||||
<h2 style="color:#333;">Welcome to %s</h2>
|
||||
<p>Hello <strong>%s</strong>,</p>
|
||||
<p>Please click the button below to activate your account.</p>
|
||||
<div style="text-align:center;margin:30px 0;">
|
||||
<a href="%s"
|
||||
style="background:#2563eb;color:#fff;padding:14px 32px;text-decoration:none;border-radius:8px;font-size:16px;font-weight:bold;">
|
||||
Activate Account
|
||||
</a>
|
||||
</div>
|
||||
<p>If the button does not work, copy this link into your browser:</p>
|
||||
<p style="word-break:break-all;color:#2563eb;">%s</p>
|
||||
<p>This link expires in <strong>%d hours</strong>.</p>
|
||||
</body>
|
||||
</html>`, siteName, username, activationURL, activationURL, int(ttl.Hours()))
|
||||
}
|
||||
|
||||
func generateEmailCode() (string, error) {
|
||||
buffer := make([]byte, 3)
|
||||
if _, err := cryptorand.Read(buffer); err != nil {
|
||||
return "", fmt.Errorf("generate email code failed: %w", err)
|
||||
}
|
||||
|
||||
value := int(buffer[0])<<16 | int(buffer[1])<<8 | int(buffer[2])
|
||||
value = value % 1000000
|
||||
if value < 100000 {
|
||||
value += 100000
|
||||
}
|
||||
return fmt.Sprintf("%06d", value), nil
|
||||
}
|
||||
534
internal/service/export.go
Normal file
534
internal/service/export.go
Normal file
@@ -0,0 +1,534 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/xuri/excelize/v2"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
const (
|
||||
ExportFormatCSV = "csv"
|
||||
ExportFormatXLSX = "xlsx"
|
||||
)
|
||||
|
||||
// ExportUsersRequest defines the supported export filters and output options.
|
||||
type ExportUsersRequest struct {
|
||||
Format string
|
||||
Fields []string
|
||||
Keyword string
|
||||
Status *int
|
||||
}
|
||||
|
||||
type exportColumn struct {
|
||||
Key string
|
||||
Header string
|
||||
Value func(*domain.User) string
|
||||
}
|
||||
|
||||
var defaultExportColumns = []exportColumn{
|
||||
{Key: "id", Header: "ID", Value: func(u *domain.User) string { return fmt.Sprintf("%d", u.ID) }},
|
||||
{Key: "username", Header: "用户名", Value: func(u *domain.User) string { return u.Username }},
|
||||
{Key: "email", Header: "邮箱", Value: func(u *domain.User) string { return domain.DerefStr(u.Email) }},
|
||||
{Key: "phone", Header: "手机号", Value: func(u *domain.User) string { return domain.DerefStr(u.Phone) }},
|
||||
{Key: "nickname", Header: "昵称", Value: func(u *domain.User) string { return u.Nickname }},
|
||||
{Key: "avatar", Header: "头像", Value: func(u *domain.User) string { return u.Avatar }},
|
||||
{Key: "gender", Header: "性别", Value: func(u *domain.User) string { return genderLabel(u.Gender) }},
|
||||
{Key: "status", Header: "状态", Value: func(u *domain.User) string { return userStatusLabel(u.Status) }},
|
||||
{Key: "region", Header: "地区", Value: func(u *domain.User) string { return u.Region }},
|
||||
{Key: "bio", Header: "个人简介", Value: func(u *domain.User) string { return u.Bio }},
|
||||
{Key: "totp_enabled", Header: "TOTP已启用", Value: func(u *domain.User) string { return boolLabel(u.TOTPEnabled) }},
|
||||
{Key: "last_login_time", Header: "最后登录时间", Value: func(u *domain.User) string { return timeLabel(u.LastLoginTime) }},
|
||||
{Key: "last_login_ip", Header: "最后登录IP", Value: func(u *domain.User) string { return u.LastLoginIP }},
|
||||
{Key: "created_at", Header: "注册时间", Value: func(u *domain.User) string { return u.CreatedAt.Format("2006-01-02 15:04:05") }},
|
||||
}
|
||||
|
||||
// ExportService 用户数据导入导出服务
|
||||
type ExportService struct {
|
||||
userRepo *repository.UserRepository
|
||||
roleRepo *repository.RoleRepository
|
||||
}
|
||||
|
||||
// NewExportService 创建导入导出服务
|
||||
func NewExportService(
|
||||
userRepo *repository.UserRepository,
|
||||
roleRepo *repository.RoleRepository,
|
||||
) *ExportService {
|
||||
return &ExportService{
|
||||
userRepo: userRepo,
|
||||
roleRepo: roleRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// ExportUsers exports users as CSV or XLSX.
|
||||
func (s *ExportService) ExportUsers(ctx context.Context, req *ExportUsersRequest) ([]byte, string, string, error) {
|
||||
if req == nil {
|
||||
req = &ExportUsersRequest{}
|
||||
}
|
||||
|
||||
format, err := normalizeExportFormat(req.Format)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
columns, err := resolveExportColumns(req.Fields)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
users, err := s.listUsersForExport(ctx, req)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("users_%s.%s", time.Now().Format("20060102_150405"), format)
|
||||
switch format {
|
||||
case ExportFormatCSV:
|
||||
data, err := buildCSVExport(columns, users)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return data, filename, "text/csv; charset=utf-8", nil
|
||||
case ExportFormatXLSX:
|
||||
data, err := buildXLSXExport(columns, users)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return data, filename, "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
|
||||
default:
|
||||
return nil, "", "", fmt.Errorf("不支持的导出格式: %s", req.Format)
|
||||
}
|
||||
}
|
||||
|
||||
// ExportUsersCSV keeps backward compatibility for callers that still expect CSV-only export.
|
||||
func (s *ExportService) ExportUsersCSV(ctx context.Context) ([]byte, string, error) {
|
||||
data, filename, _, err := s.ExportUsers(ctx, &ExportUsersRequest{Format: ExportFormatCSV})
|
||||
return data, filename, err
|
||||
}
|
||||
|
||||
// ExportUsersXLSX exports users as Excel.
|
||||
func (s *ExportService) ExportUsersXLSX(ctx context.Context) ([]byte, string, error) {
|
||||
data, filename, _, err := s.ExportUsers(ctx, &ExportUsersRequest{Format: ExportFormatXLSX})
|
||||
return data, filename, err
|
||||
}
|
||||
|
||||
func (s *ExportService) listUsersForExport(ctx context.Context, req *ExportUsersRequest) ([]*domain.User, error) {
|
||||
var allUsers []*domain.User
|
||||
offset := 0
|
||||
batchSize := 500
|
||||
|
||||
for {
|
||||
var (
|
||||
users []*domain.User
|
||||
total int64
|
||||
err error
|
||||
)
|
||||
|
||||
if req.Keyword != "" || req.Status != nil {
|
||||
filter := &repository.AdvancedFilter{
|
||||
Keyword: req.Keyword,
|
||||
Status: -1,
|
||||
SortBy: "created_at",
|
||||
SortOrder: "desc",
|
||||
Offset: offset,
|
||||
Limit: batchSize,
|
||||
}
|
||||
if req.Status != nil {
|
||||
filter.Status = *req.Status
|
||||
}
|
||||
users, total, err = s.userRepo.AdvancedSearch(ctx, filter)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询用户失败: %w", err)
|
||||
}
|
||||
allUsers = append(allUsers, users...)
|
||||
offset += len(users)
|
||||
if offset >= int(total) || len(users) == 0 {
|
||||
break
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
users, _, err = s.userRepo.List(ctx, offset, batchSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("查询用户失败: %w", err)
|
||||
}
|
||||
allUsers = append(allUsers, users...)
|
||||
if len(users) < batchSize {
|
||||
break
|
||||
}
|
||||
offset += batchSize
|
||||
}
|
||||
|
||||
return allUsers, nil
|
||||
}
|
||||
|
||||
// ImportUsers imports users from CSV or XLSX.
|
||||
func (s *ExportService) ImportUsers(ctx context.Context, data []byte, format string) (successCount, failCount int, errs []string) {
|
||||
normalized, err := normalizeExportFormat(format)
|
||||
if err != nil {
|
||||
return 0, 0, []string{err.Error()}
|
||||
}
|
||||
|
||||
var records [][]string
|
||||
switch normalized {
|
||||
case ExportFormatCSV:
|
||||
records, err = parseCSVRecords(data)
|
||||
case ExportFormatXLSX:
|
||||
records, err = parseXLSXRecords(data)
|
||||
default:
|
||||
err = fmt.Errorf("不支持的导入格式: %s", format)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, 0, []string{err.Error()}
|
||||
}
|
||||
|
||||
return s.importUsersRecords(ctx, records)
|
||||
}
|
||||
|
||||
// ImportUsersCSV keeps backward compatibility for callers that still upload CSV.
|
||||
func (s *ExportService) ImportUsersCSV(ctx context.Context, data []byte) (successCount, failCount int, errs []string) {
|
||||
return s.ImportUsers(ctx, data, ExportFormatCSV)
|
||||
}
|
||||
|
||||
// ImportUsersXLSX imports users from Excel.
|
||||
func (s *ExportService) ImportUsersXLSX(ctx context.Context, data []byte) (successCount, failCount int, errs []string) {
|
||||
return s.ImportUsers(ctx, data, ExportFormatXLSX)
|
||||
}
|
||||
|
||||
func (s *ExportService) importUsersRecords(ctx context.Context, records [][]string) (successCount, failCount int, errs []string) {
|
||||
if len(records) < 2 {
|
||||
return 0, 0, []string{"导入文件为空或没有数据行"}
|
||||
}
|
||||
|
||||
headers := records[0]
|
||||
colIdx := buildColIndex(headers)
|
||||
getCol := func(row []string, name string) string {
|
||||
idx, ok := colIdx[name]
|
||||
if !ok || idx >= len(row) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(row[idx])
|
||||
}
|
||||
|
||||
for i, row := range records[1:] {
|
||||
lineNum := i + 2
|
||||
username := getCol(row, "用户名")
|
||||
password := getCol(row, "密码")
|
||||
|
||||
if username == "" || password == "" {
|
||||
failCount++
|
||||
errs = append(errs, fmt.Sprintf("第%d行:用户名和密码不能为空", lineNum))
|
||||
continue
|
||||
}
|
||||
|
||||
exists, err := s.userRepo.ExistsByUsername(ctx, username)
|
||||
if err != nil {
|
||||
failCount++
|
||||
errs = append(errs, fmt.Sprintf("第%d行:检查用户名失败: %v", lineNum, err))
|
||||
continue
|
||||
}
|
||||
if exists {
|
||||
failCount++
|
||||
errs = append(errs, fmt.Sprintf("第%d行:用户名 '%s' 已存在", lineNum, username))
|
||||
continue
|
||||
}
|
||||
|
||||
hashedPwd, err := hashPassword(password)
|
||||
if err != nil {
|
||||
failCount++
|
||||
errs = append(errs, fmt.Sprintf("第%d行:密码加密失败: %v", lineNum, err))
|
||||
continue
|
||||
}
|
||||
|
||||
user := &domain.User{
|
||||
Username: username,
|
||||
Email: domain.StrPtr(getCol(row, "邮箱")),
|
||||
Phone: domain.StrPtr(getCol(row, "手机号")),
|
||||
Nickname: getCol(row, "昵称"),
|
||||
Password: hashedPwd,
|
||||
Region: getCol(row, "地区"),
|
||||
Bio: getCol(row, "个人简介"),
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
failCount++
|
||||
errs = append(errs, fmt.Sprintf("第%d行:创建用户失败: %v", lineNum, err))
|
||||
continue
|
||||
}
|
||||
successCount++
|
||||
}
|
||||
|
||||
return successCount, failCount, errs
|
||||
}
|
||||
|
||||
// GetImportTemplate keeps backward compatibility for callers that still expect CSV templates.
|
||||
func (s *ExportService) GetImportTemplate() ([]byte, string) {
|
||||
data, filename, _, _ := s.GetImportTemplateByFormat(ExportFormatCSV)
|
||||
return data, filename
|
||||
}
|
||||
|
||||
// GetImportTemplateByFormat returns a CSV or XLSX template for imports.
|
||||
func (s *ExportService) GetImportTemplateByFormat(format string) ([]byte, string, string, error) {
|
||||
normalized, err := normalizeExportFormat(format)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
headers := []string{"用户名", "密码", "邮箱", "手机号", "昵称", "性别", "地区", "个人简介"}
|
||||
rows := [][]string{{
|
||||
"john_doe", "Password123!", "john@example.com", "13800138000",
|
||||
"约翰", "男", "北京", "这是个人简介",
|
||||
}}
|
||||
|
||||
switch normalized {
|
||||
case ExportFormatCSV:
|
||||
data, err := buildCSVRecords(headers, rows)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return data, "user_import_template.csv", "text/csv; charset=utf-8", nil
|
||||
case ExportFormatXLSX:
|
||||
data, err := buildXLSXRecords(headers, rows)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return data, "user_import_template.xlsx", "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
|
||||
default:
|
||||
return nil, "", "", fmt.Errorf("不支持的模板格式: %s", format)
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeExportFormat(format string) (string, error) {
|
||||
normalized := strings.ToLower(strings.TrimSpace(format))
|
||||
if normalized == "" {
|
||||
normalized = ExportFormatCSV
|
||||
}
|
||||
switch normalized {
|
||||
case ExportFormatCSV, ExportFormatXLSX:
|
||||
return normalized, nil
|
||||
default:
|
||||
return "", fmt.Errorf("不支持的格式: %s", format)
|
||||
}
|
||||
}
|
||||
|
||||
func resolveExportColumns(fields []string) ([]exportColumn, error) {
|
||||
if len(fields) == 0 {
|
||||
return defaultExportColumns, nil
|
||||
}
|
||||
|
||||
columnMap := make(map[string]exportColumn, len(defaultExportColumns))
|
||||
for _, col := range defaultExportColumns {
|
||||
columnMap[col.Key] = col
|
||||
}
|
||||
|
||||
selected := make([]exportColumn, 0, len(fields))
|
||||
seen := make(map[string]struct{}, len(fields))
|
||||
for _, field := range fields {
|
||||
key := strings.ToLower(strings.TrimSpace(field))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[key]; ok {
|
||||
continue
|
||||
}
|
||||
col, ok := columnMap[key]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("不支持的导出字段: %s", field)
|
||||
}
|
||||
selected = append(selected, col)
|
||||
seen[key] = struct{}{}
|
||||
}
|
||||
|
||||
if len(selected) == 0 {
|
||||
return defaultExportColumns, nil
|
||||
}
|
||||
|
||||
return selected, nil
|
||||
}
|
||||
|
||||
func buildCSVExport(columns []exportColumn, users []*domain.User) ([]byte, error) {
|
||||
headers := make([]string, 0, len(columns))
|
||||
rows := make([][]string, 0, len(users))
|
||||
for _, col := range columns {
|
||||
headers = append(headers, col.Header)
|
||||
}
|
||||
for _, u := range users {
|
||||
row := make([]string, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
row = append(row, col.Value(u))
|
||||
}
|
||||
rows = append(rows, row)
|
||||
}
|
||||
return buildCSVRecords(headers, rows)
|
||||
}
|
||||
|
||||
func buildCSVRecords(headers []string, rows [][]string) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte{0xEF, 0xBB, 0xBF})
|
||||
writer := csv.NewWriter(&buf)
|
||||
|
||||
if err := writer.Write(headers); err != nil {
|
||||
return nil, fmt.Errorf("写CSV表头失败: %w", err)
|
||||
}
|
||||
for _, row := range rows {
|
||||
if err := writer.Write(row); err != nil {
|
||||
return nil, fmt.Errorf("写CSV行失败: %w", err)
|
||||
}
|
||||
}
|
||||
writer.Flush()
|
||||
if err := writer.Error(); err != nil {
|
||||
return nil, fmt.Errorf("CSV Flush 失败: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func buildXLSXExport(columns []exportColumn, users []*domain.User) ([]byte, error) {
|
||||
headers := make([]string, 0, len(columns))
|
||||
rows := make([][]string, 0, len(users))
|
||||
for _, col := range columns {
|
||||
headers = append(headers, col.Header)
|
||||
}
|
||||
for _, u := range users {
|
||||
row := make([]string, 0, len(columns))
|
||||
for _, col := range columns {
|
||||
row = append(row, col.Value(u))
|
||||
}
|
||||
rows = append(rows, row)
|
||||
}
|
||||
return buildXLSXRecords(headers, rows)
|
||||
}
|
||||
|
||||
func buildXLSXRecords(headers []string, rows [][]string) ([]byte, error) {
|
||||
file := excelize.NewFile()
|
||||
defer file.Close()
|
||||
|
||||
sheet := file.GetSheetName(file.GetActiveSheetIndex())
|
||||
if sheet == "" {
|
||||
sheet = "Sheet1"
|
||||
}
|
||||
|
||||
for idx, header := range headers {
|
||||
cell, err := excelize.CoordinatesToCellName(idx+1, 1)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成表头单元格失败: %w", err)
|
||||
}
|
||||
if err := file.SetCellValue(sheet, cell, header); err != nil {
|
||||
return nil, fmt.Errorf("写入表头失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
for rowIdx, row := range rows {
|
||||
for colIdx, value := range row {
|
||||
cell, err := excelize.CoordinatesToCellName(colIdx+1, rowIdx+2)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成数据单元格失败: %w", err)
|
||||
}
|
||||
if err := file.SetCellValue(sheet, cell, value); err != nil {
|
||||
return nil, fmt.Errorf("写入单元格失败: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if _, err := file.WriteTo(&buf); err != nil {
|
||||
return nil, fmt.Errorf("生成Excel失败: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func parseCSVRecords(data []byte) ([][]string, error) {
|
||||
if len(data) >= 3 && data[0] == 0xEF && data[1] == 0xBB && data[2] == 0xBF {
|
||||
data = data[3:]
|
||||
}
|
||||
|
||||
reader := csv.NewReader(bytes.NewReader(data))
|
||||
records, err := reader.ReadAll()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CSV 解析失败: %w", err)
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func parseXLSXRecords(data []byte) ([][]string, error) {
|
||||
file, err := excelize.OpenReader(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Excel 解析失败: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
sheets := file.GetSheetList()
|
||||
if len(sheets) == 0 {
|
||||
return nil, fmt.Errorf("Excel 文件没有可用工作表")
|
||||
}
|
||||
|
||||
rows, err := file.GetRows(sheets[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取Excel行失败: %w", err)
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
// ---- 辅助函数 ----
|
||||
|
||||
func genderLabel(g domain.Gender) string {
|
||||
switch g {
|
||||
case domain.GenderMale:
|
||||
return "男"
|
||||
case domain.GenderFemale:
|
||||
return "女"
|
||||
default:
|
||||
return "未知"
|
||||
}
|
||||
}
|
||||
|
||||
func userStatusLabel(s domain.UserStatus) string {
|
||||
switch s {
|
||||
case domain.UserStatusActive:
|
||||
return "已激活"
|
||||
case domain.UserStatusInactive:
|
||||
return "未激活"
|
||||
case domain.UserStatusLocked:
|
||||
return "已锁定"
|
||||
case domain.UserStatusDisabled:
|
||||
return "已禁用"
|
||||
default:
|
||||
return "未知"
|
||||
}
|
||||
}
|
||||
|
||||
func boolLabel(b bool) string {
|
||||
if b {
|
||||
return "是"
|
||||
}
|
||||
return "否"
|
||||
}
|
||||
|
||||
func timeLabel(t *time.Time) string {
|
||||
if t == nil {
|
||||
return ""
|
||||
}
|
||||
return t.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
// buildColIndex 将表头列名映射到列索引
|
||||
func buildColIndex(headers []string) map[string]int {
|
||||
idx := make(map[string]int, len(headers))
|
||||
for i, h := range headers {
|
||||
idx[h] = i
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
// hashPassword hashes imported passwords with the primary runtime algorithm.
|
||||
func hashPassword(password string) (string, error) {
|
||||
return auth.HashPassword(password)
|
||||
}
|
||||
157
internal/service/header_util.go
Normal file
157
internal/service/header_util.go
Normal file
@@ -0,0 +1,157 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// headerWireCasing 定义每个白名单 header 在真实 Claude CLI 抓包中的准确大小写。
|
||||
// Go 的 HTTP server 解析请求时会将所有 header key 转为 Canonical 形式(如 x-app → X-App),
|
||||
// 此 map 用于在转发时恢复到真实的 wire format。
|
||||
//
|
||||
// 来源:对真实 Claude CLI (claude-cli/2.1.81) 到 api.anthropic.com 的 HTTPS 流量抓包。
|
||||
var headerWireCasing = map[string]string{
|
||||
// Title case
|
||||
"accept": "Accept",
|
||||
"user-agent": "User-Agent",
|
||||
|
||||
// X-Stainless-* 保持 SDK 原始大小写
|
||||
"x-stainless-retry-count": "X-Stainless-Retry-Count",
|
||||
"x-stainless-timeout": "X-Stainless-Timeout",
|
||||
"x-stainless-lang": "X-Stainless-Lang",
|
||||
"x-stainless-package-version": "X-Stainless-Package-Version",
|
||||
"x-stainless-os": "X-Stainless-OS",
|
||||
"x-stainless-arch": "X-Stainless-Arch",
|
||||
"x-stainless-runtime": "X-Stainless-Runtime",
|
||||
"x-stainless-runtime-version": "X-Stainless-Runtime-Version",
|
||||
"x-stainless-helper-method": "x-stainless-helper-method",
|
||||
|
||||
// Anthropic SDK 自身设置的 header,全小写
|
||||
"anthropic-dangerous-direct-browser-access": "anthropic-dangerous-direct-browser-access",
|
||||
"anthropic-version": "anthropic-version",
|
||||
"anthropic-beta": "anthropic-beta",
|
||||
"x-app": "x-app",
|
||||
"content-type": "content-type",
|
||||
"accept-language": "accept-language",
|
||||
"sec-fetch-mode": "sec-fetch-mode",
|
||||
"accept-encoding": "accept-encoding",
|
||||
"authorization": "authorization",
|
||||
}
|
||||
|
||||
// headerWireOrder 定义真实 Claude CLI 发送 header 的顺序(基于抓包)。
|
||||
// 用于 debug log 按此顺序输出,便于与抓包结果直接对比。
|
||||
var headerWireOrder = []string{
|
||||
"Accept",
|
||||
"X-Stainless-Retry-Count",
|
||||
"X-Stainless-Timeout",
|
||||
"X-Stainless-Lang",
|
||||
"X-Stainless-Package-Version",
|
||||
"X-Stainless-OS",
|
||||
"X-Stainless-Arch",
|
||||
"X-Stainless-Runtime",
|
||||
"X-Stainless-Runtime-Version",
|
||||
"anthropic-dangerous-direct-browser-access",
|
||||
"anthropic-version",
|
||||
"authorization",
|
||||
"x-app",
|
||||
"User-Agent",
|
||||
"content-type",
|
||||
"anthropic-beta",
|
||||
"accept-language",
|
||||
"sec-fetch-mode",
|
||||
"accept-encoding",
|
||||
"x-stainless-helper-method",
|
||||
}
|
||||
|
||||
// headerWireOrderSet 用于快速判断某个 key 是否在 headerWireOrder 中(按 lowercase 匹配)。
|
||||
var headerWireOrderSet map[string]struct{}
|
||||
|
||||
func init() {
|
||||
headerWireOrderSet = make(map[string]struct{}, len(headerWireOrder))
|
||||
for _, k := range headerWireOrder {
|
||||
headerWireOrderSet[strings.ToLower(k)] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// resolveWireCasing 将 Go canonical key(如 X-Stainless-Os)映射为真实 wire casing(如 X-Stainless-OS)。
|
||||
// 如果 map 中没有对应条目,返回原始 key 不变。
|
||||
func resolveWireCasing(key string) string {
|
||||
if wk, ok := headerWireCasing[strings.ToLower(key)]; ok {
|
||||
return wk
|
||||
}
|
||||
return key
|
||||
}
|
||||
|
||||
// setHeaderRaw sets a header bypassing Go's canonical-case normalization.
|
||||
// The key is stored exactly as provided, preserving original casing.
|
||||
//
|
||||
// It first removes any existing value under the canonical key, the wire casing key,
|
||||
// and the exact raw key, preventing duplicates from any source.
|
||||
func setHeaderRaw(h http.Header, key, value string) {
|
||||
h.Del(key) // remove canonical form (e.g. "Anthropic-Beta")
|
||||
if wk := resolveWireCasing(key); wk != key {
|
||||
delete(h, wk) // remove wire casing form if different
|
||||
}
|
||||
delete(h, key) // remove exact raw key if it differs from canonical
|
||||
h[key] = []string{value}
|
||||
}
|
||||
|
||||
// addHeaderRaw appends a header value bypassing Go's canonical-case normalization.
|
||||
func addHeaderRaw(h http.Header, key, value string) {
|
||||
h[key] = append(h[key], value)
|
||||
}
|
||||
|
||||
// getHeaderRaw reads a header value, trying multiple key forms to handle the mismatch
|
||||
// between Go canonical keys, wire casing keys, and raw keys:
|
||||
// 1. exact key as provided
|
||||
// 2. wire casing form (from headerWireCasing)
|
||||
// 3. Go canonical form (via http.Header.Get)
|
||||
func getHeaderRaw(h http.Header, key string) string {
|
||||
// 1. exact key
|
||||
if vals := h[key]; len(vals) > 0 {
|
||||
return vals[0]
|
||||
}
|
||||
// 2. wire casing (e.g. looking up "Anthropic-Dangerous-Direct-Browser-Access" finds "anthropic-dangerous-direct-browser-access")
|
||||
if wk := resolveWireCasing(key); wk != key {
|
||||
if vals := h[wk]; len(vals) > 0 {
|
||||
return vals[0]
|
||||
}
|
||||
}
|
||||
// 3. canonical fallback
|
||||
return h.Get(key)
|
||||
}
|
||||
|
||||
// sortHeadersByWireOrder 按照真实 Claude CLI 的 header 顺序返回排序后的 key 列表。
|
||||
// 在 headerWireOrder 中定义的 key 按其顺序排列,未定义的 key 追加到末尾。
|
||||
func sortHeadersByWireOrder(h http.Header) []string {
|
||||
// 构建 lowercase -> actual map key 的映射
|
||||
present := make(map[string]string, len(h))
|
||||
for k := range h {
|
||||
present[strings.ToLower(k)] = k
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(h))
|
||||
seen := make(map[string]struct{}, len(h))
|
||||
|
||||
// 先按 wire order 输出
|
||||
for _, wk := range headerWireOrder {
|
||||
lk := strings.ToLower(wk)
|
||||
if actual, ok := present[lk]; ok {
|
||||
if _, dup := seen[lk]; !dup {
|
||||
result = append(result, actual)
|
||||
seen[lk] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 再追加不在 wire order 中的 header
|
||||
for k := range h {
|
||||
lk := strings.ToLower(k)
|
||||
if _, ok := seen[lk]; !ok {
|
||||
result = append(result, k)
|
||||
seen[lk] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
257
internal/service/login_log.go
Normal file
257
internal/service/login_log.go
Normal file
@@ -0,0 +1,257 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/csv"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/xuri/excelize/v2"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
// LoginLogService 登录日志服务
|
||||
type LoginLogService struct {
|
||||
loginLogRepo *repository.LoginLogRepository
|
||||
}
|
||||
|
||||
// NewLoginLogService 创建登录日志服务
|
||||
func NewLoginLogService(loginLogRepo *repository.LoginLogRepository) *LoginLogService {
|
||||
return &LoginLogService{loginLogRepo: loginLogRepo}
|
||||
}
|
||||
|
||||
// RecordLogin 记录登录日志
|
||||
func (s *LoginLogService) RecordLogin(ctx context.Context, req *RecordLoginRequest) error {
|
||||
log := &domain.LoginLog{
|
||||
LoginType: req.LoginType,
|
||||
DeviceID: req.DeviceID,
|
||||
IP: req.IP,
|
||||
Location: req.Location,
|
||||
Status: req.Status,
|
||||
FailReason: req.FailReason,
|
||||
}
|
||||
if req.UserID != 0 {
|
||||
log.UserID = &req.UserID
|
||||
}
|
||||
return s.loginLogRepo.Create(ctx, log)
|
||||
}
|
||||
|
||||
// RecordLoginRequest 记录登录请求
|
||||
type RecordLoginRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
LoginType int `json:"login_type"` // 1-用户名, 2-邮箱, 3-手机
|
||||
DeviceID string `json:"device_id"`
|
||||
IP string `json:"ip"`
|
||||
Location string `json:"location"`
|
||||
Status int `json:"status"` // 0-失败, 1-成功
|
||||
FailReason string `json:"fail_reason"`
|
||||
}
|
||||
|
||||
// ListLoginLogRequest 登录日志列表请求
|
||||
type ListLoginLogRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Status int `json:"status"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
StartAt string `json:"start_at"`
|
||||
EndAt string `json:"end_at"`
|
||||
}
|
||||
|
||||
// GetLoginLogs 获取登录日志列表
|
||||
func (s *LoginLogService) GetLoginLogs(ctx context.Context, req *ListLoginLogRequest) ([]*domain.LoginLog, int64, error) {
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
|
||||
// 按用户 ID 查询
|
||||
if req.UserID > 0 {
|
||||
return s.loginLogRepo.ListByUserID(ctx, req.UserID, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// 按时间范围查询
|
||||
if req.StartAt != "" && req.EndAt != "" {
|
||||
start, err1 := time.Parse(time.RFC3339, req.StartAt)
|
||||
end, err2 := time.Parse(time.RFC3339, req.EndAt)
|
||||
if err1 == nil && err2 == nil {
|
||||
return s.loginLogRepo.ListByTimeRange(ctx, start, end, offset, req.PageSize)
|
||||
}
|
||||
}
|
||||
|
||||
// 按状态查询
|
||||
if req.Status == 0 || req.Status == 1 {
|
||||
return s.loginLogRepo.ListByStatus(ctx, req.Status, offset, req.PageSize)
|
||||
}
|
||||
|
||||
return s.loginLogRepo.List(ctx, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// GetMyLoginLogs 获取当前用户的登录日志
|
||||
func (s *LoginLogService) GetMyLoginLogs(ctx context.Context, userID int64, page, pageSize int) ([]*domain.LoginLog, int64, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
return s.loginLogRepo.ListByUserID(ctx, userID, offset, pageSize)
|
||||
}
|
||||
|
||||
// CleanupOldLogs 清理旧日志(保留最近 N 天)
|
||||
func (s *LoginLogService) CleanupOldLogs(ctx context.Context, retentionDays int) error {
|
||||
return s.loginLogRepo.DeleteOlderThan(ctx, retentionDays)
|
||||
}
|
||||
|
||||
// ExportLoginLogRequest 导出登录日志请求
|
||||
type ExportLoginLogRequest struct {
|
||||
UserID int64 `form:"user_id"`
|
||||
Status int `form:"status"`
|
||||
Format string `form:"format"`
|
||||
StartAt string `form:"start_at"`
|
||||
EndAt string `form:"end_at"`
|
||||
}
|
||||
|
||||
// ExportLoginLogs 导出登录日志
|
||||
func (s *LoginLogService) ExportLoginLogs(ctx context.Context, req *ExportLoginLogRequest) ([]byte, string, string, error) {
|
||||
format := "csv"
|
||||
if req.Format == "xlsx" {
|
||||
format = "xlsx"
|
||||
}
|
||||
|
||||
var startAt, endAt *time.Time
|
||||
if req.StartAt != "" {
|
||||
if t, err := time.Parse(time.RFC3339, req.StartAt); err == nil {
|
||||
startAt = &t
|
||||
}
|
||||
}
|
||||
if req.EndAt != "" {
|
||||
if t, err := time.Parse(time.RFC3339, req.EndAt); err == nil {
|
||||
endAt = &t
|
||||
}
|
||||
}
|
||||
|
||||
logs, err := s.loginLogRepo.ListAllForExport(ctx, req.UserID, req.Status, startAt, endAt)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("查询登录日志失败: %w", err)
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("login_logs_%s.%s", time.Now().Format("20060102_150405"), format)
|
||||
|
||||
if format == "xlsx" {
|
||||
data, err := buildLoginLogXLSXExport(logs)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return data, filename, "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", nil
|
||||
}
|
||||
|
||||
data, err := buildLoginLogCSVExport(logs)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
return data, filename, "text/csv; charset=utf-8", nil
|
||||
}
|
||||
|
||||
func buildLoginLogCSVExport(logs []*domain.LoginLog) ([]byte, error) {
|
||||
headers := []string{"ID", "用户ID", "登录方式", "设备ID", "IP地址", "位置", "状态", "失败原因", "时间"}
|
||||
rows := make([][]string, 0, len(logs)+1)
|
||||
rows = append(rows, headers)
|
||||
|
||||
for _, log := range logs {
|
||||
rows = append(rows, []string{
|
||||
fmt.Sprintf("%d", log.ID),
|
||||
fmt.Sprintf("%d", derefInt64(log.UserID)),
|
||||
loginTypeLabel(log.LoginType),
|
||||
log.DeviceID,
|
||||
log.IP,
|
||||
log.Location,
|
||||
loginStatusLabel(log.Status),
|
||||
log.FailReason,
|
||||
log.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
})
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.Write([]byte{0xEF, 0xBB, 0xBF})
|
||||
writer := csv.NewWriter(&buf)
|
||||
if err := writer.WriteAll(rows); err != nil {
|
||||
return nil, fmt.Errorf("写CSV失败: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func buildLoginLogXLSXExport(logs []*domain.LoginLog) ([]byte, error) {
|
||||
file := excelize.NewFile()
|
||||
defer file.Close()
|
||||
|
||||
sheet := file.GetSheetName(file.GetActiveSheetIndex())
|
||||
if sheet == "" {
|
||||
sheet = "Sheet1"
|
||||
}
|
||||
|
||||
headers := []string{"ID", "用户ID", "登录方式", "设备ID", "IP地址", "位置", "状态", "失败原因", "时间"}
|
||||
for idx, header := range headers {
|
||||
cell, _ := excelize.CoordinatesToCellName(idx+1, 1)
|
||||
_ = file.SetCellValue(sheet, cell, header)
|
||||
}
|
||||
|
||||
for rowIdx, log := range logs {
|
||||
row := []string{
|
||||
fmt.Sprintf("%d", log.ID),
|
||||
fmt.Sprintf("%d", derefInt64(log.UserID)),
|
||||
loginTypeLabel(log.LoginType),
|
||||
log.DeviceID,
|
||||
log.IP,
|
||||
log.Location,
|
||||
loginStatusLabel(log.Status),
|
||||
log.FailReason,
|
||||
log.CreatedAt.Format("2006-01-02 15:04:05"),
|
||||
}
|
||||
for colIdx, value := range row {
|
||||
cell, _ := excelize.CoordinatesToCellName(colIdx+1, rowIdx+2)
|
||||
_ = file.SetCellValue(sheet, cell, value)
|
||||
}
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
if _, err := file.WriteTo(&buf); err != nil {
|
||||
return nil, fmt.Errorf("生成Excel失败: %w", err)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func loginTypeLabel(t int) string {
|
||||
switch t {
|
||||
case 1:
|
||||
return "密码登录"
|
||||
case 2:
|
||||
return "邮箱验证码"
|
||||
case 3:
|
||||
return "手机验证码"
|
||||
case 4:
|
||||
return "OAuth"
|
||||
default:
|
||||
return "未知"
|
||||
}
|
||||
}
|
||||
|
||||
func loginStatusLabel(s int) string {
|
||||
if s == 1 {
|
||||
return "成功"
|
||||
}
|
||||
return "失败"
|
||||
}
|
||||
|
||||
func derefInt64(v *int64) int64 {
|
||||
if v == nil {
|
||||
return 0
|
||||
}
|
||||
return *v
|
||||
}
|
||||
115
internal/service/operation_log.go
Normal file
115
internal/service/operation_log.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
// OperationLogService 操作日志服务
|
||||
type OperationLogService struct {
|
||||
operationLogRepo *repository.OperationLogRepository
|
||||
}
|
||||
|
||||
// NewOperationLogService 创建操作日志服务
|
||||
func NewOperationLogService(operationLogRepo *repository.OperationLogRepository) *OperationLogService {
|
||||
return &OperationLogService{operationLogRepo: operationLogRepo}
|
||||
}
|
||||
|
||||
// RecordOperation 记录操作日志
|
||||
func (s *OperationLogService) RecordOperation(ctx context.Context, req *RecordOperationRequest) error {
|
||||
log := &domain.OperationLog{
|
||||
OperationType: req.OperationType,
|
||||
OperationName: req.OperationName,
|
||||
RequestMethod: req.RequestMethod,
|
||||
RequestPath: req.RequestPath,
|
||||
RequestParams: req.RequestParams,
|
||||
ResponseStatus: req.ResponseStatus,
|
||||
IP: req.IP,
|
||||
UserAgent: req.UserAgent,
|
||||
}
|
||||
if req.UserID != 0 {
|
||||
log.UserID = &req.UserID
|
||||
}
|
||||
return s.operationLogRepo.Create(ctx, log)
|
||||
}
|
||||
|
||||
// RecordOperationRequest 记录操作请求
|
||||
type RecordOperationRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
OperationType string `json:"operation_type"`
|
||||
OperationName string `json:"operation_name"`
|
||||
RequestMethod string `json:"request_method"`
|
||||
RequestPath string `json:"request_path"`
|
||||
RequestParams string `json:"request_params"`
|
||||
ResponseStatus int `json:"response_status"`
|
||||
IP string `json:"ip"`
|
||||
UserAgent string `json:"user_agent"`
|
||||
}
|
||||
|
||||
// ListOperationLogRequest 操作日志列表请求
|
||||
type ListOperationLogRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Method string `json:"method"`
|
||||
Keyword string `json:"keyword"`
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
StartAt string `json:"start_at"`
|
||||
EndAt string `json:"end_at"`
|
||||
}
|
||||
|
||||
// GetOperationLogs 获取操作日志列表
|
||||
func (s *OperationLogService) GetOperationLogs(ctx context.Context, req *ListOperationLogRequest) ([]*domain.OperationLog, int64, error) {
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
|
||||
// 按关键词搜索
|
||||
if req.Keyword != "" {
|
||||
return s.operationLogRepo.Search(ctx, req.Keyword, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// 按用户 ID 查询
|
||||
if req.UserID > 0 {
|
||||
return s.operationLogRepo.ListByUserID(ctx, req.UserID, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// 按 HTTP 方法查询
|
||||
if req.Method != "" {
|
||||
return s.operationLogRepo.ListByMethod(ctx, req.Method, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// 按时间范围查询
|
||||
if req.StartAt != "" && req.EndAt != "" {
|
||||
start, err1 := time.Parse(time.RFC3339, req.StartAt)
|
||||
end, err2 := time.Parse(time.RFC3339, req.EndAt)
|
||||
if err1 == nil && err2 == nil {
|
||||
return s.operationLogRepo.ListByTimeRange(ctx, start, end, offset, req.PageSize)
|
||||
}
|
||||
}
|
||||
|
||||
return s.operationLogRepo.List(ctx, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// GetMyOperationLogs 获取当前用户的操作日志
|
||||
func (s *OperationLogService) GetMyOperationLogs(ctx context.Context, userID int64, page, pageSize int) ([]*domain.OperationLog, int64, error) {
|
||||
if page <= 0 {
|
||||
page = 1
|
||||
}
|
||||
if pageSize <= 0 {
|
||||
pageSize = 20
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
return s.operationLogRepo.ListByUserID(ctx, userID, offset, pageSize)
|
||||
}
|
||||
|
||||
// CleanupOldLogs 清理旧日志(保留最近 N 天)
|
||||
func (s *OperationLogService) CleanupOldLogs(ctx context.Context, retentionDays int) error {
|
||||
return s.operationLogRepo.DeleteOlderThan(ctx, retentionDays)
|
||||
}
|
||||
272
internal/service/password_reset.go
Normal file
272
internal/service/password_reset.go
Normal file
@@ -0,0 +1,272 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/smtp"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/security"
|
||||
)
|
||||
|
||||
// PasswordResetConfig controls reset-token issuance and SMTP delivery.
|
||||
type PasswordResetConfig struct {
|
||||
TokenTTL time.Duration
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPUser string
|
||||
SMTPPass string
|
||||
FromEmail string
|
||||
SiteURL string
|
||||
PasswordMinLen int
|
||||
PasswordRequireSpecial bool
|
||||
PasswordRequireNumber bool
|
||||
}
|
||||
|
||||
func DefaultPasswordResetConfig() *PasswordResetConfig {
|
||||
return &PasswordResetConfig{
|
||||
TokenTTL: 15 * time.Minute,
|
||||
SMTPHost: "",
|
||||
SMTPPort: 587,
|
||||
SMTPUser: "",
|
||||
SMTPPass: "",
|
||||
FromEmail: "noreply@example.com",
|
||||
SiteURL: "http://localhost:8080",
|
||||
PasswordMinLen: 8,
|
||||
PasswordRequireSpecial: false,
|
||||
PasswordRequireNumber: false,
|
||||
}
|
||||
}
|
||||
|
||||
type PasswordResetService struct {
|
||||
userRepo userRepositoryInterface
|
||||
cache *cache.CacheManager
|
||||
config *PasswordResetConfig
|
||||
}
|
||||
|
||||
func NewPasswordResetService(
|
||||
userRepo userRepositoryInterface,
|
||||
cache *cache.CacheManager,
|
||||
config *PasswordResetConfig,
|
||||
) *PasswordResetService {
|
||||
if config == nil {
|
||||
config = DefaultPasswordResetConfig()
|
||||
}
|
||||
return &PasswordResetService{
|
||||
userRepo: userRepo,
|
||||
cache: cache,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PasswordResetService) ForgotPassword(ctx context.Context, email string) error {
|
||||
user, err := s.userRepo.GetByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tokenBytes := make([]byte, 32)
|
||||
if _, err := cryptorand.Read(tokenBytes); err != nil {
|
||||
return fmt.Errorf("生成重置Token失败: %w", err)
|
||||
}
|
||||
resetToken := hex.EncodeToString(tokenBytes)
|
||||
|
||||
cacheKey := "pwd_reset:" + resetToken
|
||||
ttl := s.config.TokenTTL
|
||||
if err := s.cache.Set(ctx, cacheKey, user.ID, ttl, ttl); err != nil {
|
||||
return fmt.Errorf("缓存重置Token失败: %w", err)
|
||||
}
|
||||
|
||||
go s.sendResetEmail(domain.DerefStr(user.Email), user.Username, resetToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PasswordResetService) ResetPassword(ctx context.Context, token, newPassword string) error {
|
||||
if token == "" || newPassword == "" {
|
||||
return errors.New("参数不完整")
|
||||
}
|
||||
|
||||
cacheKey := "pwd_reset:" + token
|
||||
val, ok := s.cache.Get(ctx, cacheKey)
|
||||
if !ok {
|
||||
return errors.New("重置链接已失效或不存在,请重新申请")
|
||||
}
|
||||
|
||||
userID, ok := int64Value(val)
|
||||
if !ok {
|
||||
return errors.New("重置Token数据异常")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
if err := s.doResetPassword(ctx, user, newPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.cache.Delete(ctx, cacheKey); err != nil {
|
||||
return fmt.Errorf("清理重置Token失败: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PasswordResetService) ValidateResetToken(ctx context.Context, token string) (bool, error) {
|
||||
if token == "" {
|
||||
return false, errors.New("token不能为空")
|
||||
}
|
||||
_, ok := s.cache.Get(ctx, "pwd_reset:"+token)
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (s *PasswordResetService) sendResetEmail(email, username, token string) {
|
||||
if s.config.SMTPHost == "" {
|
||||
return
|
||||
}
|
||||
|
||||
resetURL := fmt.Sprintf("%s/reset-password?token=%s", s.config.SiteURL, token)
|
||||
subject := "密码重置请求"
|
||||
body := fmt.Sprintf(`您好 %s:
|
||||
|
||||
您收到此邮件,是因为有人请求重置账户密码。
|
||||
请点击以下链接重置密码(链接将在 %s 后失效):
|
||||
%s
|
||||
|
||||
如果不是您本人操作,请忽略此邮件,您的密码不会被修改。
|
||||
|
||||
用户管理系统团队`, username, s.config.TokenTTL.String(), resetURL)
|
||||
|
||||
var authInfo smtp.Auth
|
||||
if s.config.SMTPUser != "" || s.config.SMTPPass != "" {
|
||||
authInfo = smtp.PlainAuth("", s.config.SMTPUser, s.config.SMTPPass, s.config.SMTPHost)
|
||||
}
|
||||
|
||||
msg := fmt.Sprintf(
|
||||
"From: %s\r\nTo: %s\r\nSubject: %s\r\nContent-Type: text/plain; charset=UTF-8\r\n\r\n%s",
|
||||
s.config.FromEmail,
|
||||
email,
|
||||
subject,
|
||||
body,
|
||||
)
|
||||
addr := fmt.Sprintf("%s:%d", s.config.SMTPHost, s.config.SMTPPort)
|
||||
if err := smtp.SendMail(addr, authInfo, s.config.FromEmail, []string{email}, []byte(msg)); err != nil {
|
||||
log.Printf("password-reset-email: send failed to=%s err=%v", email, err)
|
||||
}
|
||||
}
|
||||
|
||||
// ForgotPasswordByPhoneRequest 短信密码重置请求
|
||||
type ForgotPasswordByPhoneRequest struct {
|
||||
Phone string `json:"phone" binding:"required"`
|
||||
}
|
||||
|
||||
// ForgotPasswordByPhone 通过手机验证码重置密码 - 发送验证码
|
||||
func (s *PasswordResetService) ForgotPasswordByPhone(ctx context.Context, phone string) (string, error) {
|
||||
user, err := s.userRepo.GetByPhone(ctx, phone)
|
||||
if err != nil {
|
||||
return "", nil // 用户不存在不提示,防止用户枚举
|
||||
}
|
||||
|
||||
// 生成6位数字验证码
|
||||
code, err := generateSMSCode()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("生成验证码失败: %w", err)
|
||||
}
|
||||
|
||||
// 存储验证码,关联用户ID
|
||||
cacheKey := fmt.Sprintf("pwd_reset_sms:%s", phone)
|
||||
ttl := s.config.TokenTTL
|
||||
if err := s.cache.Set(ctx, cacheKey, user.ID, ttl, ttl); err != nil {
|
||||
return "", fmt.Errorf("缓存验证码失败: %w", err)
|
||||
}
|
||||
|
||||
// 存储验证码到另一个key,用于后续校验
|
||||
codeKey := fmt.Sprintf("pwd_reset_sms_code:%s", phone)
|
||||
if err := s.cache.Set(ctx, codeKey, code, ttl, ttl); err != nil {
|
||||
return "", fmt.Errorf("缓存验证码失败: %w", err)
|
||||
}
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// ResetPasswordByPhoneRequest 通过手机验证码重置密码请求
|
||||
type ResetPasswordByPhoneRequest struct {
|
||||
Phone string `json:"phone" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
NewPassword string `json:"new_password" binding:"required"`
|
||||
}
|
||||
|
||||
// ResetPasswordByPhone 通过手机验证码重置密码 - 验证并重置
|
||||
func (s *PasswordResetService) ResetPasswordByPhone(ctx context.Context, req *ResetPasswordByPhoneRequest) error {
|
||||
if req.Phone == "" || req.Code == "" || req.NewPassword == "" {
|
||||
return errors.New("参数不完整")
|
||||
}
|
||||
|
||||
codeKey := fmt.Sprintf("pwd_reset_sms_code:%s", req.Phone)
|
||||
storedCode, ok := s.cache.Get(ctx, codeKey)
|
||||
if !ok {
|
||||
return errors.New("验证码已失效,请重新获取")
|
||||
}
|
||||
|
||||
code, ok := storedCode.(string)
|
||||
if !ok || code != req.Code {
|
||||
return errors.New("验证码不正确")
|
||||
}
|
||||
|
||||
// 获取用户ID
|
||||
cacheKey := fmt.Sprintf("pwd_reset_sms:%s", req.Phone)
|
||||
val, ok := s.cache.Get(ctx, cacheKey)
|
||||
if !ok {
|
||||
return errors.New("验证码已失效,请重新获取")
|
||||
}
|
||||
|
||||
userID, ok := int64Value(val)
|
||||
if !ok {
|
||||
return errors.New("验证码数据异常")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
if err := s.doResetPassword(ctx, user, req.NewPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 清理验证码
|
||||
s.cache.Delete(ctx, codeKey)
|
||||
s.cache.Delete(ctx, cacheKey)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PasswordResetService) doResetPassword(ctx context.Context, user *domain.User, newPassword string) error {
|
||||
policy := security.PasswordPolicy{
|
||||
MinLength: s.config.PasswordMinLen,
|
||||
RequireSpecial: s.config.PasswordRequireSpecial,
|
||||
RequireNumber: s.config.PasswordRequireNumber,
|
||||
}.Normalize()
|
||||
if err := policy.Validate(newPassword); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return fmt.Errorf("密码加密失败: %w", err)
|
||||
}
|
||||
|
||||
user.Password = hashedPassword
|
||||
if err := s.userRepo.Update(ctx, user); err != nil {
|
||||
return fmt.Errorf("更新密码失败: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
223
internal/service/permission.go
Normal file
223
internal/service/permission.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
// PermissionService 权限服务
|
||||
type PermissionService struct {
|
||||
permissionRepo *repository.PermissionRepository
|
||||
}
|
||||
|
||||
// NewPermissionService 创建权限服务
|
||||
func NewPermissionService(
|
||||
permissionRepo *repository.PermissionRepository,
|
||||
) *PermissionService {
|
||||
return &PermissionService{
|
||||
permissionRepo: permissionRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreatePermissionRequest 创建权限请求
|
||||
type CreatePermissionRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
Type int `json:"type" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
ParentID *int64 `json:"parent_id"`
|
||||
Path string `json:"path"`
|
||||
Method string `json:"method"`
|
||||
Sort int `json:"sort"`
|
||||
Icon string `json:"icon"`
|
||||
}
|
||||
|
||||
// UpdatePermissionRequest 更新权限请求
|
||||
type UpdatePermissionRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ParentID *int64 `json:"parent_id"`
|
||||
Path string `json:"path"`
|
||||
Method string `json:"method"`
|
||||
Sort int `json:"sort"`
|
||||
Icon string `json:"icon"`
|
||||
}
|
||||
|
||||
// CreatePermission 创建权限
|
||||
func (s *PermissionService) CreatePermission(ctx context.Context, req *CreatePermissionRequest) (*domain.Permission, error) {
|
||||
// 检查权限代码是否已存在
|
||||
exists, err := s.permissionRepo.ExistsByCode(ctx, req.Code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, errors.New("权限代码已存在")
|
||||
}
|
||||
|
||||
// 检查父权限是否存在
|
||||
if req.ParentID != nil {
|
||||
_, err := s.permissionRepo.GetByID(ctx, *req.ParentID)
|
||||
if err != nil {
|
||||
return nil, errors.New("父权限不存在")
|
||||
}
|
||||
}
|
||||
|
||||
// 创建权限
|
||||
permission := &domain.Permission{
|
||||
Name: req.Name,
|
||||
Code: req.Code,
|
||||
Type: domain.PermissionType(req.Type),
|
||||
Description: req.Description,
|
||||
ParentID: req.ParentID,
|
||||
Level: 1,
|
||||
Path: req.Path,
|
||||
Method: req.Method,
|
||||
Sort: req.Sort,
|
||||
Icon: req.Icon,
|
||||
Status: domain.PermissionStatusEnabled,
|
||||
}
|
||||
|
||||
if req.ParentID != nil {
|
||||
permission.Level = 2
|
||||
}
|
||||
|
||||
if err := s.permissionRepo.Create(ctx, permission); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return permission, nil
|
||||
}
|
||||
|
||||
// UpdatePermission 更新权限
|
||||
func (s *PermissionService) UpdatePermission(ctx context.Context, permissionID int64, req *UpdatePermissionRequest) (*domain.Permission, error) {
|
||||
permission, err := s.permissionRepo.GetByID(ctx, permissionID)
|
||||
if err != nil {
|
||||
return nil, errors.New("权限不存在")
|
||||
}
|
||||
|
||||
// 检查父权限是否存在
|
||||
if req.ParentID != nil {
|
||||
if *req.ParentID == permissionID {
|
||||
return nil, errors.New("不能将权限设置为自己的父权限")
|
||||
}
|
||||
_, err := s.permissionRepo.GetByID(ctx, *req.ParentID)
|
||||
if err != nil {
|
||||
return nil, errors.New("父权限不存在")
|
||||
}
|
||||
permission.ParentID = req.ParentID
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != "" {
|
||||
permission.Name = req.Name
|
||||
}
|
||||
if req.Description != "" {
|
||||
permission.Description = req.Description
|
||||
}
|
||||
if req.Path != "" {
|
||||
permission.Path = req.Path
|
||||
}
|
||||
if req.Method != "" {
|
||||
permission.Method = req.Method
|
||||
}
|
||||
if req.Sort > 0 {
|
||||
permission.Sort = req.Sort
|
||||
}
|
||||
if req.Icon != "" {
|
||||
permission.Icon = req.Icon
|
||||
}
|
||||
|
||||
if err := s.permissionRepo.Update(ctx, permission); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return permission, nil
|
||||
}
|
||||
|
||||
// DeletePermission 删除权限
|
||||
func (s *PermissionService) DeletePermission(ctx context.Context, permissionID int64) error {
|
||||
_, err := s.permissionRepo.GetByID(ctx, permissionID)
|
||||
if err != nil {
|
||||
return errors.New("权限不存在")
|
||||
}
|
||||
|
||||
// 检查是否有子权限
|
||||
children, err := s.permissionRepo.ListByParentID(ctx, permissionID)
|
||||
if err == nil && len(children) > 0 {
|
||||
return errors.New("存在子权限,无法删除")
|
||||
}
|
||||
|
||||
return s.permissionRepo.Delete(ctx, permissionID)
|
||||
}
|
||||
|
||||
// GetPermission 获取权限信息
|
||||
func (s *PermissionService) GetPermission(ctx context.Context, permissionID int64) (*domain.Permission, error) {
|
||||
return s.permissionRepo.GetByID(ctx, permissionID)
|
||||
}
|
||||
|
||||
// ListPermissions 获取权限列表
|
||||
type ListPermissionRequest struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Type int `json:"type"`
|
||||
Status int `json:"status"`
|
||||
Keyword string `json:"keyword"`
|
||||
}
|
||||
|
||||
func (s *PermissionService) ListPermissions(ctx context.Context, req *ListPermissionRequest) ([]*domain.Permission, int64, error) {
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
|
||||
if req.Keyword != "" {
|
||||
return s.permissionRepo.Search(ctx, req.Keyword, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// Type > 0 表示按类型过滤;0 表示不过滤(查全部)
|
||||
if req.Type > 0 {
|
||||
return s.permissionRepo.ListByType(ctx, domain.PermissionType(req.Type), offset, req.PageSize)
|
||||
}
|
||||
|
||||
// Status > 0 表示按状态过滤;0 表示不过滤(查全部)
|
||||
if req.Status > 0 {
|
||||
return s.permissionRepo.ListByStatus(ctx, domain.PermissionStatus(req.Status), offset, req.PageSize)
|
||||
}
|
||||
|
||||
return s.permissionRepo.List(ctx, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// UpdatePermissionStatus 更新权限状态
|
||||
func (s *PermissionService) UpdatePermissionStatus(ctx context.Context, permissionID int64, status domain.PermissionStatus) error {
|
||||
return s.permissionRepo.UpdateStatus(ctx, permissionID, status)
|
||||
}
|
||||
|
||||
// GetPermissionTree 获取权限树
|
||||
func (s *PermissionService) GetPermissionTree(ctx context.Context) ([]*domain.Permission, error) {
|
||||
// 获取所有权限
|
||||
permissions, _, err := s.permissionRepo.List(ctx, 0, 1000)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 构建树形结构
|
||||
return s.buildPermissionTree(permissions, 0), nil
|
||||
}
|
||||
|
||||
// buildPermissionTree 构建权限树
|
||||
func (s *PermissionService) buildPermissionTree(permissions []*domain.Permission, parentID int64) []*domain.Permission {
|
||||
var tree []*domain.Permission
|
||||
for _, perm := range permissions {
|
||||
if (parentID == 0 && perm.ParentID == nil) || (perm.ParentID != nil && *perm.ParentID == parentID) {
|
||||
perm.Children = s.buildPermissionTree(permissions, perm.ID)
|
||||
tree = append(tree, perm)
|
||||
}
|
||||
}
|
||||
return tree
|
||||
}
|
||||
122
internal/service/prompts/codex_opencode_bridge.txt
Normal file
122
internal/service/prompts/codex_opencode_bridge.txt
Normal file
@@ -0,0 +1,122 @@
|
||||
# Codex Running in OpenCode
|
||||
|
||||
You are running Codex through OpenCode, an open-source terminal coding assistant. OpenCode provides different tools but follows Codex operating principles.
|
||||
|
||||
## CRITICAL: Tool Replacements
|
||||
|
||||
<critical_rule priority="0">
|
||||
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
|
||||
- NEVER use: apply_patch, applyPatch
|
||||
- ALWAYS use: edit tool for ALL file modifications
|
||||
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
|
||||
</critical_rule>
|
||||
|
||||
<critical_rule priority="0">
|
||||
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
|
||||
- NEVER use: update_plan, updatePlan, read_plan, readPlan
|
||||
- ALWAYS use: todowrite for task/plan updates, todoread to read plans
|
||||
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
|
||||
</critical_rule>
|
||||
|
||||
## Available OpenCode Tools
|
||||
|
||||
**File Operations:**
|
||||
- `write` - Create new files
|
||||
- Overwriting existing files requires a prior Read in this session; default to ASCII unless the file already uses Unicode.
|
||||
- `edit` - Modify existing files (REPLACES apply_patch)
|
||||
- Requires a prior Read in this session; preserve exact indentation; ensure `oldString` uniquely matches or use `replaceAll`; edit fails if ambiguous or missing.
|
||||
- `read` - Read file contents
|
||||
|
||||
**Search/Discovery:**
|
||||
- `grep` - Search file contents (tool, not bash grep); use `include` to filter patterns; set `path` only when not searching workspace root; for cross-file match counts use bash with `rg`.
|
||||
- `glob` - Find files by pattern; defaults to workspace cwd unless `path` is set.
|
||||
- `list` - List directories (requires absolute paths)
|
||||
|
||||
**Execution:**
|
||||
- `bash` - Run shell commands
|
||||
- No workdir parameter; do not include it in tool calls.
|
||||
- Always include a short description for the command.
|
||||
- Do not use cd; use absolute paths in commands.
|
||||
- Quote paths containing spaces with double quotes.
|
||||
- Chain multiple commands with ';' or '&&'; avoid newlines.
|
||||
- Use Grep/Glob tools for searches; only use bash with `rg` when you need counts or advanced features.
|
||||
- Do not use `ls`/`cat` in bash; use `list`/`read` tools instead.
|
||||
- For deletions (rm), verify by listing parent dir with `list`.
|
||||
|
||||
**Network:**
|
||||
- `webfetch` - Fetch web content
|
||||
- Use fully-formed URLs (http/https; http auto-upgrades to https).
|
||||
- Always set `format` to one of: text | markdown | html; prefer markdown unless otherwise required.
|
||||
- Read-only; short cache window.
|
||||
|
||||
**Task Management:**
|
||||
- `todowrite` - Manage tasks/plans (REPLACES update_plan)
|
||||
- `todoread` - Read current plan
|
||||
|
||||
## Substitution Rules
|
||||
|
||||
Base instruction says: You MUST use instead:
|
||||
apply_patch → edit
|
||||
update_plan → todowrite
|
||||
read_plan → todoread
|
||||
|
||||
**Path Usage:** Use per-tool conventions to avoid conflicts:
|
||||
- Tool calls: `read`, `edit`, `write`, `list` require absolute paths.
|
||||
- Searches: `grep`/`glob` default to the workspace cwd; prefer relative include patterns; set `path` only when a different root is needed.
|
||||
- Presentation: In assistant messages, show workspace-relative paths; use absolute paths only inside tool calls.
|
||||
- Tool schema overrides general path preferences—do not convert required absolute paths to relative.
|
||||
|
||||
## Verification Checklist
|
||||
|
||||
Before file/plan modifications:
|
||||
1. Am I using "edit" NOT "apply_patch"?
|
||||
2. Am I using "todowrite" NOT "update_plan"?
|
||||
3. Is this tool in the approved list above?
|
||||
4. Am I following each tool's path requirements?
|
||||
|
||||
If ANY answer is NO → STOP and correct before proceeding.
|
||||
|
||||
## OpenCode Working Style
|
||||
|
||||
**Communication:**
|
||||
- Send brief preambles (8-12 words) before tool calls, building on prior context
|
||||
- Provide progress updates during longer tasks
|
||||
|
||||
**Execution:**
|
||||
- Keep working autonomously until query is fully resolved before yielding
|
||||
- Don't return to user with partial solutions
|
||||
|
||||
**Code Approach:**
|
||||
- New projects: Be ambitious and creative
|
||||
- Existing codebases: Surgical precision - modify only what's requested unless explicitly instructed to do otherwise
|
||||
|
||||
**Testing:**
|
||||
- If tests exist: Start specific to your changes, then broader validation
|
||||
|
||||
## Advanced Tools
|
||||
|
||||
**Task Tool (Sub-Agents):**
|
||||
- Use the Task tool (functions.task) to launch sub-agents
|
||||
- Check the Task tool description for current agent types and their capabilities
|
||||
- Useful for complex analysis, specialized workflows, or tasks requiring isolated context
|
||||
- The agent list is dynamically generated - refer to tool schema for available agents
|
||||
|
||||
**Parallelization:**
|
||||
- When multiple independent tool calls are needed, use multi_tool_use.parallel to run them concurrently.
|
||||
- Reserve sequential calls for ordered or data-dependent steps.
|
||||
|
||||
**MCP Tools:**
|
||||
- Model Context Protocol servers provide additional capabilities
|
||||
- MCP tools are prefixed: `mcp__<server-name>__<tool-name>`
|
||||
- Check your available tools for MCP integrations
|
||||
- Use when the tool's functionality matches your task needs
|
||||
|
||||
## What Remains from Codex
|
||||
|
||||
Sandbox policies, approval mechanisms, final answer formatting, git commit protocols, and file reference formats all follow Codex instructions. In approval policy "never", never request escalations.
|
||||
|
||||
## Approvals & Safety
|
||||
- Assume workspace-write filesystem, network enabled, approval on-failure unless explicitly stated otherwise.
|
||||
- When a command fails due to sandboxing or permissions, retry with escalated permissions if allowed by policy, including a one-line justification.
|
||||
- Treat destructive commands (e.g., `rm`, `git reset --hard`) as requiring explicit user request or approval.
|
||||
- When uncertain, prefer non-destructive verification first (e.g., confirm file existence with `list`, then delete with `bash`).
|
||||
63
internal/service/prompts/tool_remap_message.txt
Normal file
63
internal/service/prompts/tool_remap_message.txt
Normal file
@@ -0,0 +1,63 @@
|
||||
<user_instructions priority="0">
|
||||
<environment_override priority="0">
|
||||
YOU ARE IN A DIFFERENT ENVIRONMENT. These instructions override ALL previous tool references.
|
||||
</environment_override>
|
||||
|
||||
<tool_replacements priority="0">
|
||||
<critical_rule priority="0">
|
||||
❌ APPLY_PATCH DOES NOT EXIST → ✅ USE "edit" INSTEAD
|
||||
- NEVER use: apply_patch, applyPatch
|
||||
- ALWAYS use: edit tool for ALL file modifications
|
||||
- Before modifying files: Verify you're using "edit", NOT "apply_patch"
|
||||
</critical_rule>
|
||||
|
||||
<critical_rule priority="0">
|
||||
❌ UPDATE_PLAN DOES NOT EXIST → ✅ USE "todowrite" INSTEAD
|
||||
- NEVER use: update_plan, updatePlan
|
||||
- ALWAYS use: todowrite for ALL task/plan operations
|
||||
- Use todoread to read current plan
|
||||
- Before plan operations: Verify you're using "todowrite", NOT "update_plan"
|
||||
</critical_rule>
|
||||
</tool_replacements>
|
||||
|
||||
<available_tools priority="0">
|
||||
File Operations:
|
||||
• write - Create new files
|
||||
• edit - Modify existing files (REPLACES apply_patch)
|
||||
• patch - Apply diff patches
|
||||
• read - Read file contents
|
||||
|
||||
Search/Discovery:
|
||||
• grep - Search file contents
|
||||
• glob - Find files by pattern
|
||||
• list - List directories (use relative paths)
|
||||
|
||||
Execution:
|
||||
• bash - Run shell commands
|
||||
|
||||
Network:
|
||||
• webfetch - Fetch web content
|
||||
|
||||
Task Management:
|
||||
• todowrite - Manage tasks/plans (REPLACES update_plan)
|
||||
• todoread - Read current plan
|
||||
</available_tools>
|
||||
|
||||
<substitution_rules priority="0">
|
||||
Base instruction says: You MUST use instead:
|
||||
apply_patch → edit
|
||||
update_plan → todowrite
|
||||
read_plan → todoread
|
||||
absolute paths → relative paths
|
||||
</substitution_rules>
|
||||
|
||||
<verification_checklist priority="0">
|
||||
Before file/plan modifications:
|
||||
1. Am I using "edit" NOT "apply_patch"?
|
||||
2. Am I using "todowrite" NOT "update_plan"?
|
||||
3. Is this tool in the approved list above?
|
||||
4. Am I using relative paths?
|
||||
|
||||
If ANY answer is NO → STOP and correct before proceeding.
|
||||
</verification_checklist>
|
||||
</user_instructions>
|
||||
216
internal/service/request_metadata.go
Normal file
216
internal/service/request_metadata.go
Normal file
@@ -0,0 +1,216 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/user-management-system/internal/pkg/ctxkey"
|
||||
)
|
||||
|
||||
type requestMetadataContextKey struct{}
|
||||
|
||||
var requestMetadataKey = requestMetadataContextKey{}
|
||||
|
||||
type RequestMetadata struct {
|
||||
IsMaxTokensOneHaikuRequest *bool
|
||||
ThinkingEnabled *bool
|
||||
PrefetchedStickyAccountID *int64
|
||||
PrefetchedStickyGroupID *int64
|
||||
SingleAccountRetry *bool
|
||||
AccountSwitchCount *int
|
||||
}
|
||||
|
||||
var (
|
||||
requestMetadataFallbackIsMaxTokensOneHaikuTotal atomic.Int64
|
||||
requestMetadataFallbackThinkingEnabledTotal atomic.Int64
|
||||
requestMetadataFallbackPrefetchedStickyAccount atomic.Int64
|
||||
requestMetadataFallbackPrefetchedStickyGroup atomic.Int64
|
||||
requestMetadataFallbackSingleAccountRetryTotal atomic.Int64
|
||||
requestMetadataFallbackAccountSwitchCountTotal atomic.Int64
|
||||
)
|
||||
|
||||
func RequestMetadataFallbackStats() (isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount int64) {
|
||||
return requestMetadataFallbackIsMaxTokensOneHaikuTotal.Load(),
|
||||
requestMetadataFallbackThinkingEnabledTotal.Load(),
|
||||
requestMetadataFallbackPrefetchedStickyAccount.Load(),
|
||||
requestMetadataFallbackPrefetchedStickyGroup.Load(),
|
||||
requestMetadataFallbackSingleAccountRetryTotal.Load(),
|
||||
requestMetadataFallbackAccountSwitchCountTotal.Load()
|
||||
}
|
||||
|
||||
func metadataFromContext(ctx context.Context) *RequestMetadata {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
md, _ := ctx.Value(requestMetadataKey).(*RequestMetadata)
|
||||
return md
|
||||
}
|
||||
|
||||
func updateRequestMetadata(
|
||||
ctx context.Context,
|
||||
bridgeOldKeys bool,
|
||||
update func(md *RequestMetadata),
|
||||
legacyBridge func(ctx context.Context) context.Context,
|
||||
) context.Context {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
current := metadataFromContext(ctx)
|
||||
next := &RequestMetadata{}
|
||||
if current != nil {
|
||||
*next = *current
|
||||
}
|
||||
update(next)
|
||||
ctx = context.WithValue(ctx, requestMetadataKey, next)
|
||||
if bridgeOldKeys && legacyBridge != nil {
|
||||
ctx = legacyBridge(ctx)
|
||||
}
|
||||
return ctx
|
||||
}
|
||||
|
||||
func WithIsMaxTokensOneHaikuRequest(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
|
||||
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
|
||||
v := value
|
||||
md.IsMaxTokensOneHaikuRequest = &v
|
||||
}, func(base context.Context) context.Context {
|
||||
return context.WithValue(base, ctxkey.IsMaxTokensOneHaikuRequest, value)
|
||||
})
|
||||
}
|
||||
|
||||
func WithThinkingEnabled(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
|
||||
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
|
||||
v := value
|
||||
md.ThinkingEnabled = &v
|
||||
}, func(base context.Context) context.Context {
|
||||
return context.WithValue(base, ctxkey.ThinkingEnabled, value)
|
||||
})
|
||||
}
|
||||
|
||||
func WithPrefetchedStickySession(ctx context.Context, accountID, groupID int64, bridgeOldKeys bool) context.Context {
|
||||
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
|
||||
account := accountID
|
||||
group := groupID
|
||||
md.PrefetchedStickyAccountID = &account
|
||||
md.PrefetchedStickyGroupID = &group
|
||||
}, func(base context.Context) context.Context {
|
||||
bridged := context.WithValue(base, ctxkey.PrefetchedStickyAccountID, accountID)
|
||||
return context.WithValue(bridged, ctxkey.PrefetchedStickyGroupID, groupID)
|
||||
})
|
||||
}
|
||||
|
||||
func WithSingleAccountRetry(ctx context.Context, value bool, bridgeOldKeys bool) context.Context {
|
||||
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
|
||||
v := value
|
||||
md.SingleAccountRetry = &v
|
||||
}, func(base context.Context) context.Context {
|
||||
return context.WithValue(base, ctxkey.SingleAccountRetry, value)
|
||||
})
|
||||
}
|
||||
|
||||
func WithAccountSwitchCount(ctx context.Context, value int, bridgeOldKeys bool) context.Context {
|
||||
return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) {
|
||||
v := value
|
||||
md.AccountSwitchCount = &v
|
||||
}, func(base context.Context) context.Context {
|
||||
return context.WithValue(base, ctxkey.AccountSwitchCount, value)
|
||||
})
|
||||
}
|
||||
|
||||
func IsMaxTokensOneHaikuRequestFromContext(ctx context.Context) (bool, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.IsMaxTokensOneHaikuRequest != nil {
|
||||
return *md.IsMaxTokensOneHaikuRequest, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return false, false
|
||||
}
|
||||
if value, ok := ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok {
|
||||
requestMetadataFallbackIsMaxTokensOneHaikuTotal.Add(1)
|
||||
return value, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func ThinkingEnabledFromContext(ctx context.Context) (bool, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.ThinkingEnabled != nil {
|
||||
return *md.ThinkingEnabled, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return false, false
|
||||
}
|
||||
if value, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok {
|
||||
requestMetadataFallbackThinkingEnabledTotal.Add(1)
|
||||
return value, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func PrefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyGroupID != nil {
|
||||
return *md.PrefetchedStickyGroupID, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return 0, false
|
||||
}
|
||||
v := ctx.Value(ctxkey.PrefetchedStickyGroupID)
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
requestMetadataFallbackPrefetchedStickyGroup.Add(1)
|
||||
return t, true
|
||||
case int:
|
||||
requestMetadataFallbackPrefetchedStickyGroup.Add(1)
|
||||
return int64(t), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func PrefetchedStickyAccountIDFromContext(ctx context.Context) (int64, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyAccountID != nil {
|
||||
return *md.PrefetchedStickyAccountID, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return 0, false
|
||||
}
|
||||
v := ctx.Value(ctxkey.PrefetchedStickyAccountID)
|
||||
switch t := v.(type) {
|
||||
case int64:
|
||||
requestMetadataFallbackPrefetchedStickyAccount.Add(1)
|
||||
return t, true
|
||||
case int:
|
||||
requestMetadataFallbackPrefetchedStickyAccount.Add(1)
|
||||
return int64(t), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func SingleAccountRetryFromContext(ctx context.Context) (bool, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.SingleAccountRetry != nil {
|
||||
return *md.SingleAccountRetry, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return false, false
|
||||
}
|
||||
if value, ok := ctx.Value(ctxkey.SingleAccountRetry).(bool); ok {
|
||||
requestMetadataFallbackSingleAccountRetryTotal.Add(1)
|
||||
return value, true
|
||||
}
|
||||
return false, false
|
||||
}
|
||||
|
||||
func AccountSwitchCountFromContext(ctx context.Context) (int, bool) {
|
||||
if md := metadataFromContext(ctx); md != nil && md.AccountSwitchCount != nil {
|
||||
return *md.AccountSwitchCount, true
|
||||
}
|
||||
if ctx == nil {
|
||||
return 0, false
|
||||
}
|
||||
v := ctx.Value(ctxkey.AccountSwitchCount)
|
||||
switch t := v.(type) {
|
||||
case int:
|
||||
requestMetadataFallbackAccountSwitchCountTotal.Add(1)
|
||||
return t, true
|
||||
case int64:
|
||||
requestMetadataFallbackAccountSwitchCountTotal.Add(1)
|
||||
return int(t), true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
284
internal/service/role.go
Normal file
284
internal/service/role.go
Normal file
@@ -0,0 +1,284 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
// RoleService 角色服务
|
||||
type RoleService struct {
|
||||
roleRepo *repository.RoleRepository
|
||||
rolePermissionRepo *repository.RolePermissionRepository
|
||||
}
|
||||
|
||||
// NewRoleService 创建角色服务
|
||||
func NewRoleService(
|
||||
roleRepo *repository.RoleRepository,
|
||||
rolePermissionRepo *repository.RolePermissionRepository,
|
||||
) *RoleService {
|
||||
return &RoleService{
|
||||
roleRepo: roleRepo,
|
||||
rolePermissionRepo: rolePermissionRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateRoleRequest 创建角色请求
|
||||
type CreateRoleRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
Description string `json:"description"`
|
||||
ParentID *int64 `json:"parent_id"`
|
||||
}
|
||||
|
||||
// UpdateRoleRequest 更新角色请求
|
||||
type UpdateRoleRequest struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
ParentID *int64 `json:"parent_id"`
|
||||
}
|
||||
|
||||
// CreateRole 创建角色
|
||||
func (s *RoleService) CreateRole(ctx context.Context, req *CreateRoleRequest) (*domain.Role, error) {
|
||||
// 检查角色代码是否已存在
|
||||
exists, err := s.roleRepo.ExistsByCode(ctx, req.Code)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if exists {
|
||||
return nil, errors.New("角色代码已存在")
|
||||
}
|
||||
|
||||
// 设置角色层级
|
||||
level := 1
|
||||
if req.ParentID != nil {
|
||||
parentRole, err := s.roleRepo.GetByID(ctx, *req.ParentID)
|
||||
if err != nil {
|
||||
return nil, errors.New("父角色不存在")
|
||||
}
|
||||
level = parentRole.Level + 1
|
||||
}
|
||||
|
||||
// 创建角色
|
||||
role := &domain.Role{
|
||||
Name: req.Name,
|
||||
Code: req.Code,
|
||||
Description: req.Description,
|
||||
ParentID: req.ParentID,
|
||||
Level: level,
|
||||
Status: domain.RoleStatusEnabled,
|
||||
}
|
||||
|
||||
if err := s.roleRepo.Create(ctx, role); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return role, nil
|
||||
}
|
||||
|
||||
const maxRoleDepth = 5 // 角色继承深度上限,可配置
|
||||
|
||||
// UpdateRole 更新角色
|
||||
func (s *RoleService) UpdateRole(ctx context.Context, roleID int64, req *UpdateRoleRequest) (*domain.Role, error) {
|
||||
role, err := s.roleRepo.GetByID(ctx, roleID)
|
||||
if err != nil {
|
||||
return nil, errors.New("角色不存在")
|
||||
}
|
||||
|
||||
// 检查父角色是否存在
|
||||
if req.ParentID != nil {
|
||||
if *req.ParentID == roleID {
|
||||
return nil, errors.New("不能将角色设置为自己的父角色")
|
||||
}
|
||||
// 检测循环继承:检查新父角色的祖先链是否包含当前角色
|
||||
if err := s.checkCircularInheritance(ctx, roleID, *req.ParentID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 检测继承深度:计算新父角色的深度 + 1
|
||||
if err := s.checkInheritanceDepth(ctx, *req.ParentID, maxRoleDepth-1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
role.ParentID = req.ParentID
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != "" {
|
||||
role.Name = req.Name
|
||||
}
|
||||
if req.Description != "" {
|
||||
role.Description = req.Description
|
||||
}
|
||||
|
||||
if err := s.roleRepo.Update(ctx, role); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return role, nil
|
||||
}
|
||||
|
||||
// checkCircularInheritance 检测循环继承
|
||||
// 如果将 childID 的父角色设为 parentID,检查 parentID 的祖先链是否包含 childID
|
||||
func (s *RoleService) checkCircularInheritance(ctx context.Context, childID, parentID int64) error {
|
||||
ancestorIDs, err := s.roleRepo.GetAncestorIDs(ctx, parentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, ancestorID := range ancestorIDs {
|
||||
if ancestorID == childID {
|
||||
return errors.New("检测到循环继承,操作被拒绝")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkInheritanceDepth 检测继承深度是否超限
|
||||
func (s *RoleService) checkInheritanceDepth(ctx context.Context, roleID int64, maxDepth int) error {
|
||||
if maxDepth <= 0 {
|
||||
return errors.New("继承深度超限,最大支持5层")
|
||||
}
|
||||
|
||||
depth := 0
|
||||
currentID := roleID
|
||||
for {
|
||||
role, err := s.roleRepo.GetByID(ctx, currentID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
if role.ParentID == nil {
|
||||
break
|
||||
}
|
||||
depth++
|
||||
if depth > maxDepth {
|
||||
return errors.New("继承深度超限,最大支持5层")
|
||||
}
|
||||
currentID = *role.ParentID
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRole 删除角色
|
||||
func (s *RoleService) DeleteRole(ctx context.Context, roleID int64) error {
|
||||
role, err := s.roleRepo.GetByID(ctx, roleID)
|
||||
if err != nil {
|
||||
return errors.New("角色不存在")
|
||||
}
|
||||
|
||||
// 系统角色不能删除
|
||||
if role.IsSystem {
|
||||
return errors.New("系统角色不能删除")
|
||||
}
|
||||
|
||||
// 检查是否有子角色
|
||||
children, err := s.roleRepo.ListByParentID(ctx, roleID)
|
||||
if err == nil && len(children) > 0 {
|
||||
return errors.New("存在子角色,无法删除")
|
||||
}
|
||||
|
||||
// 删除角色权限关联
|
||||
if err := s.rolePermissionRepo.DeleteByRoleID(ctx, roleID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 删除角色
|
||||
return s.roleRepo.Delete(ctx, roleID)
|
||||
}
|
||||
|
||||
// GetRole 获取角色信息
|
||||
func (s *RoleService) GetRole(ctx context.Context, roleID int64) (*domain.Role, error) {
|
||||
return s.roleRepo.GetByID(ctx, roleID)
|
||||
}
|
||||
|
||||
// ListRoles 获取角色列表
|
||||
type ListRoleRequest struct {
|
||||
Page int `json:"page"`
|
||||
PageSize int `json:"page_size"`
|
||||
Status int `json:"status"`
|
||||
Keyword string `json:"keyword"`
|
||||
}
|
||||
|
||||
func (s *RoleService) ListRoles(ctx context.Context, req *ListRoleRequest) ([]*domain.Role, int64, error) {
|
||||
if req.Page <= 0 {
|
||||
req.Page = 1
|
||||
}
|
||||
if req.PageSize <= 0 {
|
||||
req.PageSize = 20
|
||||
}
|
||||
offset := (req.Page - 1) * req.PageSize
|
||||
|
||||
if req.Keyword != "" {
|
||||
return s.roleRepo.Search(ctx, req.Keyword, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// Status > 0 表示按状态过滤;0 表示不过滤(查全部)
|
||||
if req.Status > 0 {
|
||||
return s.roleRepo.ListByStatus(ctx, domain.RoleStatus(req.Status), offset, req.PageSize)
|
||||
}
|
||||
|
||||
return s.roleRepo.List(ctx, offset, req.PageSize)
|
||||
}
|
||||
|
||||
// UpdateRoleStatus 更新角色状态
|
||||
func (s *RoleService) UpdateRoleStatus(ctx context.Context, roleID int64, status domain.RoleStatus) error {
|
||||
role, err := s.roleRepo.GetByID(ctx, roleID)
|
||||
if err != nil {
|
||||
return errors.New("角色不存在")
|
||||
}
|
||||
|
||||
// 系统角色不能禁用
|
||||
if role.IsSystem && status == domain.RoleStatusDisabled {
|
||||
return errors.New("系统角色不能禁用")
|
||||
}
|
||||
|
||||
return s.roleRepo.UpdateStatus(ctx, roleID, status)
|
||||
}
|
||||
|
||||
// GetRolePermissions 获取角色权限(包含继承的父角色权限)
|
||||
func (s *RoleService) GetRolePermissions(ctx context.Context, roleID int64) ([]*domain.Permission, error) {
|
||||
// 收集所有角色ID(包括当前角色和所有父角色)
|
||||
allRoleIDs := []int64{roleID}
|
||||
ancestorIDs, err := s.roleRepo.GetAncestorIDs(ctx, roleID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allRoleIDs = append(allRoleIDs, ancestorIDs...)
|
||||
|
||||
// 批量获取所有角色的权限ID
|
||||
permissionIDs, err := s.rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, allRoleIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 批量获取权限详情
|
||||
permissions, err := s.rolePermissionRepo.GetPermissionsByIDs(ctx, permissionIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return permissions, nil
|
||||
}
|
||||
|
||||
// AssignPermissions 分配权限
|
||||
func (s *RoleService) AssignPermissions(ctx context.Context, roleID int64, permissionIDs []int64) error {
|
||||
// 删除原有权限
|
||||
if err := s.rolePermissionRepo.DeleteByRoleID(ctx, roleID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 创建新权限关联
|
||||
var rolePermissions []*domain.RolePermission
|
||||
for _, permissionID := range permissionIDs {
|
||||
rolePermissions = append(rolePermissions, &domain.RolePermission{
|
||||
RoleID: roleID,
|
||||
PermissionID: permissionID,
|
||||
})
|
||||
}
|
||||
|
||||
return s.rolePermissionRepo.BatchCreate(ctx, rolePermissions)
|
||||
}
|
||||
462
internal/service/sms.go
Normal file
462
internal/service/sms.go
Normal file
@@ -0,0 +1,462 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
aliyunopenapiutil "github.com/alibabacloud-go/darabonba-openapi/v2/utils"
|
||||
aliyunsms "github.com/alibabacloud-go/dysmsapi-20170525/v5/client"
|
||||
"github.com/alibabacloud-go/tea/dara"
|
||||
tccommon "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
|
||||
tcprofile "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
|
||||
tcsms "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/sms/v20210111"
|
||||
)
|
||||
|
||||
var (
|
||||
validPhonePattern = regexp.MustCompile(`^((\+86|86)?1[3-9]\d{9}|\+[1-9]\d{6,14})$`)
|
||||
mainlandPhonePattern = regexp.MustCompile(`^1[3-9]\d{9}$`)
|
||||
mainlandPhone86Pattern = regexp.MustCompile(`^86(1[3-9]\d{9})$`)
|
||||
mainlandPhone0086Pattern = regexp.MustCompile(`^0086(1[3-9]\d{9})$`)
|
||||
verificationCodeCharset10 = 1000000
|
||||
)
|
||||
|
||||
// SMSProvider sends one verification code to one phone number.
|
||||
type SMSProvider interface {
|
||||
SendVerificationCode(ctx context.Context, phone, code string) error
|
||||
}
|
||||
|
||||
// MockSMSProvider is a test helper and is not wired into the server runtime.
|
||||
type MockSMSProvider struct{}
|
||||
|
||||
func (m *MockSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
|
||||
_ = ctx
|
||||
// 安全:不在日志中记录完整验证码,仅显示部分信息用于调试
|
||||
maskedCode := "****"
|
||||
if len(code) >= 4 {
|
||||
maskedCode = strings.Repeat("*", len(code)-4) + code[len(code)-4:]
|
||||
}
|
||||
log.Printf("[sms-mock] phone=%s code=%s ttl=5m", phone, maskedCode)
|
||||
return nil
|
||||
}
|
||||
|
||||
type aliyunSMSClient interface {
|
||||
SendSms(request *aliyunsms.SendSmsRequest) (*aliyunsms.SendSmsResponse, error)
|
||||
}
|
||||
|
||||
type tencentSMSClient interface {
|
||||
SendSmsWithContext(ctx context.Context, request *tcsms.SendSmsRequest) (*tcsms.SendSmsResponse, error)
|
||||
}
|
||||
|
||||
type AliyunSMSConfig struct {
|
||||
AccessKeyID string
|
||||
AccessKeySecret string
|
||||
SignName string
|
||||
TemplateCode string
|
||||
Endpoint string
|
||||
RegionID string
|
||||
CodeParamName string
|
||||
}
|
||||
|
||||
type AliyunSMSProvider struct {
|
||||
cfg AliyunSMSConfig
|
||||
client aliyunSMSClient
|
||||
}
|
||||
|
||||
func NewAliyunSMSProvider(cfg AliyunSMSConfig) (SMSProvider, error) {
|
||||
cfg = normalizeAliyunSMSConfig(cfg)
|
||||
if cfg.AccessKeyID == "" || cfg.AccessKeySecret == "" || cfg.SignName == "" || cfg.TemplateCode == "" {
|
||||
return nil, fmt.Errorf("aliyun SMS config is incomplete")
|
||||
}
|
||||
|
||||
client, err := newAliyunSMSClient(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create aliyun SMS client failed: %w", err)
|
||||
}
|
||||
|
||||
return &AliyunSMSProvider{
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newAliyunSMSClient(cfg AliyunSMSConfig) (aliyunSMSClient, error) {
|
||||
client, err := aliyunsms.NewClient(&aliyunopenapiutil.Config{
|
||||
AccessKeyId: dara.String(cfg.AccessKeyID),
|
||||
AccessKeySecret: dara.String(cfg.AccessKeySecret),
|
||||
Endpoint: stringPointerOrNil(cfg.Endpoint),
|
||||
RegionId: dara.String(cfg.RegionID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (a *AliyunSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
|
||||
_ = ctx
|
||||
|
||||
templateParam, err := json.Marshal(map[string]string{
|
||||
a.cfg.CodeParamName: code,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal aliyun SMS template param failed: %w", err)
|
||||
}
|
||||
|
||||
resp, err := a.client.SendSms(
|
||||
new(aliyunsms.SendSmsRequest).
|
||||
SetPhoneNumbers(normalizePhoneForSMS(phone)).
|
||||
SetSignName(a.cfg.SignName).
|
||||
SetTemplateCode(a.cfg.TemplateCode).
|
||||
SetTemplateParam(string(templateParam)),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("aliyun SMS request failed: %w", err)
|
||||
}
|
||||
if resp == nil || resp.Body == nil {
|
||||
return fmt.Errorf("aliyun SMS returned empty response")
|
||||
}
|
||||
|
||||
body := resp.Body
|
||||
if !strings.EqualFold(dara.StringValue(body.Code), "OK") {
|
||||
return fmt.Errorf(
|
||||
"aliyun SMS rejected: code=%s message=%s request_id=%s",
|
||||
valueOrDefault(dara.StringValue(body.Code), "unknown"),
|
||||
valueOrDefault(dara.StringValue(body.Message), "unknown"),
|
||||
valueOrDefault(dara.StringValue(body.RequestId), "unknown"),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type TencentSMSConfig struct {
|
||||
SecretID string
|
||||
SecretKey string
|
||||
AppID string
|
||||
SignName string
|
||||
TemplateID string
|
||||
Region string
|
||||
Endpoint string
|
||||
}
|
||||
|
||||
type TencentSMSProvider struct {
|
||||
cfg TencentSMSConfig
|
||||
client tencentSMSClient
|
||||
}
|
||||
|
||||
func NewTencentSMSProvider(cfg TencentSMSConfig) (SMSProvider, error) {
|
||||
cfg = normalizeTencentSMSConfig(cfg)
|
||||
if cfg.SecretID == "" || cfg.SecretKey == "" || cfg.AppID == "" || cfg.SignName == "" || cfg.TemplateID == "" {
|
||||
return nil, fmt.Errorf("tencent SMS config is incomplete")
|
||||
}
|
||||
|
||||
client, err := newTencentSMSClient(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create tencent SMS client failed: %w", err)
|
||||
}
|
||||
|
||||
return &TencentSMSProvider{
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newTencentSMSClient(cfg TencentSMSConfig) (tencentSMSClient, error) {
|
||||
clientProfile := tcprofile.NewClientProfile()
|
||||
clientProfile.HttpProfile.ReqTimeout = 30
|
||||
if cfg.Endpoint != "" {
|
||||
clientProfile.HttpProfile.Endpoint = cfg.Endpoint
|
||||
}
|
||||
|
||||
client, err := tcsms.NewClient(
|
||||
tccommon.NewCredential(cfg.SecretID, cfg.SecretKey),
|
||||
cfg.Region,
|
||||
clientProfile,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return client, nil
|
||||
}
|
||||
|
||||
func (t *TencentSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
|
||||
req := tcsms.NewSendSmsRequest()
|
||||
req.PhoneNumberSet = []*string{tccommon.StringPtr(normalizePhoneForSMS(phone))}
|
||||
req.SmsSdkAppId = tccommon.StringPtr(t.cfg.AppID)
|
||||
req.SignName = tccommon.StringPtr(t.cfg.SignName)
|
||||
req.TemplateId = tccommon.StringPtr(t.cfg.TemplateID)
|
||||
req.TemplateParamSet = []*string{tccommon.StringPtr(code)}
|
||||
|
||||
resp, err := t.client.SendSmsWithContext(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("tencent SMS request failed: %w", err)
|
||||
}
|
||||
if resp == nil || resp.Response == nil {
|
||||
return fmt.Errorf("tencent SMS returned empty response")
|
||||
}
|
||||
if len(resp.Response.SendStatusSet) == 0 {
|
||||
return fmt.Errorf(
|
||||
"tencent SMS returned empty status list: request_id=%s",
|
||||
valueOrDefault(pointerString(resp.Response.RequestId), "unknown"),
|
||||
)
|
||||
}
|
||||
|
||||
status := resp.Response.SendStatusSet[0]
|
||||
if !strings.EqualFold(pointerString(status.Code), "Ok") {
|
||||
return fmt.Errorf(
|
||||
"tencent SMS rejected: code=%s message=%s request_id=%s",
|
||||
valueOrDefault(pointerString(status.Code), "unknown"),
|
||||
valueOrDefault(pointerString(status.Message), "unknown"),
|
||||
valueOrDefault(pointerString(resp.Response.RequestId), "unknown"),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type SMSCodeConfig struct {
|
||||
CodeTTL time.Duration
|
||||
ResendCooldown time.Duration
|
||||
MaxDailyLimit int
|
||||
}
|
||||
|
||||
func DefaultSMSCodeConfig() SMSCodeConfig {
|
||||
return SMSCodeConfig{
|
||||
CodeTTL: 5 * time.Minute,
|
||||
ResendCooldown: time.Minute,
|
||||
MaxDailyLimit: 10,
|
||||
}
|
||||
}
|
||||
|
||||
type SMSCodeService struct {
|
||||
provider SMSProvider
|
||||
cache cacheInterface
|
||||
cfg SMSCodeConfig
|
||||
}
|
||||
|
||||
type cacheInterface interface {
|
||||
Get(ctx context.Context, key string) (interface{}, bool)
|
||||
Set(ctx context.Context, key string, value interface{}, l1TTL, l2TTL time.Duration) error
|
||||
Delete(ctx context.Context, key string) error
|
||||
}
|
||||
|
||||
func NewSMSCodeService(provider SMSProvider, cacheManager cacheInterface, cfg SMSCodeConfig) *SMSCodeService {
|
||||
if cfg.CodeTTL <= 0 {
|
||||
cfg.CodeTTL = 5 * time.Minute
|
||||
}
|
||||
if cfg.ResendCooldown <= 0 {
|
||||
cfg.ResendCooldown = time.Minute
|
||||
}
|
||||
if cfg.MaxDailyLimit <= 0 {
|
||||
cfg.MaxDailyLimit = 10
|
||||
}
|
||||
|
||||
return &SMSCodeService{
|
||||
provider: provider,
|
||||
cache: cacheManager,
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
type SendCodeRequest struct {
|
||||
Phone string `json:"phone" binding:"required"`
|
||||
Purpose string `json:"purpose"`
|
||||
Scene string `json:"scene"`
|
||||
}
|
||||
|
||||
type SendCodeResponse struct {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Cooldown int `json:"cooldown"`
|
||||
}
|
||||
|
||||
func (s *SMSCodeService) SendCode(ctx context.Context, req *SendCodeRequest) (*SendCodeResponse, error) {
|
||||
if s == nil || s.provider == nil || s.cache == nil {
|
||||
return nil, fmt.Errorf("sms code service is not configured")
|
||||
}
|
||||
if req == nil {
|
||||
return nil, newValidationError("\u8bf7\u6c42\u4e0d\u80fd\u4e3a\u7a7a")
|
||||
}
|
||||
|
||||
phone := strings.TrimSpace(req.Phone)
|
||||
if !isValidPhone(phone) {
|
||||
return nil, newValidationError("\u624b\u673a\u53f7\u7801\u683c\u5f0f\u4e0d\u6b63\u786e")
|
||||
}
|
||||
purpose := strings.TrimSpace(req.Purpose)
|
||||
if purpose == "" {
|
||||
purpose = strings.TrimSpace(req.Scene)
|
||||
}
|
||||
|
||||
cooldownKey := fmt.Sprintf("sms_cooldown:%s", phone)
|
||||
if _, ok := s.cache.Get(ctx, cooldownKey); ok {
|
||||
return nil, newRateLimitError(fmt.Sprintf("\u64cd\u4f5c\u8fc7\u4e8e\u9891\u7e41\uff0c\u8bf7 %d \u79d2\u540e\u518d\u8bd5", int(s.cfg.ResendCooldown.Seconds())))
|
||||
}
|
||||
|
||||
dailyKey := fmt.Sprintf("sms_daily:%s:%s", phone, time.Now().Format("2006-01-02"))
|
||||
var dailyCount int
|
||||
if val, ok := s.cache.Get(ctx, dailyKey); ok {
|
||||
if n, ok := intValue(val); ok {
|
||||
dailyCount = n
|
||||
}
|
||||
}
|
||||
if dailyCount >= s.cfg.MaxDailyLimit {
|
||||
return nil, newRateLimitError(fmt.Sprintf("\u4eca\u65e5\u53d1\u9001\u6b21\u6570\u5df2\u8fbe\u4e0a\u9650\uff08%d\u6b21\uff09\uff0c\u8bf7\u660e\u65e5\u518d\u8bd5", s.cfg.MaxDailyLimit))
|
||||
}
|
||||
|
||||
code, err := generateSMSCode()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate sms code failed: %w", err)
|
||||
}
|
||||
|
||||
codeKey := fmt.Sprintf("sms_code:%s:%s", purpose, phone)
|
||||
if err := s.cache.Set(ctx, codeKey, code, s.cfg.CodeTTL, s.cfg.CodeTTL); err != nil {
|
||||
return nil, fmt.Errorf("store sms code failed: %w", err)
|
||||
}
|
||||
if err := s.cache.Set(ctx, cooldownKey, true, s.cfg.ResendCooldown, s.cfg.ResendCooldown); err != nil {
|
||||
_ = s.cache.Delete(ctx, codeKey)
|
||||
return nil, fmt.Errorf("store sms cooldown failed: %w", err)
|
||||
}
|
||||
if err := s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour, 24*time.Hour); err != nil {
|
||||
_ = s.cache.Delete(ctx, codeKey)
|
||||
_ = s.cache.Delete(ctx, cooldownKey)
|
||||
return nil, fmt.Errorf("store sms daily counter failed: %w", err)
|
||||
}
|
||||
|
||||
if err := s.provider.SendVerificationCode(ctx, phone, code); err != nil {
|
||||
_ = s.cache.Delete(ctx, codeKey)
|
||||
_ = s.cache.Delete(ctx, cooldownKey)
|
||||
return nil, fmt.Errorf("\u77ed\u4fe1\u53d1\u9001\u5931\u8d25: %w", err)
|
||||
}
|
||||
|
||||
return &SendCodeResponse{
|
||||
ExpiresIn: int(s.cfg.CodeTTL.Seconds()),
|
||||
Cooldown: int(s.cfg.ResendCooldown.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, purpose, code string) error {
|
||||
if s == nil || s.cache == nil {
|
||||
return fmt.Errorf("sms code service is not configured")
|
||||
}
|
||||
if strings.TrimSpace(code) == "" {
|
||||
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u80fd\u4e3a\u7a7a")
|
||||
}
|
||||
|
||||
phone = strings.TrimSpace(phone)
|
||||
purpose = strings.TrimSpace(purpose)
|
||||
codeKey := fmt.Sprintf("sms_code:%s:%s", purpose, phone)
|
||||
val, ok := s.cache.Get(ctx, codeKey)
|
||||
if !ok {
|
||||
return fmt.Errorf("\u9a8c\u8bc1\u7801\u5df2\u8fc7\u671f\u6216\u4e0d\u5b58\u5728")
|
||||
}
|
||||
|
||||
stored, ok := val.(string)
|
||||
if !ok || stored != code {
|
||||
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u6b63\u786e")
|
||||
}
|
||||
|
||||
if err := s.cache.Delete(ctx, codeKey); err != nil {
|
||||
return fmt.Errorf("consume sms code failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func isValidPhone(phone string) bool {
|
||||
return validPhonePattern.MatchString(strings.TrimSpace(phone))
|
||||
}
|
||||
|
||||
func generateSMSCode() (string, error) {
|
||||
b := make([]byte, 4)
|
||||
if _, err := cryptorand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
n := int(b[0])<<24 | int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
||||
if n < 0 {
|
||||
n = -n
|
||||
}
|
||||
n = n % verificationCodeCharset10
|
||||
if n < 100000 {
|
||||
n += 100000
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%06d", n), nil
|
||||
}
|
||||
|
||||
func normalizeAliyunSMSConfig(cfg AliyunSMSConfig) AliyunSMSConfig {
|
||||
cfg.AccessKeyID = strings.TrimSpace(cfg.AccessKeyID)
|
||||
cfg.AccessKeySecret = strings.TrimSpace(cfg.AccessKeySecret)
|
||||
cfg.SignName = strings.TrimSpace(cfg.SignName)
|
||||
cfg.TemplateCode = strings.TrimSpace(cfg.TemplateCode)
|
||||
cfg.Endpoint = strings.TrimSpace(cfg.Endpoint)
|
||||
cfg.RegionID = strings.TrimSpace(cfg.RegionID)
|
||||
cfg.CodeParamName = strings.TrimSpace(cfg.CodeParamName)
|
||||
|
||||
if cfg.RegionID == "" {
|
||||
cfg.RegionID = "cn-hangzhou"
|
||||
}
|
||||
if cfg.CodeParamName == "" {
|
||||
cfg.CodeParamName = "code"
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func normalizeTencentSMSConfig(cfg TencentSMSConfig) TencentSMSConfig {
|
||||
cfg.SecretID = strings.TrimSpace(cfg.SecretID)
|
||||
cfg.SecretKey = strings.TrimSpace(cfg.SecretKey)
|
||||
cfg.AppID = strings.TrimSpace(cfg.AppID)
|
||||
cfg.SignName = strings.TrimSpace(cfg.SignName)
|
||||
cfg.TemplateID = strings.TrimSpace(cfg.TemplateID)
|
||||
cfg.Region = strings.TrimSpace(cfg.Region)
|
||||
cfg.Endpoint = strings.TrimSpace(cfg.Endpoint)
|
||||
|
||||
if cfg.Region == "" {
|
||||
cfg.Region = "ap-guangzhou"
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
func normalizePhoneForSMS(phone string) string {
|
||||
phone = strings.TrimSpace(phone)
|
||||
|
||||
switch {
|
||||
case mainlandPhonePattern.MatchString(phone):
|
||||
return "+86" + phone
|
||||
case mainlandPhone86Pattern.MatchString(phone):
|
||||
return "+" + phone
|
||||
case mainlandPhone0086Pattern.MatchString(phone):
|
||||
return "+86" + mainlandPhone0086Pattern.ReplaceAllString(phone, "$1")
|
||||
default:
|
||||
return phone
|
||||
}
|
||||
}
|
||||
|
||||
func stringPointerOrNil(value string) *string {
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
return dara.String(value)
|
||||
}
|
||||
|
||||
func pointerString(value *string) string {
|
||||
if value == nil {
|
||||
return ""
|
||||
}
|
||||
return *value
|
||||
}
|
||||
|
||||
func valueOrDefault(value, fallback string) string {
|
||||
if strings.TrimSpace(value) == "" {
|
||||
return fallback
|
||||
}
|
||||
return value
|
||||
}
|
||||
124
internal/service/stats.go
Normal file
124
internal/service/stats.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
// StatsService 统计服务
|
||||
type StatsService struct {
|
||||
userRepo *repository.UserRepository
|
||||
loginLogRepo *repository.LoginLogRepository
|
||||
}
|
||||
|
||||
// NewStatsService 创建统计服务
|
||||
func NewStatsService(
|
||||
userRepo *repository.UserRepository,
|
||||
loginLogRepo *repository.LoginLogRepository,
|
||||
) *StatsService {
|
||||
return &StatsService{
|
||||
userRepo: userRepo,
|
||||
loginLogRepo: loginLogRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// UserStats 用户统计数据
|
||||
type UserStats struct {
|
||||
TotalUsers int64 `json:"total_users"`
|
||||
ActiveUsers int64 `json:"active_users"`
|
||||
InactiveUsers int64 `json:"inactive_users"`
|
||||
LockedUsers int64 `json:"locked_users"`
|
||||
DisabledUsers int64 `json:"disabled_users"`
|
||||
NewUsersToday int64 `json:"new_users_today"`
|
||||
NewUsersWeek int64 `json:"new_users_week"`
|
||||
NewUsersMonth int64 `json:"new_users_month"`
|
||||
}
|
||||
|
||||
// LoginStats 登录统计数据
|
||||
type LoginStats struct {
|
||||
LoginsTodaySuccess int64 `json:"logins_today_success"`
|
||||
LoginsTodayFailed int64 `json:"logins_today_failed"`
|
||||
LoginsWeek int64 `json:"logins_week"`
|
||||
}
|
||||
|
||||
// DashboardStats 仪表盘综合统计
|
||||
type DashboardStats struct {
|
||||
Users UserStats `json:"users"`
|
||||
Logins LoginStats `json:"logins"`
|
||||
}
|
||||
|
||||
// GetUserStats 获取用户统计
|
||||
func (s *StatsService) GetUserStats(ctx context.Context) (*UserStats, error) {
|
||||
stats := &UserStats{}
|
||||
|
||||
// 统计总用户数
|
||||
_, total, err := s.userRepo.List(ctx, 0, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stats.TotalUsers = total
|
||||
|
||||
// 按状态统计
|
||||
statusCounts := map[domain.UserStatus]*int64{
|
||||
domain.UserStatusActive: &stats.ActiveUsers,
|
||||
domain.UserStatusInactive: &stats.InactiveUsers,
|
||||
domain.UserStatusLocked: &stats.LockedUsers,
|
||||
domain.UserStatusDisabled: &stats.DisabledUsers,
|
||||
}
|
||||
for status, countPtr := range statusCounts {
|
||||
_, cnt, err := s.userRepo.ListByStatus(ctx, status, 0, 1)
|
||||
if err == nil {
|
||||
*countPtr = cnt
|
||||
}
|
||||
}
|
||||
|
||||
// 今日新增
|
||||
stats.NewUsersToday = s.countNewUsers(ctx, daysAgo(0))
|
||||
// 本周新增
|
||||
stats.NewUsersWeek = s.countNewUsers(ctx, daysAgo(7))
|
||||
// 本月新增
|
||||
stats.NewUsersMonth = s.countNewUsers(ctx, daysAgo(30))
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// countNewUsers 统计指定时间之后的新增用户数
|
||||
func (s *StatsService) countNewUsers(ctx context.Context, since time.Time) int64 {
|
||||
_, count, err := s.userRepo.ListCreatedAfter(ctx, since, 0, 0)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// GetDashboardStats 获取仪表盘综合统计
|
||||
func (s *StatsService) GetDashboardStats(ctx context.Context) (*DashboardStats, error) {
|
||||
userStats, err := s.GetUserStats(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
loginStats := &LoginStats{}
|
||||
// 今日登录成功/失败
|
||||
today := daysAgo(0)
|
||||
if s.loginLogRepo != nil {
|
||||
loginStats.LoginsTodaySuccess = s.loginLogRepo.CountByResultSince(ctx, true, today)
|
||||
loginStats.LoginsTodayFailed = s.loginLogRepo.CountByResultSince(ctx, false, today)
|
||||
loginStats.LoginsWeek = s.loginLogRepo.CountByResultSince(ctx, true, daysAgo(7))
|
||||
}
|
||||
|
||||
return &DashboardStats{
|
||||
Users: *userStats,
|
||||
Logins: *loginStats,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// daysAgo 返回N天前的时间(当天0点)
|
||||
func daysAgo(n int) time.Time {
|
||||
now := time.Now()
|
||||
start := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
|
||||
return start.AddDate(0, 0, -n)
|
||||
}
|
||||
206
internal/service/theme.go
Normal file
206
internal/service/theme.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
// ThemeService 主题服务
|
||||
type ThemeService struct {
|
||||
themeRepo *repository.ThemeConfigRepository
|
||||
}
|
||||
|
||||
// NewThemeService 创建主题服务
|
||||
func NewThemeService(themeRepo *repository.ThemeConfigRepository) *ThemeService {
|
||||
return &ThemeService{themeRepo: themeRepo}
|
||||
}
|
||||
|
||||
// CreateThemeRequest 创建主题请求
|
||||
type CreateThemeRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
LogoURL string `json:"logo_url"`
|
||||
FaviconURL string `json:"favicon_url"`
|
||||
PrimaryColor string `json:"primary_color"`
|
||||
SecondaryColor string `json:"secondary_color"`
|
||||
BackgroundColor string `json:"background_color"`
|
||||
TextColor string `json:"text_color"`
|
||||
CustomCSS string `json:"custom_css"`
|
||||
CustomJS string `json:"custom_js"`
|
||||
IsDefault bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// UpdateThemeRequest 更新主题请求
|
||||
type UpdateThemeRequest struct {
|
||||
LogoURL string `json:"logo_url"`
|
||||
FaviconURL string `json:"favicon_url"`
|
||||
PrimaryColor string `json:"primary_color"`
|
||||
SecondaryColor string `json:"secondary_color"`
|
||||
BackgroundColor string `json:"background_color"`
|
||||
TextColor string `json:"text_color"`
|
||||
CustomCSS string `json:"custom_css"`
|
||||
CustomJS string `json:"custom_js"`
|
||||
Enabled *bool `json:"enabled"`
|
||||
IsDefault *bool `json:"is_default"`
|
||||
}
|
||||
|
||||
// CreateTheme 创建主题
|
||||
func (s *ThemeService) CreateTheme(ctx context.Context, req *CreateThemeRequest) (*domain.ThemeConfig, error) {
|
||||
// 检查主题名称是否已存在
|
||||
existing, err := s.themeRepo.GetByName(ctx, req.Name)
|
||||
if err == nil && existing != nil {
|
||||
return nil, errors.New("主题名称已存在")
|
||||
}
|
||||
|
||||
theme := &domain.ThemeConfig{
|
||||
Name: req.Name,
|
||||
LogoURL: req.LogoURL,
|
||||
FaviconURL: req.FaviconURL,
|
||||
PrimaryColor: req.PrimaryColor,
|
||||
SecondaryColor: req.SecondaryColor,
|
||||
BackgroundColor: req.BackgroundColor,
|
||||
TextColor: req.TextColor,
|
||||
CustomCSS: req.CustomCSS,
|
||||
CustomJS: req.CustomJS,
|
||||
IsDefault: req.IsDefault,
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
// 如果设置为默认,先清除其他默认
|
||||
if req.IsDefault {
|
||||
if err := s.clearDefaultThemes(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.themeRepo.Create(ctx, theme); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return theme, nil
|
||||
}
|
||||
|
||||
// UpdateTheme 更新主题
|
||||
func (s *ThemeService) UpdateTheme(ctx context.Context, id int64, req *UpdateThemeRequest) (*domain.ThemeConfig, error) {
|
||||
theme, err := s.themeRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, errors.New("主题不存在")
|
||||
}
|
||||
|
||||
if req.LogoURL != "" {
|
||||
theme.LogoURL = req.LogoURL
|
||||
}
|
||||
if req.FaviconURL != "" {
|
||||
theme.FaviconURL = req.FaviconURL
|
||||
}
|
||||
if req.PrimaryColor != "" {
|
||||
theme.PrimaryColor = req.PrimaryColor
|
||||
}
|
||||
if req.SecondaryColor != "" {
|
||||
theme.SecondaryColor = req.SecondaryColor
|
||||
}
|
||||
if req.BackgroundColor != "" {
|
||||
theme.BackgroundColor = req.BackgroundColor
|
||||
}
|
||||
if req.TextColor != "" {
|
||||
theme.TextColor = req.TextColor
|
||||
}
|
||||
if req.CustomCSS != "" {
|
||||
theme.CustomCSS = req.CustomCSS
|
||||
}
|
||||
if req.CustomJS != "" {
|
||||
theme.CustomJS = req.CustomJS
|
||||
}
|
||||
if req.Enabled != nil {
|
||||
theme.Enabled = *req.Enabled
|
||||
}
|
||||
if req.IsDefault != nil && *req.IsDefault {
|
||||
if err := s.clearDefaultThemes(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
theme.IsDefault = true
|
||||
}
|
||||
|
||||
if err := s.themeRepo.Update(ctx, theme); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return theme, nil
|
||||
}
|
||||
|
||||
// DeleteTheme 删除主题
|
||||
func (s *ThemeService) DeleteTheme(ctx context.Context, id int64) error {
|
||||
theme, err := s.themeRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return errors.New("主题不存在")
|
||||
}
|
||||
|
||||
if theme.IsDefault {
|
||||
return errors.New("不能删除默认主题")
|
||||
}
|
||||
|
||||
return s.themeRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// GetTheme 获取主题
|
||||
func (s *ThemeService) GetTheme(ctx context.Context, id int64) (*domain.ThemeConfig, error) {
|
||||
return s.themeRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// ListThemes 获取所有已启用主题
|
||||
func (s *ThemeService) ListThemes(ctx context.Context) ([]*domain.ThemeConfig, error) {
|
||||
return s.themeRepo.List(ctx)
|
||||
}
|
||||
|
||||
// ListAllThemes 获取所有主题
|
||||
func (s *ThemeService) ListAllThemes(ctx context.Context) ([]*domain.ThemeConfig, error) {
|
||||
return s.themeRepo.ListAll(ctx)
|
||||
}
|
||||
|
||||
// GetDefaultTheme 获取默认主题
|
||||
func (s *ThemeService) GetDefaultTheme(ctx context.Context) (*domain.ThemeConfig, error) {
|
||||
return s.themeRepo.GetDefault(ctx)
|
||||
}
|
||||
|
||||
// SetDefaultTheme 设置默认主题
|
||||
func (s *ThemeService) SetDefaultTheme(ctx context.Context, id int64) error {
|
||||
theme, err := s.themeRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return errors.New("主题不存在")
|
||||
}
|
||||
|
||||
if !theme.Enabled {
|
||||
return errors.New("不能将禁用的主题设为默认")
|
||||
}
|
||||
|
||||
return s.themeRepo.SetDefault(ctx, id)
|
||||
}
|
||||
|
||||
// GetActiveTheme 获取当前生效的主题
|
||||
func (s *ThemeService) GetActiveTheme(ctx context.Context) (*domain.ThemeConfig, error) {
|
||||
theme, err := s.themeRepo.GetDefault(ctx)
|
||||
if err != nil {
|
||||
// 返回默认配置
|
||||
return domain.DefaultThemeConfig(), nil
|
||||
}
|
||||
return theme, nil
|
||||
}
|
||||
|
||||
// clearDefaultThemes 清除所有默认主题标记
|
||||
func (s *ThemeService) clearDefaultThemes(ctx context.Context) error {
|
||||
themes, err := s.themeRepo.ListAll(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, t := range themes {
|
||||
if t.IsDefault {
|
||||
t.IsDefault = false
|
||||
if err := s.themeRepo.Update(ctx, t); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
148
internal/service/totp.go
Normal file
148
internal/service/totp.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
)
|
||||
|
||||
// TOTPService manages 2FA setup, enable/disable, and verification.
|
||||
type TOTPService struct {
|
||||
userRepo userRepositoryInterface
|
||||
totpManager *auth.TOTPManager
|
||||
}
|
||||
|
||||
func NewTOTPService(userRepo userRepositoryInterface) *TOTPService {
|
||||
return &TOTPService{
|
||||
userRepo: userRepo,
|
||||
totpManager: auth.NewTOTPManager(),
|
||||
}
|
||||
}
|
||||
|
||||
type SetupTOTPResponse struct {
|
||||
Secret string `json:"secret"`
|
||||
QRCodeBase64 string `json:"qr_code_base64"`
|
||||
RecoveryCodes []string `json:"recovery_codes"`
|
||||
}
|
||||
|
||||
func (s *TOTPService) SetupTOTP(ctx context.Context, userID int64) (*SetupTOTPResponse, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
||||
}
|
||||
if user.TOTPEnabled {
|
||||
return nil, errors.New("2FA \u5df2\u7ecf\u542f\u7528\uff0c\u5982\u9700\u91cd\u7f6e\u8bf7\u5148\u7981\u7528")
|
||||
}
|
||||
|
||||
setup, err := s.totpManager.GenerateSecret(user.Username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("\u751f\u6210 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
|
||||
}
|
||||
|
||||
// Persist the generated secret and recovery codes before activation.
|
||||
user.TOTPSecret = setup.Secret
|
||||
// Hash recovery codes before storing (SEC-03 fix)
|
||||
hashedCodes := make([]string, len(setup.RecoveryCodes))
|
||||
for i, code := range setup.RecoveryCodes {
|
||||
hashedCodes[i], _ = auth.HashRecoveryCode(code)
|
||||
}
|
||||
codesJSON, _ := json.Marshal(hashedCodes)
|
||||
user.TOTPRecoveryCodes = string(codesJSON)
|
||||
|
||||
if err := s.userRepo.UpdateTOTP(ctx, user); err != nil {
|
||||
return nil, fmt.Errorf("\u4fdd\u5b58 TOTP \u5bc6\u94a5\u5931\u8d25: %w", err)
|
||||
}
|
||||
|
||||
return &SetupTOTPResponse{
|
||||
Secret: setup.Secret,
|
||||
QRCodeBase64: setup.QRCodeBase64,
|
||||
RecoveryCodes: setup.RecoveryCodes,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *TOTPService) EnableTOTP(ctx context.Context, userID int64, code string) error {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
||||
}
|
||||
if user.TOTPSecret == "" {
|
||||
return errors.New("\u8bf7\u5148\u521d\u59cb\u5316 2FA\uff0c\u83b7\u53d6\u4e8c\u7ef4\u7801\u540e\u518d\u6fc0\u6d3b")
|
||||
}
|
||||
if user.TOTPEnabled {
|
||||
return errors.New("2FA \u5df2\u542f\u7528")
|
||||
}
|
||||
|
||||
if !s.totpManager.ValidateCode(user.TOTPSecret, code) {
|
||||
return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
|
||||
}
|
||||
|
||||
user.TOTPEnabled = true
|
||||
return s.userRepo.UpdateTOTP(ctx, user)
|
||||
}
|
||||
|
||||
func (s *TOTPService) DisableTOTP(ctx context.Context, userID int64, code string) error {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
||||
}
|
||||
if !user.TOTPEnabled {
|
||||
return errors.New("2FA \u672a\u542f\u7528")
|
||||
}
|
||||
|
||||
valid := s.totpManager.ValidateCode(user.TOTPSecret, code)
|
||||
if !valid {
|
||||
var hashedCodes []string
|
||||
if user.TOTPRecoveryCodes != "" {
|
||||
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &hashedCodes)
|
||||
}
|
||||
_, matched := auth.VerifyRecoveryCode(code, hashedCodes)
|
||||
if !matched {
|
||||
return errors.New("\u9a8c\u8bc1\u7801\u6216\u6062\u590d\u7801\u9519\u8bef")
|
||||
}
|
||||
}
|
||||
|
||||
user.TOTPEnabled = false
|
||||
user.TOTPSecret = ""
|
||||
user.TOTPRecoveryCodes = ""
|
||||
return s.userRepo.UpdateTOTP(ctx, user)
|
||||
}
|
||||
|
||||
func (s *TOTPService) VerifyTOTP(ctx context.Context, userID int64, code string) error {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
||||
}
|
||||
if !user.TOTPEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if s.totpManager.ValidateCode(user.TOTPSecret, code) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var storedCodes []string
|
||||
if user.TOTPRecoveryCodes != "" {
|
||||
_ = json.Unmarshal([]byte(user.TOTPRecoveryCodes), &storedCodes)
|
||||
}
|
||||
idx, matched := auth.ValidateRecoveryCode(code, storedCodes)
|
||||
if !matched {
|
||||
return errors.New("\u9a8c\u8bc1\u7801\u9519\u8bef\u6216\u5df2\u8fc7\u671f")
|
||||
}
|
||||
|
||||
storedCodes = append(storedCodes[:idx], storedCodes[idx+1:]...)
|
||||
codesJSON, _ := json.Marshal(storedCodes)
|
||||
user.TOTPRecoveryCodes = string(codesJSON)
|
||||
_ = s.userRepo.UpdateTOTP(ctx, user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *TOTPService) GetTOTPStatus(ctx context.Context, userID int64) (bool, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("\u7528\u6237\u4e0d\u5b58\u5728")
|
||||
}
|
||||
return user.TOTPEnabled, nil
|
||||
}
|
||||
133
internal/service/user_service.go
Normal file
133
internal/service/user_service.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
// UserService 用户服务
|
||||
type UserService struct {
|
||||
userRepo *repository.UserRepository
|
||||
userRoleRepo *repository.UserRoleRepository
|
||||
roleRepo *repository.RoleRepository
|
||||
passwordHistoryRepo *repository.PasswordHistoryRepository
|
||||
}
|
||||
|
||||
const passwordHistoryLimit = 5 // 保留最近5条密码历史
|
||||
|
||||
// NewUserService 创建用户服务实例
|
||||
func NewUserService(
|
||||
userRepo *repository.UserRepository,
|
||||
userRoleRepo *repository.UserRoleRepository,
|
||||
roleRepo *repository.RoleRepository,
|
||||
passwordHistoryRepo *repository.PasswordHistoryRepository,
|
||||
) *UserService {
|
||||
return &UserService{
|
||||
userRepo: userRepo,
|
||||
userRoleRepo: userRoleRepo,
|
||||
roleRepo: roleRepo,
|
||||
passwordHistoryRepo: passwordHistoryRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// ChangePassword 修改用户密码(含历史记录检查)
|
||||
func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassword, newPassword string) error {
|
||||
if s.userRepo == nil {
|
||||
return errors.New("user repository is not configured")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return errors.New("用户不存在")
|
||||
}
|
||||
|
||||
// 验证旧密码
|
||||
if strings.TrimSpace(oldPassword) == "" {
|
||||
return errors.New("请输入当前密码")
|
||||
}
|
||||
if !auth.VerifyPassword(user.Password, oldPassword) {
|
||||
return errors.New("当前密码不正确")
|
||||
}
|
||||
|
||||
// 检查新密码强度
|
||||
if strings.TrimSpace(newPassword) == "" {
|
||||
return errors.New("新密码不能为空")
|
||||
}
|
||||
if err := validatePasswordStrength(newPassword, 8, false); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 检查密码历史
|
||||
if s.passwordHistoryRepo != nil {
|
||||
histories, err := s.passwordHistoryRepo.GetByUserID(ctx, userID, passwordHistoryLimit)
|
||||
if err == nil && len(histories) > 0 {
|
||||
for _, h := range histories {
|
||||
if auth.VerifyPassword(h.PasswordHash, newPassword) {
|
||||
return errors.New("新密码不能与最近5次密码相同")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 保存新密码到历史记录
|
||||
newHashedPassword, hashErr := auth.HashPassword(newPassword)
|
||||
if hashErr != nil {
|
||||
return errors.New("密码哈希失败")
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = s.passwordHistoryRepo.Create(context.Background(), &domain.PasswordHistory{
|
||||
UserID: userID,
|
||||
PasswordHash: newHashedPassword,
|
||||
})
|
||||
_ = s.passwordHistoryRepo.DeleteOldRecords(context.Background(), userID, passwordHistoryLimit)
|
||||
}()
|
||||
}
|
||||
|
||||
// 更新密码
|
||||
newHashedPassword, err := auth.HashPassword(newPassword)
|
||||
if err != nil {
|
||||
return errors.New("密码哈希失败")
|
||||
}
|
||||
user.Password = newHashedPassword
|
||||
return s.userRepo.Update(ctx, user)
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
func (s *UserService) GetByID(ctx context.Context, id int64) (*domain.User, error) {
|
||||
return s.userRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// GetByEmail 根据邮箱获取用户
|
||||
func (s *UserService) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||||
return s.userRepo.GetByEmail(ctx, email)
|
||||
}
|
||||
|
||||
// Create 创建用户
|
||||
func (s *UserService) Create(ctx context.Context, user *domain.User) error {
|
||||
return s.userRepo.Create(ctx, user)
|
||||
}
|
||||
|
||||
// Update 更新用户
|
||||
func (s *UserService) Update(ctx context.Context, user *domain.User) error {
|
||||
return s.userRepo.Update(ctx, user)
|
||||
}
|
||||
|
||||
// Delete 删除用户
|
||||
func (s *UserService) Delete(ctx context.Context, id int64) error {
|
||||
return s.userRepo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
// List 获取用户列表
|
||||
func (s *UserService) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
|
||||
return s.userRepo.List(ctx, offset, limit)
|
||||
}
|
||||
|
||||
// UpdateStatus 更新用户状态
|
||||
func (s *UserService) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
|
||||
return s.userRepo.UpdateStatus(ctx, id, status)
|
||||
}
|
||||
484
internal/service/webhook.go
Normal file
484
internal/service/webhook.go
Normal file
@@ -0,0 +1,484 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
cryptorand "crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// WebhookService Webhook 服务
|
||||
type WebhookService struct {
|
||||
db *gorm.DB
|
||||
repo *repository.WebhookRepository
|
||||
queue chan *deliveryTask
|
||||
workers int
|
||||
config WebhookServiceConfig
|
||||
wg sync.WaitGroup
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
type WebhookServiceConfig struct {
|
||||
Enabled bool
|
||||
SecretHeader string
|
||||
TimeoutSec int
|
||||
MaxRetries int
|
||||
RetryBackoff string
|
||||
WorkerCount int
|
||||
QueueSize int
|
||||
}
|
||||
|
||||
// deliveryTask 投递任务
|
||||
type deliveryTask struct {
|
||||
webhook *domain.Webhook
|
||||
eventType domain.WebhookEventType
|
||||
payload []byte
|
||||
attempt int
|
||||
}
|
||||
|
||||
// WebhookEvent 发布的事件结构
|
||||
type WebhookEvent struct {
|
||||
EventID string `json:"event_id"`
|
||||
EventType domain.WebhookEventType `json:"event_type"`
|
||||
Timestamp time.Time `json:"timestamp"`
|
||||
Data interface{} `json:"data"`
|
||||
}
|
||||
|
||||
// NewWebhookService 创建 Webhook 服务
|
||||
func NewWebhookService(db *gorm.DB, cfgs ...WebhookServiceConfig) *WebhookService {
|
||||
cfg := defaultWebhookServiceConfig()
|
||||
if len(cfgs) > 0 {
|
||||
cfg = cfgs[0]
|
||||
}
|
||||
if cfg.WorkerCount <= 0 {
|
||||
cfg.WorkerCount = defaultWebhookServiceConfig().WorkerCount
|
||||
}
|
||||
if cfg.QueueSize <= 0 {
|
||||
cfg.QueueSize = defaultWebhookServiceConfig().QueueSize
|
||||
}
|
||||
if cfg.SecretHeader == "" {
|
||||
cfg.SecretHeader = defaultWebhookServiceConfig().SecretHeader
|
||||
}
|
||||
if cfg.TimeoutSec <= 0 {
|
||||
cfg.TimeoutSec = defaultWebhookServiceConfig().TimeoutSec
|
||||
}
|
||||
if cfg.MaxRetries <= 0 {
|
||||
cfg.MaxRetries = defaultWebhookServiceConfig().MaxRetries
|
||||
}
|
||||
if cfg.RetryBackoff == "" {
|
||||
cfg.RetryBackoff = defaultWebhookServiceConfig().RetryBackoff
|
||||
}
|
||||
|
||||
svc := &WebhookService{
|
||||
db: db,
|
||||
repo: repository.NewWebhookRepository(db),
|
||||
queue: make(chan *deliveryTask, cfg.QueueSize),
|
||||
workers: cfg.WorkerCount,
|
||||
config: cfg,
|
||||
}
|
||||
svc.startWorkers()
|
||||
return svc
|
||||
}
|
||||
|
||||
func defaultWebhookServiceConfig() WebhookServiceConfig {
|
||||
return WebhookServiceConfig{
|
||||
Enabled: true,
|
||||
SecretHeader: "X-Webhook-Signature",
|
||||
TimeoutSec: 10,
|
||||
MaxRetries: 3,
|
||||
RetryBackoff: "exponential",
|
||||
WorkerCount: 4,
|
||||
QueueSize: 1000,
|
||||
}
|
||||
}
|
||||
|
||||
// startWorkers 启动后台投递 worker
|
||||
func (s *WebhookService) startWorkers() {
|
||||
s.once.Do(func() {
|
||||
for i := 0; i < s.workers; i++ {
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
for task := range s.queue {
|
||||
s.deliver(task)
|
||||
}
|
||||
}()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Publish 发布事件:找到订阅该事件的所有 Webhook,异步投递
|
||||
func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
|
||||
if !s.config.Enabled {
|
||||
return
|
||||
}
|
||||
// 查询所有活跃 Webhook
|
||||
webhooks, err := s.repo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// 构建事件载荷
|
||||
eventID, err := generateEventID()
|
||||
if err != nil {
|
||||
slog.Error("generate event ID failed", "error", err)
|
||||
return
|
||||
}
|
||||
event := &WebhookEvent{
|
||||
EventID: eventID,
|
||||
EventType: eventType,
|
||||
Timestamp: time.Now().UTC(),
|
||||
Data: data,
|
||||
}
|
||||
payloadBytes, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for i := range webhooks {
|
||||
wh := webhooks[i]
|
||||
// 检查是否订阅了该事件类型
|
||||
if !webhookSubscribesTo(wh, eventType) {
|
||||
continue
|
||||
}
|
||||
|
||||
task := &deliveryTask{
|
||||
webhook: wh,
|
||||
eventType: eventType,
|
||||
payload: payloadBytes,
|
||||
attempt: 1,
|
||||
}
|
||||
|
||||
// 非阻塞投递到队列
|
||||
select {
|
||||
case s.queue <- task:
|
||||
default:
|
||||
// 队列满时记录但不阻塞
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// deliver 执行单次 HTTP 投递
|
||||
func (s *WebhookService) deliver(task *deliveryTask) {
|
||||
wh := task.webhook
|
||||
|
||||
// NEW-SEC-01 修复:检查 URL 安全性
|
||||
if !isSafeURL(wh.URL) {
|
||||
s.recordDelivery(task, 0, "", "webhook URL 不安全: 可能存在 SSRF 风险", false)
|
||||
return
|
||||
}
|
||||
|
||||
timeout := time.Duration(wh.TimeoutSec) * time.Second
|
||||
if timeout <= 0 {
|
||||
timeout = time.Duration(s.config.TimeoutSec) * time.Second
|
||||
}
|
||||
if timeout <= 0 {
|
||||
timeout = 10 * time.Second
|
||||
}
|
||||
|
||||
client := &http.Client{Timeout: timeout}
|
||||
|
||||
req, err := http.NewRequest("POST", wh.URL, bytes.NewReader(task.payload))
|
||||
if err != nil {
|
||||
s.recordDelivery(task, 0, "", err.Error(), false)
|
||||
return
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", "UserManagementSystem-Webhook/1.0")
|
||||
req.Header.Set("X-Webhook-Event", string(task.eventType))
|
||||
req.Header.Set("X-Webhook-Attempt", fmt.Sprintf("%d", task.attempt))
|
||||
|
||||
// HMAC 签名
|
||||
if wh.Secret != "" {
|
||||
sig := computeHMAC(task.payload, wh.Secret)
|
||||
req.Header.Set(s.config.SecretHeader, "sha256="+sig)
|
||||
}
|
||||
|
||||
// 使用带超时的 context 避免请求无限等待
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
resp, err := client.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
s.handleFailure(task, 0, "", err.Error())
|
||||
return
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var respBuf bytes.Buffer
|
||||
respBuf.ReadFrom(resp.Body)
|
||||
success := resp.StatusCode >= 200 && resp.StatusCode < 300
|
||||
|
||||
if !success {
|
||||
s.handleFailure(task, resp.StatusCode, respBuf.String(), "非 2xx 响应")
|
||||
return
|
||||
}
|
||||
|
||||
s.recordDelivery(task, resp.StatusCode, respBuf.String(), "", true)
|
||||
}
|
||||
|
||||
// handleFailure 处理投递失败(重试逻辑)
|
||||
func (s *WebhookService) handleFailure(task *deliveryTask, statusCode int, body, errMsg string) {
|
||||
s.recordDelivery(task, statusCode, body, errMsg, false)
|
||||
|
||||
// 指数退避重试
|
||||
if task.attempt < task.webhook.MaxRetries {
|
||||
backoff := time.Second
|
||||
if s.config.RetryBackoff == "fixed" {
|
||||
backoff = 2 * time.Second
|
||||
} else {
|
||||
backoff = time.Duration(1<<uint(task.attempt)) * time.Second
|
||||
}
|
||||
time.AfterFunc(backoff, func() {
|
||||
task.attempt++
|
||||
select {
|
||||
case s.queue <- task:
|
||||
default:
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// recordDelivery 记录投递日志
|
||||
func (s *WebhookService) recordDelivery(task *deliveryTask, statusCode int, body, errMsg string, success bool) {
|
||||
now := time.Now()
|
||||
delivery := &domain.WebhookDelivery{
|
||||
WebhookID: task.webhook.ID,
|
||||
EventType: task.eventType,
|
||||
Payload: string(task.payload),
|
||||
StatusCode: statusCode,
|
||||
ResponseBody: body,
|
||||
Attempt: task.attempt,
|
||||
Success: success,
|
||||
Error: errMsg,
|
||||
}
|
||||
if success {
|
||||
delivery.DeliveredAt = &now
|
||||
}
|
||||
_ = s.repo.CreateDelivery(context.Background(), delivery)
|
||||
}
|
||||
|
||||
// CreateWebhook 创建 Webhook
|
||||
func (s *WebhookService) CreateWebhook(ctx context.Context, req *CreateWebhookRequest, createdBy int64) (*domain.Webhook, error) {
|
||||
eventsJSON, err := json.Marshal(req.Events)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("序列化事件列表失败")
|
||||
}
|
||||
|
||||
secret := req.Secret
|
||||
if secret == "" {
|
||||
generatedSecret, err := generateWebhookSecret()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate webhook secret failed: %w", err)
|
||||
}
|
||||
secret = generatedSecret
|
||||
}
|
||||
|
||||
wh := &domain.Webhook{
|
||||
Name: req.Name,
|
||||
URL: req.URL,
|
||||
Secret: secret,
|
||||
Events: string(eventsJSON),
|
||||
Status: domain.WebhookStatusActive,
|
||||
MaxRetries: s.config.MaxRetries,
|
||||
TimeoutSec: s.config.TimeoutSec,
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
if err := s.repo.Create(ctx, wh); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return wh, nil
|
||||
}
|
||||
|
||||
// UpdateWebhook 更新 Webhook
|
||||
func (s *WebhookService) UpdateWebhook(ctx context.Context, id int64, req *UpdateWebhookRequest) error {
|
||||
updates := map[string]interface{}{}
|
||||
if req.Name != "" {
|
||||
updates["name"] = req.Name
|
||||
}
|
||||
if req.URL != "" {
|
||||
updates["url"] = req.URL
|
||||
}
|
||||
if len(req.Events) > 0 {
|
||||
b, _ := json.Marshal(req.Events)
|
||||
updates["events"] = string(b)
|
||||
}
|
||||
if req.Status != nil {
|
||||
updates["status"] = *req.Status
|
||||
}
|
||||
return s.repo.Update(ctx, id, updates)
|
||||
}
|
||||
|
||||
// DeleteWebhook 删除 Webhook
|
||||
func (s *WebhookService) DeleteWebhook(ctx context.Context, id int64) error {
|
||||
return s.repo.Delete(ctx, id)
|
||||
}
|
||||
|
||||
func (s *WebhookService) GetWebhook(ctx context.Context, id int64) (*domain.Webhook, error) {
|
||||
return s.repo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
// ListWebhooks 获取 Webhook 列表(不分页)
|
||||
func (s *WebhookService) ListWebhooks(ctx context.Context, createdBy int64) ([]*domain.Webhook, error) {
|
||||
return s.repo.ListByCreator(ctx, createdBy)
|
||||
}
|
||||
|
||||
// ListWebhooksPaginated 获取 Webhook 列表(分页)
|
||||
func (s *WebhookService) ListWebhooksPaginated(ctx context.Context, createdBy int64, offset, limit int) ([]*domain.Webhook, int64, error) {
|
||||
return s.repo.ListByCreatorPaginated(ctx, createdBy, offset, limit)
|
||||
}
|
||||
|
||||
// GetWebhookDeliveries 获取投递记录
|
||||
func (s *WebhookService) GetWebhookDeliveries(ctx context.Context, webhookID int64, limit int) ([]*domain.WebhookDelivery, error) {
|
||||
return s.repo.ListDeliveries(ctx, webhookID, limit)
|
||||
}
|
||||
|
||||
// ---- Request/Response 结构 ----
|
||||
|
||||
// CreateWebhookRequest 创建 Webhook 请求
|
||||
type CreateWebhookRequest struct {
|
||||
Name string `json:"name" binding:"required"`
|
||||
URL string `json:"url" binding:"required,url"`
|
||||
Secret string `json:"secret"`
|
||||
Events []domain.WebhookEventType `json:"events" binding:"required,min=1"`
|
||||
}
|
||||
|
||||
// UpdateWebhookRequest 更新 Webhook 请求
|
||||
type UpdateWebhookRequest struct {
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
Events []domain.WebhookEventType `json:"events"`
|
||||
Status *domain.WebhookStatus `json:"status"`
|
||||
}
|
||||
|
||||
// ---- 辅助函数 ----
|
||||
|
||||
// webhookSubscribesTo 检查 Webhook 是否订阅了指定事件类型
|
||||
func webhookSubscribesTo(w *domain.Webhook, eventType domain.WebhookEventType) bool {
|
||||
var events []domain.WebhookEventType
|
||||
if err := json.Unmarshal([]byte(w.Events), &events); err != nil {
|
||||
return false
|
||||
}
|
||||
for _, e := range events {
|
||||
if e == eventType || e == "*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SubscribesTo 检查 Webhook 是否订阅了指定事件类型(为 domain.Webhook 添加方法,通过包装实现)
|
||||
// 注意:此函数在 domain 包外部无法直接扩展,使用独立函数代替
|
||||
|
||||
// isSafeURL 检查 URL 是否安全(防止 SSRF 攻击)
|
||||
// NEW-SEC-01 修复:添加完整的 URL 安全检查
|
||||
func isSafeURL(rawURL string) bool {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil || u.Scheme == "" {
|
||||
return false
|
||||
}
|
||||
// 只允许 http/https
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return false
|
||||
}
|
||||
|
||||
host := u.Hostname()
|
||||
|
||||
// 禁止 localhost
|
||||
if host == "localhost" || host == "127.0.0.1" || host == "::1" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查内网 IP
|
||||
if ip := net.ParseIP(host); ip != nil {
|
||||
if isPrivateIP(ip) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 检查内网域名
|
||||
if strings.HasSuffix(host, ".internal") ||
|
||||
strings.HasSuffix(host, ".local") ||
|
||||
strings.HasSuffix(host, ".corp") ||
|
||||
strings.HasSuffix(host, ".lan") ||
|
||||
strings.HasSuffix(host, ".intranet") {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查知名内网服务地址
|
||||
blockedHosts := []string{
|
||||
"metadata.google.internal", // GCP 元数据服务
|
||||
"169.254.169.254", // AWS/Azure/GCP 元数据服务
|
||||
"metadata.azure.internal", // Azure 元数据服务
|
||||
"100.100.100.200", // 阿里云元数据服务
|
||||
}
|
||||
for _, blocked := range blockedHosts {
|
||||
if host == blocked {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isPrivateIP 检查是否为内网 IP
|
||||
func isPrivateIP(ip net.IP) bool {
|
||||
privateRanges := []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"127.0.0.0/8",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
}
|
||||
for _, cidr := range privateRanges {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// computeHMAC 计算 HMAC-SHA256 签名
|
||||
func computeHMAC(payload []byte, secret string) string {
|
||||
mac := hmac.New(sha256.New, []byte(secret))
|
||||
mac.Write(payload)
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
|
||||
// generateEventID 生成随机事件 ID
|
||||
func generateEventID() (string, error) {
|
||||
b := make([]byte, 8)
|
||||
if _, err := cryptorand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate event ID failed: %w", err)
|
||||
}
|
||||
return "evt_" + hex.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// generateWebhookSecret 生成随机 Webhook 签名密钥
|
||||
func generateWebhookSecret() (string, error) {
|
||||
b := make([]byte, 24)
|
||||
if _, err := cryptorand.Read(b); err != nil {
|
||||
return "", fmt.Errorf("generate webhook secret failed: %w", err)
|
||||
}
|
||||
return strings.ToLower(hex.EncodeToString(b)), nil
|
||||
}
|
||||
Reference in New Issue
Block a user