package middleware

import (
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/gin-gonic/gin"
	"github.com/user-management-system/internal/config"
)

// TestCORS_UsesConfiguredOrigins verifies that a preflight OPTIONS request from
// an allow-listed origin is answered with 204 and echoes that origin with
// credentials enabled, and that an origin outside the allow-list is never
// reflected back in Access-Control-Allow-Origin.
func TestCORS_UsesConfiguredOrigins(t *testing.T) {
	gin.SetMode(gin.TestMode)
	SetCORSConfig(config.CORSConfig{
		AllowedOrigins:   []string{"https://app.example.com"},
		AllowCredentials: true,
	})
	// The CORS configuration is package-level state, so put it back for the
	// other tests once this one finishes.
	// NOTE(review): assumes "*" with credentials is the package default that
	// other tests rely on — confirm against SetCORSConfig's initial value.
	t.Cleanup(func() {
		SetCORSConfig(config.CORSConfig{
			AllowedOrigins:   []string{"*"},
			AllowCredentials: true,
		})
	})

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodOptions, "/api/v1/users", nil)
	c.Request.Header.Set("Origin", "https://app.example.com")
	c.Request.Header.Set("Access-Control-Request-Headers", "Authorization")

	CORS()(c)

	if recorder.Code != http.StatusNoContent {
		t.Fatalf("expected 204, got %d", recorder.Code)
	}
	if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != "https://app.example.com" {
		t.Fatalf("unexpected allow origin: %q", got)
	}
	if got := recorder.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
		t.Fatalf("expected credentials header to be 'true', got %q", got)
	}

	// An origin that is not in the allow-list must never be echoed back;
	// reflecting it (with credentials enabled) would defeat the allow-list.
	rejected := httptest.NewRecorder()
	rc, _ := gin.CreateTestContext(rejected)
	rc.Request = httptest.NewRequest(http.MethodOptions, "/api/v1/users", nil)
	rc.Request.Header.Set("Origin", "https://evil.example.com")
	rc.Request.Header.Set("Access-Control-Request-Headers", "Authorization")

	CORS()(rc)

	if got := rejected.Header().Get("Access-Control-Allow-Origin"); got == "https://evil.example.com" {
		t.Fatalf("did not expect unconfigured origin to be allowed, got %q", got)
	}
}

// TestSanitizeQuery_MasksSensitiveValues verifies that sanitizeQuery rewrites a
// query string so the values of token, access_token and secret no longer
// appear in its output, and that an empty query stays empty.
func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) {
	raw := "token=abc123&foo=bar&access_token=xyz&secret=s1"
	sanitized := sanitizeQuery(raw)
	if sanitized == "" {
		t.Fatal("expected sanitized query")
	}
	if sanitized == raw {
		t.Fatal("expected query to be sanitized")
	}
	for _, value := range []string{"abc123", "xyz", "s1"} {
		if strings.Contains(sanitized, value) {
			t.Fatalf("expected sensitive value %q to be masked in %q", value, sanitized)
		}
	}
	if sanitizeQuery("") != "" {
		t.Fatal("expected empty query to stay empty")
	}
}

// TestSecurityHeaders_AttachesExpectedHeaders verifies the baseline security
// headers on a plain-HTTP request, and that HSTS is not sent without HTTPS.
func TestSecurityHeaders_AttachesExpectedHeaders(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)

	SecurityHeaders()(c)

	if got := recorder.Header().Get("X-Content-Type-Options"); got != "nosniff" {
		t.Fatalf("unexpected nosniff header: %q", got)
	}
	if got := recorder.Header().Get("X-Frame-Options"); got != "DENY" {
		t.Fatalf("unexpected frame options: %q", got)
	}
	if got := recorder.Header().Get("Content-Security-Policy"); got == "" {
		t.Fatal("expected content security policy header")
	}
	if got := recorder.Header().Get("Strict-Transport-Security"); got != "" {
		t.Fatalf("did not expect hsts header for http request, got %q", got)
	}
}

// TestSecurityHeaders_AttachesHSTSForForwardedHTTPS verifies that a request
// marked as HTTPS via X-Forwarded-Proto receives an HSTS header with a
// one-year max-age.
func TestSecurityHeaders_AttachesHSTSForForwardedHTTPS(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
	c.Request.Header.Set("X-Forwarded-Proto", "https")

	SecurityHeaders()(c)

	if got := recorder.Header().Get("Strict-Transport-Security"); !strings.Contains(got, "max-age=31536000") {
		t.Fatalf("expected hsts header, got %q", got)
	}
}

// TestNoStoreSensitiveResponses_AttachesExpectedHeadersToAuthRoutes verifies
// that responses under /api/v1/auth/ carry the full set of no-store caching
// headers.
func TestNoStoreSensitiveResponses_AttachesExpectedHeadersToAuthRoutes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/capabilities", nil)

	NoStoreSensitiveResponses()(c)

	if got := recorder.Header().Get("Cache-Control"); got != sensitiveNoStoreCacheControl {
		t.Fatalf("unexpected cache-control header: %q", got)
	}
	if got := recorder.Header().Get("Pragma"); got != "no-cache" {
		t.Fatalf("unexpected pragma header: %q", got)
	}
	if got := recorder.Header().Get("Expires"); got != "0" {
		t.Fatalf("unexpected expires header: %q", got)
	}
	if got := recorder.Header().Get("Surrogate-Control"); got != "no-store" {
		t.Fatalf("unexpected surrogate-control header: %q", got)
	}
}

// TestNoStoreSensitiveResponses_DoesNotAttachHeadersToNonAuthRoutes verifies
// that none of the no-store caching headers asserted for auth routes are
// attached to a non-auth route.
func TestNoStoreSensitiveResponses_DoesNotAttachHeadersToNonAuthRoutes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)

	NoStoreSensitiveResponses()(c)

	// Mirror the full header set checked on auth routes, not just Cache-Control.
	for _, header := range []string{"Cache-Control", "Pragma", "Expires", "Surrogate-Control"} {
		if got := recorder.Header().Get(header); got != "" {
			t.Fatalf("did not expect %s header, got %q", header, got)
		}
	}
}