package auth

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"net/url"
	"reflect"
	"strings"
	"testing"
	"time"
)

// closedServerURL returns the base URL of an httptest server that has already
// been shut down, so requests to it should fail with a connection error. This
// is more dependable than guessing at a port that happens to be unused.
func closedServerURL(t *testing.T) string {
	t.Helper()
	server := httptest.NewServer(http.NotFoundHandler())
	addr := server.URL
	server.Close()
	return addr
}

// TestGenerateState checks that GenerateState returns a non-empty value that
// contains none of the standard-base64 characters '+' or '/'.
func TestGenerateState(t *testing.T) {
	state, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState() error = %v", err)
	}

	if state == "" {
		t.Error("GenerateState() returned empty state")
	}

	// '+' and '/' belong to the standard base64 alphabet and need escaping in
	// URLs; the URL-safe alphabet replaces them with '-' and '_'.
	if strings.ContainsAny(state, "+/") {
		t.Error("GenerateState() should use URL-safe base64 encoding")
	}
}

// TestValidateState checks that a freshly generated state validates exactly
// once and that an unknown state is rejected.
func TestValidateState(t *testing.T) {
	state, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState() error = %v", err)
	}

	if !ValidateState(state) {
		t.Error("ValidateState() returned false for valid state")
	}

	// The state is expected to be consumed by the first successful validation.
	if ValidateState(state) {
		t.Error("ValidateState() should return false for consumed state")
	}

	if ValidateState("invalid-state") {
		t.Error("ValidateState() returned true for invalid state")
	}
}

// TestValidateState_Expired checks that a state whose stored time lies in the
// past is rejected.
func TestValidateState_Expired(t *testing.T) {
	state, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState() error = %v", err)
	}

	// Overwrite the stored time with one an hour in the past; ValidateState is
	// expected to treat that as expired.
	stateStore.mu.Lock()
	stateStore.states[state] = time.Now().Add(-1 * time.Hour)
	stateStore.mu.Unlock()

	if ValidateState(state) {
		t.Error("ValidateState() should return false for expired state")
	}
}

// TestCleanupStates checks that CleanupStates removes an expired entry while
// keeping freshly generated ones.
func TestCleanupStates(t *testing.T) {
	// Start from an empty store so only the entries created below are present.
	stateStore.mu.Lock()
	stateStore.states = make(map[string]time.Time)
	stateStore.mu.Unlock()

	state1, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState() error = %v", err)
	}
	state2, err := GenerateState()
	if err != nil {
		t.Fatalf("GenerateState() error = %v", err)
	}

	// Insert an entry whose time is an hour in the past; CleanupStates is
	// expected to treat it as expired.
	stateStore.mu.Lock()
	stateStore.states["expired-state"] = time.Now().Add(-1 * time.Hour)
	stateStore.mu.Unlock()

	CleanupStates()

	stateStore.mu.RLock()
	defer stateStore.mu.RUnlock()

	if _, ok := stateStore.states["expired-state"]; ok {
		t.Error("CleanupStates() did not remove expired state")
	}
	if _, ok := stateStore.states[state1]; !ok {
		t.Error("CleanupStates() removed valid state1")
	}
	if _, ok := stateStore.states[state2]; !ok {
		t.Error("CleanupStates() removed valid state2")
	}
}

// TestGet checks that Get issues a GET request and returns the response.
func TestGet(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodGet {
			t.Errorf("Expected GET request, got %s", r.Method)
		}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	}))
	defer server.Close()

	resp, err := Get(server.URL)
	if err != nil {
		t.Fatalf("Get() error = %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("Get() status = %d, want %d", resp.StatusCode, http.StatusOK)
	}
}

// TestPostForm checks that PostForm issues a POST request carrying the given
// form values and returns the response.
func TestPostForm(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("Expected POST request, got %s", r.Method)
		}
		// Verify the url.Values payload actually reached the server.
		if err := r.ParseForm(); err != nil {
			t.Errorf("ParseForm() error = %v", err)
		} else if got := r.PostForm.Get("key"); got != "value" {
			t.Errorf("form key = %q, want %q", got, "value")
		}
		w.WriteHeader(http.StatusOK)
		w.Write([]byte("OK"))
	}))
	defer server.Close()

	data := url.Values{}
	data.Set("key", "value")

	resp, err := PostForm(server.URL, data)
	if err != nil {
		t.Fatalf("PostForm() error = %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Errorf("PostForm() status = %d, want %d", resp.StatusCode, http.StatusOK)
	}
}

// TestGetJSON checks that GetJSON decodes a JSON response body into the
// supplied value.
func TestGetJSON(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]string{"message": "hello"}); err != nil {
			t.Errorf("encoding response: %v", err)
		}
	}))
	defer server.Close()

	var result struct {
		Message string `json:"message"`
	}
	if err := GetJSON(server.URL, &result); err != nil {
		t.Fatalf("GetJSON() error = %v", err)
	}

	if result.Message != "hello" {
		t.Errorf("GetJSON() result.Message = %s, want hello", result.Message)
	}
}

