Files
user-system/internal/repository/role.go
long-agent 8095307d82 fix: P0/P1 security and quality fixes
P0-01: Add ESCAPE clause to LIKE queries in operation_log.go and device.go
P0-02: Add atomic Increment to L1Cache and L2Cache interfaces
P0-07: Add TOTP verification step after password login
P1-01: Sanitize error messages in error.go middleware
P1-03: Remove err.Error() from export error messages
P1-04: Add error return to CountByResultSince in login_log.go
P1-05: Add transactional DeleteCascade to RoleRepository
P1-06: Add PasswordChangedAt tracking for JWT token invalidation
P1-07: Wrap theme SetDefault in database transaction
P1-08: Use config values for database pool parameters
P1-09: Add rows.Err() checks in social_account_repo.go
P1-10: Validate sortOrder with map in user.go ORDER BY
P1-11: Add GORM tags to Announcement struct
P1-15: Add pageSize upper limit (100) to device and log handlers
2026-04-18 15:33:12 +08:00

230 lines
6.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"errors"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// RoleRepository is the data-access layer for roles. Every method takes a
// context.Context and applies it to the underlying *gorm.DB via WithContext.
type RoleRepository struct {
	db *gorm.DB
}

// NewRoleRepository builds a RoleRepository that issues its queries through db.
func NewRoleRepository(db *gorm.DB) *RoleRepository {
	repo := &RoleRepository{db: db}
	return repo
}
// Create inserts role inside a transaction.
//
// GORM omits zero-valued fields that have a database default when inserting,
// so a role created with status=0 (disabled) would otherwise be stored with
// the column default. When the caller asked for the disabled status, the row
// is patched within the same transaction and role.Status is reset to the
// requested value before returning.
func (r *RoleRepository) Create(ctx context.Context, role *domain.Role) error {
	wantStatus := role.Status

	insert := func(tx *gorm.DB) error {
		if err := tx.Create(role).Error; err != nil {
			return err
		}
		// Only the disabled status needs an explicit backfill.
		if wantStatus != domain.RoleStatusDisabled {
			return nil
		}
		patch := tx.Model(&domain.Role{}).Where("id = ?", role.ID).Update("status", wantStatus)
		if patch.Error != nil {
			return patch.Error
		}
		role.Status = wantStatus
		return nil
	}

	return r.db.WithContext(ctx).Transaction(insert)
}
// Update persists all fields of role using GORM's Save.
func (r *RoleRepository) Update(ctx context.Context, role *domain.Role) error {
	result := r.db.WithContext(ctx).Save(role)
	return result.Error
}
// Delete deletes the role with the given primary key. It does not touch
// role-permission links; DeleteCascade removes those as well.
func (r *RoleRepository) Delete(ctx context.Context, id int64) error {
	result := r.db.WithContext(ctx).Delete(&domain.Role{}, id)
	return result.Error
}
// DeleteCascade deletes a role together with its role-permission links in a
// single transaction. The links (rows with role_id = id) are removed first,
// then the role itself. Any error rolls back both deletions.
func (r *RoleRepository) DeleteCascade(ctx context.Context, id int64) error {
	cascade := func(tx *gorm.DB) error {
		// Dependent rows go before the role they reference.
		links := tx.Where("role_id = ?", id).Delete(&domain.RolePermission{})
		if links.Error != nil {
			return links.Error
		}
		return tx.Delete(&domain.Role{}, id).Error
	}
	return r.db.WithContext(ctx).Transaction(cascade)
}
// GetByID loads a single role by primary key. If the query fails (including
// gorm.ErrRecordNotFound when no row matches), the error is returned unchanged
// with a nil role.
func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) {
	role := new(domain.Role)
	if err := r.db.WithContext(ctx).First(role, id).Error; err != nil {
		return nil, err
	}
	return role, nil
}
// GetByCode loads the first role whose code column equals code. Query errors,
// including gorm.ErrRecordNotFound, are returned unchanged with a nil role.
func (r *RoleRepository) GetByCode(ctx context.Context, code string) (*domain.Role, error) {
	role := new(domain.Role)
	if err := r.db.WithContext(ctx).Where("code = ?", code).First(role).Error; err != nil {
		return nil, err
	}
	return role, nil
}
// List returns one page of roles together with the total number of roles.
//
// offset and limit are passed straight to GORM's Offset/Limit. Rows are
// ordered by primary key so that pagination is stable. Without an ORDER BY
// the database may return rows in any order, so entries could be duplicated
// or skipped across pages.
func (r *RoleRepository) List(ctx context.Context, offset, limit int) ([]*domain.Role, int64, error) {
	var roles []*domain.Role
	var total int64

	// Session(&gorm.Session{}) makes the base query safe to reuse for both the
	// Count and the Find below; a plain chained *gorm.DB is not.
	query := r.db.WithContext(ctx).Model(&domain.Role{}).Session(&gorm.Session{})

	// Total count (unaffected by ordering/pagination).
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Requested page, in deterministic order.
	if err := query.Order("id ASC").Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
		return nil, 0, err
	}
	return roles, total, nil
}
// ListByStatus returns one page of roles with the given status, together with
// the total number of roles having that status.
//
// offset and limit are passed straight to GORM's Offset/Limit. Rows are
// ordered by primary key so pagination is stable across requests. Without an
// ORDER BY the database may return rows in any order.
func (r *RoleRepository) ListByStatus(ctx context.Context, status domain.RoleStatus, offset, limit int) ([]*domain.Role, int64, error) {
	var roles []*domain.Role
	var total int64

	// Session(&gorm.Session{}) makes the filtered base query safe to reuse for
	// both the Count and the Find below.
	query := r.db.WithContext(ctx).Model(&domain.Role{}).
		Where("status = ?", status).
		Session(&gorm.Session{})

	// Total count of matching roles.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Requested page, in deterministic order.
	if err := query.Order("id ASC").Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
		return nil, 0, err
	}
	return roles, total, nil
}
// GetDefaultRoles returns every role whose is_default column is true. No
// ordering or pagination is applied. On error the slice is nil.
func (r *RoleRepository) GetDefaultRoles(ctx context.Context) ([]*domain.Role, error) {
	var defaults []*domain.Role
	result := r.db.WithContext(ctx).Where("is_default = ?", true).Find(&defaults)
	if result.Error != nil {
		return nil, result.Error
	}
	return defaults, nil
}
// ExistsByCode reports whether at least one role has the given code. The
// query error, if any, is returned alongside the boolean.
func (r *RoleRepository) ExistsByCode(ctx context.Context, code string) (bool, error) {
	var matches int64
	countErr := r.db.WithContext(ctx).
		Model(&domain.Role{}).
		Where("code = ?", code).
		Count(&matches).
		Error
	exists := matches > 0
	return exists, countErr
}
// UpdateStatus sets the status column of the role with the given id. Because
// a single named column is updated, a zero-valued status is written too. No
// error is reported when no row matches id.
func (r *RoleRepository) UpdateStatus(ctx context.Context, id int64, status domain.RoleStatus) error {
	target := r.db.WithContext(ctx).Model(&domain.Role{}).Where("id = ?", id)
	return target.Update("status", status).Error
}
// Search returns one page of roles whose name, code or description contains
// keyword as a literal substring, together with the total number of matches.
//
// LIKE wildcards (% and _) inside keyword are escaped by escapeLikePattern so
// they match literally. An explicit ESCAPE clause is required for that:
// SQLite has no default LIKE escape character. The escape character is bound
// as a parameter rather than written as a literal '\' because MySQL would
// treat a backslash inside a SQL string literal as an escape of its own.
// Rows are ordered by primary key so pagination is stable.
func (r *RoleRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.Role, int64, error) {
	var roles []*domain.Role
	var total int64

	// NOTE(review): assumes escapeLikePattern escapes with a backslash (the
	// MySQL/PostgreSQL default) — confirm against its definition.
	const likeEscape = `\`
	pattern := "%" + escapeLikePattern(keyword) + "%"

	// The OR group is parenthesized so it cannot bind incorrectly with any
	// condition GORM appends (e.g. soft-delete filters). Session makes the
	// base query safe to reuse for both Count and Find.
	query := r.db.WithContext(ctx).Model(&domain.Role{}).
		Where("(name LIKE ? ESCAPE ? OR code LIKE ? ESCAPE ? OR description LIKE ? ESCAPE ?)",
			pattern, likeEscape, pattern, likeEscape, pattern, likeEscape).
		Session(&gorm.Session{})

	// Total count of matching roles.
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	// Requested page, in deterministic order.
	if err := query.Order("id ASC").Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
		return nil, 0, err
	}
	return roles, total, nil
}
// ListByParentID returns the roles whose parent_id equals parentID, i.e. the
// direct children of that role. No ordering or pagination is applied.
func (r *RoleRepository) ListByParentID(ctx context.Context, parentID int64) ([]*domain.Role, error) {
	var children []*domain.Role
	result := r.db.WithContext(ctx).Where("parent_id = ?", parentID).Find(&children)
	if result.Error != nil {
		return nil, result.Error
	}
	return children, nil
}
// GetByIDs loads every role whose primary key appears in ids. IDs without a
// matching row are simply absent from the result, and the result order is
// whatever the database returns. An empty ids slice short-circuits to an
// empty, non-nil slice without a database round-trip.
func (r *RoleRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.Role, error) {
	if len(ids) > 0 {
		var found []*domain.Role
		if err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&found).Error; err != nil {
			return nil, err
		}
		return found, nil
	}
	return []*domain.Role{}, nil
}
// GetAncestorIDs returns the IDs of all ancestors of roleID, used for
// permission inheritance. It follows parent_id links upward one query at a
// time, so the result is ordered nearest parent first, root last.
//
// The walk stops when:
//   - a role has no parent (ParentID == nil);
//   - a referenced role does not exist (gorm.ErrRecordNotFound). A dangling
//     parent ID that was already collected stays in the result, and an
//     unknown roleID yields an empty result;
//   - a parent_id link points back to a role already on the chain (for
//     example a self-parent or A -> B -> A). Without this guard a corrupted
//     hierarchy made the loop run forever, issuing queries and growing the
//     slice without bound.
//
// Any other query error is returned with a nil slice.
func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
	var ancestorIDs []int64

	// WithContext returns a new session that GORM documents as safe to reuse,
	// so it is created once instead of on every iteration.
	db := r.db.WithContext(ctx)

	// visited holds every role ID already on the chain, including the start.
	visited := map[int64]struct{}{roleID: {}}

	currentID := roleID
	for {
		var role domain.Role
		err := db.Select("id", "parent_id").First(&role, currentID).Error
		if err != nil {
			if errors.Is(err, gorm.ErrRecordNotFound) {
				break
			}
			return nil, err
		}
		if role.ParentID == nil {
			break
		}

		parentID := *role.ParentID
		if _, seen := visited[parentID]; seen {
			// Cycle in the parent_id links: stop instead of looping forever.
			break
		}
		visited[parentID] = struct{}{}

		ancestorIDs = append(ancestorIDs, parentID)
		currentID = parentID
	}
	return ancestorIDs, nil
}
// GetAncestors returns the full inheritance chain of roleID as role records,
// ordered from the top-most ancestor down to the immediate parent, i.e. from
// parent to child. The role itself is not included.
//
// GetByIDs uses "WHERE id IN ?", which returns rows in arbitrary database
// order. The fetched roles are therefore re-sorted explicitly to honor the
// documented order. Ancestor IDs with no matching row (dangling parent_id
// references) are omitted. Returns an empty, non-nil slice when the role has
// no ancestors.
func (r *RoleRepository) GetAncestors(ctx context.Context, roleID int64) ([]*domain.Role, error) {
	ancestorIDs, err := r.GetAncestorIDs(ctx, roleID)
	if err != nil {
		return nil, err
	}
	if len(ancestorIDs) == 0 {
		return []*domain.Role{}, nil
	}

	roles, err := r.GetByIDs(ctx, ancestorIDs)
	if err != nil {
		return nil, err
	}

	byID := make(map[int64]*domain.Role, len(roles))
	for _, role := range roles {
		byID[int64(role.ID)] = role
	}

	// ancestorIDs is nearest-parent-first; walk it backwards for root-first.
	chain := make([]*domain.Role, 0, len(roles))
	for i := len(ancestorIDs) - 1; i >= 0; i-- {
		if role, ok := byID[ancestorIDs[i]]; ok {
			chain = append(chain, role)
		}
	}
	return chain, nil
}