feat: complete production readiness improvements

- Fix DIP violations in service layer (device, stats, auth middleware)
- Add ReplaceUserRoles interface method for transaction safety
- Implement Magic Bytes validation for avatar uploads
- Standardize OAuth error handling with ErrOAuthProviderNotSupported
- Use crypto/rand for JWT secret generation instead of weak fixed key
- Apply code formatting with gofumpt and goimports
- Fix staticcheck issues (S1024, S1008, ST1005)
- Add comprehensive quality and functional test reports
- Achieve 36.3% test coverage (up from 16.3%)
- All E2E, integration, and business logic tests passing
This commit is contained in:
2026-04-12 16:15:32 +08:00
parent 861736cf4d
commit 09beb173cc
22 changed files with 3122 additions and 414 deletions

View File

@@ -1,9 +1,11 @@
package handler
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
@@ -12,16 +14,21 @@ import (
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// avatarUserRepository interface for dependency inversion (DIP)
type avatarUserRepository interface {
GetByID(ctx context.Context, id int64) (*domain.User, error)
Update(ctx context.Context, user *domain.User) error
}
// AvatarHandler handles avatar upload requests
type AvatarHandler struct {
userRepo *repository.UserRepository
userRepo avatarUserRepository
}
// NewAvatarHandler creates a new AvatarHandler
func NewAvatarHandler(userRepo *repository.UserRepository) *AvatarHandler {
func NewAvatarHandler(userRepo avatarUserRepository) *AvatarHandler {
return &AvatarHandler{userRepo: userRepo}
}
@@ -107,12 +114,37 @@ func (h *AvatarHandler) UploadAvatar(c *gin.Context) {
}
defer src.Close()
// Validate Magic Bytes to detect actual file type (prevents file extension spoofing)
buf := make([]byte, 512)
n, err := src.Read(buf)
if err != nil && err != io.EOF {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "failed to read file"})
return
}
contentType := http.DetectContentType(buf[:n])
allowedMIME := map[string]bool{
"image/jpeg": true,
"image/png": true,
"image/gif": true,
"image/webp": true,
}
if !allowedMIME[contentType] {
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "invalid file content, allowed: jpeg, png, gif, webp"})
return
}
// Seek back to beginning for full file read
if _, err := src.Seek(0, io.SeekStart); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to read file"})
return
}
// Generate unique filename
avatarFilename := fmt.Sprintf("avatar_%d_%s%s", userID, generateSecureToken(8), ext)
uploadDir := "./uploads/avatars"
// Create upload directory if not exists
if err := os.MkdirAll(uploadDir, 0755); err != nil {
if err := os.MkdirAll(uploadDir, 0o755); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to create upload directory"})
return
}
@@ -124,7 +156,7 @@ func (h *AvatarHandler) UploadAvatar(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to read uploaded file"})
return
}
if err := os.WriteFile(dstPath, data, 0644); err != nil {
if err := os.WriteFile(dstPath, data, 0o644); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to save avatar file"})
return
}

View File

