diff --git a/internal/api/middleware/cors_test.go b/internal/api/middleware/cors_test.go new file mode 100644 index 0000000..fae130d --- /dev/null +++ b/internal/api/middleware/cors_test.go @@ -0,0 +1,215 @@ +package middleware + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/user-management-system/internal/config" +) + +func TestValidateCORSConfig(t *testing.T) { + tests := []struct { + name string + cfg config.CORSConfig + wantErr bool + }{ + { + name: "valid config with specific origins", + cfg: config.CORSConfig{ + AllowedOrigins: []string{"https://example.com"}, + AllowCredentials: true, + }, + wantErr: false, + }, + { + name: "valid config with wildcard no credentials", + cfg: config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: false, + }, + wantErr: false, + }, + { + name: "invalid config with wildcard and credentials", + cfg: config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: true, + }, + wantErr: true, + }, + { + name: "empty origins", + cfg: config.CORSConfig{ + AllowedOrigins: []string{}, + AllowCredentials: false, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateCORSConfig(tt.cfg) + if tt.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +func TestSetCORSConfig(t *testing.T) { + // Save original config + originalConfig := corsConfig + defer func() { corsConfig = originalConfig }() + + t.Run("valid config", func(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"https://example.com"}, + AllowCredentials: true, + } + err := SetCORSConfig(cfg) + assert.NoError(t, err) + assert.Equal(t, cfg, corsConfig) + }) + + t.Run("invalid config", func(t *testing.T) { + cfg := config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: true, + } + err := SetCORSConfig(cfg) + assert.Error(t, err) + }) +} + +func TestResolveAllowedOrigin(t *testing.T) { + tests := []struct { + name string + origin string + allowedOrigins []string + allowCredentials bool + wantOrigin string + wantAllowed bool + }{ + { + name: "exact match", + origin: "https://example.com", + allowedOrigins: []string{"https://example.com"}, + allowCredentials: true, + wantOrigin: "https://example.com", + wantAllowed: true, + }, + { + name: "wildcard without credentials", + origin: "https://any.com", + allowedOrigins: []string{"*"}, + allowCredentials: false, + wantOrigin: "*", + wantAllowed: true, + }, + { + name: "wildcard with credentials returns origin", + origin: "https://any.com", + allowedOrigins: []string{"*"}, + allowCredentials: true, + wantOrigin: "https://any.com", + wantAllowed: true, + }, + { + name: "no match", + origin: "https://evil.com", + allowedOrigins: []string{"https://example.com"}, + allowCredentials: false, + wantOrigin: "", + wantAllowed: false, + }, + { + name: "case insensitive match", + origin: "HTTPS://EXAMPLE.COM", + allowedOrigins: []string{"https://example.com"}, + allowCredentials: false, + wantOrigin: "HTTPS://EXAMPLE.COM", + wantAllowed: true, + }, + { + name: "empty origins list", + origin: "https://example.com", + allowedOrigins: []string{}, + allowCredentials: false, + wantOrigin: "", + wantAllowed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotOrigin, gotAllowed := resolveAllowedOrigin(tt.origin, tt.allowedOrigins, tt.allowCredentials) + assert.Equal(t, tt.wantOrigin, gotOrigin) + assert.Equal(t, tt.wantAllowed, gotAllowed) + }) + } +} + +func TestCORS(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Save and restore original config + originalConfig := corsConfig + defer func() { corsConfig = originalConfig }() + + // Set test config + corsConfig = config.CORSConfig{ + AllowedOrigins: []string{"https://example.com"}, + AllowCredentials: true, + } + + router := gin.New() + router.Use(CORS()) + router.GET("/test", func(c *gin.Context) { + c.String(200, "OK") + }) + + t.Run("allow valid origin", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "https://example.com") + router.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + assert.Equal(t, "https://example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials")) + }) + + t.Run("forbid invalid origin", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + req.Header.Set("Origin", "https://evil.com") + router.ServeHTTP(w, req) + + assert.Equal(t, 403, w.Code) + }) + + t.Run("handle OPTIONS request", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("OPTIONS", "/test", nil) + req.Header.Set("Origin", "https://example.com") + router.ServeHTTP(w, req) + + assert.Equal(t, 204, w.Code) + assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods")) + assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers")) + }) + + t.Run("no origin header", func(t *testing.T) { + w := httptest.NewRecorder() + req, _ := http.NewRequest("GET", "/test", nil) + router.ServeHTTP(w, req) + + assert.Equal(t, 200, w.Code) + }) +}