fix: enforce resource ownership checks
This commit is contained in:
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
apimiddleware "github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/service"
|
||||
)
|
||||
@@ -118,9 +119,8 @@ func (h *DeviceHandler) GetDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
device, err := h.deviceService.GetDevice(c.Request.Context(), id)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
device, ok := h.authorizeDeviceAccess(c, id)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -151,6 +151,10 @@ func (h *DeviceHandler) UpdateDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req service.UpdateDeviceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
@@ -187,6 +191,10 @@ func (h *DeviceHandler) DeleteDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.DeleteDevice(c.Request.Context(), id); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
@@ -218,6 +226,10 @@ func (h *DeviceHandler) UpdateDeviceStatus(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req struct {
|
||||
Status string `json:"status" binding:"required"`
|
||||
}
|
||||
@@ -269,27 +281,14 @@ func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 检查是否为管理员
|
||||
roleCodes, _ := c.Get("role_codes")
|
||||
isAdmin := false
|
||||
if roles, ok := roleCodes.([]string); ok {
|
||||
for _, role := range roles {
|
||||
if role == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
userIDParam := c.Param("id")
|
||||
userID, err := strconv.ParseInt(userIDParam, 10, 64)
|
||||
userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "invalid user id"})
|
||||
return
|
||||
}
|
||||
|
||||
// 非管理员只能查看自己的设备
|
||||
if !isAdmin && userID != currentUserID {
|
||||
if !apimiddleware.IsAdmin(c) && userID != currentUserID {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "无权访问该用户的设备列表"})
|
||||
return
|
||||
}
|
||||
@@ -396,6 +395,10 @@ func (h *DeviceHandler) TrustDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req TrustDeviceRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
@@ -478,6 +481,10 @@ func (h *DeviceHandler) UntrustDevice(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeDeviceAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.deviceService.UntrustDevice(c.Request.Context(), id); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
@@ -555,6 +562,27 @@ func (h *DeviceHandler) LogoutAllOtherDevices(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *DeviceHandler) authorizeDeviceAccess(c *gin.Context, deviceID int64) (*domain.Device, bool) {
|
||||
currentUserID, ok := getUserIDFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return nil, false
|
||||
}
|
||||
|
||||
device, err := h.deviceService.GetDevice(c.Request.Context(), deviceID)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
if device.UserID != currentUserID && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return device, true
|
||||
}
|
||||
|
||||
// parseDuration 解析duration字符串,如 "30d" -> 30天的time.Duration
|
||||
func parseDuration(s string) time.Duration {
|
||||
if s == "" {
|
||||
|
||||
@@ -118,6 +118,7 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
deviceSvc := service.NewDeviceService(deviceRepo, userRepo)
|
||||
loginLogSvc := service.NewLoginLogService(loginLogRepo)
|
||||
opLogSvc := service.NewOperationLogService(opLogRepo)
|
||||
webhookSvc := service.NewWebhookService(db)
|
||||
captchaSvc := service.NewCaptchaService(cacheManager)
|
||||
totpSvc := service.NewTOTPService(userRepo)
|
||||
pwdResetCfg := service.DefaultPasswordResetConfig()
|
||||
@@ -141,6 +142,7 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
permHandler := handler.NewPermissionHandler(permSvc)
|
||||
deviceHandler := handler.NewDeviceHandler(deviceSvc)
|
||||
logHandler := handler.NewLogHandler(loginLogSvc, opLogSvc)
|
||||
webhookHandler := handler.NewWebhookHandler(webhookSvc)
|
||||
captchaHandler := handler.NewCaptchaHandler(captchaSvc)
|
||||
totpHandler := handler.NewTOTPHandler(authSvc, totpSvc)
|
||||
pwdResetHandler := handler.NewPasswordResetHandler(pwdResetSvc)
|
||||
@@ -149,7 +151,7 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
r := router.NewRouter(
|
||||
authHandler, userHandler, roleHandler, permHandler, deviceHandler,
|
||||
logHandler, authMiddleware, rateLimitMiddleware, opLogMiddleware,
|
||||
pwdResetHandler, captchaHandler, totpHandler, nil,
|
||||
pwdResetHandler, captchaHandler, totpHandler, webhookHandler,
|
||||
nil, nil, nil, nil, nil, themeHandler, nil, nil, nil, avatarH,
|
||||
)
|
||||
engine := r.Setup()
|
||||
@@ -233,6 +235,62 @@ func registerUser(baseURL, username, email, password string) bool {
|
||||
return resp.StatusCode == http.StatusCreated
|
||||
}
|
||||
|
||||
func createDeviceAndGetID(t *testing.T, baseURL, token, deviceID string) int64 {
|
||||
t.Helper()
|
||||
|
||||
resp, body := doPost(baseURL+"/api/v1/devices", token, map[string]interface{}{
|
||||
"device_id": deviceID,
|
||||
"device_name": "Owned Device",
|
||||
"device_type": 3,
|
||||
"device_os": "Linux",
|
||||
"device_browser": "Chrome",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create device failed: status=%d body=%s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Data struct {
|
||||
ID int64 `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("decode create device response failed: %v body=%s", err, body)
|
||||
}
|
||||
if result.Data.ID == 0 {
|
||||
t.Fatalf("expected non-zero device id, body=%s", body)
|
||||
}
|
||||
return result.Data.ID
|
||||
}
|
||||
|
||||
func createWebhookAndGetID(t *testing.T, baseURL, token, name string) int64 {
|
||||
t.Helper()
|
||||
|
||||
resp, body := doPost(baseURL+"/api/v1/webhooks", token, map[string]interface{}{
|
||||
"name": name,
|
||||
"url": "https://example.com/webhook",
|
||||
"events": []string{"user.created"},
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create webhook failed: status=%d body=%s", resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Data struct {
|
||||
ID int64 `json:"id"`
|
||||
} `json:"data"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("decode create webhook response failed: %v body=%s", err, body)
|
||||
}
|
||||
if result.Data.ID == 0 {
|
||||
t.Fatalf("expected non-zero webhook id, body=%s", body)
|
||||
}
|
||||
return result.Data.ID
|
||||
}
|
||||
|
||||
func bootstrapAdminToken(baseURL, username, email, password string) string {
|
||||
payload, _ := json.Marshal(map[string]interface{}{
|
||||
"username": username,
|
||||
@@ -876,6 +934,73 @@ func TestDeviceHandler_CreateDevice_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_DeviceByIDRoutes_ForbiddenForOtherUser(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "device-owner", "device-owner@test.com", "UserPass123!")
|
||||
registerUser(server.URL, "device-attacker", "device-attacker@test.com", "UserPass123!")
|
||||
ownerToken := getToken(server.URL, "device-owner", "UserPass123!")
|
||||
attackerToken := getToken(server.URL, "device-attacker", "UserPass123!")
|
||||
deviceID := createDeviceAndGetID(t, server.URL, ownerToken, "device-owner-001")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
url string
|
||||
body map[string]interface{}
|
||||
}{
|
||||
{name: "get", method: http.MethodGet, url: fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID)},
|
||||
{name: "update", method: http.MethodPut, url: fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), body: map[string]interface{}{"device_name": "hijacked"}},
|
||||
{name: "delete", method: http.MethodDelete, url: fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID)},
|
||||
{name: "status", method: http.MethodPut, url: fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), body: map[string]interface{}{"status": "inactive"}},
|
||||
{name: "trust", method: http.MethodPost, url: fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), body: map[string]interface{}{"trust_duration": "30d"}},
|
||||
{name: "untrust", method: http.MethodDelete, url: fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID)},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, body := doRequest(tc.method, tc.url, attackerToken, tc.body)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected 403 for %s, got %d body=%s", tc.name, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebhookHandler_OtherUserCannotManageForeignWebhook(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "webhook-owner", "webhook-owner@test.com", "UserPass123!")
|
||||
registerUser(server.URL, "webhook-attacker", "webhook-attacker@test.com", "UserPass123!")
|
||||
ownerToken := getToken(server.URL, "webhook-owner", "UserPass123!")
|
||||
attackerToken := getToken(server.URL, "webhook-attacker", "UserPass123!")
|
||||
webhookID := createWebhookAndGetID(t, server.URL, ownerToken, "owner-webhook")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
url string
|
||||
body map[string]interface{}
|
||||
}{
|
||||
{name: "update", method: http.MethodPut, url: fmt.Sprintf("%s/api/v1/webhooks/%d", server.URL, webhookID), body: map[string]interface{}{"name": "hijacked"}},
|
||||
{name: "delete", method: http.MethodDelete, url: fmt.Sprintf("%s/api/v1/webhooks/%d", server.URL, webhookID)},
|
||||
{name: "deliveries", method: http.MethodGet, url: fmt.Sprintf("%s/api/v1/webhooks/%d/deliveries", server.URL, webhookID)},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
resp, body := doRequest(tc.method, tc.url, attackerToken, tc.body)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("expected 403 for webhook %s, got %d body=%s", tc.name, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Role Handler Tests
|
||||
// =============================================================================
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
apimiddleware "github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/service"
|
||||
)
|
||||
|
||||
@@ -117,6 +118,10 @@ func (h *WebhookHandler) UpdateWebhook(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeWebhookAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
var req service.UpdateWebhookRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
@@ -150,6 +155,10 @@ func (h *WebhookHandler) DeleteWebhook(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeWebhookAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.webhookService.DeleteWebhook(c.Request.Context(), id); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "删除 Webhook 失败"})
|
||||
return
|
||||
@@ -178,6 +187,10 @@ func (h *WebhookHandler) GetWebhookDeliveries(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if _, ok := h.authorizeWebhookAccess(c, id); !ok {
|
||||
return
|
||||
}
|
||||
|
||||
limit, _ := strconv.Atoi(c.DefaultQuery("limit", "20"))
|
||||
if limit < 1 || limit > 100 {
|
||||
limit = 20
|
||||
@@ -191,3 +204,24 @@ func (h *WebhookHandler) GetWebhookDeliveries(c *gin.Context) {
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "success", "data": gin.H{"deliveries": deliveries}})
|
||||
}
|
||||
|
||||
func (h *WebhookHandler) authorizeWebhookAccess(c *gin.Context, webhookID int64) (int64, bool) {
|
||||
userID, ok := getUserIDFromContext(c)
|
||||
if !ok {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return 0, false
|
||||
}
|
||||
|
||||
webhook, err := h.webhookService.GetWebhook(c.Request.Context(), webhookID)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return 0, false
|
||||
}
|
||||
|
||||
if webhook.CreatedBy != userID && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return userID, true
|
||||
}
|
||||
|
||||
@@ -359,9 +359,9 @@ func TestWebhookHandler_DeleteWebhook_NotFound(t *testing.T) {
|
||||
resp := doRequestWithCheck(t, "DELETE", server.URL+"/api/v1/webhooks/99999", token, nil)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Delete is idempotent - returns 200 even if not found
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.StatusCode)
|
||||
// 先做归属/存在性校验,不存在的 webhook 返回 404
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Fatalf("expected status 404, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user