fix: enforce resource ownership checks
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
apimiddleware "github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/service"
|
||||
)
|
||||
@@ -118,9 +119,8 @@ func (h *DeviceHandler) GetDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
device, err := h.deviceService.GetDevice(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
device, ok := h.authorizeDeviceAccess(c, id)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -151,6 +151,10 @@ func (h *DeviceHandler) UpdateDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req service.UpdateDeviceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
@@ -187,6 +191,10 @@ func (h *DeviceHandler) DeleteDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.DeleteDevice(c.Request.Context(), id); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
@@ -218,6 +226,10 @@ func (h *DeviceHandler) UpdateDeviceStatus(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Status string `json:"status" binding:"required"`
|
||||
}
|
||||
@@ -269,27 +281,14 @@ func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为管理员
|
||||
roleCodes, _ := c.Get("role_codes")
|
||||
isAdmin := false
|
||||
if roles, ok := roleCodes.([]string); ok {
|
||||
for _, role := range roles {
|
||||
if role == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userIDParam := c.Param("id")
|
||||
userID, err := strconv.ParseInt(userIDParam, 10, 64)
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
// 非管理员只能查看自己的设备
|
||||
if !isAdmin && userID != currentUserID {
|
||||
if !apimiddleware.IsAdmin(c) && userID != currentUserID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "无权访问该用户的设备列表"})
|
||||
return
|
||||
}
|
||||
@@ -396,6 +395,10 @@ func (h *DeviceHandler) TrustDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req TrustDeviceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
@@ -478,6 +481,10 @@ func (h *DeviceHandler) UntrustDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.UntrustDevice(c.Request.Context(), id); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
@@ -555,6 +562,27 @@ func (h *DeviceHandler) LogoutAllOtherDevices(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *DeviceHandler) authorizeDeviceAccess(c *gin.Context, deviceID int64) (*domain.Device, bool) {
|
||||
currentUserID, ok := getUserIDFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return nil, false
|
||||
}
|
||||
|
||||
device, err := h.deviceService.GetDevice(c.Request.Context(), deviceID)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if device.UserID != currentUserID && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return device, true
|
||||
}
|
||||
|
||||
// parseDuration 解析duration字符串,如 "30d" -> 30天的time.Duration
|
||||
func parseDuration(s string) time.Duration {
|
||||
if s == "" {
|
||||
|
||||
Reference in New Issue
Block a user