// TestGetJSON_NonOKStatus checks that GetJSON reports an error when the server
// answers with a non-200 status.
func TestGetJSON_NonOKStatus(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}))
	defer server.Close()

	var result struct{}
	if err := GetJSON(server.URL, &result); err == nil {
		t.Error("GetJSON() should return error for non-OK status")
	}
}

// TestPostFormJSON checks that PostFormJSON POSTs the form values and decodes
// the JSON response into the supplied value.
func TestPostFormJSON(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != http.MethodPost {
			t.Errorf("Expected POST request, got %s", r.Method)
		}
		// Verify the url.Values payload actually reached the server.
		if err := r.ParseForm(); err != nil {
			t.Errorf("ParseForm() error = %v", err)
		} else if got := r.PostForm.Get("grant_type"); got != "authorization_code" {
			t.Errorf("form grant_type = %q, want %q", got, "authorization_code")
		}
		w.Header().Set("Content-Type", "application/json")
		if err := json.NewEncoder(w).Encode(map[string]string{"token": "abc123"}); err != nil {
			t.Errorf("encoding response: %v", err)
		}
	}))
	defer server.Close()

	data := url.Values{}
	data.Set("grant_type", "authorization_code")

	var result struct {
		Token string `json:"token"`
	}
	if err := PostFormJSON(server.URL, data, &result); err != nil {
		t.Fatalf("PostFormJSON() error = %v", err)
	}

	if result.Token != "abc123" {
		t.Errorf("PostFormJSON() result.Token = %s, want abc123", result.Token)
	}
}

// TestPostFormJSON_NonOKStatus checks that PostFormJSON reports an error when
// the server answers with a non-200 status.
func TestPostFormJSON_NonOKStatus(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
	}))
	defer server.Close()

	var result struct{}
	if err := PostFormJSON(server.URL, url.Values{}, &result); err == nil {
		t.Error("PostFormJSON() should return error for non-OK status")
	}
}

// TestBuildAuthURL checks that BuildAuthURL keeps the base URL's scheme and
// host and sets the expected OAuth query parameters.
func TestBuildAuthURL(t *testing.T) {
	baseURL := "https://example.com/oauth/authorize"
	clientID := "test-client-id"
	redirectURI := "https://myapp.com/callback"
	scope := "openid email"
	state := "random-state"

	result := BuildAuthURL(baseURL, clientID, redirectURI, scope, state)

	u, err := url.Parse(result)
	if err != nil {
		t.Fatalf("BuildAuthURL() produced invalid URL: %v", err)
	}

	if u.Scheme != "https" {
		t.Errorf("BuildAuthURL() scheme = %s, want https", u.Scheme)
	}
	if u.Host != "example.com" {
		t.Errorf("BuildAuthURL() host = %s, want example.com", u.Host)
	}

	q := u.Query()
	if q.Get("client_id") != clientID {
		t.Errorf("BuildAuthURL() client_id = %s, want %s", q.Get("client_id"), clientID)
	}
	if q.Get("redirect_uri") != redirectURI {
		t.Errorf("BuildAuthURL() redirect_uri = %s, want %s", q.Get("redirect_uri"), redirectURI)
	}
	if q.Get("scope") != scope {
		t.Errorf("BuildAuthURL() scope = %s, want %s", q.Get("scope"), scope)
	}
	if q.Get("state") != state {
		t.Errorf("BuildAuthURL() state = %s, want %s", q.Get("state"), state)
	}
	if q.Get("response_type") != "code" {
		t.Errorf("BuildAuthURL() response_type = %s, want code", q.Get("response_type"))
	}
}

// TestParseAccessTokenResponse checks that every field of a JSON token
// response is decoded.
func TestParseAccessTokenResponse(t *testing.T) {
	jsonData := `{
		"access_token": "test-access-token",
		"refresh_token": "test-refresh-token",
		"expires_in": 3600,
		"token_type": "Bearer"
	}`

	token, err := ParseAccessTokenResponse([]byte(jsonData))
	if err != nil {
		t.Fatalf("ParseAccessTokenResponse() error = %v", err)
	}

	if token.AccessToken != "test-access-token" {
		t.Errorf("AccessToken = %s, want test-access-token", token.AccessToken)
	}
	if token.RefreshToken != "test-refresh-token" {
		t.Errorf("RefreshToken = %s, want test-refresh-token", token.RefreshToken)
	}
	if token.ExpiresIn != 3600 {
		t.Errorf("ExpiresIn = %d, want 3600", token.ExpiresIn)
	}
	if token.TokenType != "Bearer" {
		t.Errorf("TokenType = %s, want Bearer", token.TokenType)
	}
}

// TestParseAccessTokenResponse_InvalidJSON checks that malformed JSON is
// reported as an error.
func TestParseAccessTokenResponse_InvalidJSON(t *testing.T) {
	if _, err := ParseAccessTokenResponse([]byte("invalid json")); err == nil {
		t.Error("ParseAccessTokenResponse() should return error for invalid JSON")
	}
}

