package service

import (
	"context"
	"errors"
	"log"
	"log/slog"
	"strconv"
	"strings"
	"time"
)

const (
	// geminiTokenRefreshSkew is the remaining lifetime at or below which an
	// access token is considered due for refresh.
	geminiTokenRefreshSkew = 3 * time.Minute
	// geminiTokenCacheSkew is subtracted from the remaining token lifetime when
	// computing the cache TTL, so cached entries expire before the token does.
	geminiTokenCacheSkew = 5 * time.Minute
)

// GeminiTokenProvider manages access_token for Gemini OAuth accounts.
type GeminiTokenProvider struct {
	// accountRepo persists auto-detected credentials and is passed to
	// CheckTokenVersion before the cache is populated.
	accountRepo AccountRepository
	// tokenCache stores access tokens and refresh locks; nil disables caching.
	tokenCache GeminiTokenCache
	// geminiOAuthService is used for project_id auto-detection; may be nil.
	geminiOAuthService *GeminiOAuthService
	// refreshAPI and executor are optional and injected via SetRefreshAPI.
	refreshAPI *OAuthRefreshAPI
	executor   OAuthRefreshExecutor
	// refreshPolicy controls how refresh errors and held locks are handled.
	refreshPolicy ProviderRefreshPolicy
}

// NewGeminiTokenProvider builds a provider with the given dependencies and the
// policy returned by GeminiProviderRefreshPolicy. The refresh API and executor
// are left unset until SetRefreshAPI is called.
func NewGeminiTokenProvider(
	accountRepo AccountRepository,
	tokenCache GeminiTokenCache,
	geminiOAuthService *GeminiOAuthService,
) *GeminiTokenProvider {
	return &GeminiTokenProvider{
		accountRepo:        accountRepo,
		tokenCache:         tokenCache,
		geminiOAuthService: geminiOAuthService,
		refreshPolicy:      GeminiProviderRefreshPolicy(),
	}
}

// SetRefreshAPI injects unified OAuth refresh API and executor.
func (p *GeminiTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
	p.refreshAPI = api
	p.executor = executor
}

// SetRefreshPolicy injects caller-side refresh policy.
func (p *GeminiTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
	p.refreshPolicy = policy
}

