package repository import ( "context" "strings" "time" "gorm.io/gorm" "github.com/user-management-system/internal/domain" ) // escapeLikePattern 转义 LIKE 模式中的特殊字符(% 和 _) // 这些字符在 LIKE 查询中有特殊含义,需要转义才能作为普通字符匹配 func escapeLikePattern(s string) string { // 先转义 \,再转义 % 和 _(顺序很重要) s = strings.ReplaceAll(s, `\`, `\\`) s = strings.ReplaceAll(s, `%`, `\%`) s = strings.ReplaceAll(s, `_`, `\_`) return s } // UserRepository 用户数据访问层 type UserRepository struct { db *gorm.DB } // NewUserRepository 创建用户数据访问层 func NewUserRepository(db *gorm.DB) *UserRepository { return &UserRepository{db: db} } // Create 创建用户 func (r *UserRepository) Create(ctx context.Context, user *domain.User) error { return r.db.WithContext(ctx).Create(user).Error } // Update 更新用户 func (r *UserRepository) Update(ctx context.Context, user *domain.User) error { return r.db.WithContext(ctx).Save(user).Error } // Delete 删除用户(软删除) func (r *UserRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&domain.User{}, id).Error } // GetByID 根据ID获取用户 func (r *UserRepository) GetByID(ctx context.Context, id int64) (*domain.User, error) { var user domain.User err := r.db.WithContext(ctx).First(&user, id).Error if err != nil { return nil, err } return &user, nil } // GetByUsername 根据用户名获取用户 func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*domain.User, error) { var user domain.User err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error if err != nil { return nil, err } return &user, nil } // GetByEmail 根据邮箱获取用户 func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) { var user domain.User err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error if err != nil { return nil, err } return &user, nil } // GetByPhone 根据手机号获取用户 func (r *UserRepository) GetByPhone(ctx context.Context, phone string) (*domain.User, error) { var user domain.User err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error if err != nil { return nil, err } return &user, nil } // List 获取用户列表 func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) { var users []*domain.User var total int64 query := r.db.WithContext(ctx).Model(&domain.User{}) // 获取总数 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 获取列表 if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil { return nil, 0, err } return users, total, nil } // ListByStatus 根据状态获取用户列表 func (r *UserRepository) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) { var users []*domain.User var total int64 query := r.db.WithContext(ctx).Model(&domain.User{}).Where("status = ?", status) // 获取总数 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 获取列表 if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil { return nil, 0, err } return users, total, nil } // UpdateStatus 更新用户状态 func (r *UserRepository) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error { return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("status", status).Error } // UpdateLastLogin 更新最后登录信息 func (r *UserRepository) UpdateLastLogin(ctx context.Context, id int64, ip string) error { now := time.Now() return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Updates(map[string]interface{}{ "last_login_time": &now, "last_login_ip": ip, }).Error } // ExistsByUsername 检查用户名是否存在 func (r *UserRepository) ExistsByUsername(ctx context.Context, username string) (bool, error) { var count int64 err := r.db.WithContext(ctx).Model(&domain.User{}).Where("username = ?", username).Count(&count).Error return count > 0, err } // ExistsByEmail 检查邮箱是否存在 func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) { var count int64 err := r.db.WithContext(ctx).Model(&domain.User{}).Where("email = ?", email).Count(&count).Error return count > 0, err } // ExistsByPhone 检查手机号是否存在 func (r *UserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) { var count int64 err := r.db.WithContext(ctx).Model(&domain.User{}).Where("phone = ?", phone).Count(&count).Error return count > 0, err } // Search 搜索用户 func (r *UserRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) { var users []*domain.User var total int64 // 转义 LIKE 特殊字符,防止搜索被意外干扰 escapedKeyword := escapeLikePattern(keyword) pattern := "%" + escapedKeyword + "%" query := r.db.WithContext(ctx).Model(&domain.User{}).Where( "username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?", pattern, pattern, pattern, pattern, ) // 获取总数 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 获取列表 if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil { return nil, 0, err } return users, total, nil } // UpdateTOTP 更新用户的 TOTP 字段 func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) error { return r.db.WithContext(ctx).Model(user).Updates(map[string]interface{}{ "totp_enabled": user.TOTPEnabled, "totp_secret": user.TOTPSecret, "totp_recovery_codes": user.TOTPRecoveryCodes, }).Error } // UpdatePassword 更新用户密码 func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error { return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error } // ListCreatedAfter 查询指定时间之后创建的用户(limit=0表示不限制数量) func (r *UserRepository) ListCreatedAfter(ctx context.Context, since time.Time, offset, limit int) ([]*domain.User, int64, error) { var users []*domain.User var total int64 query := r.db.WithContext(ctx).Model(&domain.User{}).Where("created_at >= ?", since) if err := query.Count(&total).Error; err != nil { return nil, 0, err } if limit > 0 { query = query.Offset(offset).Limit(limit) } if err := query.Find(&users).Error; err != nil { return nil, 0, err } return users, total, nil } // AdvancedFilter 高级用户筛选请求 type AdvancedFilter struct { Keyword string // 关键字(用户名/邮箱/手机号/昵称) Status int // 状态:-1 全部,0/1/2/3 对应 UserStatus RoleIDs []int64 // 角色ID列表(按角色筛选) CreatedFrom *time.Time // 注册时间范围(起始) CreatedTo *time.Time // 注册时间范围(截止) LastLoginFrom *time.Time // 最后登录时间范围(起始) SortBy string // 排序字段:created_at, last_login_time, username SortOrder string // 排序方向:asc, desc Offset int Limit int } // AdvancedSearch 高级用户搜索(支持多维度组合筛选) func (r *UserRepository) AdvancedSearch(ctx context.Context, filter *AdvancedFilter) ([]*domain.User, int64, error) { var users []*domain.User var total int64 query := r.db.WithContext(ctx).Model(&domain.User{}) // 关键字搜索(转义 LIKE 特殊字符) if filter.Keyword != "" { like := "%" + escapeLikePattern(filter.Keyword) + "%" query = query.Where( "username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?", like, like, like, like, ) } // 状态筛选 if filter.Status >= 0 { query = query.Where("status = ?", filter.Status) } // 注册时间范围 if filter.CreatedFrom != nil { query = query.Where("created_at >= ?", filter.CreatedFrom) } if filter.CreatedTo != nil { query = query.Where("created_at <= ?", filter.CreatedTo) } // 最后登录时间范围 if filter.LastLoginFrom != nil { query = query.Where("last_login_time >= ?", filter.LastLoginFrom) } // 按角色筛选(子查询) if len(filter.RoleIDs) > 0 { query = query.Where( "id IN (SELECT user_id FROM user_roles WHERE role_id IN ? AND deleted_at IS NULL)", filter.RoleIDs, ) } // 获取总数 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 排序 sortBy := "created_at" sortOrder := "DESC" if filter.SortBy != "" { allowedFields := map[string]bool{ "created_at": true, "last_login_time": true, "username": true, "updated_at": true, } if allowedFields[filter.SortBy] { sortBy = filter.SortBy } } if filter.SortOrder == "asc" { sortOrder = "ASC" } query = query.Order(sortBy + " " + sortOrder) // 分页 limit := filter.Limit if limit <= 0 { limit = 20 } if limit > 200 { limit = 200 } query = query.Offset(filter.Offset).Limit(limit) if err := query.Find(&users).Error; err != nil { return nil, 0, err } return users, total, nil }