214 lines
5.8 KiB
Go
214 lines
5.8 KiB
Go
package repository
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
|
||
"gorm.io/gorm"
|
||
|
||
"github.com/user-management-system/internal/domain"
|
||
)
|
||
|
||
// RoleRepository 角色数据访问层
|
||
type RoleRepository struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewRoleRepository 创建角色数据访问层
|
||
func NewRoleRepository(db *gorm.DB) *RoleRepository {
|
||
return &RoleRepository{db: db}
|
||
}
|
||
|
||
// Create 创建角色
|
||
func (r *RoleRepository) Create(ctx context.Context, role *domain.Role) error {
|
||
// GORM omits zero values on insert for fields with DB defaults. Explicitly
|
||
// backfill disabled status so callers can persist status=0 roles.
|
||
requestedStatus := role.Status
|
||
return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||
if err := tx.Create(role).Error; err != nil {
|
||
return err
|
||
}
|
||
if requestedStatus == domain.RoleStatusDisabled {
|
||
if err := tx.Model(&domain.Role{}).Where("id = ?", role.ID).Update("status", requestedStatus).Error; err != nil {
|
||
return err
|
||
}
|
||
role.Status = requestedStatus
|
||
}
|
||
return nil
|
||
})
|
||
}
|
||
|
||
// Update 更新角色
|
||
func (r *RoleRepository) Update(ctx context.Context, role *domain.Role) error {
|
||
return r.db.WithContext(ctx).Save(role).Error
|
||
}
|
||
|
||
// Delete 删除角色
|
||
func (r *RoleRepository) Delete(ctx context.Context, id int64) error {
|
||
return r.db.WithContext(ctx).Delete(&domain.Role{}, id).Error
|
||
}
|
||
|
||
// GetByID 根据ID获取角色
|
||
func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) {
|
||
var role domain.Role
|
||
err := r.db.WithContext(ctx).First(&role, id).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &role, nil
|
||
}
|
||
|
||
// GetByCode 根据代码获取角色
|
||
func (r *RoleRepository) GetByCode(ctx context.Context, code string) (*domain.Role, error) {
|
||
var role domain.Role
|
||
err := r.db.WithContext(ctx).Where("code = ?", code).First(&role).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &role, nil
|
||
}
|
||
|
||
// List 获取角色列表
|
||
func (r *RoleRepository) List(ctx context.Context, offset, limit int) ([]*domain.Role, int64, error) {
|
||
var roles []*domain.Role
|
||
var total int64
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.Role{})
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 获取列表
|
||
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return roles, total, nil
|
||
}
|
||
|
||
// ListByStatus 根据状态获取角色列表
|
||
func (r *RoleRepository) ListByStatus(ctx context.Context, status domain.RoleStatus, offset, limit int) ([]*domain.Role, int64, error) {
|
||
var roles []*domain.Role
|
||
var total int64
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.Role{}).Where("status = ?", status)
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 获取列表
|
||
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return roles, total, nil
|
||
}
|
||
|
||
// GetDefaultRoles 获取默认角色
|
||
func (r *RoleRepository) GetDefaultRoles(ctx context.Context) ([]*domain.Role, error) {
|
||
var roles []*domain.Role
|
||
err := r.db.WithContext(ctx).Where("is_default = ?", true).Find(&roles).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return roles, nil
|
||
}
|
||
|
||
// ExistsByCode 检查角色代码是否存在
|
||
func (r *RoleRepository) ExistsByCode(ctx context.Context, code string) (bool, error) {
|
||
var count int64
|
||
err := r.db.WithContext(ctx).Model(&domain.Role{}).Where("code = ?", code).Count(&count).Error
|
||
return count > 0, err
|
||
}
|
||
|
||
// UpdateStatus 更新角色状态
|
||
func (r *RoleRepository) UpdateStatus(ctx context.Context, id int64, status domain.RoleStatus) error {
|
||
return r.db.WithContext(ctx).Model(&domain.Role{}).Where("id = ?", id).Update("status", status).Error
|
||
}
|
||
|
||
// Search 搜索角色
|
||
func (r *RoleRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.Role, int64, error) {
|
||
var roles []*domain.Role
|
||
var total int64
|
||
|
||
query := r.db.WithContext(ctx).Model(&domain.Role{}).
|
||
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||
|
||
// 获取总数
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
// 获取列表
|
||
if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return roles, total, nil
|
||
}
|
||
|
||
// ListByParentID 根据父ID获取角色列表
|
||
func (r *RoleRepository) ListByParentID(ctx context.Context, parentID int64) ([]*domain.Role, error) {
|
||
var roles []*domain.Role
|
||
err := r.db.WithContext(ctx).Where("parent_id = ?", parentID).Find(&roles).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return roles, nil
|
||
}
|
||
|
||
// GetByIDs 根据ID列表批量获取角色
|
||
func (r *RoleRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.Role, error) {
|
||
if len(ids) == 0 {
|
||
return []*domain.Role{}, nil
|
||
}
|
||
|
||
var roles []*domain.Role
|
||
err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&roles).Error
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return roles, nil
|
||
}
|
||
|
||
// GetAncestorIDs 获取角色的所有祖先角色ID(用于权限继承)
|
||
func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) {
|
||
var ancestorIDs []int64
|
||
currentID := roleID
|
||
|
||
// 循环向上查找父角色,直到没有父角色为止
|
||
for {
|
||
var role domain.Role
|
||
err := r.db.WithContext(ctx).Select("id", "parent_id").First(&role, currentID).Error
|
||
if err != nil {
|
||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||
break
|
||
}
|
||
return nil, err
|
||
}
|
||
if role.ParentID == nil {
|
||
break
|
||
}
|
||
ancestorIDs = append(ancestorIDs, *role.ParentID)
|
||
currentID = *role.ParentID
|
||
}
|
||
|
||
return ancestorIDs, nil
|
||
}
|
||
|
||
// GetAncestors 获取角色的完整继承链(从父到子)
|
||
func (r *RoleRepository) GetAncestors(ctx context.Context, roleID int64) ([]*domain.Role, error) {
|
||
ancestorIDs, err := r.GetAncestorIDs(ctx, roleID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
if len(ancestorIDs) == 0 {
|
||
return []*domain.Role{}, nil
|
||
}
|
||
return r.GetByIDs(ctx, ancestorIDs)
|
||
}
|