// GetAccessToken returns an access token for a Gemini OAuth account.
//
// The lookup proceeds in this order:
//  1. Return the cached token if the cache holds a non-blank entry.
//  2. If expires_at is missing or within geminiTokenRefreshSkew, ask the
//     injected refresh API to refresh; without a refresh API only the legacy
//     cache lock is taken.
//  3. Read access_token from the (possibly refreshed) account credentials.
//  4. If project_id is empty and auto_detect_project_id is "true", try to
//     detect the project and persist it on the account.
//  5. Populate the cache, unless CheckTokenVersion reports the account as
//     stale, in which case the latest account's token is returned uncached.
//
// It returns an error when account is nil, when the account is not a Gemini
// OAuth account, when the refresh fails and the policy is
// ProviderRefreshErrorReturn, or when no access_token is available.
//
// Note that account.Credentials may be mutated in place when a project_id is
// auto-detected.
func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
	if account == nil {
		return "", errors.New("account is nil")
	}
	if account.Platform != PlatformGemini || account.Type != AccountTypeOAuth {
		return "", errors.New("not a gemini oauth account")
	}

	// The key is computed once, before any project_id auto-detection, and is
	// used for every cache and lock operation in this call.
	cacheKey := GeminiTokenCacheKey(account)

	// 1) Try cache first.
	if p.tokenCache != nil {
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
			return token, nil
		}
	}

	// 2) Refresh if needed (pre-expiry skew).
	expiresAt := account.GetCredentialAsTime("expires_at")
	needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= geminiTokenRefreshSkew

	switch {
	case needsRefresh && p.refreshAPI != nil && p.executor != nil:
		result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, geminiTokenRefreshSkew)
		switch {
		case err != nil:
			// Any other policy falls through and uses the existing token.
			if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
				return "", err
			}
		case result.LockHeld:
			// Another caller holds the refresh lock; optionally re-check the
			// cache, otherwise continue with the existing token.
			if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
				if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
					return token, nil
				}
			}
			slog.Debug("gemini_token_lock_held_use_old", "account_id", account.ID)
		case result.Account != nil:
			account = result.Account
			expiresAt = account.GetCredentialAsTime("expires_at")
		}
		// If the result carries no account, keep the caller's account rather
		// than dereferencing nil.

	case needsRefresh && p.tokenCache != nil:
		// Backward-compatible test path when refreshAPI is not injected.
		locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
		if lockErr == nil && locked {
			defer func() {
				// Best effort: the lock was acquired with a 30s TTL, so a
				// failed release is not acted upon here.
				_ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey)
			}()
		} else if lockErr != nil {
			slog.Warn("gemini_token_lock_failed", "account_id", account.ID, "error", lockErr)
		}
	}

	accessToken := account.GetCredential("access_token")
	if strings.TrimSpace(accessToken) == "" {
		return "", errors.New("access_token not found in credentials")
	}

	// project_id is optional now:
	// - If present: use Code Assist API (requires project_id)
	// - If absent: use AI Studio API with OAuth token.
	projectID := strings.TrimSpace(account.GetCredential("project_id"))
	autoDetectProjectID := account.GetCredential("auto_detect_project_id") == "true"
	if projectID == "" && autoDetectProjectID {
		if p.geminiOAuthService == nil {
			// No detection possible; the token is returned without caching.
			return accessToken, nil
		}

		var proxyURL string
		if account.ProxyID != nil && p.geminiOAuthService.proxyRepo != nil {
			if proxy, err := p.geminiOAuthService.proxyRepo.GetByID(ctx, *account.ProxyID); err == nil && proxy != nil {
				proxyURL = proxy.URL()
			}
		}

		detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
		if err != nil {
			log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
			return accessToken, nil
		}

		detected = strings.TrimSpace(detected)
		tierID = strings.TrimSpace(tierID)
		if detected != "" {
			if account.Credentials == nil {
				account.Credentials = make(map[string]any)
			}
			account.Credentials["project_id"] = detected
			if tierID != "" {
				account.Credentials["tier_id"] = tierID
			}
			// Persisting is best effort: the token is still usable when the
			// update fails, but the failure is logged instead of being dropped.
			// The nil guard avoids a panic when no repository was injected.
			if p.accountRepo != nil {
				if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
					slog.Warn("gemini_token_project_id_persist_failed", "account_id", account.ID, "error", updateErr)
				}
			}
		}
	}

	// 3) Populate cache with TTL.
	if p.tokenCache != nil {
		latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
		if isStale && latestAccount != nil {
			// The account is stale: return the latest token and skip the cache
			// write.
			slog.Debug("gemini_token_version_stale_use_latest", "account_id", account.ID)
			accessToken = latestAccount.GetCredential("access_token")
			if strings.TrimSpace(accessToken) == "" {
				return "", errors.New("access_token not found after version check")
			}
		} else {
			// Default TTL when expires_at is unknown.
			ttl := 30 * time.Minute
			if expiresAt != nil {
				until := time.Until(*expiresAt)
				switch {
				case until > geminiTokenCacheSkew:
					ttl = until - geminiTokenCacheSkew
				case until > 0:
					ttl = until
				default:
					ttl = time.Minute
				}
			}
			// A cache write failure must not fail the request; log it and
			// still return the token.
			if setErr := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); setErr != nil {
				slog.Warn("gemini_token_cache_set_failed", "account_id", account.ID, "error", setErr)
			}
		}
	}

	return accessToken, nil
}

// GeminiTokenCacheKey returns the token cache key for account:
// "gemini:<project_id>" when the trimmed project_id credential is non-empty,
// otherwise "gemini:account:<ID>" using the decimal account ID.
func GeminiTokenCacheKey(account *Account) string {
	projectID := strings.TrimSpace(account.GetCredential("project_id"))
	if projectID != "" {
		return "gemini:" + projectID
	}
	return "gemini:account:" + strconv.FormatInt(account.ID, 10)
}