// TestParseQueryAccessToken checks that access_token is extracted from a
// query-string encoded body.
func TestParseQueryAccessToken(t *testing.T) {
	body := "access_token=abc123&token_type=Bearer&expires_in=3600"

	token, err := ParseQueryAccessToken(body)
	if err != nil {
		t.Fatalf("ParseQueryAccessToken() error = %v", err)
	}

	if token != "abc123" {
		t.Errorf("ParseQueryAccessToken() = %s, want abc123", token)
	}
}

// TestParseQueryAccessToken_NoToken checks that a well-formed body without
// access_token yields an empty token and no error.
func TestParseQueryAccessToken_NoToken(t *testing.T) {
	body := "token_type=Bearer&expires_in=3600"

	token, err := ParseQueryAccessToken(body)
	if err != nil {
		t.Fatalf("ParseQueryAccessToken() error = %v", err)
	}

	if token != "" {
		t.Errorf("ParseQueryAccessToken() = %s, want empty", token)
	}
}

// TestParseQueryAccessToken_InvalidQuery checks that an invalid percent-escape
// is reported as an error.
func TestParseQueryAccessToken_InvalidQuery(t *testing.T) {
	if _, err := ParseQueryAccessToken("invalid%zz"); err == nil {
		t.Error("ParseQueryAccessToken() should return error for invalid query string")
	}
}

// TestParseJSONPResponse checks that the JSON object inside a JSONP callback
// wrapper is decoded.
func TestParseJSONPResponse(t *testing.T) {
	jsonp := `callback({"access_token":"abc123","expires_in":7200})`

	result, err := ParseJSONPResponse(jsonp)
	if err != nil {
		t.Fatalf("ParseJSONPResponse() error = %v", err)
	}

	if result["access_token"] != "abc123" {
		t.Errorf("ParseJSONPResponse() access_token = %v, want abc123", result["access_token"])
	}

	// Use the two-value assertion so a missing key or unexpected type fails
	// the test instead of panicking.
	expiresIn, ok := result["expires_in"].(float64)
	if !ok || expiresIn != 7200 {
		t.Errorf("ParseJSONPResponse() expires_in = %v, want 7200", result["expires_in"])
	}
}

// TestParseJSONPResponse_InvalidFormat checks that inputs lacking the
// callback(...) wrapper, or wrapping invalid JSON, are rejected.
func TestParseJSONPResponse_InvalidFormat(t *testing.T) {
	tests := []struct {
		name  string
		jsonp string
	}{
		{"no parentheses", "invalid"},
		{"no opening", "invalid)"},
		{"no closing", "invalid("},
		{"invalid JSON", "callback(invalid json)"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if _, err := ParseJSONPResponse(tt.jsonp); err == nil {
				t.Errorf("ParseJSONPResponse() should return error for %s", tt.name)
			}
		})
	}
}

// TestToOAuth2Config checks that OAuthConfig fields are copied into the
// oauth2 config and that the comma-separated scope string is split in order.
func TestToOAuth2Config(t *testing.T) {
	config := &OAuthConfig{
		ClientID:     "test-client-id",
		ClientSecret: "test-client-secret",
		RedirectURI:  "https://myapp.com/callback",
		Scope:        "openid,email,profile",
		AuthURL:      "https://example.com/oauth/authorize",
		TokenURL:     "https://example.com/oauth/token",
	}

	oauth2Config := ToOAuth2Config(config)

	if oauth2Config.ClientID != config.ClientID {
		t.Errorf("ClientID = %s, want %s", oauth2Config.ClientID, config.ClientID)
	}
	if oauth2Config.ClientSecret != config.ClientSecret {
		t.Errorf("ClientSecret = %s, want %s", oauth2Config.ClientSecret, config.ClientSecret)
	}
	if oauth2Config.RedirectURL != config.RedirectURI {
		t.Errorf("RedirectURL = %s, want %s", oauth2Config.RedirectURL, config.RedirectURI)
	}

	wantScopes := []string{"openid", "email", "profile"}
	if !reflect.DeepEqual(oauth2Config.Scopes, wantScopes) {
		t.Errorf("Scopes = %q, want %q", oauth2Config.Scopes, wantScopes)
	}

	if oauth2Config.Endpoint.AuthURL != config.AuthURL {
		t.Errorf("AuthURL = %s, want %s", oauth2Config.Endpoint.AuthURL, config.AuthURL)
	}
	if oauth2Config.Endpoint.TokenURL != config.TokenURL {
		t.Errorf("TokenURL = %s, want %s", oauth2Config.Endpoint.TokenURL, config.TokenURL)
	}
}

// TestGetJSON_ConnectionError checks that GetJSON reports an error when the
// server cannot be reached.
func TestGetJSON_ConnectionError(t *testing.T) {
	var result struct{}
	if err := GetJSON(closedServerURL(t), &result); err == nil {
		t.Error("GetJSON() should return error for connection failure")
	}
}

// TestPostFormJSON_ConnectionError checks that PostFormJSON reports an error
// when the server cannot be reached.
func TestPostFormJSON_ConnectionError(t *testing.T) {
	var result struct{}
	if err := PostFormJSON(closedServerURL(t), url.Values{}, &result); err == nil {
		t.Error("PostFormJSON() should return error for connection failure")
	}
}