157 lines
3.4 KiB
Go
157 lines
3.4 KiB
Go
package middleware
|
||
|
||
import (
|
||
"net/http"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
// Keys under which the current user's authorization data is looked up in the
// gin context. The getters in this package read both values as []string.
// NOTE(review): presumably populated by an upstream authentication middleware — confirm.
const (
	// ContextKeyRoleCodes is the gin context key holding the user's role codes.
	ContextKeyRoleCodes = "role_codes"

	// ContextKeyPermissionCodes is the gin context key holding the user's permission codes.
	ContextKeyPermissionCodes = "permission_codes"
)
|
||
|
||
// RequirePermission 要求用户拥有指定权限之一(OR 逻辑)
|
||
// 适用于需要单个或多选权限校验的路由
|
||
func RequirePermission(codes ...string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
if !hasAnyPermission(c, codes) {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"code": 403,
|
||
"message": "权限不足",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// RequireAllPermissions 要求用户拥有所有指定权限(AND 逻辑)
|
||
func RequireAllPermissions(codes ...string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
if !hasAllPermissions(c, codes) {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"code": 403,
|
||
"message": "权限不足,需要所有指定权限",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// RequireRole 要求用户拥有指定角色之一(OR 逻辑)
|
||
func RequireRole(codes ...string) gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
if !hasAnyRole(c, codes) {
|
||
c.JSON(http.StatusForbidden, gin.H{
|
||
"code": 403,
|
||
"message": "权限不足,角色受限",
|
||
})
|
||
c.Abort()
|
||
return
|
||
}
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
// RequireAnyPermission RequirePermission 的别名,语义更清晰
|
||
func RequireAnyPermission(codes ...string) gin.HandlerFunc {
|
||
return RequirePermission(codes...)
|
||
}
|
||
|
||
// AdminOnly 仅限 admin 角色
|
||
func AdminOnly() gin.HandlerFunc {
|
||
return RequireRole("admin")
|
||
}
|
||
|
||
// GetRoleCodes 从 Context 获取当前用户角色代码列表
|
||
func GetRoleCodes(c *gin.Context) []string {
|
||
val, exists := c.Get(ContextKeyRoleCodes)
|
||
if !exists {
|
||
return nil
|
||
}
|
||
if codes, ok := val.([]string); ok {
|
||
return codes
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// GetPermissionCodes 从 Context 获取当前用户权限代码列表
|
||
func GetPermissionCodes(c *gin.Context) []string {
|
||
val, exists := c.Get(ContextKeyPermissionCodes)
|
||
if !exists {
|
||
return nil
|
||
}
|
||
if codes, ok := val.([]string); ok {
|
||
return codes
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// IsAdmin 判断当前用户是否为 admin
|
||
func IsAdmin(c *gin.Context) bool {
|
||
return hasAnyRole(c, []string{"admin"})
|
||
}
|
||
|
||
// hasAnyPermission 判断用户是否拥有任意一个权限
|
||
func hasAnyPermission(c *gin.Context, codes []string) bool {
|
||
// admin 角色拥有所有权限
|
||
if IsAdmin(c) {
|
||
return true
|
||
}
|
||
permCodes := GetPermissionCodes(c)
|
||
if len(permCodes) == 0 {
|
||
return false
|
||
}
|
||
permSet := toSet(permCodes)
|
||
for _, code := range codes {
|
||
if _, ok := permSet[code]; ok {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// hasAllPermissions 判断用户是否拥有所有权限
|
||
func hasAllPermissions(c *gin.Context, codes []string) bool {
|
||
if IsAdmin(c) {
|
||
return true
|
||
}
|
||
permCodes := GetPermissionCodes(c)
|
||
permSet := toSet(permCodes)
|
||
for _, code := range codes {
|
||
if _, ok := permSet[code]; !ok {
|
||
return false
|
||
}
|
||
}
|
||
return true
|
||
}
|
||
|
||
// hasAnyRole 判断用户是否拥有任意一个角色
|
||
func hasAnyRole(c *gin.Context, codes []string) bool {
|
||
roleCodes := GetRoleCodes(c)
|
||
if len(roleCodes) == 0 {
|
||
return false
|
||
}
|
||
roleSet := toSet(roleCodes)
|
||
for _, code := range codes {
|
||
if _, ok := roleSet[code]; ok {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// toSet builds a set containing every element of items. The set is a map with
// zero-width values, and duplicate elements collapse into a single entry. The
// result is always a non-nil map, even for a nil or empty slice.
func toSet(items []string) map[string]struct{} {
	set := make(map[string]struct{}, len(items))
	for i := range items {
		set[items[i]] = struct{}{}
	}
	return set
}
|