Files
user-system/internal/repository/role.go
long-agent 765a50b7d4 fix: 生产安全修复 + Go SDK + CAS SSO框架
安全修复:
- CRITICAL: SSO重定向URL注入漏洞 - 修复redirect_uri白名单验证
- HIGH: SSO ClientSecret未验证 - 使用crypto/subtle.ConstantTimeCompare验证
- HIGH: 邮件验证码熵值过低(3字节) - 提升到6字节(48位熵)
- HIGH: 短信验证码熵值过低(4字节) - 提升到6字节
- HIGH: Goroutine使用已取消上下文 - auth_email.go使用独立context+超时
- HIGH: SQL LIKE查询注入风险 - permission/role仓库使用escapeLikePattern

新功能:
- Go SDK: sdk/go/user-management/ 完整SDK实现
- CAS SSO框架: internal/auth/cas.go CAS协议支持

其他:
- L1Cache实例问题修复 - AuthMiddleware共享l1Cache
- 设备指纹XSS防护 - 内存存储替代localStorage
- 响应格式协议中间件
- 导出无界查询修复
2026-04-03 17:38:31 +08:00

218 lines
5.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"errors"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// RoleRepository 角色数据访问层
type RoleRepository struct {
db *gorm.DB
}
// NewRoleRepository 创建角色数据访问层
func NewRoleRepository(db *gorm.DB) *RoleRepository {
return &RoleRepository{db: db}
}
// Create 创建角色
func (r *RoleRepository) Create(ctx context.Context, role *domain.Role) error {
// GORM omits zero values on insert for fields with DB defaults. Explicitly
// backfill disabled status so callers can persist status=0 roles.
requestedStatus := role.Status
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if err := tx.Create(role).Error; err != nil {
return err
}
if requestedStatus == domain.RoleStatusDisabled {
if err := tx.Model(&domain.Role{}).Where("id = ?", role.ID).Update("status", requestedStatus).Error; err != nil {
return err
}
role.Status = requestedStatus
}
return nil
})
}
// Update 更新角色
func (r *RoleRepository) Update(ctx context.Context, role *domain.Role) error {
return r.db.WithContext(ctx).Save(role).Error
}
// Delete 删除角色
func (r *RoleRepository) Delete(ctx context.Context, id int64) error {
return r.db.WithContext(ctx).Delete(&domain.Role{}, id).Error
}
// GetByID 根据ID获取角色
func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) {
var role domain.Role
err := r.db.WithContext(ctx).First(&role, id).Error
if err != nil {
return nil, err
}
return &role, nil
}
// GetByCode 根据代码获取角色
func (r *RoleRepository) GetByCode(ctx context.Context, code string) (*domain.Role, error) {
var role domain.Role
err := r.db.WithContext(ctx).Where("code = ?", code).First(&role).Error
if err != nil {
return nil, err
}
return &role, nil
}
// List 获取角色列表
func (r *RoleRepository) List(ctx context.Context, offset, limit int) ([]*domain.Role, int64, error) {
var roles []*domain.Role
var total int64
query := r.db.WithContext(ctx).Model(&domain.Role{})
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
return nil, 0, err
}
return roles, total, nil
}
// ListByStatus 根据状态获取角色列表
func (r *RoleRepository) ListByStatus(ctx context.Context, status domain.RoleStatus, offset, limit int) ([]*domain.Role, int64, error) {
var roles []*domain.Role
var total int64
query := r.db.WithContext(ctx).Model(&domain.Role{}).Where("status = ?", status)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
return nil, 0, err
}
return roles, total, nil
}
// GetDefaultRoles 获取默认角色
func (r *RoleRepository) GetDefaultRoles(ctx context.Context) ([]*domain.Role, error) {
var roles []*domain.Role
err := r.db.WithContext(ctx).Where("is_default = ?", true).Find(&roles).Error
if err != nil {
return nil, err
}
return roles, nil
}
// ExistsByCode 检查角色代码是否存在
func (r *RoleRepository) ExistsByCode(ctx context.Context, code string) (bool, error) {
var count int64
err := r.db.WithContext(ctx).Model(&domain.Role{}).Where("code = ?", code).Count(&count).Error
return count > 0, err
}
// UpdateStatus 更新角色状态
func (r *RoleRepository) UpdateStatus(ctx context.Context, id int64, status domain.RoleStatus) error {
return r.db.WithContext(ctx).Model(&domain.Role{}).Where("id = ?", id).Update("status", status).Error
}
// Search 搜索角色
func (r *RoleRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.Role, int64, error) {
var roles []*domain.Role
var total int64
// 转义 LIKE 特殊字符,防止搜索被意外干扰
escapedKeyword := escapeLikePattern(keyword)
pattern := "%" + escapedKeyword + "%"
query := r.db.WithContext(ctx).Model(&domain.Role{}).
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", pattern, pattern, pattern)
// 获取总数
if err := query.Count(&total).Error; err != nil {
return nil, 0, err
}
// 获取列表
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
return nil, 0, err
}
return roles, total, nil
}
// ListByParentID 根据父ID获取角色列表
func (r *RoleRepository) ListByParentID(ctx context.Context, parentID int64) ([]*domain.Role, error) {
var roles []*domain.Role
err := r.db.WithContext(ctx).Where("parent_id = ?", parentID).Find(&roles).Error
if err != nil {
return nil, err
}
return roles, nil
}
// GetByIDs 根据ID列表批量获取角色
func (r *RoleRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.Role, error) {
if len(ids) == 0 {
return []*domain.Role{}, nil
}
var roles []*domain.Role
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&roles).Error
if err != nil {
return nil, err
}
return roles, nil
}
// GetAncestorIDs 获取角色的所有祖先角色ID用于权限继承
func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
var ancestorIDs []int64
currentID := roleID
// 循环向上查找父角色,直到没有父角色为止
for {
var role domain.Role
err := r.db.WithContext(ctx).Select("id", "parent_id").First(&role, currentID).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
break
}
return nil, err
}
if role.ParentID == nil {
break
}
ancestorIDs = append(ancestorIDs, *role.ParentID)
currentID = *role.ParentID
}
return ancestorIDs, nil
}
// GetAncestors 获取角色的完整继承链(从父到子)
func (r *RoleRepository) GetAncestors(ctx context.Context, roleID int64) ([]*domain.Role, error) {
ancestorIDs, err := r.GetAncestorIDs(ctx, roleID)
if err != nil {
return nil, err
}
if len(ancestorIDs) == 0 {
return []*domain.Role{}, nil
}
return r.GetByIDs(ctx, ancestorIDs)
}