fix: 系统性修复安全问题、性能问题和错误处理
安全问题修复: - X-Forwarded-For越界检查(auth.go) - checkTokenStatus Context参数传递(auth.go) - Type Assertion安全检查(auth.go) 性能问题修复: - TokenCache过期清理机制 - BruteForceProtection过期清理 - InMemoryIdempotencyStore过期清理 错误处理修复: - AuditStore.Emit返回error - domain层emitAudit辅助方法 - List方法返回空slice而非nil - 金额/价格负数验证 架构一致性: - 统一使用model.RoleHierarchyLevels 新增功能: - Alert API完整实现(CRUD+Resolve) - pkg/error错误码集中管理
This commit is contained in:
350
supply-api/internal/audit/handler/alert_handler.go
Normal file
350
supply-api/internal/audit/handler/alert_handler.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
"lijiaoqiao/supply-api/internal/audit/service"
|
||||
)
|
||||
|
||||
// AlertHandler 告警HTTP处理器
|
||||
type AlertHandler struct {
|
||||
svc *service.AlertService
|
||||
}
|
||||
|
||||
// NewAlertHandler 创建告警处理器
|
||||
func NewAlertHandler(svc *service.AlertService) *AlertHandler {
|
||||
return &AlertHandler{svc: svc}
|
||||
}
|
||||
|
||||
// CreateAlertRequest 创建告警请求
|
||||
type CreateAlertRequest struct {
|
||||
AlertName string `json:"alert_name"`
|
||||
AlertType string `json:"alert_type"`
|
||||
AlertLevel string `json:"alert_level"`
|
||||
TenantID int64 `json:"tenant_id"`
|
||||
SupplierID int64 `json:"supplier_id,omitempty"`
|
||||
Title string `json:"title"`
|
||||
Message string `json:"message"`
|
||||
Description string `json:"description,omitempty"`
|
||||
EventID string `json:"event_id,omitempty"`
|
||||
EventIDs []string `json:"event_ids,omitempty"`
|
||||
NotifyEnabled bool `json:"notify_enabled"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
}
|
||||
|
||||
// UpdateAlertRequest 更新告警请求
|
||||
type UpdateAlertRequest struct {
|
||||
Title string `json:"title,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
AlertLevel string `json:"alert_level,omitempty"`
|
||||
Status string `json:"status,omitempty"`
|
||||
NotifyEnabled *bool `json:"notify_enabled,omitempty"`
|
||||
NotifyChannels []string `json:"notify_channels,omitempty"`
|
||||
Tags []string `json:"tags,omitempty"`
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
// ResolveAlertRequest 解决告警请求
|
||||
type ResolveAlertRequest struct {
|
||||
ResolvedBy string `json:"resolved_by"`
|
||||
Note string `json:"note"`
|
||||
}
|
||||
|
||||
// AlertResponse 告警响应
|
||||
type AlertResponse struct {
|
||||
Alert *model.Alert `json:"alert"`
|
||||
}
|
||||
|
||||
// AlertListResponse 告警列表响应
|
||||
type AlertListResponse struct {
|
||||
Alerts []*model.Alert `json:"alerts"`
|
||||
Total int64 `json:"total"`
|
||||
Offset int `json:"offset"`
|
||||
Limit int `json:"limit"`
|
||||
}
|
||||
|
||||
// CreateAlert 处理 POST /api/v1/audit/alerts
|
||||
func (h *AlertHandler) CreateAlert(w http.ResponseWriter, r *http.Request) {
|
||||
var req CreateAlertRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 验证必填字段
|
||||
if req.Title == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "title is required")
|
||||
return
|
||||
}
|
||||
if req.AlertType == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "alert_type is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 创建告警
|
||||
alert := &model.Alert{
|
||||
AlertName: req.AlertName,
|
||||
AlertType: req.AlertType,
|
||||
AlertLevel: req.AlertLevel,
|
||||
TenantID: req.TenantID,
|
||||
SupplierID: req.SupplierID,
|
||||
Title: req.Title,
|
||||
Message: req.Message,
|
||||
Description: req.Description,
|
||||
EventID: req.EventID,
|
||||
EventIDs: req.EventIDs,
|
||||
NotifyEnabled: req.NotifyEnabled,
|
||||
Tags: req.Tags,
|
||||
}
|
||||
|
||||
result, err := h.svc.CreateAlert(r.Context(), alert)
|
||||
if err != nil {
|
||||
writeAlertError(w, http.StatusInternalServerError, "CREATE_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
json.NewEncoder(w).Encode(AlertResponse{Alert: result})
|
||||
}
|
||||
|
||||
// GetAlert 处理 GET /api/v1/audit/alerts/{alert_id}
|
||||
func (h *AlertHandler) GetAlert(w http.ResponseWriter, r *http.Request) {
|
||||
alertID := extractAlertID(r)
|
||||
if alertID == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
alert, err := h.svc.GetAlert(r.Context(), alertID)
|
||||
if err != nil {
|
||||
if err == service.ErrAlertNotFound {
|
||||
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
|
||||
return
|
||||
}
|
||||
writeAlertError(w, http.StatusInternalServerError, "GET_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(AlertResponse{Alert: alert})
|
||||
}
|
||||
|
||||
// ListAlerts 处理 GET /api/v1/audit/alerts
|
||||
func (h *AlertHandler) ListAlerts(w http.ResponseWriter, r *http.Request) {
|
||||
filter := &model.AlertFilter{}
|
||||
|
||||
// 解析查询参数
|
||||
if tenantIDStr := r.URL.Query().Get("tenant_id"); tenantIDStr != "" {
|
||||
tenantID, err := strconv.ParseInt(tenantIDStr, 10, 64)
|
||||
if err == nil {
|
||||
filter.TenantID = tenantID
|
||||
}
|
||||
}
|
||||
|
||||
if supplierIDStr := r.URL.Query().Get("supplier_id"); supplierIDStr != "" {
|
||||
supplierID, err := strconv.ParseInt(supplierIDStr, 10, 64)
|
||||
if err == nil {
|
||||
filter.SupplierID = supplierID
|
||||
}
|
||||
}
|
||||
|
||||
if alertType := r.URL.Query().Get("alert_type"); alertType != "" {
|
||||
filter.AlertType = alertType
|
||||
}
|
||||
|
||||
if alertLevel := r.URL.Query().Get("alert_level"); alertLevel != "" {
|
||||
filter.AlertLevel = alertLevel
|
||||
}
|
||||
|
||||
if status := r.URL.Query().Get("status"); status != "" {
|
||||
filter.Status = status
|
||||
}
|
||||
|
||||
if keywords := r.URL.Query().Get("keywords"); keywords != "" {
|
||||
filter.Keywords = keywords
|
||||
}
|
||||
|
||||
if offsetStr := r.URL.Query().Get("offset"); offsetStr != "" {
|
||||
offset, err := strconv.Atoi(offsetStr)
|
||||
if err == nil && offset >= 0 {
|
||||
filter.Offset = offset
|
||||
}
|
||||
}
|
||||
|
||||
if limitStr := r.URL.Query().Get("limit"); limitStr != "" {
|
||||
limit, err := strconv.Atoi(limitStr)
|
||||
if err == nil && limit > 0 && limit <= 1000 {
|
||||
filter.Limit = limit
|
||||
}
|
||||
}
|
||||
|
||||
if filter.Limit == 0 {
|
||||
filter.Limit = 100
|
||||
}
|
||||
|
||||
alerts, total, err := h.svc.ListAlerts(r.Context(), filter)
|
||||
if err != nil {
|
||||
writeAlertError(w, http.StatusInternalServerError, "LIST_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(AlertListResponse{
|
||||
Alerts: alerts,
|
||||
Total: total,
|
||||
Offset: filter.Offset,
|
||||
Limit: filter.Limit,
|
||||
})
|
||||
}
|
||||
|
||||
// UpdateAlert 处理 PUT /api/v1/audit/alerts/{alert_id}
|
||||
func (h *AlertHandler) UpdateAlert(w http.ResponseWriter, r *http.Request) {
|
||||
alertID := extractAlertID(r)
|
||||
if alertID == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
// 获取现有告警
|
||||
alert, err := h.svc.GetAlert(r.Context(), alertID)
|
||||
if err != nil {
|
||||
if err == service.ErrAlertNotFound {
|
||||
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
|
||||
return
|
||||
}
|
||||
writeAlertError(w, http.StatusInternalServerError, "GET_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
var req UpdateAlertRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Title != "" {
|
||||
alert.Title = req.Title
|
||||
}
|
||||
if req.Message != "" {
|
||||
alert.Message = req.Message
|
||||
}
|
||||
if req.Description != "" {
|
||||
alert.Description = req.Description
|
||||
}
|
||||
if req.AlertLevel != "" {
|
||||
alert.AlertLevel = req.AlertLevel
|
||||
}
|
||||
if req.Status != "" {
|
||||
alert.Status = req.Status
|
||||
}
|
||||
if req.NotifyEnabled != nil {
|
||||
alert.NotifyEnabled = *req.NotifyEnabled
|
||||
}
|
||||
if len(req.NotifyChannels) > 0 {
|
||||
alert.NotifyChannels = req.NotifyChannels
|
||||
}
|
||||
if len(req.Tags) > 0 {
|
||||
alert.Tags = req.Tags
|
||||
}
|
||||
if req.Metadata != nil {
|
||||
alert.Metadata = req.Metadata
|
||||
}
|
||||
|
||||
result, err := h.svc.UpdateAlert(r.Context(), alert)
|
||||
if err != nil {
|
||||
writeAlertError(w, http.StatusInternalServerError, "UPDATE_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(AlertResponse{Alert: result})
|
||||
}
|
||||
|
||||
// DeleteAlert 处理 DELETE /api/v1/audit/alerts/{alert_id}
|
||||
func (h *AlertHandler) DeleteAlert(w http.ResponseWriter, r *http.Request) {
|
||||
alertID := extractAlertID(r)
|
||||
if alertID == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
err := h.svc.DeleteAlert(r.Context(), alertID)
|
||||
if err != nil {
|
||||
if err == service.ErrAlertNotFound {
|
||||
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
|
||||
return
|
||||
}
|
||||
writeAlertError(w, http.StatusInternalServerError, "DELETE_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// ResolveAlert 处理 POST /api/v1/audit/alerts/{alert_id}/resolve
|
||||
func (h *AlertHandler) ResolveAlert(w http.ResponseWriter, r *http.Request) {
|
||||
alertID := extractAlertID(r)
|
||||
if alertID == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_PARAM", "alert_id is required")
|
||||
return
|
||||
}
|
||||
|
||||
var req ResolveAlertRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeAlertError(w, http.StatusBadRequest, "INVALID_REQUEST", "invalid request body: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
if req.ResolvedBy == "" {
|
||||
writeAlertError(w, http.StatusBadRequest, "MISSING_FIELD", "resolved_by is required")
|
||||
return
|
||||
}
|
||||
|
||||
result, err := h.svc.ResolveAlert(r.Context(), alertID, req.ResolvedBy, req.Note)
|
||||
if err != nil {
|
||||
if err == service.ErrAlertNotFound {
|
||||
writeAlertError(w, http.StatusNotFound, "NOT_FOUND", "alert not found")
|
||||
return
|
||||
}
|
||||
writeAlertError(w, http.StatusInternalServerError, "RESOLVE_FAILED", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(AlertResponse{Alert: result})
|
||||
}
|
||||
|
||||
// extractAlertID 从请求中提取alert_id(优先从路径,其次从查询参数)
|
||||
func extractAlertID(r *http.Request) string {
|
||||
// 先尝试从路径提取
|
||||
path := r.URL.Path
|
||||
parts := strings.Split(strings.TrimPrefix(path, "/"), "/")
|
||||
if len(parts) >= 5 && parts[0] == "api" && parts[1] == "v1" && parts[2] == "audit" && parts[3] == "alerts" {
|
||||
if parts[4] != "" && parts[4] != "resolve" {
|
||||
return parts[4]
|
||||
}
|
||||
}
|
||||
// 再尝试从查询参数提取
|
||||
if alertID := r.URL.Query().Get("alert_id"); alertID != "" {
|
||||
return alertID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// writeAlertError 写入错误响应
|
||||
func writeAlertError(w http.ResponseWriter, status int, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(ErrorResponse{
|
||||
Error: message,
|
||||
Code: code,
|
||||
Details: "",
|
||||
})
|
||||
}
|
||||
315
supply-api/internal/audit/handler/alert_handler_test.go
Normal file
315
supply-api/internal/audit/handler/alert_handler_test.go
Normal file
@@ -0,0 +1,315 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
"lijiaoqiao/supply-api/internal/audit/service"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockAlertStore 模拟告警存储
|
||||
type mockAlertStore struct {
|
||||
alerts map[string]*model.Alert
|
||||
}
|
||||
|
||||
func newMockAlertStore() *mockAlertStore {
|
||||
return &mockAlertStore{
|
||||
alerts: make(map[string]*model.Alert),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockAlertStore) Create(ctx context.Context, alert *model.Alert) error {
|
||||
if alert.AlertID == "" {
|
||||
alert.AlertID = "test-alert-id"
|
||||
}
|
||||
alert.CreatedAt = testTime
|
||||
alert.UpdatedAt = testTime
|
||||
m.alerts[alert.AlertID] = alert
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAlertStore) GetByID(ctx context.Context, alertID string) (*model.Alert, error) {
|
||||
if alert, ok := m.alerts[alertID]; ok {
|
||||
return alert, nil
|
||||
}
|
||||
return nil, service.ErrAlertNotFound
|
||||
}
|
||||
|
||||
func (m *mockAlertStore) Update(ctx context.Context, alert *model.Alert) error {
|
||||
if _, ok := m.alerts[alert.AlertID]; !ok {
|
||||
return service.ErrAlertNotFound
|
||||
}
|
||||
alert.UpdatedAt = testTime
|
||||
m.alerts[alert.AlertID] = alert
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAlertStore) Delete(ctx context.Context, alertID string) error {
|
||||
if _, ok := m.alerts[alertID]; !ok {
|
||||
return service.ErrAlertNotFound
|
||||
}
|
||||
delete(m.alerts, alertID)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockAlertStore) List(ctx context.Context, filter *model.AlertFilter) ([]*model.Alert, int64, error) {
|
||||
var result []*model.Alert
|
||||
for _, alert := range m.alerts {
|
||||
if filter.TenantID > 0 && alert.TenantID != filter.TenantID {
|
||||
continue
|
||||
}
|
||||
if filter.Status != "" && alert.Status != filter.Status {
|
||||
continue
|
||||
}
|
||||
result = append(result, alert)
|
||||
}
|
||||
return result, int64(len(result)), nil
|
||||
}
|
||||
|
||||
var testTime = time.Now()
|
||||
|
||||
// TestAlertHandler_CreateAlert_Success 测试创建告警成功
|
||||
func TestAlertHandler_CreateAlert_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
reqBody := CreateAlertRequest{
|
||||
AlertName: "TEST_ALERT",
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
Title: "Test Alert Title",
|
||||
Message: "Test alert message",
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("POST", "/api/v1/audit/alerts", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.CreateAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusCreated, w.Code)
|
||||
|
||||
var result AlertResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Test Alert Title", result.Alert.Title)
|
||||
assert.Equal(t, "security", result.Alert.AlertType)
|
||||
}
|
||||
|
||||
// TestAlertHandler_CreateAlert_MissingTitle 测试缺少标题
|
||||
func TestAlertHandler_CreateAlert_MissingTitle(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
reqBody := CreateAlertRequest{
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
}
|
||||
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("POST", "/api/v1/audit/alerts", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.CreateAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusBadRequest, w.Code)
|
||||
}
|
||||
|
||||
// TestAlertHandler_GetAlert_Success 测试获取告警成功
|
||||
func TestAlertHandler_GetAlert_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
// 先创建一个告警
|
||||
alert := &model.Alert{
|
||||
AlertID: "test-alert-123",
|
||||
AlertName: "TEST_ALERT",
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
Title: "Test Alert",
|
||||
Message: "Test message",
|
||||
}
|
||||
store.Create(context.Background(), alert)
|
||||
|
||||
// 获取告警
|
||||
req := httptest.NewRequest("GET", "/api/v1/audit/alerts/test-alert-123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.GetAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var result AlertResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "test-alert-123", result.Alert.AlertID)
|
||||
}
|
||||
|
||||
// TestAlertHandler_GetAlert_NotFound 测试告警不存在
|
||||
func TestAlertHandler_GetAlert_NotFound(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/audit/alerts/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.GetAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
// TestAlertHandler_ListAlerts_Success 测试列出告警成功
|
||||
func TestAlertHandler_ListAlerts_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
// 创建多个告警
|
||||
for i := 0; i < 3; i++ {
|
||||
alert := &model.Alert{
|
||||
AlertID: "alert-" + string(rune('a'+i)),
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
Title: "Test Alert",
|
||||
}
|
||||
store.Create(context.Background(), alert)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/audit/alerts?tenant_id=2001", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ListAlerts(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var result AlertListResponse
|
||||
err := json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, int64(3), result.Total)
|
||||
}
|
||||
|
||||
// TestAlertHandler_UpdateAlert_Success 测试更新告警成功
|
||||
func TestAlertHandler_UpdateAlert_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
// 先创建一个告警
|
||||
alert := &model.Alert{
|
||||
AlertID: "test-alert-123",
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
Title: "Original Title",
|
||||
}
|
||||
store.Create(context.Background(), alert)
|
||||
|
||||
// 更新告警
|
||||
reqBody := UpdateAlertRequest{
|
||||
Title: "Updated Title",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("PUT", "/api/v1/audit/alerts/test-alert-123", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.UpdateAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var result AlertResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.Equal(t, "Updated Title", result.Alert.Title)
|
||||
}
|
||||
|
||||
// TestAlertHandler_DeleteAlert_Success 测试删除告警成功
|
||||
func TestAlertHandler_DeleteAlert_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
// 先创建一个告警
|
||||
alert := &model.Alert{
|
||||
AlertID: "test-alert-123",
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
}
|
||||
store.Create(context.Background(), alert)
|
||||
|
||||
// 删除告警
|
||||
req := httptest.NewRequest("DELETE", "/api/v1/audit/alerts/test-alert-123", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.DeleteAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNoContent, w.Code)
|
||||
}
|
||||
|
||||
// TestAlertHandler_DeleteAlert_NotFound 测试删除不存在的告警
|
||||
func TestAlertHandler_DeleteAlert_NotFound(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
req := httptest.NewRequest("DELETE", "/api/v1/audit/alerts/nonexistent", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.DeleteAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusNotFound, w.Code)
|
||||
}
|
||||
|
||||
// TestAlertHandler_ResolveAlert_Success 测试解决告警成功
|
||||
func TestAlertHandler_ResolveAlert_Success(t *testing.T) {
|
||||
store := newMockAlertStore()
|
||||
svc := service.NewAlertService(store)
|
||||
h := NewAlertHandler(svc)
|
||||
|
||||
// 先创建一个告警
|
||||
alert := &model.Alert{
|
||||
AlertID: "test-alert-123",
|
||||
AlertType: "security",
|
||||
AlertLevel: "warning",
|
||||
TenantID: 2001,
|
||||
Status: model.AlertStatusActive,
|
||||
}
|
||||
store.Create(context.Background(), alert)
|
||||
|
||||
// 解决告警
|
||||
reqBody := ResolveAlertRequest{
|
||||
ResolvedBy: "admin",
|
||||
Note: "Fixed the issue",
|
||||
}
|
||||
body, _ := json.Marshal(reqBody)
|
||||
req := httptest.NewRequest("POST", "/api/v1/audit/alerts/test-alert-123/resolve", bytes.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
h.ResolveAlert(w, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
|
||||
var result AlertResponse
|
||||
json.Unmarshal(w.Body.Bytes(), &result)
|
||||
assert.Equal(t, model.AlertStatusResolved, result.Alert.Status)
|
||||
assert.Equal(t, "admin", result.Alert.ResolvedBy)
|
||||
}
|
||||
Reference in New Issue
Block a user