- Add new test files for auth, service, and handler modules - Improve test organization and coverage - Refactor code for better maintainability - Add captcha, settings, stats, and theme handler tests - Add auth module tests (CAS, OAuth, password, SSO, state) - Add service layer tests for auth, export, permissions, roles - All Go tests pass (exit code 0) - All frontend tests pass (325 tests in 59 files)
406 lines
11 KiB
Go
406 lines
11 KiB
Go
package auth
|
|
|
|
import (
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestGenerateState(t *testing.T) {
|
|
state, err := GenerateState()
|
|
if err != nil {
|
|
t.Fatalf("GenerateState() error = %v", err)
|
|
}
|
|
if state == "" {
|
|
t.Error("GenerateState() returned empty state")
|
|
}
|
|
// State should be base64 encoded, so no special chars that would break URLs
|
|
if strings.ContainsAny(state, "+/") {
|
|
t.Error("GenerateState() should use URL-safe base64 encoding")
|
|
}
|
|
}
|
|
|
|
func TestValidateState(t *testing.T) {
|
|
// Test valid state
|
|
state, err := GenerateState()
|
|
if err != nil {
|
|
t.Fatalf("GenerateState() error = %v", err)
|
|
}
|
|
|
|
if !ValidateState(state) {
|
|
t.Error("ValidateState() returned false for valid state")
|
|
}
|
|
|
|
// State should be consumed (one-time use)
|
|
if ValidateState(state) {
|
|
t.Error("ValidateState() should return false for consumed state")
|
|
}
|
|
|
|
// Test invalid state
|
|
if ValidateState("invalid-state") {
|
|
t.Error("ValidateState() returned true for invalid state")
|
|
}
|
|
}
|
|
|
|
func TestValidateState_Expired(t *testing.T) {
|
|
// Create a state and manually expire it
|
|
state, err := GenerateState()
|
|
if err != nil {
|
|
t.Fatalf("GenerateState() error = %v", err)
|
|
}
|
|
|
|
// Manually set expired time
|
|
stateStore.mu.Lock()
|
|
stateStore.states[state] = time.Now().Add(-1 * time.Hour)
|
|
stateStore.mu.Unlock()
|
|
|
|
if ValidateState(state) {
|
|
t.Error("ValidateState() should return false for expired state")
|
|
}
|
|
}
|
|
|
|
func TestCleanupStates(t *testing.T) {
|
|
// Clear existing states
|
|
stateStore.mu.Lock()
|
|
stateStore.states = make(map[string]time.Time)
|
|
stateStore.mu.Unlock()
|
|
|
|
// Add some states
|
|
state1, _ := GenerateState()
|
|
state2, _ := GenerateState()
|
|
|
|
// Manually expire one
|
|
stateStore.mu.Lock()
|
|
stateStore.states["expired-state"] = time.Now().Add(-1 * time.Hour)
|
|
stateStore.mu.Unlock()
|
|
|
|
// Cleanup
|
|
CleanupStates()
|
|
|
|
stateStore.mu.RLock()
|
|
defer stateStore.mu.RUnlock()
|
|
|
|
// Expired state should be removed
|
|
if _, ok := stateStore.states["expired-state"]; ok {
|
|
t.Error("CleanupStates() did not remove expired state")
|
|
}
|
|
|
|
// Valid states should remain
|
|
if _, ok := stateStore.states[state1]; !ok {
|
|
t.Error("CleanupStates() removed valid state1")
|
|
}
|
|
if _, ok := stateStore.states[state2]; !ok {
|
|
t.Error("CleanupStates() removed valid state2")
|
|
}
|
|
}
|
|
|
|
func TestGet(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "GET" {
|
|
t.Errorf("Expected GET request, got %s", r.Method)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
resp, err := Get(server.URL)
|
|
if err != nil {
|
|
t.Fatalf("Get() error = %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("Get() status = %d, want %d", resp.StatusCode, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestPostForm(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
t.Errorf("Expected POST request, got %s", r.Method)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
w.Write([]byte("OK"))
|
|
}))
|
|
defer server.Close()
|
|
|
|
data := url.Values{}
|
|
data.Set("key", "value")
|
|
|
|
resp, err := PostForm(server.URL, data)
|
|
if err != nil {
|
|
t.Fatalf("PostForm() error = %v", err)
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
t.Errorf("PostForm() status = %d, want %d", resp.StatusCode, http.StatusOK)
|
|
}
|
|
}
|
|
|
|
func TestGetJSON(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{"message": "hello"})
|
|
}))
|
|
defer server.Close()
|
|
|
|
var result struct {
|
|
Message string `json:"message"`
|
|
}
|
|
err := GetJSON(server.URL, &result)
|
|
if err != nil {
|
|
t.Fatalf("GetJSON() error = %v", err)
|
|
}
|
|
if result.Message != "hello" {
|
|
t.Errorf("GetJSON() result.Message = %s, want hello", result.Message)
|
|
}
|
|
}
|
|
|
|
func TestGetJSON_NonOKStatus(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
}))
|
|
defer server.Close()
|
|
|
|
var result struct{}
|
|
err := GetJSON(server.URL, &result)
|
|
if err == nil {
|
|
t.Error("GetJSON() should return error for non-OK status")
|
|
}
|
|
}
|
|
|
|
func TestPostFormJSON(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != "POST" {
|
|
t.Errorf("Expected POST request, got %s", r.Method)
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
json.NewEncoder(w).Encode(map[string]string{"token": "abc123"})
|
|
}))
|
|
defer server.Close()
|
|
|
|
data := url.Values{}
|
|
data.Set("grant_type", "authorization_code")
|
|
|
|
var result struct {
|
|
Token string `json:"token"`
|
|
}
|
|
err := PostFormJSON(server.URL, data, &result)
|
|
if err != nil {
|
|
t.Fatalf("PostFormJSON() error = %v", err)
|
|
}
|
|
if result.Token != "abc123" {
|
|
t.Errorf("PostFormJSON() result.Token = %s, want abc123", result.Token)
|
|
}
|
|
}
|
|
|
|
func TestPostFormJSON_NonOKStatus(t *testing.T) {
|
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusBadRequest)
|
|
}))
|
|
defer server.Close()
|
|
|
|
var result struct{}
|
|
err := PostFormJSON(server.URL, url.Values{}, &result)
|
|
if err == nil {
|
|
t.Error("PostFormJSON() should return error for non-OK status")
|
|
}
|
|
}
|
|
|
|
func TestBuildAuthURL(t *testing.T) {
|
|
baseURL := "https://example.com/oauth/authorize"
|
|
clientID := "test-client-id"
|
|
redirectURI := "https://myapp.com/callback"
|
|
scope := "openid email"
|
|
state := "random-state"
|
|
|
|
result := BuildAuthURL(baseURL, clientID, redirectURI, scope, state)
|
|
|
|
u, err := url.Parse(result)
|
|
if err != nil {
|
|
t.Fatalf("BuildAuthURL() produced invalid URL: %v", err)
|
|
}
|
|
|
|
if u.Scheme != "https" {
|
|
t.Errorf("BuildAuthURL() scheme = %s, want https", u.Scheme)
|
|
}
|
|
if u.Host != "example.com" {
|
|
t.Errorf("BuildAuthURL() host = %s, want example.com", u.Host)
|
|
}
|
|
|
|
q := u.Query()
|
|
if q.Get("client_id") != clientID {
|
|
t.Errorf("BuildAuthURL() client_id = %s, want %s", q.Get("client_id"), clientID)
|
|
}
|
|
if q.Get("redirect_uri") != redirectURI {
|
|
t.Errorf("BuildAuthURL() redirect_uri = %s, want %s", q.Get("redirect_uri"), redirectURI)
|
|
}
|
|
if q.Get("scope") != scope {
|
|
t.Errorf("BuildAuthURL() scope = %s, want %s", q.Get("scope"), scope)
|
|
}
|
|
if q.Get("state") != state {
|
|
t.Errorf("BuildAuthURL() state = %s, want %s", q.Get("state"), state)
|
|
}
|
|
if q.Get("response_type") != "code" {
|
|
t.Errorf("BuildAuthURL() response_type = %s, want code", q.Get("response_type"))
|
|
}
|
|
}
|
|
|
|
func TestParseAccessTokenResponse(t *testing.T) {
|
|
jsonData := `{
|
|
"access_token": "test-access-token",
|
|
"refresh_token": "test-refresh-token",
|
|
"expires_in": 3600,
|
|
"token_type": "Bearer"
|
|
}`
|
|
|
|
token, err := ParseAccessTokenResponse([]byte(jsonData))
|
|
if err != nil {
|
|
t.Fatalf("ParseAccessTokenResponse() error = %v", err)
|
|
}
|
|
|
|
if token.AccessToken != "test-access-token" {
|
|
t.Errorf("AccessToken = %s, want test-access-token", token.AccessToken)
|
|
}
|
|
if token.RefreshToken != "test-refresh-token" {
|
|
t.Errorf("RefreshToken = %s, want test-refresh-token", token.RefreshToken)
|
|
}
|
|
if token.ExpiresIn != 3600 {
|
|
t.Errorf("ExpiresIn = %d, want 3600", token.ExpiresIn)
|
|
}
|
|
if token.TokenType != "Bearer" {
|
|
t.Errorf("TokenType = %s, want Bearer", token.TokenType)
|
|
}
|
|
}
|
|
|
|
func TestParseAccessTokenResponse_InvalidJSON(t *testing.T) {
|
|
_, err := ParseAccessTokenResponse([]byte("invalid json"))
|
|
if err == nil {
|
|
t.Error("ParseAccessTokenResponse() should return error for invalid JSON")
|
|
}
|
|
}
|
|
|
|
func TestParseQueryAccessToken(t *testing.T) {
|
|
body := "access_token=abc123&token_type=Bearer&expires_in=3600"
|
|
|
|
token, err := ParseQueryAccessToken(body)
|
|
if err != nil {
|
|
t.Fatalf("ParseQueryAccessToken() error = %v", err)
|
|
}
|
|
|
|
if token != "abc123" {
|
|
t.Errorf("ParseQueryAccessToken() = %s, want abc123", token)
|
|
}
|
|
}
|
|
|
|
func TestParseQueryAccessToken_NoToken(t *testing.T) {
|
|
body := "token_type=Bearer&expires_in=3600"
|
|
|
|
token, err := ParseQueryAccessToken(body)
|
|
if err != nil {
|
|
t.Fatalf("ParseQueryAccessToken() error = %v", err)
|
|
}
|
|
|
|
if token != "" {
|
|
t.Errorf("ParseQueryAccessToken() = %s, want empty", token)
|
|
}
|
|
}
|
|
|
|
func TestParseQueryAccessToken_InvalidQuery(t *testing.T) {
|
|
_, err := ParseQueryAccessToken("invalid%zz")
|
|
if err == nil {
|
|
t.Error("ParseQueryAccessToken() should return error for invalid query string")
|
|
}
|
|
}
|
|
|
|
func TestParseJSONPResponse(t *testing.T) {
|
|
jsonp := `callback({"access_token":"abc123","expires_in":7200})`
|
|
|
|
result, err := ParseJSONPResponse(jsonp)
|
|
if err != nil {
|
|
t.Fatalf("ParseJSONPResponse() error = %v", err)
|
|
}
|
|
|
|
if result["access_token"] != "abc123" {
|
|
t.Errorf("ParseJSONPResponse() access_token = %v, want abc123", result["access_token"])
|
|
}
|
|
if result["expires_in"].(float64) != 7200 {
|
|
t.Errorf("ParseJSONPResponse() expires_in = %v, want 7200", result["expires_in"])
|
|
}
|
|
}
|
|
|
|
func TestParseJSONPResponse_InvalidFormat(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
jsonp string
|
|
}{
|
|
{"no parentheses", "invalid"},
|
|
{"no opening", "invalid)"},
|
|
{"no closing", "invalid("},
|
|
{"invalid JSON", "callback(invalid json)"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
_, err := ParseJSONPResponse(tt.jsonp)
|
|
if err == nil {
|
|
t.Errorf("ParseJSONPResponse() should return error for %s", tt.name)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestToOAuth2Config(t *testing.T) {
|
|
config := &OAuthConfig{
|
|
ClientID: "test-client-id",
|
|
ClientSecret: "test-client-secret",
|
|
RedirectURI: "https://myapp.com/callback",
|
|
Scope: "openid,email,profile",
|
|
AuthURL: "https://example.com/oauth/authorize",
|
|
TokenURL: "https://example.com/oauth/token",
|
|
}
|
|
|
|
oauth2Config := ToOAuth2Config(config)
|
|
|
|
if oauth2Config.ClientID != config.ClientID {
|
|
t.Errorf("ClientID = %s, want %s", oauth2Config.ClientID, config.ClientID)
|
|
}
|
|
if oauth2Config.ClientSecret != config.ClientSecret {
|
|
t.Errorf("ClientSecret = %s, want %s", oauth2Config.ClientSecret, config.ClientSecret)
|
|
}
|
|
if oauth2Config.RedirectURL != config.RedirectURI {
|
|
t.Errorf("RedirectURL = %s, want %s", oauth2Config.RedirectURL, config.RedirectURI)
|
|
}
|
|
if len(oauth2Config.Scopes) != 3 {
|
|
t.Errorf("Scopes length = %d, want 3", len(oauth2Config.Scopes))
|
|
}
|
|
if oauth2Config.Endpoint.AuthURL != config.AuthURL {
|
|
t.Errorf("AuthURL = %s, want %s", oauth2Config.Endpoint.AuthURL, config.AuthURL)
|
|
}
|
|
if oauth2Config.Endpoint.TokenURL != config.TokenURL {
|
|
t.Errorf("TokenURL = %s, want %s", oauth2Config.Endpoint.TokenURL, config.TokenURL)
|
|
}
|
|
}
|
|
|
|
func TestGetJSON_ConnectionError(t *testing.T) {
|
|
var result struct{}
|
|
err := GetJSON("http://127.0.0.1:1", &result) // Invalid port
|
|
if err == nil {
|
|
t.Error("GetJSON() should return error for connection failure")
|
|
}
|
|
}
|
|
|
|
func TestPostFormJSON_ConnectionError(t *testing.T) {
|
|
var result struct{}
|
|
err := PostFormJSON("http://127.0.0.1:1", url.Values{}, &result) // Invalid port
|
|
if err == nil {
|
|
t.Error("PostFormJSON() should return error for connection failure")
|
|
}
|
|
}
|