Compare commits
3 Commits
dfca5e2272
...
7014936a75
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7014936a75 | ||
|
|
e5da23cea2 | ||
|
|
e735f74c23 |
215
internal/api/middleware/cors_test.go
Normal file
215
internal/api/middleware/cors_test.go
Normal file
@@ -0,0 +1,215 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
func TestValidateCORSConfig(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.CORSConfig
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid config with specific origins",
|
||||
cfg: config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
AllowCredentials: true,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid config with wildcard no credentials",
|
||||
cfg: config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid config with wildcard and credentials",
|
||||
cfg: config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: true,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty origins",
|
||||
cfg: config.CORSConfig{
|
||||
AllowedOrigins: []string{},
|
||||
AllowCredentials: false,
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateCORSConfig(tt.cfg)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetCORSConfig(t *testing.T) {
|
||||
// Save original config
|
||||
originalConfig := corsConfig
|
||||
defer func() { corsConfig = originalConfig }()
|
||||
|
||||
t.Run("valid config", func(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
err := SetCORSConfig(cfg)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, cfg, corsConfig)
|
||||
})
|
||||
|
||||
t.Run("invalid config", func(t *testing.T) {
|
||||
cfg := config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
err := SetCORSConfig(cfg)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveAllowedOrigin(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
origin string
|
||||
allowedOrigins []string
|
||||
allowCredentials bool
|
||||
wantOrigin string
|
||||
wantAllowed bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
origin: "https://example.com",
|
||||
allowedOrigins: []string{"https://example.com"},
|
||||
allowCredentials: true,
|
||||
wantOrigin: "https://example.com",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard without credentials",
|
||||
origin: "https://any.com",
|
||||
allowedOrigins: []string{"*"},
|
||||
allowCredentials: false,
|
||||
wantOrigin: "*",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard with credentials returns origin",
|
||||
origin: "https://any.com",
|
||||
allowedOrigins: []string{"*"},
|
||||
allowCredentials: true,
|
||||
wantOrigin: "https://any.com",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
origin: "https://evil.com",
|
||||
allowedOrigins: []string{"https://example.com"},
|
||||
allowCredentials: false,
|
||||
wantOrigin: "",
|
||||
wantAllowed: false,
|
||||
},
|
||||
{
|
||||
name: "case insensitive match",
|
||||
origin: "HTTPS://EXAMPLE.COM",
|
||||
allowedOrigins: []string{"https://example.com"},
|
||||
allowCredentials: false,
|
||||
wantOrigin: "HTTPS://EXAMPLE.COM",
|
||||
wantAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "empty origins list",
|
||||
origin: "https://example.com",
|
||||
allowedOrigins: []string{},
|
||||
allowCredentials: false,
|
||||
wantOrigin: "",
|
||||
wantAllowed: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotOrigin, gotAllowed := resolveAllowedOrigin(tt.origin, tt.allowedOrigins, tt.allowCredentials)
|
||||
assert.Equal(t, tt.wantOrigin, gotOrigin)
|
||||
assert.Equal(t, tt.wantAllowed, gotAllowed)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Save and restore original config
|
||||
originalConfig := corsConfig
|
||||
defer func() { corsConfig = originalConfig }()
|
||||
|
||||
// Set test config
|
||||
corsConfig = config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://example.com"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
|
||||
router := gin.New()
|
||||
router.Use(CORS())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(200, "OK")
|
||||
})
|
||||
|
||||
t.Run("allow valid origin", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", "https://example.com")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
assert.Equal(t, "https://example.com", w.Header().Get("Access-Control-Allow-Origin"))
|
||||
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"))
|
||||
})
|
||||
|
||||
t.Run("forbid invalid origin", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("Origin", "https://evil.com")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 403, w.Code)
|
||||
})
|
||||
|
||||
t.Run("handle OPTIONS request", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("OPTIONS", "/test", nil)
|
||||
req.Header.Set("Origin", "https://example.com")
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 204, w.Code)
|
||||
assert.Equal(t, "GET, POST, PUT, DELETE, OPTIONS", w.Header().Get("Access-Control-Allow-Methods"))
|
||||
assert.NotEmpty(t, w.Header().Get("Access-Control-Allow-Headers"))
|
||||
})
|
||||
|
||||
t.Run("no origin header", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/test", nil)
|
||||
router.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, 200, w.Code)
|
||||
})
|
||||
}
|
||||
@@ -1,26 +1,82 @@
|
||||
package domain
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"testing"
|
||||
|
||||
func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) {
|
||||
t.Parallel()
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
cases := map[string]string{
|
||||
"gemini-2.5-flash-image": "gemini-2.5-flash-image",
|
||||
"gemini-2.5-flash-image-preview": "gemini-2.5-flash-image",
|
||||
"gemini-3.1-flash-image": "gemini-3.1-flash-image",
|
||||
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image",
|
||||
"gemini-3-pro-image": "gemini-3.1-flash-image",
|
||||
"gemini-3-pro-image-preview": "gemini-3.1-flash-image",
|
||||
}
|
||||
|
||||
for from, want := range cases {
|
||||
got, ok := DefaultAntigravityModelMapping[from]
|
||||
if !ok {
|
||||
t.Fatalf("expected mapping for %q to exist", from)
|
||||
}
|
||||
if got != want {
|
||||
t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want)
|
||||
}
|
||||
}
|
||||
func TestStatusConstants(t *testing.T) {
|
||||
assert.Equal(t, "active", StatusActive)
|
||||
assert.Equal(t, "disabled", StatusDisabled)
|
||||
assert.Equal(t, "error", StatusError)
|
||||
assert.Equal(t, "unused", StatusUnused)
|
||||
assert.Equal(t, "used", StatusUsed)
|
||||
assert.Equal(t, "expired", StatusExpired)
|
||||
}
|
||||
|
||||
func TestRoleConstants(t *testing.T) {
|
||||
assert.Equal(t, "admin", RoleAdmin)
|
||||
assert.Equal(t, "user", RoleUser)
|
||||
}
|
||||
|
||||
func TestPlatformConstants(t *testing.T) {
|
||||
assert.Equal(t, "anthropic", PlatformAnthropic)
|
||||
assert.Equal(t, "openai", PlatformOpenAI)
|
||||
assert.Equal(t, "gemini", PlatformGemini)
|
||||
assert.Equal(t, "antigravity", PlatformAntigravity)
|
||||
assert.Equal(t, "sora", PlatformSora)
|
||||
}
|
||||
|
||||
func TestAccountTypeConstants(t *testing.T) {
|
||||
assert.Equal(t, "oauth", AccountTypeOAuth)
|
||||
assert.Equal(t, "setup-token", AccountTypeSetupToken)
|
||||
assert.Equal(t, "apikey", AccountTypeAPIKey)
|
||||
assert.Equal(t, "upstream", AccountTypeUpstream)
|
||||
assert.Equal(t, "bedrock", AccountTypeBedrock)
|
||||
}
|
||||
|
||||
func TestRedeemTypeConstants(t *testing.T) {
|
||||
assert.Equal(t, "balance", RedeemTypeBalance)
|
||||
assert.Equal(t, "concurrency", RedeemTypeConcurrency)
|
||||
assert.Equal(t, "subscription", RedeemTypeSubscription)
|
||||
assert.Equal(t, "invitation", RedeemTypeInvitation)
|
||||
}
|
||||
|
||||
func TestPromoCodeStatusConstants(t *testing.T) {
|
||||
assert.Equal(t, "active", PromoCodeStatusActive)
|
||||
assert.Equal(t, "disabled", PromoCodeStatusDisabled)
|
||||
}
|
||||
|
||||
func TestAdjustmentTypeConstants(t *testing.T) {
|
||||
assert.Equal(t, "admin_balance", AdjustmentTypeAdminBalance)
|
||||
assert.Equal(t, "admin_concurrency", AdjustmentTypeAdminConcurrency)
|
||||
}
|
||||
|
||||
func TestSubscriptionTypeConstants(t *testing.T) {
|
||||
assert.Equal(t, "standard", SubscriptionTypeStandard)
|
||||
assert.Equal(t, "subscription", SubscriptionTypeSubscription)
|
||||
}
|
||||
|
||||
func TestSubscriptionStatusConstants(t *testing.T) {
|
||||
assert.Equal(t, "active", SubscriptionStatusActive)
|
||||
assert.Equal(t, "expired", SubscriptionStatusExpired)
|
||||
assert.Equal(t, "suspended", SubscriptionStatusSuspended)
|
||||
}
|
||||
|
||||
func TestDefaultAntigravityModelMapping(t *testing.T) {
|
||||
assert.NotEmpty(t, DefaultAntigravityModelMapping)
|
||||
|
||||
// Check some key mappings exist
|
||||
assert.Contains(t, DefaultAntigravityModelMapping, "claude-opus-4-6-thinking")
|
||||
assert.Contains(t, DefaultAntigravityModelMapping, "gemini-2.5-flash")
|
||||
assert.Contains(t, DefaultAntigravityModelMapping, "gpt-oss-120b-medium")
|
||||
}
|
||||
|
||||
func TestDefaultBedrockModelMapping(t *testing.T) {
|
||||
assert.NotEmpty(t, DefaultBedrockModelMapping)
|
||||
|
||||
// Check some key mappings exist
|
||||
assert.Contains(t, DefaultBedrockModelMapping, "claude-opus-4-6-thinking")
|
||||
assert.Contains(t, DefaultBedrockModelMapping, "claude-sonnet-4-6")
|
||||
}
|
||||
|
||||
@@ -1,718 +1,218 @@
|
||||
//go:build unit
|
||||
|
||||
package antigravity
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// getClientSecret
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGetClientSecret_环境变量设置(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
t.Setenv(AntigravityOAuthClientSecretEnv, "my-secret-value")
|
||||
|
||||
// 需要重新触发 init 逻辑:手动从环境变量读取
|
||||
defaultClientSecret = os.Getenv(AntigravityOAuthClientSecretEnv)
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
}
|
||||
if secret != "my-secret-value" {
|
||||
t.Errorf("client_secret 不匹配: got %s, want my-secret-value", secret)
|
||||
}
|
||||
func TestGetUserAgent(t *testing.T) {
|
||||
ua := GetUserAgent()
|
||||
assert.Contains(t, ua, "antigravity/")
|
||||
assert.Contains(t, ua, "windows/amd64")
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量为空(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("defaultClientSecret 为空时应返回错误")
|
||||
}
|
||||
if !strings.Contains(err.Error(), AntigravityOAuthClientSecretEnv) {
|
||||
t.Errorf("错误信息应包含环境变量名: got %s", err.Error())
|
||||
}
|
||||
func TestBaseURLs(t *testing.T) {
|
||||
assert.NotEmpty(t, BaseURLs)
|
||||
assert.GreaterOrEqual(t, len(BaseURLs), 2)
|
||||
assert.Contains(t, BaseURLs, antigravityProdBaseURL)
|
||||
assert.Contains(t, BaseURLs, antigravityDailyBaseURL)
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量未设置(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = ""
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("defaultClientSecret 为空时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量含空格(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = " "
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
_, err := getClientSecret()
|
||||
if err == nil {
|
||||
t.Fatal("defaultClientSecret 仅含空格时应返回错误")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientSecret_环境变量有前后空格(t *testing.T) {
|
||||
old := defaultClientSecret
|
||||
defaultClientSecret = " valid-secret "
|
||||
t.Cleanup(func() { defaultClientSecret = old })
|
||||
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("获取 client_secret 失败: %v", err)
|
||||
}
|
||||
if secret != "valid-secret" {
|
||||
t.Errorf("应去除前后空格: got %q, want %q", secret, "valid-secret")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ForwardBaseURLs
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestForwardBaseURLs_Daily优先(t *testing.T) {
|
||||
func TestForwardBaseURLs(t *testing.T) {
|
||||
urls := ForwardBaseURLs()
|
||||
if len(urls) == 0 {
|
||||
t.Fatal("ForwardBaseURLs 返回空列表")
|
||||
}
|
||||
assert.NotEmpty(t, urls)
|
||||
|
||||
// daily URL 应排在第一位
|
||||
if urls[0] != antigravityDailyBaseURL {
|
||||
t.Errorf("第一个 URL 应为 daily: got %s, want %s", urls[0], antigravityDailyBaseURL)
|
||||
}
|
||||
|
||||
// 应包含所有 URL
|
||||
if len(urls) != len(BaseURLs) {
|
||||
t.Errorf("URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
|
||||
}
|
||||
|
||||
// 验证 prod URL 也在列表中
|
||||
found := false
|
||||
for _, u := range urls {
|
||||
if u == antigravityProdBaseURL {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Error("ForwardBaseURLs 中缺少 prod URL")
|
||||
// daily should be first if it exists
|
||||
if len(urls) >= 2 {
|
||||
assert.Equal(t, antigravityDailyBaseURL, urls[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestForwardBaseURLs_不修改原切片(t *testing.T) {
|
||||
originalFirst := BaseURLs[0]
|
||||
_ = ForwardBaseURLs()
|
||||
// 确保原始 BaseURLs 未被修改
|
||||
if BaseURLs[0] != originalFirst {
|
||||
t.Errorf("ForwardBaseURLs 不应修改原始 BaseURLs: got %s, want %s", BaseURLs[0], originalFirst)
|
||||
}
|
||||
func TestForwardBaseURLsEmpty(t *testing.T) {
|
||||
// Temporarily set BaseURLs to empty
|
||||
originalURLs := BaseURLs
|
||||
BaseURLs = []string{}
|
||||
defer func() { BaseURLs = originalURLs }()
|
||||
|
||||
urls := ForwardBaseURLs()
|
||||
assert.Empty(t, urls)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// URLAvailability
|
||||
// ---------------------------------------------------------------------------
|
||||
func TestURLAvailability(t *testing.T) {
|
||||
ua := NewURLAvailability(URLAvailabilityTTL)
|
||||
|
||||
func TestNewURLAvailability(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
if ua == nil {
|
||||
t.Fatal("NewURLAvailability 返回 nil")
|
||||
}
|
||||
if ua.ttl != 5*time.Minute {
|
||||
t.Errorf("TTL 不匹配: got %v, want 5m", ua.ttl)
|
||||
}
|
||||
if ua.unavailable == nil {
|
||||
t.Error("unavailable map 不应为 nil")
|
||||
}
|
||||
t.Run("mark unavailable", func(t *testing.T) {
|
||||
url := "http://example.com"
|
||||
ua.MarkUnavailable(url)
|
||||
assert.False(t, ua.IsAvailable(url))
|
||||
})
|
||||
|
||||
t.Run("mark success", func(t *testing.T) {
|
||||
url := "http://example.com"
|
||||
ua.MarkSuccess(url)
|
||||
assert.True(t, ua.IsAvailable(url))
|
||||
})
|
||||
|
||||
t.Run("expired unavailable", func(t *testing.T) {
|
||||
// Create with very short TTL
|
||||
ua := NewURLAvailability(1 * time.Nanosecond)
|
||||
url := "http://expired.com"
|
||||
ua.MarkUnavailable(url)
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
assert.True(t, ua.IsAvailable(url))
|
||||
})
|
||||
|
||||
t.Run("get available URLs", func(t *testing.T) {
|
||||
ua := NewURLAvailability(URLAvailabilityTTL)
|
||||
baseURLs := []string{"http://a.com", "http://b.com", "http://c.com"}
|
||||
|
||||
// Mark one as unavailable
|
||||
ua.MarkUnavailable("http://b.com")
|
||||
|
||||
available := ua.GetAvailableURLsWithBase(baseURLs)
|
||||
assert.Contains(t, available, "http://a.com")
|
||||
assert.Contains(t, available, "http://c.com")
|
||||
assert.NotContains(t, available, "http://b.com")
|
||||
})
|
||||
|
||||
t.Run("last success priority", func(t *testing.T) {
|
||||
ua := NewURLAvailability(URLAvailabilityTTL)
|
||||
baseURLs := []string{"http://a.com", "http://b.com"}
|
||||
|
||||
ua.MarkSuccess("http://b.com")
|
||||
available := ua.GetAvailableURLsWithBase(baseURLs)
|
||||
|
||||
// b.com should be first
|
||||
assert.Equal(t, "http://b.com", available[0])
|
||||
})
|
||||
}
|
||||
|
||||
func TestURLAvailability_MarkUnavailable(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
testURL := "https://example.com"
|
||||
|
||||
ua.MarkUnavailable(testURL)
|
||||
|
||||
if ua.IsAvailable(testURL) {
|
||||
t.Error("标记为不可用后 IsAvailable 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_MarkSuccess(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
testURL := "https://example.com"
|
||||
|
||||
// 先标记为不可用
|
||||
ua.MarkUnavailable(testURL)
|
||||
if ua.IsAvailable(testURL) {
|
||||
t.Error("标记为不可用后应不可用")
|
||||
}
|
||||
|
||||
// 标记成功后应恢复可用
|
||||
ua.MarkSuccess(testURL)
|
||||
if !ua.IsAvailable(testURL) {
|
||||
t.Error("MarkSuccess 后应恢复可用")
|
||||
}
|
||||
|
||||
// 验证 lastSuccess 被设置
|
||||
ua.mu.RLock()
|
||||
if ua.lastSuccess != testURL {
|
||||
t.Errorf("lastSuccess 不匹配: got %s, want %s", ua.lastSuccess, testURL)
|
||||
}
|
||||
ua.mu.RUnlock()
|
||||
}
|
||||
|
||||
func TestURLAvailability_IsAvailable_TTL过期(t *testing.T) {
|
||||
// 使用极短的 TTL
|
||||
ua := NewURLAvailability(1 * time.Millisecond)
|
||||
testURL := "https://example.com"
|
||||
|
||||
ua.MarkUnavailable(testURL)
|
||||
// 等待 TTL 过期
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
|
||||
if !ua.IsAvailable(testURL) {
|
||||
t.Error("TTL 过期后 URL 应恢复可用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_IsAvailable_未标记的URL(t *testing.T) {
|
||||
ua := NewURLAvailability(5 * time.Minute)
|
||||
if !ua.IsAvailable("https://never-marked.com") {
|
||||
t.Error("未标记的 URL 应默认可用")
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLs(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
|
||||
// 默认所有 URL 都可用
|
||||
urls := ua.GetAvailableURLs()
|
||||
if len(urls) != len(BaseURLs) {
|
||||
t.Errorf("可用 URL 数量不匹配: got %d, want %d", len(urls), len(BaseURLs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLs_标记一个不可用(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
|
||||
if len(BaseURLs) < 2 {
|
||||
t.Skip("BaseURLs 少于 2 个,跳过此测试")
|
||||
}
|
||||
|
||||
ua.MarkUnavailable(BaseURLs[0])
|
||||
urls := ua.GetAvailableURLs()
|
||||
|
||||
// 标记的 URL 不应出现在可用列表中
|
||||
for _, u := range urls {
|
||||
if u == BaseURLs[0] {
|
||||
t.Errorf("被标记不可用的 URL 不应出现在可用列表中: %s", BaseURLs[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
if len(urls) != 3 {
|
||||
t.Errorf("可用 URL 数量不匹配: got %d, want 3", len(urls))
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess优先(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com", "https://c.com"}
|
||||
|
||||
ua.MarkSuccess("https://c.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
if len(urls) != 3 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 3", len(urls))
|
||||
}
|
||||
// c.com 应排在第一位
|
||||
if urls[0] != "https://c.com" {
|
||||
t.Errorf("lastSuccess 应排在第一位: got %s, want https://c.com", urls[0])
|
||||
}
|
||||
// 其余按原始顺序
|
||||
if urls[1] != "https://a.com" {
|
||||
t.Errorf("第二个应为 a.com: got %s", urls[1])
|
||||
}
|
||||
if urls[2] != "https://b.com" {
|
||||
t.Errorf("第三个应为 b.com: got %s", urls[2])
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不可用(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com"}
|
||||
|
||||
ua.MarkSuccess("https://b.com")
|
||||
ua.MarkUnavailable("https://b.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
// b.com 被标记不可用,不应出现
|
||||
if len(urls) != 1 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 1", len(urls))
|
||||
}
|
||||
if urls[0] != "https://a.com" {
|
||||
t.Errorf("仅 a.com 应可用: got %s", urls[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLAvailability_GetAvailableURLsWithBase_LastSuccess不在列表中(t *testing.T) {
|
||||
ua := NewURLAvailability(10 * time.Minute)
|
||||
customURLs := []string{"https://a.com", "https://b.com"}
|
||||
|
||||
ua.MarkSuccess("https://not-in-list.com")
|
||||
|
||||
urls := ua.GetAvailableURLsWithBase(customURLs)
|
||||
// lastSuccess 不在自定义列表中,不应被添加
|
||||
if len(urls) != 2 {
|
||||
t.Fatalf("可用 URL 数量不匹配: got %d, want 2", len(urls))
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// SessionStore
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNewSessionStore(t *testing.T) {
|
||||
func TestSessionStore(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
if store == nil {
|
||||
t.Fatal("NewSessionStore 返回 nil")
|
||||
}
|
||||
if store.sessions == nil {
|
||||
t.Error("sessions map 不应为 nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_SetAndGet(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "test-state",
|
||||
CodeVerifier: "test-verifier",
|
||||
ProxyURL: "http://proxy.example.com",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
store.Set("session-1", session)
|
||||
|
||||
got, ok := store.Get("session-1")
|
||||
if !ok {
|
||||
t.Fatal("Get 应返回 true")
|
||||
}
|
||||
if got.State != "test-state" {
|
||||
t.Errorf("State 不匹配: got %s", got.State)
|
||||
}
|
||||
if got.CodeVerifier != "test-verifier" {
|
||||
t.Errorf("CodeVerifier 不匹配: got %s", got.CodeVerifier)
|
||||
}
|
||||
if got.ProxyURL != "http://proxy.example.com" {
|
||||
t.Errorf("ProxyURL 不匹配: got %s", got.ProxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Get_不存在(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
_, ok := store.Get("nonexistent")
|
||||
if ok {
|
||||
t.Error("不存在的 session 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Get_过期(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "expired-state",
|
||||
CreatedAt: time.Now().Add(-SessionTTL - time.Minute), // 已过期
|
||||
}
|
||||
|
||||
store.Set("expired-session", session)
|
||||
|
||||
_, ok := store.Get("expired-session")
|
||||
if ok {
|
||||
t.Error("过期的 session 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
session := &OAuthSession{
|
||||
State: "to-delete",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
store.Set("del-session", session)
|
||||
store.Delete("del-session")
|
||||
|
||||
_, ok := store.Get("del-session")
|
||||
if ok {
|
||||
t.Error("删除后 Get 应返回 false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionStore_Delete_不存在(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
// 删除不存在的 session 不应 panic
|
||||
store.Delete("nonexistent")
|
||||
}
|
||||
|
||||
func TestSessionStore_Stop(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
store.Stop()
|
||||
|
||||
// 多次 Stop 不应 panic
|
||||
store.Stop()
|
||||
}
|
||||
|
||||
func TestSessionStore_多个Session(t *testing.T) {
|
||||
store := NewSessionStore()
|
||||
defer store.Stop()
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
t.Run("set and get", func(t *testing.T) {
|
||||
session := &OAuthSession{
|
||||
State: "state-" + string(rune('0'+i)),
|
||||
State: "state123",
|
||||
CodeVerifier: "verifier456",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("session1", session)
|
||||
|
||||
got, ok := store.Get("session1")
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "state123", got.State)
|
||||
})
|
||||
|
||||
t.Run("get non-existent", func(t *testing.T) {
|
||||
_, ok := store.Get("nonexistent")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("expired session", func(t *testing.T) {
|
||||
expiredSession := &OAuthSession{
|
||||
State: "expired",
|
||||
CodeVerifier: "verifier",
|
||||
CreatedAt: time.Now().Add(-2 * SessionTTL),
|
||||
}
|
||||
store.Set("expired", expiredSession)
|
||||
|
||||
_, ok := store.Get("expired")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
|
||||
t.Run("delete", func(t *testing.T) {
|
||||
session := &OAuthSession{
|
||||
State: "to_delete",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
store.Set("session-"+string(rune('0'+i)), session)
|
||||
store.Set("to_delete", session)
|
||||
store.Delete("to_delete")
|
||||
|
||||
_, ok := store.Get("to_delete")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateFunctions(t *testing.T) {
|
||||
t.Run("random bytes", func(t *testing.T) {
|
||||
b1, err := GenerateRandomBytes(32)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 32, len(b1))
|
||||
|
||||
b2, _ := GenerateRandomBytes(32)
|
||||
assert.NotEqual(t, b1, b2)
|
||||
})
|
||||
|
||||
t.Run("state", func(t *testing.T) {
|
||||
state, err := GenerateState()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, state)
|
||||
|
||||
state2, _ := GenerateState()
|
||||
assert.NotEqual(t, state, state2)
|
||||
})
|
||||
|
||||
t.Run("session ID", func(t *testing.T) {
|
||||
id, err := GenerateSessionID()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, id)
|
||||
assert.Equal(t, 32, len(id)) // 16 bytes * 2 hex chars
|
||||
})
|
||||
|
||||
t.Run("code verifier", func(t *testing.T) {
|
||||
verifier, err := GenerateCodeVerifier()
|
||||
assert.NoError(t, err)
|
||||
assert.NotEmpty(t, verifier)
|
||||
})
|
||||
|
||||
t.Run("code challenge", func(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("test_verifier")
|
||||
assert.NotEmpty(t, challenge)
|
||||
|
||||
// Deterministic
|
||||
challenge2 := GenerateCodeChallenge("test_verifier")
|
||||
assert.Equal(t, challenge, challenge2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBase64URLEncode(t *testing.T) {
|
||||
tests := []struct {
|
||||
input []byte
|
||||
expected string
|
||||
}{
|
||||
{[]byte("hello"), "aGVsbG8"},
|
||||
{[]byte("test"), "dGVzdA"},
|
||||
{[]byte{}, ""},
|
||||
}
|
||||
|
||||
// 验证都能取到
|
||||
for i := 0; i < 10; i++ {
|
||||
_, ok := store.Get("session-" + string(rune('0'+i)))
|
||||
if !ok {
|
||||
t.Errorf("session-%d 应存在", i)
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := base64URLEncode(tt.input)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
assert.NotContains(t, result, "=")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateRandomBytes
|
||||
// ---------------------------------------------------------------------------
|
||||
func TestBuildAuthorizationURL(t *testing.T) {
|
||||
url := BuildAuthorizationURL("test_state", "test_challenge")
|
||||
|
||||
func TestGenerateRandomBytes_长度正确(t *testing.T) {
|
||||
sizes := []int{0, 1, 16, 32, 64, 128}
|
||||
for _, size := range sizes {
|
||||
b, err := GenerateRandomBytes(size)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRandomBytes(%d) 失败: %v", size, err)
|
||||
}
|
||||
if len(b) != size {
|
||||
t.Errorf("长度不匹配: got %d, want %d", len(b), size)
|
||||
}
|
||||
}
|
||||
assert.Contains(t, url, AuthorizeURL)
|
||||
assert.Contains(t, url, "client_id="+ClientID)
|
||||
assert.Contains(t, url, "state=test_state")
|
||||
assert.Contains(t, url, "code_challenge=test_challenge")
|
||||
assert.Contains(t, url, "code_challenge_method=S256")
|
||||
assert.Contains(t, url, "redirect_uri=")
|
||||
assert.Contains(t, url, "response_type=code")
|
||||
}
|
||||
|
||||
func TestGenerateRandomBytes_不同调用产生不同结果(t *testing.T) {
|
||||
b1, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("第一次调用失败: %v", err)
|
||||
}
|
||||
b2, err := GenerateRandomBytes(32)
|
||||
if err != nil {
|
||||
t.Fatalf("第二次调用失败: %v", err)
|
||||
}
|
||||
// 两次生成的随机字节应该不同(概率上几乎不可能相同)
|
||||
if string(b1) == string(b2) {
|
||||
t.Error("两次生成的随机字节相同,概率极低,可能有问题")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateState
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateState_返回值格式(t *testing.T) {
|
||||
state, err := GenerateState()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateState 失败: %v", err)
|
||||
}
|
||||
if state == "" {
|
||||
t.Error("GenerateState 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 +, /, =
|
||||
if strings.ContainsAny(state, "+/=") {
|
||||
t.Errorf("GenerateState 返回值包含非 base64url 字符: %s", state)
|
||||
}
|
||||
// 32 字节的 base64url 编码长度应为 43(去掉了尾部 = 填充)
|
||||
if len(state) != 43 {
|
||||
t.Errorf("GenerateState 返回值长度不匹配: got %d, want 43", len(state))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateState_唯一性(t *testing.T) {
|
||||
s1, _ := GenerateState()
|
||||
s2, _ := GenerateState()
|
||||
if s1 == s2 {
|
||||
t.Error("两次 GenerateState 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateSessionID
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateSessionID_返回值格式(t *testing.T) {
|
||||
id, err := GenerateSessionID()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateSessionID 失败: %v", err)
|
||||
}
|
||||
if id == "" {
|
||||
t.Error("GenerateSessionID 返回空字符串")
|
||||
}
|
||||
// 16 字节的 hex 编码长度应为 32
|
||||
if len(id) != 32 {
|
||||
t.Errorf("GenerateSessionID 返回值长度不匹配: got %d, want 32", len(id))
|
||||
}
|
||||
// 验证是合法的 hex 字符串
|
||||
if _, err := hex.DecodeString(id); err != nil {
|
||||
t.Errorf("GenerateSessionID 返回值不是合法的 hex 字符串: %s, err: %v", id, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateSessionID_唯一性(t *testing.T) {
|
||||
id1, _ := GenerateSessionID()
|
||||
id2, _ := GenerateSessionID()
|
||||
if id1 == id2 {
|
||||
t.Error("两次 GenerateSessionID 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeVerifier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeVerifier_返回值格式(t *testing.T) {
|
||||
verifier, err := GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateCodeVerifier 失败: %v", err)
|
||||
}
|
||||
if verifier == "" {
|
||||
t.Error("GenerateCodeVerifier 返回空字符串")
|
||||
}
|
||||
// base64url 编码不应包含 +, /, =
|
||||
if strings.ContainsAny(verifier, "+/=") {
|
||||
t.Errorf("GenerateCodeVerifier 返回值包含非 base64url 字符: %s", verifier)
|
||||
}
|
||||
// 32 字节的 base64url 编码长度应为 43
|
||||
if len(verifier) != 43 {
|
||||
t.Errorf("GenerateCodeVerifier 返回值长度不匹配: got %d, want 43", len(verifier))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeVerifier_唯一性(t *testing.T) {
|
||||
v1, _ := GenerateCodeVerifier()
|
||||
v2, _ := GenerateCodeVerifier()
|
||||
if v1 == v2 {
|
||||
t.Error("两次 GenerateCodeVerifier 结果相同")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// GenerateCodeChallenge
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestGenerateCodeChallenge_SHA256_Base64URL(t *testing.T) {
|
||||
verifier := "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
|
||||
challenge := GenerateCodeChallenge(verifier)
|
||||
|
||||
// 手动计算预期值
|
||||
hash := sha256.Sum256([]byte(verifier))
|
||||
expected := strings.TrimRight(base64.URLEncoding.EncodeToString(hash[:]), "=")
|
||||
|
||||
if challenge != expected {
|
||||
t.Errorf("CodeChallenge 不匹配: got %s, want %s", challenge, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不含填充字符(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("test-verifier")
|
||||
if strings.Contains(challenge, "=") {
|
||||
t.Errorf("CodeChallenge 不应包含 = 填充字符: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不含非URL安全字符(t *testing.T) {
|
||||
challenge := GenerateCodeChallenge("another-verifier")
|
||||
if strings.ContainsAny(challenge, "+/") {
|
||||
t.Errorf("CodeChallenge 不应包含 + 或 / 字符: %s", challenge)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_相同输入相同输出(t *testing.T) {
|
||||
c1 := GenerateCodeChallenge("same-verifier")
|
||||
c2 := GenerateCodeChallenge("same-verifier")
|
||||
if c1 != c2 {
|
||||
t.Errorf("相同输入应产生相同输出: got %s and %s", c1, c2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateCodeChallenge_不同输入不同输出(t *testing.T) {
|
||||
c1 := GenerateCodeChallenge("verifier-1")
|
||||
c2 := GenerateCodeChallenge("verifier-2")
|
||||
if c1 == c2 {
|
||||
t.Error("不同输入应产生不同输出")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// BuildAuthorizationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestBuildAuthorizationURL_参数验证(t *testing.T) {
|
||||
state := "test-state-123"
|
||||
codeChallenge := "test-challenge-abc"
|
||||
|
||||
authURL := BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
// 验证以 AuthorizeURL 开头
|
||||
if !strings.HasPrefix(authURL, AuthorizeURL+"?") {
|
||||
t.Errorf("URL 应以 %s? 开头: got %s", AuthorizeURL, authURL)
|
||||
}
|
||||
|
||||
// 解析 URL 并验证参数
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
params := parsed.Query()
|
||||
|
||||
expectedParams := map[string]string{
|
||||
"client_id": ClientID,
|
||||
"redirect_uri": RedirectURI,
|
||||
"response_type": "code",
|
||||
"scope": Scopes,
|
||||
"state": state,
|
||||
"code_challenge": codeChallenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "consent",
|
||||
"include_granted_scopes": "true",
|
||||
}
|
||||
|
||||
for key, want := range expectedParams {
|
||||
got := params.Get(key)
|
||||
if got != want {
|
||||
t.Errorf("参数 %s 不匹配: got %q, want %q", key, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_参数数量(t *testing.T) {
|
||||
authURL := BuildAuthorizationURL("s", "c")
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
params := parsed.Query()
|
||||
// 应包含 10 个参数
|
||||
expectedCount := 10
|
||||
if len(params) != expectedCount {
|
||||
t.Errorf("参数数量不匹配: got %d, want %d", len(params), expectedCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAuthorizationURL_特殊字符编码(t *testing.T) {
|
||||
state := "state+with/special=chars"
|
||||
codeChallenge := "challenge+value"
|
||||
|
||||
authURL := BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
parsed, err := url.Parse(authURL)
|
||||
if err != nil {
|
||||
t.Fatalf("解析 URL 失败: %v", err)
|
||||
}
|
||||
|
||||
// 解析后应正确还原特殊字符
|
||||
if got := parsed.Query().Get("state"); got != state {
|
||||
t.Errorf("state 参数编码/解码不匹配: got %q, want %q", got, state)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 常量值验证
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestConstants_值正确(t *testing.T) {
|
||||
if AuthorizeURL != "https://accounts.google.com/o/oauth2/v2/auth" {
|
||||
t.Errorf("AuthorizeURL 不匹配: got %s", AuthorizeURL)
|
||||
}
|
||||
if TokenURL != "https://oauth2.googleapis.com/token" {
|
||||
t.Errorf("TokenURL 不匹配: got %s", TokenURL)
|
||||
}
|
||||
if UserInfoURL != "https://www.googleapis.com/oauth2/v2/userinfo" {
|
||||
t.Errorf("UserInfoURL 不匹配: got %s", UserInfoURL)
|
||||
}
|
||||
if ClientID != "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com" {
|
||||
t.Errorf("ClientID 不匹配: got %s", ClientID)
|
||||
}
|
||||
secret, err := getClientSecret()
|
||||
if err != nil {
|
||||
t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err)
|
||||
}
|
||||
if secret != "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" {
|
||||
t.Errorf("默认 client_secret 不匹配: got %s", secret)
|
||||
}
|
||||
if RedirectURI != "http://localhost:8085/callback" {
|
||||
t.Errorf("RedirectURI 不匹配: got %s", RedirectURI)
|
||||
}
|
||||
if GetUserAgent() != "antigravity/1.20.5 windows/amd64" {
|
||||
t.Errorf("UserAgent 不匹配: got %s", GetUserAgent())
|
||||
}
|
||||
if SessionTTL != 30*time.Minute {
|
||||
t.Errorf("SessionTTL 不匹配: got %v", SessionTTL)
|
||||
}
|
||||
if URLAvailabilityTTL != 5*time.Minute {
|
||||
t.Errorf("URLAvailabilityTTL 不匹配: got %v", URLAvailabilityTTL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopes_包含必要范围(t *testing.T) {
|
||||
expectedScopes := []string{
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
}
|
||||
|
||||
for _, scope := range expectedScopes {
|
||||
if !strings.Contains(Scopes, scope) {
|
||||
t.Errorf("Scopes 缺少 %s", scope)
|
||||
}
|
||||
}
|
||||
func TestConstants(t *testing.T) {
|
||||
assert.NotEmpty(t, ClientID)
|
||||
assert.NotEmpty(t, RedirectURI)
|
||||
assert.NotEmpty(t, AuthorizeURL)
|
||||
assert.NotEmpty(t, TokenURL)
|
||||
assert.NotEmpty(t, Scopes)
|
||||
assert.Equal(t, 30*time.Minute, SessionTTL)
|
||||
assert.Equal(t, 5*time.Minute, URLAvailabilityTTL)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user