package repository import ( "context" "errors" "gorm.io/gorm" "github.com/user-management-system/internal/domain" ) // RoleRepository 角色数据访问层 type RoleRepository struct { db *gorm.DB } // NewRoleRepository 创建角色数据访问层 func NewRoleRepository(db *gorm.DB) *RoleRepository { return &RoleRepository{db: db} } // Create 创建角色 func (r *RoleRepository) Create(ctx context.Context, role *domain.Role) error { // GORM omits zero values on insert for fields with DB defaults. Explicitly // backfill disabled status so callers can persist status=0 roles. requestedStatus := role.Status return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { if err := tx.Create(role).Error; err != nil { return err } if requestedStatus == domain.RoleStatusDisabled { if err := tx.Model(&domain.Role{}).Where("id = ?", role.ID).Update("status", requestedStatus).Error; err != nil { return err } role.Status = requestedStatus } return nil }) } // Update 更新角色 func (r *RoleRepository) Update(ctx context.Context, role *domain.Role) error { return r.db.WithContext(ctx).Save(role).Error } // Delete 删除角色 func (r *RoleRepository) Delete(ctx context.Context, id int64) error { return r.db.WithContext(ctx).Delete(&domain.Role{}, id).Error } // GetByID 根据ID获取角色 func (r *RoleRepository) GetByID(ctx context.Context, id int64) (*domain.Role, error) { var role domain.Role err := r.db.WithContext(ctx).First(&role, id).Error if err != nil { return nil, err } return &role, nil } // GetByCode 根据代码获取角色 func (r *RoleRepository) GetByCode(ctx context.Context, code string) (*domain.Role, error) { var role domain.Role err := r.db.WithContext(ctx).Where("code = ?", code).First(&role).Error if err != nil { return nil, err } return &role, nil } // List 获取角色列表 func (r *RoleRepository) List(ctx context.Context, offset, limit int) ([]*domain.Role, int64, error) { var roles []*domain.Role var total int64 query := r.db.WithContext(ctx).Model(&domain.Role{}) // 获取总数 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 获取列表 if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil { return nil, 0, err } return roles, total, nil } // ListByStatus 根据状态获取角色列表 func (r *RoleRepository) ListByStatus(ctx context.Context, status domain.RoleStatus, offset, limit int) ([]*domain.Role, int64, error) { var roles []*domain.Role var total int64 query := r.db.WithContext(ctx).Model(&domain.Role{}).Where("status = ?", status) // 获取总数 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 获取列表 if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil { return nil, 0, err } return roles, total, nil } // GetDefaultRoles 获取默认角色 func (r *RoleRepository) GetDefaultRoles(ctx context.Context) ([]*domain.Role, error) { var roles []*domain.Role err := r.db.WithContext(ctx).Where("is_default = ?", true).Find(&roles).Error if err != nil { return nil, err } return roles, nil } // ExistsByCode 检查角色代码是否存在 func (r *RoleRepository) ExistsByCode(ctx context.Context, code string) (bool, error) { var count int64 err := r.db.WithContext(ctx).Model(&domain.Role{}).Where("code = ?", code).Count(&count).Error return count > 0, err } // UpdateStatus 更新角色状态 func (r *RoleRepository) UpdateStatus(ctx context.Context, id int64, status domain.RoleStatus) error { return r.db.WithContext(ctx).Model(&domain.Role{}).Where("id = ?", id).Update("status", status).Error } // Search 搜索角色 func (r *RoleRepository) Search(ctx context.Context, keyword string, offset, limit int) ([]*domain.Role, int64, error) { var roles []*domain.Role var total int64 query := r.db.WithContext(ctx).Model(&domain.Role{}). Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%") // 获取总数 if err := query.Count(&total).Error; err != nil { return nil, 0, err } // 获取列表 if err := query.Offset(offset).Limit(limit).Find(&roles).Error; err != nil { return nil, 0, err } return roles, total, nil } // ListByParentID 根据父ID获取角色列表 func (r *RoleRepository) ListByParentID(ctx context.Context, parentID int64) ([]*domain.Role, error) { var roles []*domain.Role err := r.db.WithContext(ctx).Where("parent_id = ?", parentID).Find(&roles).Error if err != nil { return nil, err } return roles, nil } // GetByIDs 根据ID列表批量获取角色 func (r *RoleRepository) GetByIDs(ctx context.Context, ids []int64) ([]*domain.Role, error) { if len(ids) == 0 { return []*domain.Role{}, nil } var roles []*domain.Role err := r.db.WithContext(ctx).Where("id IN ?", ids).Find(&roles).Error if err != nil { return nil, err } return roles, nil } // GetAncestorIDs 获取角色的所有祖先角色ID(用于权限继承) func (r *RoleRepository) GetAncestorIDs(ctx context.Context, roleID int64) ([]int64, error) { var ancestorIDs []int64 currentID := roleID // 循环向上查找父角色,直到没有父角色为止 for { var role domain.Role err := r.db.WithContext(ctx).Select("id", "parent_id").First(&role, currentID).Error if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { break } return nil, err } if role.ParentID == nil { break } ancestorIDs = append(ancestorIDs, *role.ParentID) currentID = *role.ParentID } return ancestorIDs, nil } // GetAncestors 获取角色的完整继承链(从父到子) func (r *RoleRepository) GetAncestors(ctx context.Context, roleID int64) ([]*domain.Role, error) { ancestorIDs, err := r.GetAncestorIDs(ctx, roleID) if err != nil { return nil, err } if len(ancestorIDs) == 0 { return []*domain.Role{}, nil } return r.GetByIDs(ctx, ancestorIDs) }