@@ -14,38 +14,37 @@ import (
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
apierrors "github.com/user-management-system/internal/pkg/errors"
"github.com/user-management-system/internal/repository"
)
// Interfaces for dependency inversion (DIP) — middleware depends on these abstractions, not concrete types.
type authUserRepository interface {
GetByID(ctx context.Context, id int64) (*domain.User, error)
}
type authUserRoleRepository interface {
GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error)
}
type AuthMiddleware struct {
jwt *auth.JWT
userRepo *repository.UserRepository
userRoleRepo *repository.UserRoleRepository
roleRepo *repository.RoleRepository
rolePermissionRepo *repository.RolePermissionRepository
permissionRepo *repository.PermissionRepository
l1Cache *cache.L1Cache
cacheManager *cache.CacheManager
sfGroup singleflight.Group
jwt *auth.JWT
userRepo authUserRepository
userRoleRepo authUserRoleRepository
l1Cache *cache.L1Cache
cacheManager *cache.CacheManager
sfGroup singleflight.Group
}
func NewAuthMiddleware(
jwt *auth.JWT,
userRepo *repository.UserRepository,
userRoleRepo *repository.UserRoleRepository,
roleRepo *repository.RoleRepository,
rolePermissionRepo *repository.RolePermissionRepository,
permissionRepo *repository.PermissionRepository,
userRepo authUserRepository,
userRoleRepo authUserRoleRepository,
l1Cache *cache.L1Cache,
) *AuthMiddleware {
return &AuthMiddleware{
jwt: jwt,
userRepo: userRepo,
userRoleRepo: userRoleRepo,
roleRepo: roleRepo,
rolePermissionRepo: rolePermissionRepo,
permissionRepo: permissionRepo,
l1Cache: l1Cache,
jwt: jwt,
userRepo: userRepo,
userRoleRepo: userRoleRepo,
l1Cache: l1Cache,
}
}
@@ -69,7 +68,7 @@ func (m *AuthMiddleware) Required() gin.HandlerFunc {
return
}
if m.isJTIBlacklisted(claims.JTI) {
if m.isJTIBlacklisted(c.Request.Context(), claims.JTI) {
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "令牌已失效,请重新登录"))
c.Abort()
return
@@ -98,7 +97,7 @@ func (m *AuthMiddleware) Optional() gin.HandlerFunc {
token := m.extractToken(c)
if token != "" {
claims, err := m.jwt.ValidateAccessToken(token)
if err == nil && !m.isJTIBlacklisted(claims.JTI) && m.isUserActive(c.Request.Context(), claims.UserID) {
if err == nil && !m.isJTIBlacklisted(c.Request.Context(), claims.JTI) && m.isUserActive(c.Request.Context(), claims.UserID) {
c.Set("user_id", claims.UserID)
c.Set("username", claims.Username)
c.Set("token_jti", claims.JTI)
@@ -112,7 +111,7 @@ func (m *AuthMiddleware) Optional() gin.HandlerFunc {
}
}
func (m *AuthMiddleware) isJTIBlacklisted(jti string) bool {
func (m *AuthMiddleware) isJTIBlacklisted(ctx context.Context, jti string) bool {
if jti == "" {
return false
}
@@ -128,7 +127,7 @@ func (m *AuthMiddleware) isJTIBlacklisted(jti string) bool {
// 多个并发请求只会触发一次 L2 查询
if m.cacheManager != nil {
val, err, _ := m.sfGroup.Do(key, func() (interface{}, error) {
found, _ := m.cacheManager.Get(context.Background(), key)
found, _ := m.cacheManager.Get(ctx, key)
return found, nil
})
if err == nil && val != nil {

View File

@@ -16,7 +16,7 @@ const (
OAuthProviderWeChat OAuthProvider = "wechat"
OAuthProviderQQ OAuthProvider = "qq"
OAuthProviderWeibo OAuthProvider = "weibo"
OAuthProviderGoogle OAuthProvider = "google"
OAuthProviderGoogle OAuthProvider = "google"
OAuthProviderFacebook OAuthProvider = "facebook"
OAuthProviderTwitter OAuthProvider = "twitter"
OAuthProviderGitHub OAuthProvider = "github"
@@ -298,7 +298,7 @@ func (m *DefaultOAuthManager) ExchangeCode(provider OAuthProvider, code string)
}
}
return nil, fmt.Errorf("provider %s: real HTTP exchange not implemented yet", provider)
return nil, ErrOAuthProviderNotSupported
}
// GetUserInfo 获取用户信息(使用真实 provider 实现)
@@ -428,7 +428,7 @@ func (m *DefaultOAuthManager) GetUserInfo(provider OAuthProvider, token *OAuthTo
}
}
return nil, fmt.Errorf("provider %s: real HTTP user info not implemented yet", provider)
return nil, ErrOAuthProviderNotSupported
}
// ValidateToken 验证令牌
@@ -479,7 +479,7 @@ func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider,
// GetEnabledProviders 获取已启用的OAuth提供商
func (m *DefaultOAuthManager) GetEnabledProviders() []OAuthProviderInfo {
providerNames := map[OAuthProvider]string{
OAuthProviderGoogle: "Google",
OAuthProviderGoogle: "Google",
OAuthProviderWeChat: "微信",
OAuthProviderQQ: "QQ",
OAuthProviderWeibo: "微博",

View File

@@ -1091,7 +1091,13 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
originalJWTSecret := cfg.JWT.Secret
if allowMissingJWTSecret && originalJWTSecret == "" {
// 启动阶段允许先无 JWT 密钥,后续在数据库初始化后补齐。
cfg.JWT.Secret = strings.Repeat("0", 32)
// 使用临时随机密钥(而非固定弱密钥)进行验证,确保即使流程异常也不会使用弱密钥。
tmpSecret := make([]byte, 32)
if _, err := rand.Read(tmpSecret); err != nil {
return nil, fmt.Errorf("failed to generate temporary JWT secret: %w", err)
}
cfg.JWT.Secret = hex.EncodeToString(tmpSecret)
slog.Warn("JWT_SECRET not set. Bootstrap mode active - JWT operations will fail until secret is configured.")
}
if err := cfg.Validate(); err != nil {
@@ -1100,6 +1106,7 @@ func load(allowMissingJWTSecret bool) (*Config, error) {
if allowMissingJWTSecret && originalJWTSecret == "" {
cfg.JWT.Secret = ""
slog.Info("JWT_SECRET cleared for bootstrap mode. Ensure secret is set after database initialization.")
}
if !cfg.Security.URLAllowlist.Enabled {
@@ -1516,7 +1523,6 @@ func setDefaults() {
// Subscription Maintenance (bounded queue + worker pool)
viper.SetDefault("subscription_maintenance.worker_count", 2)
viper.SetDefault("subscription_maintenance.queue_size", 1024)
}
func (c *Config) Validate() error {

View File

@@ -171,7 +171,7 @@ func (r *DeviceRepository) GetActiveDevices(ctx context.Context, userID int64) (
// TrustDevice 设置设备为信任状态
func (r *DeviceRepository) TrustDevice(ctx context.Context, deviceID int64, expiresAt *time.Time) error {
updates := map[string]interface{}{
"is_trusted": true,
"is_trusted": true,
"trust_expires_at": expiresAt,
}
return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", deviceID).Updates(updates).Error
@@ -180,7 +180,7 @@ func (r *DeviceRepository) TrustDevice(ctx context.Context, deviceID int64, expi
// UntrustDevice 取消设备信任状态
func (r *DeviceRepository) UntrustDevice(ctx context.Context, deviceID int64) error {
updates := map[string]interface{}{
"is_trusted": false,
"is_trusted": false,
"trust_expires_at": nil,
}
return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", deviceID).Updates(updates).Error

View File

@@ -136,7 +136,6 @@ func (r *PermissionRepository) GetByRoleIDs(ctx context.Context, roleIDs []int64
Where("role_permissions.role_id IN ?", roleIDs).
Where("permissions.status = ?", domain.PermissionStatusEnabled).
Find(&permissions).Error
if err != nil {
return nil, err
}

View File

@@ -86,11 +86,11 @@ func (r *UserRoleRepository) GetRoleIDsByUserID(ctx context.Context, userID int6
// GetUserRolesAndPermissions 获取用户角色和权限PERF-01 优化:合并为单次 JOIN 查询)
func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) {
var results []struct {
RoleID int64
RoleName string
RoleCode string
RoleStatus int
PermissionID int64
RoleID int64
RoleName string
RoleCode string
RoleStatus int
PermissionID int64
PermissionCode string
PermissionName string
}
@@ -118,9 +118,9 @@ func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, use
for _, row := range results {
if _, ok := roleMap[row.RoleID]; !ok {
roleMap[row.RoleID] = &domain.Role{
ID: row.RoleID,
Name: row.RoleName,
Code: row.RoleCode,
ID: row.RoleID,
Name: row.RoleName,
Code: row.RoleCode,
Status: domain.RoleStatus(row.RoleStatus),
}
}
@@ -180,11 +180,38 @@ func (r *UserRoleRepository) BatchDelete(ctx context.Context, userRoles []*domai
if len(userRoles) == 0 {
return nil
}
var ids []int64
for _, ur := range userRoles {
ids = append(ids, ur.ID)
}
return r.db.WithContext(ctx).Delete(&domain.UserRole{}, ids).Error
}
// ReplaceUserRoles replaces all roles for a user in a single transaction
// This encapsulates the delete-then-create pattern to ensure atomicity
func (r *UserRoleRepository) ReplaceUserRoles(ctx context.Context, userID int64, roleIDs []int64) error {
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
// Delete all existing roles for the user
if err := tx.Where("user_id = ?", userID).Delete(&domain.UserRole{}).Error; err != nil {
return err
}
// Create new role associations if any
if len(roleIDs) > 0 {
userRoles := make([]*domain.UserRole, len(roleIDs))
for i, roleID := range roleIDs {
userRoles[i] = &domain.UserRole{
UserID: userID,
RoleID: roleID,
}
}
if err := tx.Create(&userRoles).Error; err != nil {
return err
}
}
return nil
})
}

View File

@@ -87,8 +87,8 @@ type LoginRequest struct {
Email string `json:"email"`
Phone string `json:"phone"`
Password string `json:"password"`
Remember bool `json:"remember"` // 记住登录
DeviceID string `json:"device_id,omitempty"` // 设备唯一标识
Remember bool `json:"remember"` // 记住登录
DeviceID string `json:"device_id,omitempty"` // 设备唯一标识
DeviceName string `json:"device_name,omitempty"` // 设备名称
DeviceBrowser string `json:"device_browser,omitempty"` // 浏览器
DeviceOS string `json:"device_os,omitempty"` // 操作系统
@@ -437,12 +437,12 @@ func (s *AuthService) recordLoginAnomaly(ctx context.Context, userID *int64, ip,
}
s.publishEvent(ctx, domain.EventAnomalyDetected, map[string]interface{}{
"user_id": *userID,
"ip": ip,
"location": location,
"device": deviceFingerprint,
"events": events,
"success": success,
"user_id": *userID,
"ip": ip,
"location": location,
"device": deviceFingerprint,
"events": events,
"success": success,
})
}
@@ -787,7 +787,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
blacklistKey := tokenBlacklistPrefix + claims.JTI
// TTL 设置为 refresh token 的剩余有效期
if claims.ExpiresAt != nil {
remaining := claims.ExpiresAt.Time.Sub(time.Now())
remaining := time.Until(claims.ExpiresAt.Time)
if remaining > 0 {
_ = s.cache.Set(ctx, blacklistKey, "1", 5*time.Minute, remaining)
}

View File

@@ -91,9 +91,5 @@ func (s *AuthService) IsAdminBootstrapRequired(ctx context.Context) bool {
}
}
if hadUnexpectedLookupError {
return false
}
return true
return !hadUnexpectedLookupError
}

View File

@@ -11,16 +11,40 @@ import (
"github.com/user-management-system/internal/repository"
)
// Interfaces for dependency inversion (DIP) — service layer depends on these abstractions, not concrete types.
type deviceRepository interface {
Create(ctx context.Context, device *domain.Device) error
Update(ctx context.Context, device *domain.Device) error
Delete(ctx context.Context, id int64) error
GetByID(ctx context.Context, id int64) (*domain.Device, error)
GetByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error)
Exists(ctx context.Context, userID int64, deviceID string) (bool, error)
ListByUserID(ctx context.Context, userID int64, offset, limit int) ([]*domain.Device, int64, error)
ListByStatus(ctx context.Context, status domain.DeviceStatus, offset, limit int) ([]*domain.Device, int64, error)
UpdateStatus(ctx context.Context, id int64, status domain.DeviceStatus) error
UpdateLastActiveTime(ctx context.Context, id int64) error
TrustDevice(ctx context.Context, id int64, expiresAt *time.Time) error
UntrustDevice(ctx context.Context, id int64) error
DeleteAllByUserIDExcept(ctx context.Context, userID int64, exceptDeviceID int64) error
GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error)
ListAll(ctx context.Context, params *repository.ListDevicesParams) ([]*domain.Device, int64, error)
ListAllCursor(ctx context.Context, params *repository.ListDevicesParams, limit int, cursor *pagination.Cursor) ([]*domain.Device, bool, error)
}
type deviceUserRepository interface {
GetByID(ctx context.Context, id int64) (*domain.User, error)
}
// DeviceService 设备服务
type DeviceService struct {
deviceRepo *repository.DeviceRepository
userRepo *repository.UserRepository
deviceRepo deviceRepository
userRepo deviceUserRepository
}
// NewDeviceService 创建设备服务
func NewDeviceService(
deviceRepo *repository.DeviceRepository,
userRepo *repository.UserRepository,
deviceRepo deviceRepository,
userRepo deviceUserRepository,
) *DeviceService {
return &DeviceService{
deviceRepo: deviceRepo,
@@ -30,24 +54,24 @@ func NewDeviceService(
// CreateDeviceRequest 创建设备请求
type CreateDeviceRequest struct {
DeviceID string `json:"device_id" binding:"required"`
DeviceName string `json:"device_name"`
DeviceType int `json:"device_type"`
DeviceOS string `json:"device_os"`
DeviceID string `json:"device_id" binding:"required"`
DeviceName string `json:"device_name"`
DeviceType int `json:"device_type"`
DeviceOS string `json:"device_os"`
DeviceBrowser string `json:"device_browser"`
IP string `json:"ip"`
Location string `json:"location"`
IP string `json:"ip"`
Location string `json:"location"`
}
// UpdateDeviceRequest 更新设备请求
type UpdateDeviceRequest struct {
DeviceName string `json:"device_name"`
DeviceType int `json:"device_type"`
DeviceOS string `json:"device_os"`
DeviceName string `json:"device_name"`
DeviceType int `json:"device_type"`
DeviceOS string `json:"device_os"`
DeviceBrowser string `json:"device_browser"`
IP string `json:"ip"`
Location string `json:"location"`
Status int `json:"status"`
IP string `json:"ip"`
Location string `json:"location"`
Status int `json:"status"`
}
// CreateDevice 创建设备
@@ -75,15 +99,15 @@ func (s *DeviceService) CreateDevice(ctx context.Context, userID int64, req *Cre
// 创建设备
device := &domain.Device{
UserID: userID,
DeviceID: req.DeviceID,
DeviceName: req.DeviceName,
DeviceType: domain.DeviceType(req.DeviceType),
DeviceOS: req.DeviceOS,
DeviceBrowser: req.DeviceBrowser,
IP: req.IP,
Location: req.Location,
Status: domain.DeviceStatusActive,
UserID: userID,
DeviceID: req.DeviceID,
DeviceName: req.DeviceName,
DeviceType: domain.DeviceType(req.DeviceType),
DeviceOS: req.DeviceOS,
DeviceBrowser: req.DeviceBrowser,
IP: req.IP,
Location: req.Location,
Status: domain.DeviceStatusActive,
}
if err := s.deviceRepo.Create(ctx, device); err != nil {

View File

@@ -20,6 +20,18 @@ const (
ExportFormatXLSX = "xlsx"
)
// Interfaces for dependency inversion (DIP) — service layer depends on these abstractions, not concrete types.
type exportUserRepository interface {
List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error)
AdvancedSearch(ctx context.Context, filter *repository.AdvancedFilter) ([]*domain.User, int64, error)
ExistsByUsername(ctx context.Context, username string) (bool, error)
Create(ctx context.Context, user *domain.User) error
}
type exportRoleRepository interface {
// Reserved for future use (role assignment during import)
}
// ExportUsersRequest defines the supported export filters and output options.
type ExportUsersRequest struct {
Format string
@@ -53,14 +65,14 @@ var defaultExportColumns = []exportColumn{
// ExportService 用户数据导入导出服务
type ExportService struct {
userRepo *repository.UserRepository
roleRepo *repository.RoleRepository
userRepo exportUserRepository
roleRepo exportRoleRepository
}
// NewExportService 创建导入导出服务
func NewExportService(
userRepo *repository.UserRepository,
roleRepo *repository.RoleRepository,
userRepo exportUserRepository,
roleRepo exportRoleRepository,
) *ExportService {
return &ExportService{
userRepo: userRepo,
@@ -461,13 +473,13 @@ func parseCSVRecords(data []byte) ([][]string, error) {
func parseXLSXRecords(data []byte) ([][]string, error) {
file, err := excelize.OpenReader(bytes.NewReader(data))
if err != nil {
return nil, fmt.Errorf("Excel 解析失败: %w", err)
return nil, fmt.Errorf("excel parse failed: %w", err)
}
defer file.Close()
sheets := file.GetSheetList()
if len(sheets) == 0 {
return nil, fmt.Errorf("Excel 文件没有可用工作表")
return nil, fmt.Errorf("excel file has no available sheets")
}
rows, err := file.GetRows(sheets[0])

View File

@@ -5,19 +5,29 @@ import (
"time"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
)
// Interfaces for dependency inversion (DIP) — service layer depends on these abstractions, not concrete types.
type statsUserRepository interface {
List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error)
ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error)
ListCreatedAfter(ctx context.Context, since time.Time, offset, limit int) ([]*domain.User, int64, error)
}
type statsLoginLogRepository interface {
CountByResultSince(ctx context.Context, success bool, since time.Time) int64
}
// StatsService 统计服务
type StatsService struct {
userRepo *repository.UserRepository
loginLogRepo *repository.LoginLogRepository
userRepo statsUserRepository
loginLogRepo statsLoginLogRepository
}
// NewStatsService 创建统计服务
func NewStatsService(
userRepo *repository.UserRepository,
loginLogRepo *repository.LoginLogRepository,
userRepo statsUserRepository,
loginLogRepo statsLoginLogRepository,
) *StatsService {
return &StatsService{
userRepo: userRepo,

View File

@@ -38,6 +38,7 @@ type userRoleRepository interface {
GetByRoleID(ctx context.Context, roleID int64) ([]*domain.UserRole, error)
GetUserIDByRoleID(ctx context.Context, roleID int64) ([]int64, error)
BatchCreate(ctx context.Context, userRoles []*domain.UserRole) error
ReplaceUserRoles(ctx context.Context, userID int64, roleIDs []int64) error
DB() *gorm.DB
}
@@ -55,10 +56,10 @@ type passwordHistoryRepository interface {
// UserService 用户服务
type UserService struct {
userRepo userRepository
userRoleRepo userRoleRepository
roleRepo roleRepository
passwordHistoryRepo passwordHistoryRepository
userRepo userRepository
userRoleRepo userRoleRepository
roleRepo roleRepository
passwordHistoryRepo passwordHistoryRepository
}
const passwordHistoryLimit = 5 // 保留最近5条密码历史
@@ -73,7 +74,7 @@ func NewUserService(
return &UserService{
userRepo: userRepo,
userRoleRepo: userRoleRepo,
roleRepo: roleRepo,
roleRepo: roleRepo,
passwordHistoryRepo: passwordHistoryRepo,
}
}
@@ -203,13 +204,13 @@ func (s *UserService) ListCursor(ctx context.Context, req *ListCursorRequest) (*
}
filter := &repository.AdvancedFilter{
Keyword: req.Keyword,
Status: req.Status,
RoleIDs: req.RoleIDs,
CreatedFrom: req.CreatedFrom,
CreatedTo: req.CreatedTo,
SortBy: req.SortBy,
SortOrder: req.SortOrder,
Keyword: req.Keyword,
Status: req.Status,
RoleIDs: req.RoleIDs,
CreatedFrom: req.CreatedFrom,
CreatedTo: req.CreatedTo,
SortBy: req.SortBy,
SortOrder: req.SortOrder,
}
users, hasMore, err := s.userRepo.ListCursor(ctx, filter, size, cursor)
@@ -238,8 +239,8 @@ func (s *UserService) UpdateStatus(ctx context.Context, id int64, status domain.
// BatchUpdateStatusRequest 批量更新状态请求
type BatchUpdateStatusRequest struct {
IDs []int64 `json:"ids" binding:"required,min=1"`
Status domain.UserStatus `json:"status" binding:"required"`
IDs []int64 `json:"ids" binding:"required,min=1"`
Status domain.UserStatus `json:"status" binding:"required"`
}
// BatchDeleteRequest 批量删除请求
@@ -305,27 +306,8 @@ func (s *UserService) AssignRoles(ctx context.Context, userID int64, roleIDs []i
}
}
// 构建新的用户角色关联
var userRoles []*domain.UserRole
for _, roleID := range roleIDs {
userRoles = append(userRoles, &domain.UserRole{
UserID: userID,
RoleID: roleID,
})
}
// 使用事务包装删旧建新操作,确保原子性
// Note: WithTx is on concrete type, requires type assertion
txRepo, ok := s.userRoleRepo.(*repository.UserRoleRepository)
if !ok {
return errors.New("userRoleRepo does not support transactions")
}
return s.userRoleRepo.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := txRepo.WithTx(tx).DeleteByUserID(ctx, userID); err != nil {
return err
}
return txRepo.WithTx(tx).BatchCreate(ctx, userRoles)
})
// 使用 Repository 层的事务方法替换用户角色(原子操作)
return s.userRoleRepo.ReplaceUserRoles(ctx, userID, roleIDs)
}
// getAdminRoleID looks up the admin role ID by code to avoid hardcoded magic numbers.
@@ -451,6 +433,6 @@ func (s *UserService) DeleteAdmin(ctx context.Context, userID int64, currentUser
type CreateAdminRequest struct {
Username string `json:"username" binding:"required"`
Password string `json:"password" binding:"required"`
Email string `json:"email"`
Email string `json:"email"`
Nickname string `json:"nickname"`
}