315 lines
9.1 KiB
Go
315 lines
9.1 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"strings"
|
||
"time"
|
||
|
||
"gorm.io/gorm"
|
||
|
||
"github.com/user-management-system/internal/domain"
|
||
)
|
||
|
||
// likePatternEscaper prefixes each LIKE metacharacter (\, % and _) with a
// backslash. A Replacer scans the input once and never re-examines text it has
// already emitted, so the backslashes it inserts cannot be escaped a second
// time and no call-order rule has to be maintained.
var likePatternEscaper = strings.NewReplacer(
	`\`, `\\`,
	`%`, `\%`,
	`_`, `\_`,
)

// escapeLikePattern returns s with the SQL LIKE wildcards % and _, and the
// backslash escape character itself, each prefixed by a backslash, so that s
// can be embedded in a LIKE pattern and matched literally.
//
// NOTE(review): backslash is the default LIKE escape character in MySQL and
// PostgreSQL, but SQLite requires an explicit ESCAPE clause — confirm the
// target database before relying on this helper.
func escapeLikePattern(s string) string {
	return likePatternEscaper.Replace(s)
}
|
||
|
||
// UserRepository 用户数据访问层
|
||
type UserRepository struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewUserRepository 创建用户数据访问层
|
||
func NewUserRepository(db *gorm.DB) *UserRepository {
|
||
return &UserRepository{db: db}
|
||
}
|
||
|
||
// Create 创建用户
|
||
func (r *UserRepository) Create(ctx context.Context, user *domain.User) error {
|
||
return r.db.WithContext(ctx).Create(user).Error
|
||
}
|
||
|
||
// Update 更新用户
|
||
func (r *UserRepository) Update(ctx context.Context, user *domain.User) error {
|
||
return r.db.WithContext(ctx).Save(user).Error
|
||
}
|
||
|
||
// Delete 删除用户(软删除)
|
||
func (r *UserRepository) Delete(ctx context.Context, id int64) error {
|
||
return r.db.WithContext(ctx).Delete(&domain.User{}, id).Error
|
||
}
|
||
|
||
// GetByID 根据ID获取用户
|
||
func (r *UserRepository) GetByID(ctx context.Context, id int64) (*domain.User, error) {
|
||
var user domain.User
|
||
err := r.db.WithContext(ctx).First(&user, id).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// GetByUsername 根据用户名获取用户
|
||
func (r *UserRepository) GetByUsername(ctx context.Context, username string) (*domain.User, error) {
|
||
var user domain.User
|
||
err := r.db.WithContext(ctx).Where("username = ?", username).First(&user).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// GetByEmail 根据邮箱获取用户
|
||
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
|
||
var user domain.User
|
||
err := r.db.WithContext(ctx).Where("email = ?", email).First(&user).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// GetByPhone 根据手机号获取用户
|
||
func (r *UserRepository) GetByPhone(ctx context.Context, phone string) (*domain.User, error) {
|
||
var user domain.User
|
||
err := r.db.WithContext(ctx).Where("phone = ?", phone).First(&user).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &user, nil
|
||
}
|
||
|
||
// List 获取用户列表
|
||
func (r *UserRepository) List(ctx context.Context, offset, limit int) ([]*domain.User, int64, error) {
|
||
var users []*domain.User
|
||
var total int64
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.User{})
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 获取列表
|
||
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return users, total, nil
|
||
}
|
||
|
||
// ListByStatus 根据状态获取用户列表
|
||
func (r *UserRepository) ListByStatus(ctx context.Context, status domain.UserStatus, offset, limit int) ([]*domain.User, int64, error) {
|
||
var users []*domain.User
|
||
var total int64
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.User{}).Where("status = ?", status)
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 获取列表
|
||
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return users, total, nil
|
||
}
|
||
|
||
// UpdateStatus 更新用户状态
|
||
func (r *UserRepository) UpdateStatus(ctx context.Context, id int64, status domain.UserStatus) error {
|
||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("status", status).Error
|
||
}
|
||
|
||
// UpdateLastLogin 更新最后登录信息
|
||
func (r *UserRepository) UpdateLastLogin(ctx context.Context, id int64, ip string) error {
|
||
now := time.Now()
|
||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||
"last_login_time": &now,
|
||
"last_login_ip": ip,
|
||
}).Error
|
||
}
|
||
|
||
// ExistsByUsername 检查用户名是否存在
|
||
func (r *UserRepository) ExistsByUsername(ctx context.Context, username string) (bool, error) {
|
||
var count int64
|
||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("username = ?", username).Count(&count).Error
|
||
return count > 0, err
|
||
}
|
||
|
||
// ExistsByEmail 检查邮箱是否存在
|
||
func (r *UserRepository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||
var count int64
|
||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("email = ?", email).Count(&count).Error
|
||
return count > 0, err
|
||
}
|
||
|
||
// ExistsByPhone 检查手机号是否存在
|
||
func (r *UserRepository) ExistsByPhone(ctx context.Context, phone string) (bool, error) {
|
||
var count int64
|
||
err := r.db.WithContext(ctx).Model(&domain.User{}).Where("phone = ?", phone).Count(&count).Error
|
||
return count > 0, err
|
||
}
|
||
|
||
// Search 搜索用户
|
||
func (r *UserRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.User, int64, error) {
|
||
var users []*domain.User
|
||
var total int64
|
||
|
||
// 转义 LIKE 特殊字符,防止搜索被意外干扰
|
||
escapedKeyword := escapeLikePattern(keyword)
|
||
pattern := "%" + escapedKeyword + "%"
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.User{}).Where(
|
||
"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
|
||
pattern, pattern, pattern, pattern,
|
||
)
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 获取列表
|
||
if err := query.Offset(offset).Limit(limit).Find(&users).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return users, total, nil
|
||
}
|
||
|
||
// UpdateTOTP 更新用户的 TOTP 字段
|
||
func (r *UserRepository) UpdateTOTP(ctx context.Context, user *domain.User) error {
|
||
return r.db.WithContext(ctx).Model(user).Updates(map[string]interface{}{
|
||
"totp_enabled": user.TOTPEnabled,
|
||
"totp_secret": user.TOTPSecret,
|
||
"totp_recovery_codes": user.TOTPRecoveryCodes,
|
||
}).Error
|
||
}
|
||
|
||
// UpdatePassword 更新用户密码
|
||
func (r *UserRepository) UpdatePassword(ctx context.Context, id int64, hashedPassword string) error {
|
||
return r.db.WithContext(ctx).Model(&domain.User{}).Where("id = ?", id).Update("password", hashedPassword).Error
|
||
}
|
||
|
||
// ListCreatedAfter 查询指定时间之后创建的用户(limit=0表示不限制数量)
|
||
func (r *UserRepository) ListCreatedAfter(ctx context.Context, since time.Time, offset, limit int) ([]*domain.User, int64, error) {
|
||
var users []*domain.User
|
||
var total int64
|
||
query := r.db.WithContext(ctx).Model(&domain.User{}).Where("created_at >= ?", since)
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
if limit > 0 {
|
||
query = query.Offset(offset).Limit(limit)
|
||
}
|
||
if err := query.Find(&users).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
return users, total, nil
|
||
}
|
||
|
||
// AdvancedFilter carries the optional criteria, ordering and paging options
// accepted by UserRepository.AdvancedSearch.
type AdvancedFilter struct {
	// Keyword is matched as a substring against username, email, phone and
	// nickname. An empty string disables keyword filtering.
	Keyword string

	// Status restricts results to a single status value. Any negative value
	// (-1 by convention) means "all statuses"; 0/1/2/3 correspond to
	// UserStatus values. The zero value therefore filters on status 0.
	Status int

	// RoleIDs restricts results to users holding at least one of these roles.
	// An empty slice disables role filtering.
	RoleIDs []int64

	// CreatedFrom is the inclusive lower bound on the registration time; nil
	// disables the bound.
	CreatedFrom *time.Time

	// CreatedTo is the inclusive upper bound on the registration time; nil
	// disables the bound.
	CreatedTo *time.Time

	// LastLoginFrom is the inclusive lower bound on the last login time; nil
	// disables the bound.
	LastLoginFrom *time.Time

	// SortBy names the sort column: created_at, last_login_time, username or
	// updated_at. Empty or unrecognized values fall back to created_at.
	SortBy string

	// SortOrder is the sort direction: "asc" for ascending; the default is
	// descending.
	SortOrder string

	// Offset is the number of rows to skip before the returned page.
	Offset int

	// Limit is the page size. Values of zero or below become 20 and values
	// above 200 are capped at 200.
	Limit int
}
|
||
|
||
// AdvancedSearch 高级用户搜索(支持多维度组合筛选)
|
||
func (r *UserRepository) AdvancedSearch(ctx context.Context, filter *AdvancedFilter) ([]*domain.User, int64, error) {
|
||
var users []*domain.User
|
||
var total int64
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.User{})
|
||
|
||
// 关键字搜索(转义 LIKE 特殊字符)
|
||
if filter.Keyword != "" {
|
||
like := "%" + escapeLikePattern(filter.Keyword) + "%"
|
||
query = query.Where(
|
||
"username LIKE ? OR email LIKE ? OR phone LIKE ? OR nickname LIKE ?",
|
||
like, like, like, like,
|
||
)
|
||
}
|
||
|
||
// 状态筛选
|
||
if filter.Status >= 0 {
|
||
query = query.Where("status = ?", filter.Status)
|
||
}
|
||
|
||
// 注册时间范围
|
||
if filter.CreatedFrom != nil {
|
||
query = query.Where("created_at >= ?", filter.CreatedFrom)
|
||
}
|
||
if filter.CreatedTo != nil {
|
||
query = query.Where("created_at <= ?", filter.CreatedTo)
|
||
}
|
||
|
||
// 最后登录时间范围
|
||
if filter.LastLoginFrom != nil {
|
||
query = query.Where("last_login_time >= ?", filter.LastLoginFrom)
|
||
}
|
||
|
||
// 按角色筛选(子查询)
|
||
if len(filter.RoleIDs) > 0 {
|
||
query = query.Where(
|
||
"id IN (SELECT user_id FROM user_roles WHERE role_id IN ? AND deleted_at IS NULL)",
|
||
filter.RoleIDs,
|
||
)
|
||
}
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 排序
|
||
sortBy := "created_at"
|
||
sortOrder := "DESC"
|
||
if filter.SortBy != "" {
|
||
allowedFields := map[string]bool{
|
||
"created_at": true, "last_login_time": true,
|
||
"username": true, "updated_at": true,
|
||
}
|
||
if allowedFields[filter.SortBy] {
|
||
sortBy = filter.SortBy
|
||
}
|
||
}
|
||
if filter.SortOrder == "asc" {
|
||
sortOrder = "ASC"
|
||
}
|
||
query = query.Order(sortBy + " " + sortOrder)
|
||
|
||
// 分页
|
||
limit := filter.Limit
|
||
if limit <= 0 {
|
||
limit = 20
|
||
}
|
||
if limit > 200 {
|
||
limit = 200
|
||
}
|
||
query = query.Offset(filter.Offset).Limit(limit)
|
||
|
||
if err := query.Find(&users).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return users, total, nil
|
||
}
|