Files
user-system/internal/repository/user_role.go
long-agent 2a18a6fb47 fix(n+1): 批量查询替代循环单查
- IsAdminBootstrapRequired: userRepo.GetByID 循环 → GetByIDs 批量
- AssignRoles: roleRepo.GetByID 循环 → GetByIDs 批量
- 在 userRepositoryInterface 补充 GetByIDs 方法签名
2026-05-08 08:05:26 +08:00

238 lines
7.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
	"context"
	"errors"

	"github.com/user-management-system/internal/domain"
	"gorm.io/gorm"
)
// UserRoleRepository is the data-access layer for user-role associations
// (domain.UserRole rows). The wrapped handle may be a plain *gorm.DB or a
// transaction handle supplied through WithTx.
type UserRoleRepository struct {
	db *gorm.DB // handle every query in this repository is issued against
}
// NewUserRoleRepository builds a UserRoleRepository that issues its queries
// through db.
func NewUserRoleRepository(db *gorm.DB) *UserRoleRepository {
	repo := &UserRoleRepository{}
	repo.db = db
	return repo
}
// DB exposes the *gorm.DB this repository was built with, so callers can
// open a transaction themselves and hand the tx back via WithTx.
func (r *UserRoleRepository) DB() *gorm.DB {
	return r.db
}
// WithTx returns a fresh UserRoleRepository bound to tx. The receiver is left
// unchanged and keeps using its original handle.
func (r *UserRoleRepository) WithTx(tx *gorm.DB) *UserRoleRepository {
	return NewUserRoleRepository(tx)
}
// Create inserts one user-role association row.
func (r *UserRoleRepository) Create(ctx context.Context, userRole *domain.UserRole) error {
	result := r.db.WithContext(ctx).Create(userRole)
	return result.Error
}
// Delete removes the user-role association whose primary key equals id.
func (r *UserRoleRepository) Delete(ctx context.Context, id int64) error {
	result := r.db.WithContext(ctx).Delete(&domain.UserRole{}, id)
	return result.Error
}
// DeleteByUserID removes every role association belonging to userID.
func (r *UserRoleRepository) DeleteByUserID(ctx context.Context, userID int64) error {
	scoped := r.db.WithContext(ctx).Where("user_id = ?", userID)
	return scoped.Delete(&domain.UserRole{}).Error
}
// DeleteByUserAndRole removes the association between userID and roleID.
func (r *UserRoleRepository) DeleteByUserAndRole(ctx context.Context, userID, roleID int64) error {
	scoped := r.db.WithContext(ctx).Where("user_id = ? AND role_id = ?", userID, roleID)
	return scoped.Delete(&domain.UserRole{}).Error
}
// DeleteByRoleID removes every user association that references roleID.
func (r *UserRoleRepository) DeleteByRoleID(ctx context.Context, roleID int64) error {
	scoped := r.db.WithContext(ctx).Where("role_id = ?", roleID)
	return scoped.Delete(&domain.UserRole{}).Error
}
// GetByUserID returns all user-role association rows for userID.
// On error the returned slice is nil.
func (r *UserRoleRepository) GetByUserID(ctx context.Context, userID int64) ([]*domain.UserRole, error) {
	var rows []*domain.UserRole
	if err := r.db.WithContext(ctx).Where("user_id = ?", userID).Find(&rows).Error; err != nil {
		return nil, err
	}
	return rows, nil
}
// GetByRoleID returns all user-role association rows that reference roleID.
// On error the returned slice is nil.
func (r *UserRoleRepository) GetByRoleID(ctx context.Context, roleID int64) ([]*domain.UserRole, error) {
	var rows []*domain.UserRole
	if err := r.db.WithContext(ctx).Where("role_id = ?", roleID).Find(&rows).Error; err != nil {
		return nil, err
	}
	return rows, nil
}
// GetRoleIDsByUserID returns the role_id column of every association row for
// userID. On error the returned slice is nil.
func (r *UserRoleRepository) GetRoleIDsByUserID(ctx context.Context, userID int64) ([]int64, error) {
	var ids []int64
	query := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("user_id = ?", userID)
	if err := query.Pluck("role_id", &ids).Error; err != nil {
		return nil, err
	}
	return ids, nil
}
// getRoleAncestorIDs walks the parent_id chain upward from roleID and returns
// the IDs it visited, starting with roleID itself.
//
// The walk stops when:
//   - the current ID is <= 0 (so a non-positive roleID yields an empty result),
//   - a role has a NULL parent_id (a root role),
//   - the role being looked up does not exist (gorm.ErrRecordNotFound); its ID
//     is still included because IDs are recorded before they are looked up,
//   - an ID already seen on this chain comes up again (cycle protection), or
//   - maxDepth IDs have been collected (roleID counts as the first).
//
// Any other database error (e.g. a dropped connection or a cancelled ctx) is
// returned to the caller instead of being silently treated as the end of the
// chain, which would otherwise yield a truncated set of inherited roles.
func (r *UserRoleRepository) getRoleAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
	const maxDepth = 5

	var ancestors []int64
	visited := make(map[int64]bool, maxDepth)
	current := roleID
	for depth := 0; current > 0 && depth < maxDepth; depth++ {
		if visited[current] {
			break // cycle in the parent chain
		}
		visited[current] = true
		ancestors = append(ancestors, current)

		var role domain.Role
		err := r.db.WithContext(ctx).Select("parent_id").First(&role, current).Error
		if errors.Is(err, gorm.ErrRecordNotFound) {
			break // dangling reference: nothing further up
		}
		if err != nil {
			return nil, err
		}
		if role.ParentID == nil {
			break // reached a root role
		}
		current = *role.ParentID
	}
	return ancestors, nil
}
// GetUserRolesAndPermissions returns the enabled roles of userID, including
// roles inherited through the parent_id chain, plus the permissions granted
// by those enabled roles.
//
// Steps:
//  1. load the role IDs directly assigned to userID;
//  2. expand each one with getRoleAncestorIDs (self plus ancestors, cycle-safe),
//     merging all chains into one de-duplicated set;
//  3. load the roles in that set whose status is domain.RoleStatusEnabled;
//  4. load the permissions bound to those enabled roles only.
//
// Inheritance still walks through a disabled ancestor, so that ancestor's own
// enabled parents are kept. A disabled role contributes neither itself nor
// its permissions; previously its permissions leaked into the result even
// though the role was filtered out. On success both slices are non-nil
// (possibly empty); on error both are nil.
func (r *UserRoleRepository) GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) {
	// 1. Directly assigned role IDs.
	var directRoleIDs []int64
	err := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("user_id = ?", userID).Pluck("role_id", &directRoleIDs).Error
	if err != nil {
		return nil, nil, err
	}

	// 2. Expand with ancestors; the set removes IDs shared by several chains.
	allRoleIDSet := make(map[int64]struct{})
	for _, roleID := range directRoleIDs {
		ancestors, err := r.getRoleAncestorIDs(ctx, roleID)
		if err != nil {
			return nil, nil, err
		}
		for _, id := range ancestors {
			allRoleIDSet[id] = struct{}{}
		}
	}
	if len(allRoleIDSet) == 0 {
		return []*domain.Role{}, []*domain.Permission{}, nil
	}
	allRoleIDs := make([]int64, 0, len(allRoleIDSet))
	for id := range allRoleIDSet {
		allRoleIDs = append(allRoleIDs, id)
	}

	// 3. Keep only enabled roles.
	var roles []*domain.Role
	err = r.db.WithContext(ctx).Where("id IN ? AND status = ?", allRoleIDs, domain.RoleStatusEnabled).Find(&roles).Error
	if err != nil {
		return nil, nil, err
	}
	if len(roles) == 0 {
		return []*domain.Role{}, []*domain.Permission{}, nil
	}

	// 4. Permissions come only from the enabled roles loaded above, so the
	// returned permissions are consistent with the returned roles.
	enabledRoleIDs := make([]int64, 0, len(roles))
	for _, role := range roles {
		enabledRoleIDs = append(enabledRoleIDs, role.ID)
	}
	var permissionIDs []int64
	err = r.db.WithContext(ctx).Model(&domain.RolePermission{}).Where("role_id IN ?", enabledRoleIDs).Distinct().Pluck("permission_id", &permissionIDs).Error
	if err != nil {
		return nil, nil, err
	}

	permissions := []*domain.Permission{}
	if len(permissionIDs) > 0 {
		if err := r.db.WithContext(ctx).Where("id IN ?", permissionIDs).Find(&permissions).Error; err != nil {
			return nil, nil, err
		}
	}
	return roles, permissions, nil
}
// GetUserIDByRoleID returns the user_id column of every association row that
// references roleID. On error the returned slice is nil.
func (r *UserRoleRepository) GetUserIDByRoleID(ctx context.Context, roleID int64) ([]int64, error) {
	var ids []int64
	query := r.db.WithContext(ctx).Model(&domain.UserRole{}).Where("role_id = ?", roleID)
	if err := query.Pluck("user_id", &ids).Error; err != nil {
		return nil, err
	}
	return ids, nil
}
// Exists reports whether at least one association row links userID and roleID.
// The query error, if any, is returned alongside the count-based result.
func (r *UserRoleRepository) Exists(ctx context.Context, userID, roleID int64) (bool, error) {
	var n int64
	err := r.db.WithContext(ctx).
		Model(&domain.UserRole{}).
		Where("user_id = ? AND role_id = ?", userID, roleID).
		Count(&n).Error
	found := n > 0
	return found, err
}
// BatchCreate inserts all given association rows in a single Create call.
// An empty input is a no-op that returns nil without touching the database.
func (r *UserRoleRepository) BatchCreate(ctx context.Context, userRoles []*domain.UserRole) error {
	if len(userRoles) == 0 {
		return nil // nothing to insert
	}
	result := r.db.WithContext(ctx).Create(&userRoles)
	return result.Error
}
// BatchDelete deletes the association rows whose primary keys match the ID
// field of each element of userRoles. An empty input is a no-op returning nil.
func (r *UserRoleRepository) BatchDelete(ctx context.Context, userRoles []*domain.UserRole) error {
	if len(userRoles) == 0 {
		return nil
	}
	ids := make([]int64, 0, len(userRoles))
	for i := range userRoles {
		ids = append(ids, userRoles[i].ID)
	}
	return r.db.WithContext(ctx).Delete(&domain.UserRole{}, ids).Error
}
// ReplaceUserRoles atomically swaps the role set of userID for roleIDs.
//
// Inside one transaction it deletes every existing association row for the
// user, then, when roleIDs is non-empty, inserts one row per entry. If either
// step fails the transaction is rolled back and that error is returned, so
// the user never ends up with a half-replaced role set.
func (r *UserRoleRepository) ReplaceUserRoles(ctx context.Context, userID int64, roleIDs []int64) error {
	replace := func(tx *gorm.DB) error {
		if err := tx.Where("user_id = ?", userID).Delete(&domain.UserRole{}).Error; err != nil {
			return err
		}
		if len(roleIDs) == 0 {
			return nil // user is left with no roles
		}
		userRoles := make([]*domain.UserRole, 0, len(roleIDs))
		for _, roleID := range roleIDs {
			userRoles = append(userRoles, &domain.UserRole{UserID: userID, RoleID: roleID})
		}
		return tx.Create(&userRoles).Error
	}
	return r.db.WithContext(ctx).Transaction(replace)
}