测试增强: - handler_test.go: 大幅增强 handler 集成测试(+1284/-98 行) - theme_handler_test.go: 增强主题管理测试(+174/-22 行) - auth_bootstrap_test.go: 新增 bootstrap 认证测试(+329 行) - ratelimit_test.go: 新增限流中间件测试(+153 行) - runtime_test.go: 新增运行时中间件测试(+351 行) 错误处理: - auth_handler.go: classifyErrorMessage 增加 TOTP 错误码和 2FA 状态字分类 清理: - 删除覆盖率报告残留文件(coverage_issue, handler, middleware 等) - 归档 docs/superpowers/plans/2026-05-09-middleware-test-backfill-phase1.md
570 lines
17 KiB
Go
570 lines
17 KiB
Go
package middleware
|
|
|
|
import (
	"bytes"
	"encoding/json"
	"errors"
	"log"
	"net/http"
	"net/http/httptest"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/gin-gonic/gin"

	"github.com/user-management-system/internal/config"
	apierrors "github.com/user-management-system/internal/pkg/errors"
	"github.com/user-management-system/internal/security"
)
|
|
|
|
func TestCORS_UsesConfiguredOrigins(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
SetCORSConfig(config.CORSConfig{
|
|
AllowedOrigins: []string{"https://app.example.com"},
|
|
AllowCredentials: true,
|
|
})
|
|
t.Cleanup(func() {
|
|
SetCORSConfig(config.CORSConfig{
|
|
AllowedOrigins: []string{"*"},
|
|
AllowCredentials: true,
|
|
})
|
|
})
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodOptions, "/api/v1/users", nil)
|
|
c.Request.Header.Set("Origin", "https://app.example.com")
|
|
c.Request.Header.Set("Access-Control-Request-Headers", "Authorization")
|
|
|
|
CORS()(c)
|
|
|
|
if recorder.Code != http.StatusNoContent {
|
|
t.Fatalf("expected 204, got %d", recorder.Code)
|
|
}
|
|
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != "https://app.example.com" {
|
|
t.Fatalf("unexpected allow origin: %s", got)
|
|
}
|
|
if got := recorder.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
|
|
t.Fatalf("expected credentials header to be 'true', got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestCORS_RejectsDisallowedOrigin(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
SetCORSConfig(config.CORSConfig{
|
|
AllowedOrigins: []string{"https://app.example.com"},
|
|
AllowCredentials: false,
|
|
})
|
|
t.Cleanup(func() {
|
|
SetCORSConfig(config.CORSConfig{
|
|
AllowedOrigins: []string{"*"},
|
|
AllowCredentials: true,
|
|
})
|
|
})
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
|
c.Request.Header.Set("Origin", "https://evil.example.com")
|
|
|
|
CORS()(c)
|
|
|
|
if recorder.Code != http.StatusForbidden {
|
|
t.Fatalf("expected 403, got %d", recorder.Code)
|
|
}
|
|
}
|
|
|
|
func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) {
|
|
raw := "token=abc123&foo=bar&access_token=xyz&secret=s1"
|
|
sanitized := sanitizeQuery(raw)
|
|
|
|
if sanitized == "" {
|
|
t.Fatal("expected sanitized query")
|
|
}
|
|
if sanitized == raw {
|
|
t.Fatal("expected query to be sanitized")
|
|
}
|
|
for _, value := range []string{"abc123", "xyz", "s1"} {
|
|
if strings.Contains(sanitized, value) {
|
|
t.Fatalf("expected sensitive value %q to be masked in %q", value, sanitized)
|
|
}
|
|
}
|
|
if sanitizeQuery("") != "" {
|
|
t.Fatal("expected empty query to stay empty")
|
|
}
|
|
}
|
|
|
|
func TestSecurityHeaders_AttachesExpectedHeaders(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
|
|
|
SecurityHeaders()(c)
|
|
|
|
if got := recorder.Header().Get("X-Content-Type-Options"); got != "nosniff" {
|
|
t.Fatalf("unexpected nosniff header: %q", got)
|
|
}
|
|
if got := recorder.Header().Get("X-Frame-Options"); got != "DENY" {
|
|
t.Fatalf("unexpected frame options: %q", got)
|
|
}
|
|
if got := recorder.Header().Get("Content-Security-Policy"); got == "" {
|
|
t.Fatal("expected content security policy header")
|
|
}
|
|
if got := recorder.Header().Get("Strict-Transport-Security"); got != "" {
|
|
t.Fatalf("did not expect hsts header for http request, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestSecurityHeaders_AttachesHSTSForForwardedHTTPS(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
|
c.Request.Header.Set("X-Forwarded-Proto", "https")
|
|
|
|
SecurityHeaders()(c)
|
|
|
|
if got := recorder.Header().Get("Strict-Transport-Security"); !strings.Contains(got, "max-age=31536000") {
|
|
t.Fatalf("expected hsts header, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestNoStoreSensitiveResponses_AttachesExpectedHeadersToAuthRoutes(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/capabilities", nil)
|
|
|
|
NoStoreSensitiveResponses()(c)
|
|
|
|
if got := recorder.Header().Get("Cache-Control"); got != sensitiveNoStoreCacheControl {
|
|
t.Fatalf("unexpected cache-control header: %q", got)
|
|
}
|
|
if got := recorder.Header().Get("Pragma"); got != "no-cache" {
|
|
t.Fatalf("unexpected pragma header: %q", got)
|
|
}
|
|
if got := recorder.Header().Get("Expires"); got != "0" {
|
|
t.Fatalf("unexpected expires header: %q", got)
|
|
}
|
|
if got := recorder.Header().Get("Surrogate-Control"); got != "no-store" {
|
|
t.Fatalf("unexpected surrogate-control header: %q", got)
|
|
}
|
|
}
|
|
|
|
func TestNoStoreSensitiveResponses_DoesNotAttachHeadersToNonAuthRoutes(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
|
|
|
NoStoreSensitiveResponses()(c)
|
|
|
|
if got := recorder.Header().Get("Cache-Control"); got != "" {
|
|
t.Fatalf("did not expect cache-control header, got %q", got)
|
|
}
|
|
}
|
|
|
|
// ---------- TraceID middleware ----------
|
|
|
|
func TestTraceID_GeneratesAndAttachesTraceID(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
|
|
|
TraceID()(c)
|
|
|
|
traceID := c.GetString("trace_id")
|
|
if traceID == "" {
|
|
t.Fatal("expected trace_id to be set")
|
|
}
|
|
if len(traceID) < 8 {
|
|
t.Fatalf("trace_id should be reasonably long, got %q", traceID)
|
|
}
|
|
|
|
if got := recorder.Header().Get("X-Trace-ID"); got != traceID {
|
|
t.Fatalf("expected X-Trace-ID header to match trace_id, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestTraceID_ExtractsExistingTraceID(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
existingTraceID := "existing-trace-id-12345"
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
|
c.Request.Header.Set("X-Trace-ID", existingTraceID)
|
|
|
|
TraceID()(c)
|
|
|
|
traceID := c.GetString("trace_id")
|
|
if traceID != existingTraceID {
|
|
t.Fatalf("expected trace_id to be extracted from header, got %q", traceID)
|
|
}
|
|
}
|
|
|
|
func TestTraceID_GetTraceIDHandlesMissingAndPresentValue(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
|
|
|
if got := GetTraceID(c); got != "" {
|
|
t.Fatalf("GetTraceID() = %q, want empty string", got)
|
|
}
|
|
|
|
c.Set(TraceIDKey, "trace-123")
|
|
if got := GetTraceID(c); got != "trace-123" {
|
|
t.Fatalf("GetTraceID() = %q, want trace-123", got)
|
|
}
|
|
}
|
|
|
|
// ---------- Error handling middleware ----------
|
|
|
|
func TestErrorHandler_HandlesErrors(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
|
|
|
c.Error(errors.New("test error"))
|
|
|
|
ErrorHandler()(c)
|
|
|
|
if recorder.Code != http.StatusInternalServerError {
|
|
t.Fatalf("expected status 500, got %d", recorder.Code)
|
|
}
|
|
}
|
|
|
|
func TestErrorHandler_ApplicationErrorPreservesStatusAndReason(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
router := gin.New()
|
|
router.Use(ErrorHandler())
|
|
router.GET("/users", func(c *gin.Context) {
|
|
_ = c.Error(apierrors.Forbidden("FORBIDDEN", "denied"))
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
if recorder.Code != http.StatusForbidden {
|
|
t.Fatalf("expected status 403, got %d", recorder.Code)
|
|
}
|
|
|
|
var body map[string]any
|
|
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
|
|
t.Fatalf("unmarshal body failed: %v", err)
|
|
}
|
|
if got := body["reason"]; got != "FORBIDDEN" {
|
|
t.Fatalf("reason = %#v, want FORBIDDEN", got)
|
|
}
|
|
if got := body["message"]; got != "denied" {
|
|
t.Fatalf("message = %#v, want denied", got)
|
|
}
|
|
}
|
|
|
|
func TestRecover_HandlesPanic(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, router := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/panic", nil)
|
|
|
|
router.Use(Recover())
|
|
router.GET("/panic", func(c *gin.Context) {
|
|
panic("test panic")
|
|
})
|
|
|
|
router.ServeHTTP(recorder, c.Request)
|
|
|
|
if recorder.Code != http.StatusInternalServerError {
|
|
t.Fatalf("expected status 500 after panic, got %d", recorder.Code)
|
|
}
|
|
}
|
|
|
|
func TestRecover_ReturnsInternalServerErrorPayload(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
recorder := httptest.NewRecorder()
|
|
router := gin.New()
|
|
router.Use(Recover())
|
|
router.GET("/panic", func(c *gin.Context) {
|
|
panic("boom")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/panic", nil)
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
if recorder.Code != http.StatusInternalServerError {
|
|
t.Fatalf("expected status 500 after panic, got %d", recorder.Code)
|
|
}
|
|
|
|
var body map[string]any
|
|
if err := json.Unmarshal(recorder.Body.Bytes(), &body); err != nil {
|
|
t.Fatalf("unmarshal body failed: %v", err)
|
|
}
|
|
if got := body["code"]; got != float64(http.StatusInternalServerError) {
|
|
t.Fatalf("code = %#v, want %d", got, http.StatusInternalServerError)
|
|
}
|
|
}
|
|
|
|
func TestLogger_WritesSanitizedQueryAndErrorContext(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
var buf bytes.Buffer
|
|
originalWriter := log.Writer()
|
|
log.SetOutput(&buf)
|
|
t.Cleanup(func() {
|
|
log.SetOutput(originalWriter)
|
|
})
|
|
|
|
recorder := httptest.NewRecorder()
|
|
router := gin.New()
|
|
router.Use(TraceID())
|
|
router.Use(Logger())
|
|
router.GET("/users", func(c *gin.Context) {
|
|
c.Set("user_id", int64(7))
|
|
_ = c.Error(errors.New("boom"))
|
|
c.Status(http.StatusAccepted)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/users?token=secret&name=alice", nil)
|
|
req.RemoteAddr = "203.0.113.5:1234"
|
|
req.Header.Set("User-Agent", "logger-test")
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
deadline := time.Now().Add(time.Second)
|
|
for time.Now().Before(deadline) && !strings.Contains(buf.String(), "[Query] /users?name=alice&token=%2A%2A%2A") {
|
|
time.Sleep(10 * time.Millisecond)
|
|
}
|
|
|
|
logOutput := buf.String()
|
|
if !strings.Contains(logOutput, "[API]") {
|
|
t.Fatalf("expected API log entry, got %q", logOutput)
|
|
}
|
|
if !strings.Contains(logOutput, "user_id: 7") {
|
|
t.Fatalf("expected user id in logs, got %q", logOutput)
|
|
}
|
|
if !strings.Contains(logOutput, "[Error]") || !strings.Contains(logOutput, "boom") {
|
|
t.Fatalf("expected error log entry, got %q", logOutput)
|
|
}
|
|
if strings.Contains(logOutput, "token=secret") {
|
|
t.Fatalf("expected sanitized query string, got %q", logOutput)
|
|
}
|
|
}
|
|
|
|
func TestLogger_DropsMalformedQueryString(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
var buf bytes.Buffer
|
|
originalWriter := log.Writer()
|
|
log.SetOutput(&buf)
|
|
t.Cleanup(func() {
|
|
log.SetOutput(originalWriter)
|
|
})
|
|
|
|
recorder := httptest.NewRecorder()
|
|
router := gin.New()
|
|
router.Use(Logger())
|
|
router.GET("/users", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/users?bad=%zz", nil)
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
time.Sleep(25 * time.Millisecond)
|
|
if strings.Contains(buf.String(), "[Query]") {
|
|
t.Fatalf("expected malformed query to be skipped, got %q", buf.String())
|
|
}
|
|
}
|
|
|
|
func TestResponseWrapper_SkipsSSEAndBinaryResponses(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
testCases := []struct {
|
|
name string
|
|
path string
|
|
contentType string
|
|
}{
|
|
{name: "sse", path: "/stream", contentType: "text/event-stream"},
|
|
{name: "binary", path: "/download", contentType: "application/octet-stream"},
|
|
{name: "swagger", path: "/swagger/index.html", contentType: ""},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
router := gin.New()
|
|
router.Use(ResponseWrapper())
|
|
router.GET(tc.path, func(c *gin.Context) {
|
|
c.Header("Content-Type", "application/json")
|
|
c.JSON(http.StatusOK, gin.H{"ok": true})
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, tc.path, nil)
|
|
if tc.contentType != "" {
|
|
req.Header.Set("Content-Type", tc.contentType)
|
|
}
|
|
router.ServeHTTP(recorder, req)
|
|
|
|
if recorder.Code != http.StatusOK {
|
|
t.Fatalf("expected 200, got %d", recorder.Code)
|
|
}
|
|
if got := recorder.Body.String(); got != `{"ok":true}` {
|
|
t.Fatalf("body = %s, want raw payload", got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestResponseWrapper_BufferMethodsTrackStatusAndBody(t *testing.T) {
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
wrapper := &responseWrapper{
|
|
ResponseWriter: c.Writer,
|
|
body: bytes.NewBuffer(nil),
|
|
statusCode: http.StatusOK,
|
|
}
|
|
|
|
if _, err := wrapper.Write([]byte("abc")); err != nil {
|
|
t.Fatalf("Write() error = %v", err)
|
|
}
|
|
if _, err := wrapper.WriteString("def"); err != nil {
|
|
t.Fatalf("WriteString() error = %v", err)
|
|
}
|
|
wrapper.WriteHeader(http.StatusAccepted)
|
|
|
|
if got := wrapper.body.String(); got != "abcdef" {
|
|
t.Fatalf("buffered body = %q, want abcdef", got)
|
|
}
|
|
if wrapper.statusCode != http.StatusAccepted {
|
|
t.Fatalf("statusCode = %d, want %d", wrapper.statusCode, http.StatusAccepted)
|
|
}
|
|
}
|
|
|
|
func TestIPFilter_RealIPAndInternalOnly(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
filter := security.NewIPFilter()
|
|
middleware := NewIPFilterMiddleware(filter, IPFilterConfig{
|
|
TrustProxy: true,
|
|
TrustedProxies: []string{"10.0.0.2"},
|
|
})
|
|
|
|
recorder := httptest.NewRecorder()
|
|
c, _ := gin.CreateTestContext(recorder)
|
|
c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
c.Request.RemoteAddr = "10.0.0.2:8080"
|
|
c.Request.Header.Set("X-Forwarded-For", "198.51.100.10, 10.0.0.2")
|
|
|
|
if got := middleware.realIP(c); got != "198.51.100.10" {
|
|
t.Fatalf("realIP() = %q, want 198.51.100.10", got)
|
|
}
|
|
if !middleware.isTrustedProxy("10.0.0.2") {
|
|
t.Fatal("expected trusted proxy match")
|
|
}
|
|
if middleware.isTrustedProxy("10.0.0.3") {
|
|
t.Fatal("unexpected trusted proxy match")
|
|
}
|
|
|
|
if !isPrivateIP("127.0.0.1") {
|
|
t.Fatal("expected loopback to be private")
|
|
}
|
|
if isPrivateIP("198.51.100.10") {
|
|
t.Fatal("expected public address to be non-private")
|
|
}
|
|
|
|
allowed := httptest.NewRecorder()
|
|
allowedRouter := gin.New()
|
|
allowedRouter.Use(InternalOnly())
|
|
allowedRouter.GET("/metrics", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
allowedReq := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
|
allowedReq.RemoteAddr = "127.0.0.1:12345"
|
|
allowedRouter.ServeHTTP(allowed, allowedReq)
|
|
if allowed.Code != http.StatusOK {
|
|
t.Fatalf("expected private IP to pass, got %d", allowed.Code)
|
|
}
|
|
|
|
blocked := httptest.NewRecorder()
|
|
blockedRouter := gin.New()
|
|
blockedRouter.Use(InternalOnly())
|
|
blockedRouter.GET("/metrics", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
blockedReq := httptest.NewRequest(http.MethodGet, "/metrics", nil)
|
|
blockedReq.RemoteAddr = "198.51.100.10:12345"
|
|
blockedRouter.ServeHTTP(blocked, blockedReq)
|
|
if blocked.Code != http.StatusForbidden {
|
|
t.Fatalf("expected public IP to be rejected, got %d", blocked.Code)
|
|
}
|
|
}
|
|
|
|
func TestIPFilter_FilterAndFallbacks(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
|
|
filter := security.NewIPFilter()
|
|
if err := filter.AddToBlacklist("198.51.100.10", "manual", time.Minute); err != nil {
|
|
t.Fatalf("AddToBlacklist() error = %v", err)
|
|
}
|
|
middleware := NewIPFilterMiddleware(filter, IPFilterConfig{})
|
|
if middleware.GetFilter() != filter {
|
|
t.Fatal("expected GetFilter() to expose the original filter")
|
|
}
|
|
|
|
blockedRecorder := httptest.NewRecorder()
|
|
blockedRouter := gin.New()
|
|
blockedRouter.Use(middleware.Filter())
|
|
blockedRouter.GET("/protected", func(c *gin.Context) {
|
|
c.Status(http.StatusOK)
|
|
})
|
|
blockedReq := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
blockedReq.RemoteAddr = "198.51.100.10:12345"
|
|
blockedRouter.ServeHTTP(blockedRecorder, blockedReq)
|
|
if blockedRecorder.Code != http.StatusForbidden {
|
|
t.Fatalf("expected blocked IP to be rejected, got %d", blockedRecorder.Code)
|
|
}
|
|
|
|
allowedRecorder := httptest.NewRecorder()
|
|
allowedRouter := gin.New()
|
|
allowedRouter.Use(middleware.Filter())
|
|
allowedRouter.GET("/protected", func(c *gin.Context) {
|
|
if got := c.GetString("client_ip"); got != "127.0.0.1" {
|
|
t.Fatalf("client_ip = %q, want 127.0.0.1", got)
|
|
}
|
|
c.Status(http.StatusOK)
|
|
})
|
|
allowedReq := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
allowedReq.RemoteAddr = "127.0.0.1:54321"
|
|
allowedRouter.ServeHTTP(allowedRecorder, allowedReq)
|
|
if allowedRecorder.Code != http.StatusOK {
|
|
t.Fatalf("expected allowed IP to pass, got %d", allowedRecorder.Code)
|
|
}
|
|
|
|
trustedProxyMiddleware := NewIPFilterMiddleware(filter, IPFilterConfig{
|
|
TrustProxy: true,
|
|
})
|
|
proxyRecorder := httptest.NewRecorder()
|
|
proxyCtx, _ := gin.CreateTestContext(proxyRecorder)
|
|
proxyCtx.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
|
|
proxyCtx.Request.RemoteAddr = "10.0.0.2:8080"
|
|
proxyCtx.Request.Header.Set("X-Real-IP", "203.0.113.9")
|
|
if got := trustedProxyMiddleware.realIP(proxyCtx); got != "203.0.113.9" {
|
|
t.Fatalf("realIP() X-Real-IP fallback = %q, want 203.0.113.9", got)
|
|
}
|
|
}
|