""" Has all /sso/* routes /sso/key/generate - handles user signing in with SSO and redirects to /sso/callback /sso/callback - returns JWT Redirect Response that redirects to LiteLLM UI /sso/debug/login - handles user signing in with SSO and redirects to /sso/debug/callback /sso/debug/callback - returns the OpenID object returned by the SSO provider """ import asyncio import base64 import hashlib import inspect import os import secrets from copy import deepcopy from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast if TYPE_CHECKING: import httpx import jwt from fastapi import APIRouter, Depends, HTTPException, Request, status from fastapi.responses import RedirectResponse import litellm from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid from litellm.caching import DualCache from litellm.constants import ( LITELLM_UI_SESSION_DURATION, MAX_SPENDLOG_ROWS_TO_QUERY, MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE, MICROSOFT_USER_EMAIL_ATTRIBUTE, MICROSOFT_USER_FIRST_NAME_ATTRIBUTE, MICROSOFT_USER_ID_ATTRIBUTE, MICROSOFT_USER_LAST_NAME_ATTRIBUTE, ) from litellm.litellm_core_utils.dot_notation_indexing import get_nested_value from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, get_async_httpx_client, httpxSpecialProvider, ) from litellm.proxy._types import ( CommonProxyErrors, LiteLLM_UserTable, LitellmUserRoles, Member, NewTeamRequest, NewUserRequest, NewUserResponse, ProxyErrorTypes, ProxyException, SSOUserDefinedValues, TeamMemberAddRequest, UserAPIKeyAuth, ) from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken, get_user_object from litellm.proxy.auth.auth_utils import _has_user_setup_sso from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.admin_ui_utils import ( admin_ui_disabled, show_missing_vars_in_env, ) from litellm.proxy.common_utils.html_forms.jwt_display_template import ( jwt_display_template, ) from litellm.proxy.common_utils.html_forms.ui_login import html_form from litellm.proxy.management_endpoints.internal_user_endpoints import new_user from litellm.proxy.management_endpoints.sso import CustomMicrosoftSSO from litellm.proxy.management_endpoints.sso_helper_utils import ( check_is_admin_only_access, has_admin_ui_access, ) from litellm.proxy.management_endpoints.team_endpoints import new_team, team_member_add from litellm.proxy.management_endpoints.types import ( CustomOpenID, get_litellm_user_role, is_valid_litellm_user_role, ) from litellm.proxy.utils import ( PrismaClient, ProxyLogging, get_custom_url, get_server_root_path, ) from litellm.secret_managers.main import get_secret_bool, str_to_bool from litellm.types.proxy.management_endpoints.ui_sso import * # noqa: F403, F401 from litellm.types.proxy.management_endpoints.ui_sso import ( DefaultTeamSSOParams, MicrosoftGraphAPIUserGroupDirectoryObject, MicrosoftGraphAPIUserGroupResponse, MicrosoftServicePrincipalTeam, RoleMappings, TeamMappings, ) from litellm.types.proxy.ui_sso import ParsedOpenIDResult if TYPE_CHECKING: from fastapi_sso.sso.base import OpenID else: from typing import Any as OpenID router = APIRouter() # OAuth bearer credential fields that must not appear in SSO debug responses # (received_response is included in restricted-group error messages). # Metadata fields (token_type, expires_in, scope) are intentionally kept so # response convertors see the same fields in the PKCE path as in the non-PKCE path. _OAUTH_TOKEN_FIELDS = frozenset({"access_token", "id_token", "refresh_token"}) def normalize_email(email: Optional[str]) -> Optional[str]: """ Normalize email address to lowercase for consistent storage and comparison. Email addresses should be treated as case-insensitive for SSO purposes, even though RFC 5321 technically allows case-sensitive local parts. This prevents issues where SSO providers return emails with different casing than what's stored in the database. Args: email: Email address to normalize, can be None Returns: Lowercased email address, or None if input is None """ if email is None: return None return email.lower() if isinstance(email, str) else email def determine_role_from_groups( user_groups: List[str], role_mappings: "RoleMappings", ) -> Optional[LitellmUserRoles]: """ Determine the highest privilege role for a user based on their groups. Role hierarchy (highest to lowest): - proxy_admin - proxy_admin_viewer - internal_user - internal_user_viewer Args: user_groups: List of group names from the SSO token role_mappings: RoleMappings configuration object Returns: The highest privilege role found, or default_role if no matches, or None """ if not role_mappings.roles: # No role mappings configured, return default_role return role_mappings.default_role # Role hierarchy (highest to lowest) role_hierarchy = [ LitellmUserRoles.PROXY_ADMIN, LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, LitellmUserRoles.INTERNAL_USER, LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, ] # Convert user_groups to a set for efficient lookup user_groups_set = set(user_groups) if isinstance(user_groups, list) else set() # Find the highest privilege role the user belongs to for role in role_hierarchy: if role in role_mappings.roles: role_groups = role_mappings.roles[role] if isinstance(role_groups, list) and user_groups_set.intersection( set(role_groups) ): verbose_proxy_logger.debug( f"User groups {user_groups} matched role '{role.value}' via groups: {role_groups}" ) return role # No matching groups found, return default_role verbose_proxy_logger.debug( f"User groups {user_groups} did not match any role mappings, using default_role: {role_mappings.default_role}" ) return role_mappings.default_role def process_sso_jwt_access_token( access_token_str: Optional[str], sso_jwt_handler: Optional[JWTHandler], result: Union[OpenID, dict, None], role_mappings: Optional["RoleMappings"] = None, ) -> None: """ Process SSO JWT access token and extract team IDs and user role if available. This function decodes the JWT access token and extracts team IDs and user role, then sets them on the result object. Role extraction from the access token is needed because some SSO providers (e.g., Keycloak) do not include role claims in the UserInfo endpoint response. Args: access_token_str: The JWT access token string sso_jwt_handler: SSO-specific JWT handler for team ID extraction result: The SSO result object to update with team IDs and role role_mappings: Optional role mappings configuration for group-based role determination """ if access_token_str and result: import jwt try: access_token_payload = jwt.decode( access_token_str, options={"verify_signature": False} ) except jwt.exceptions.DecodeError: verbose_proxy_logger.debug( "Access token is not a valid JWT (possibly an opaque token), skipping JWT-based extraction" ) return # Extract team IDs from access token if sso_jwt_handler is available if sso_jwt_handler: if isinstance(result, dict): result_team_ids: Optional[List[str]] = result.get("team_ids", []) if not result_team_ids: team_ids = sso_jwt_handler.get_team_ids_from_jwt( access_token_payload ) result["team_ids"] = team_ids else: result_team_ids = getattr(result, "team_ids", []) if result else [] if not result_team_ids: team_ids = sso_jwt_handler.get_team_ids_from_jwt( access_token_payload ) setattr(result, "team_ids", team_ids) # Extract user role from access token if not already set from UserInfo existing_role = ( result.get("user_role") if isinstance(result, dict) else getattr(result, "user_role", None) ) if existing_role is None: user_role: Optional[LitellmUserRoles] = None # Try role_mappings first (group-based role determination) if role_mappings is not None and role_mappings.roles: group_claim = role_mappings.group_claim user_groups_raw: Any = get_nested_value( access_token_payload, group_claim ) user_groups: List[str] = [] if isinstance(user_groups_raw, list): user_groups = [str(g) for g in user_groups_raw] elif isinstance(user_groups_raw, str): user_groups = [ g.strip() for g in user_groups_raw.split(",") if g.strip() ] elif user_groups_raw is not None: user_groups = [str(user_groups_raw)] if user_groups: user_role = determine_role_from_groups(user_groups, role_mappings) verbose_proxy_logger.debug( f"Determined role '{user_role}' from access token groups '{user_groups}' using role_mappings" ) elif role_mappings.default_role: user_role = role_mappings.default_role # Fallback: try GENERIC_USER_ROLE_ATTRIBUTE on the access token payload if user_role is None: generic_user_role_attribute_name = os.getenv( "GENERIC_USER_ROLE_ATTRIBUTE", "role" ) user_role_from_token = get_nested_value( access_token_payload, generic_user_role_attribute_name ) if user_role_from_token is not None: user_role = get_litellm_user_role(user_role_from_token) verbose_proxy_logger.debug( f"Extracted role '{user_role}' from access token field '{generic_user_role_attribute_name}'" ) if user_role is not None: if isinstance(result, dict): result["user_role"] = user_role else: setattr(result, "user_role", user_role) verbose_proxy_logger.debug( f"Set user_role='{user_role}' from JWT access token" ) @router.get("/sso/key/generate", tags=["experimental"], include_in_schema=False) async def google_login( request: Request, source: Optional[str] = None, key: Optional[str] = None, existing_key: Optional[str] = None, ): # noqa: PLR0915 """ Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" Example: """ from litellm.proxy.proxy_server import ( premium_user, prisma_client, user_custom_ui_sso_sign_in_handler, ) microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) ####### Check if UI is disabled ####### _disable_ui_flag = os.getenv("DISABLE_ADMIN_UI") if _disable_ui_flag is not None: is_disabled = str_to_bool(value=_disable_ui_flag) if is_disabled: return admin_ui_disabled() ####### Check if user is a Enterprise / Premium User ####### if ( microsoft_client_id is not None or google_client_id is not None or generic_client_id is not None ): if premium_user is not True: # Check if under 'free SSO user' limit if prisma_client is not None: total_users = await prisma_client.db.litellm_usertable.count() if total_users and total_users > 5: raise ProxyException( message="You must be a LiteLLM Enterprise user to use SSO for more than 5 users. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/cx9p-5yf-2nm/litellm-introductions You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this", type=ProxyErrorTypes.auth_error, param="premium_user", code=status.HTTP_403_FORBIDDEN, ) else: raise ProxyException( message=CommonProxyErrors.db_not_connected_error.value, type=ProxyErrorTypes.auth_error, param="premium_user", code=status.HTTP_403_FORBIDDEN, ) ####### Detect DB + MASTER KEY in .env ####### missing_env_vars = show_missing_vars_in_env() if missing_env_vars is not None: return missing_env_vars ui_username = os.getenv("UI_USERNAME") # get url from request - always use regular callback, but set state for CLI redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso( request=request, sso_callback_route="sso/callback", existing_key=existing_key, ) # Store CLI key in state for OAuth flow cli_state: Optional[str] = SSOAuthenticationHandler._get_cli_state( source=source, key=key, existing_key=existing_key, ) # check if user defined a custom auth sso sign in handler, if yes, use it if user_custom_ui_sso_sign_in_handler is not None: try: from litellm_enterprise.proxy.auth.custom_sso_handler import ( # type: ignore[import-untyped] EnterpriseCustomSSOHandler, ) return await EnterpriseCustomSSOHandler.handle_custom_ui_sso_sign_in( request=request, ) except ImportError: raise ValueError( "Enterprise features are not available. Custom UI SSO sign-in requires LiteLLM Enterprise." ) # Check if we should use SSO handler if ( SSOAuthenticationHandler.should_use_sso_handler( microsoft_client_id=microsoft_client_id, google_client_id=google_client_id, generic_client_id=generic_client_id, ) is True ): verbose_proxy_logger.info(f"Redirecting to SSO login for {redirect_url}") return await SSOAuthenticationHandler.get_sso_login_redirect( redirect_url=redirect_url, microsoft_client_id=microsoft_client_id, google_client_id=google_client_id, generic_client_id=generic_client_id, state=cli_state, ) elif ui_username is not None: # No Google, Microsoft SSO # Use UI Credentials set in .env from fastapi.responses import HTMLResponse return HTMLResponse(content=html_form, status_code=200) else: from fastapi.responses import HTMLResponse return HTMLResponse(content=html_form, status_code=200) def generic_response_convertor( response, jwt_handler: JWTHandler, sso_jwt_handler: Optional[JWTHandler] = None, role_mappings: Optional["RoleMappings"] = None, team_mappings: Optional["TeamMappings"] = None, ) -> CustomOpenID: generic_user_id_attribute_name = os.getenv( "GENERIC_USER_ID_ATTRIBUTE", "preferred_username" ) generic_user_display_name_attribute_name = os.getenv( "GENERIC_USER_DISPLAY_NAME_ATTRIBUTE", "sub" ) generic_user_email_attribute_name = os.getenv( "GENERIC_USER_EMAIL_ATTRIBUTE", "email" ) generic_user_first_name_attribute_name = os.getenv( "GENERIC_USER_FIRST_NAME_ATTRIBUTE", "first_name" ) generic_user_last_name_attribute_name = os.getenv( "GENERIC_USER_LAST_NAME_ATTRIBUTE", "last_name" ) generic_provider_attribute_name = os.getenv( "GENERIC_USER_PROVIDER_ATTRIBUTE", "provider" ) generic_user_role_attribute_name = os.getenv("GENERIC_USER_ROLE_ATTRIBUTE", "role") generic_user_extra_attributes = os.getenv("GENERIC_USER_EXTRA_ATTRIBUTES", None) verbose_proxy_logger.debug( f" generic_user_id_attribute_name: {generic_user_id_attribute_name}\n generic_user_email_attribute_name: {generic_user_email_attribute_name}" ) all_teams = [] if sso_jwt_handler is not None: team_ids = sso_jwt_handler.get_team_ids_from_jwt(cast(dict, response)) all_teams.extend(team_ids) if team_mappings is not None and team_mappings.team_ids_jwt_field is not None: team_ids_from_db_mapping: Optional[List[str]] = get_nested_value( data=cast(dict, response), key_path=team_mappings.team_ids_jwt_field, default=[], ) if team_ids_from_db_mapping: all_teams.extend(team_ids_from_db_mapping) verbose_proxy_logger.debug( f"Loaded team_ids from DB team_mappings.team_ids_jwt_field='{team_mappings.team_ids_jwt_field}': {team_ids_from_db_mapping}" ) else: team_ids = jwt_handler.get_team_ids_from_jwt(cast(dict, response)) all_teams.extend(team_ids) # Determine user role based on role_mappings if available # Only apply role_mappings for GENERIC SSO provider user_role: Optional[LitellmUserRoles] = None if role_mappings is not None and role_mappings.provider.lower() in [ "generic", "okta", ]: # Use role_mappings to determine role from groups group_claim = role_mappings.group_claim user_groups_raw: Any = get_nested_value(response, group_claim) # Handle different formats: could be a list, string (comma-separated), or single value user_groups: List[str] = [] if isinstance(user_groups_raw, list): user_groups = [str(g) for g in user_groups_raw] elif isinstance(user_groups_raw, str): # Handle comma-separated string user_groups = [g.strip() for g in user_groups_raw.split(",") if g.strip()] elif user_groups_raw is not None: # Single value user_groups = [str(user_groups_raw)] if user_groups: user_role = determine_role_from_groups(user_groups, role_mappings) verbose_proxy_logger.debug( f"Determined role '{user_role.value if user_role else None}' from groups '{user_groups}' using role_mappings" ) else: # No groups found, use default_role user_role = role_mappings.default_role verbose_proxy_logger.debug( f"No groups found in '{group_claim}', using default_role: {role_mappings.default_role}" ) # Fallback to existing logic if role_mappings not used if user_role is None: user_role_from_sso = get_nested_value( response, generic_user_role_attribute_name ) if user_role_from_sso is not None: role = get_litellm_user_role(user_role_from_sso) if role is not None: user_role = role verbose_proxy_logger.debug( f"Found valid LitellmUserRoles '{role.value}' from SSO attribute '{generic_user_role_attribute_name}'" ) # Build extra_fields dict from GENERIC_USER_EXTRA_ATTRIBUTES if specified extra_fields: Optional[Dict[str, Any]] = None if generic_user_extra_attributes: extra_fields = {} for attr_name in generic_user_extra_attributes.split(","): attr_name = attr_name.strip() extra_fields[attr_name] = get_nested_value(response, attr_name) return CustomOpenID( id=get_nested_value(response, generic_user_id_attribute_name), display_name=get_nested_value( response, generic_user_display_name_attribute_name ), email=normalize_email( get_nested_value(response, generic_user_email_attribute_name) ), first_name=get_nested_value(response, generic_user_first_name_attribute_name), last_name=get_nested_value(response, generic_user_last_name_attribute_name), provider=get_nested_value(response, generic_provider_attribute_name), team_ids=all_teams, user_role=user_role, extra_fields=extra_fields, ) def _setup_generic_sso_env_vars( generic_client_id: str, redirect_url: str ) -> Tuple[str, List[str], str, str, str, bool]: """Setup and validate Generic SSO environment variables.""" generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ") generic_authorization_endpoint = os.getenv("GENERIC_AUTHORIZATION_ENDPOINT", None) generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) generic_include_client_id = ( os.getenv("GENERIC_INCLUDE_CLIENT_ID", "false").lower() == "true" ) # Validate required environment variables if generic_client_secret is None: raise ProxyException( message="GENERIC_CLIENT_SECRET not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="GENERIC_CLIENT_SECRET", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) if generic_authorization_endpoint is None: raise ProxyException( message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="GENERIC_AUTHORIZATION_ENDPOINT", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) if generic_token_endpoint is None: raise ProxyException( message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="GENERIC_TOKEN_ENDPOINT", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) if generic_userinfo_endpoint is None: raise ProxyException( message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="GENERIC_USERINFO_ENDPOINT", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) verbose_proxy_logger.debug( f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" ) verbose_proxy_logger.debug( f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" ) return ( generic_client_secret, generic_scope, generic_authorization_endpoint, generic_token_endpoint, generic_userinfo_endpoint, generic_include_client_id, ) async def _setup_team_mappings() -> Optional["TeamMappings"]: """Setup team mappings from SSO database settings.""" team_mappings: Optional["TeamMappings"] = None try: from litellm.proxy.utils import get_prisma_client_or_throw prisma_client = get_prisma_client_or_throw( "Prisma client is None, connect a database to your proxy" ) sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique( where={"id": "sso_config"} ) if sso_db_record and sso_db_record.sso_settings: sso_settings_dict = dict(sso_db_record.sso_settings) team_mappings_data = sso_settings_dict.get("team_mappings") if team_mappings_data: from litellm.types.proxy.management_endpoints.ui_sso import TeamMappings if isinstance(team_mappings_data, dict): team_mappings = TeamMappings(**team_mappings_data) elif isinstance(team_mappings_data, TeamMappings): team_mappings = team_mappings_data if team_mappings and team_mappings.team_ids_jwt_field: verbose_proxy_logger.debug( f"Loaded team_mappings with team_ids_jwt_field: '{team_mappings.team_ids_jwt_field}'" ) except Exception as e: verbose_proxy_logger.debug( f"Could not load team_mappings from database: {e}. Continuing with config-based team mapping." ) return team_mappings async def _setup_role_mappings() -> Optional["RoleMappings"]: """Setup role mappings from SSO database settings.""" role_mappings: Optional["RoleMappings"] = None try: from litellm.proxy.utils import get_prisma_client_or_throw prisma_client = get_prisma_client_or_throw( "Prisma client is None, connect a database to your proxy" ) sso_db_record = await prisma_client.db.litellm_ssoconfig.find_unique( where={"id": "sso_config"} ) if sso_db_record and sso_db_record.sso_settings: sso_settings_dict = dict(sso_db_record.sso_settings) role_mappings_data = sso_settings_dict.get("role_mappings") if role_mappings_data: from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings if isinstance(role_mappings_data, dict): role_mappings = RoleMappings(**role_mappings_data) elif isinstance(role_mappings_data, RoleMappings): role_mappings = role_mappings_data if role_mappings: verbose_proxy_logger.debug( f"Loaded role_mappings for provider '{role_mappings.provider}'" ) except Exception as e: verbose_proxy_logger.debug( f"Could not load role_mappings from database: {e}. Continuing with existing role logic." ) generic_role_mappings = os.getenv("GENERIC_ROLE_MAPPINGS_ROLES", None) generic_role_mappings_group_claim = os.getenv( "GENERIC_ROLE_MAPPINGS_GROUP_CLAIM", None ) generic_role_mappoings_default_role = os.getenv( "GENERIC_ROLE_MAPPINGS_DEFAULT_ROLE", None ) if generic_role_mappings is not None: verbose_proxy_logger.debug( "Found role_mappings for generic provider in environment variables" ) import ast try: generic_user_role_mappings_data: Dict[ LitellmUserRoles, List[str] ] = ast.literal_eval(generic_role_mappings) if isinstance(generic_user_role_mappings_data, dict): from litellm.types.proxy.management_endpoints.ui_sso import RoleMappings role_mappings_data = { "provider": "generic", "group_claim": generic_role_mappings_group_claim, "default_role": generic_role_mappoings_default_role, "roles": generic_user_role_mappings_data, } role_mappings = RoleMappings(**role_mappings_data) verbose_proxy_logger.debug( f"Loaded role_mappings from environments for provider '{role_mappings.provider}'." ) return role_mappings except TypeError as e: verbose_proxy_logger.warning( f"Error decoding role mappings from environment variables: {e}. Continuing with existing role logic." ) return role_mappings def _parse_generic_sso_headers() -> dict: """Parse comma-separated GENERIC_SSO_HEADERS env var into a dict.""" raw = os.getenv("GENERIC_SSO_HEADERS", None) if raw is None: return {} result: Dict[str, str] = {} for header in raw.split(","): header = header.strip() if header: key, value = header.split("=") result[key] = value return result def _handle_generic_sso_error( e: Exception, generic_authorization_endpoint: Optional[str], generic_token_endpoint: Optional[str], additional_headers: dict, ) -> None: """Handle errors from generic SSO verify_and_process. Always re-raises.""" error_message = str(e) # Surface a helpful PKCE misconfiguration hint only when: # 1. The error mentions PKCE/code verifier, AND # 2. PKCE is not currently configured (GENERIC_CLIENT_USE_PKCE != true) pkce_configured = os.getenv("GENERIC_CLIENT_USE_PKCE", "false").lower() == "true" if not pkce_configured and ( "PKCE" in error_message or "code verifier" in error_message.lower() ): is_okta = ( generic_authorization_endpoint and "okta" in generic_authorization_endpoint.lower() ) or (generic_token_endpoint and "okta" in generic_token_endpoint.lower()) provider_name = "Okta" if is_okta else "Your OAuth provider" detailed_message = ( f"SSO authentication failed: {provider_name} requires PKCE (Proof Key for Code Exchange) " f"but it's not enabled in your LiteLLM configuration.\n\n" f"SOLUTION: Add this environment variable and restart your proxy:\n" f" GENERIC_CLIENT_USE_PKCE=true\n\n" ) if is_okta: detailed_message += ( "For AWS ECS: Add the environment variable to your task definition.\n" "For Docker: Add -e GENERIC_CLIENT_USE_PKCE=true to your docker run command.\n" "For .env file: Add GENERIC_CLIENT_USE_PKCE=true to your .env file.\n\n" ) detailed_message += f"Original error: {error_message}" raise ProxyException( message=detailed_message, type=ProxyErrorTypes.auth_error, param="GENERIC_CLIENT_USE_PKCE", code=status.HTTP_401_UNAUTHORIZED, ) if isinstance(e, ProxyException): verbose_proxy_logger.error( "SSO authentication failed: %s. Passed in headers: %s", e, additional_headers, ) else: verbose_proxy_logger.exception( "Error verifying and processing generic SSO: %s. Passed in headers: %s", e, additional_headers, ) raise e async def get_generic_sso_response( request: Request, jwt_handler: JWTHandler, sso_jwt_handler: Optional[ JWTHandler ], # sso specific jwt handler - used for restricted sso group access control generic_client_id: str, redirect_url: str, ) -> Tuple[Union[OpenID, dict], Optional[dict]]: # return received response # make generic sso provider from fastapi_sso.sso.base import DiscoveryDocument from fastapi_sso.sso.generic import create_provider received_response: Optional[dict] = None # Setup environment variables ( generic_client_secret, generic_scope, generic_authorization_endpoint, generic_token_endpoint, generic_userinfo_endpoint, generic_include_client_id, ) = _setup_generic_sso_env_vars(generic_client_id, redirect_url) discovery = DiscoveryDocument( authorization_endpoint=generic_authorization_endpoint, token_endpoint=generic_token_endpoint, userinfo_endpoint=generic_userinfo_endpoint, ) role_mappings = await _setup_role_mappings() team_mappings = await _setup_team_mappings() def response_convertor(response, client): nonlocal received_response # return for user debugging received_response = response return generic_response_convertor( response=response, jwt_handler=jwt_handler, sso_jwt_handler=sso_jwt_handler, role_mappings=role_mappings, team_mappings=team_mappings, ) SSOProvider = create_provider( name="oidc", discovery_document=discovery, response_convertor=response_convertor, ) generic_sso = SSOProvider( client_id=generic_client_id, client_secret=generic_client_secret, redirect_uri=redirect_url, allow_insecure_http=True, scope=generic_scope, ) verbose_proxy_logger.debug("calling generic_sso.verify_and_process") additional_generic_sso_headers_dict = _parse_generic_sso_headers() code_verifier: Optional[str] = None # assigned inside try; initialized for type tracking try: token_exchange_params = await SSOAuthenticationHandler.prepare_token_exchange_parameters( request=request, generic_include_client_id=generic_include_client_id, ) # Extract code_verifier (and the cache key for deferred deletion) before calling fastapi-sso code_verifier = token_exchange_params.pop("code_verifier", None) pkce_cache_key = token_exchange_params.pop("_pkce_cache_key", None) # Get authorization code from query params (only used in the PKCE path below; # the non-PKCE path delegates to verify_and_process which handles OAuth error # callbacks — user-denied, CSRF mismatch — internally). authorization_code = request.query_params.get("code") if code_verifier: if not authorization_code: raise ProxyException( message="Missing authorization code in callback", type=ProxyErrorTypes.auth_error, param="code", code=status.HTTP_400_BAD_REQUEST, ) if not generic_client_id: raise ProxyException( message="GENERIC_CLIENT_ID must be set when PKCE is enabled", type=ProxyErrorTypes.auth_error, param="GENERIC_CLIENT_ID", code=status.HTTP_401_UNAUTHORIZED, ) if not generic_token_endpoint: raise ProxyException( message="GENERIC_TOKEN_ENDPOINT must be set when PKCE is enabled", type=ProxyErrorTypes.auth_error, param="GENERIC_TOKEN_ENDPOINT", code=status.HTTP_401_UNAUTHORIZED, ) # All guards above raise, so authorization_code is a non-empty str here. # Use an explicit type guard rather than assert (assert is a no-op with -O). if not isinstance(authorization_code, str): raise ProxyException( message="Missing authorization code in callback", type=ProxyErrorTypes.auth_error, param="code", code=status.HTTP_400_BAD_REQUEST, ) combined_response = await SSOAuthenticationHandler._pkce_token_exchange( authorization_code=authorization_code, code_verifier=code_verifier, client_id=generic_client_id, client_secret=generic_client_secret, token_endpoint=generic_token_endpoint, userinfo_endpoint=generic_userinfo_endpoint, include_client_id=generic_include_client_id, redirect_url=redirect_url, additional_headers=additional_generic_sso_headers_dict, ) # Pass the full response so custom response_convertor implementations # can access all fields (including id_token for claim extraction). result = response_convertor(combined_response, generic_sso) # Strip bearer credentials from combined_response before storing in # received_response. received_response may appear in restricted-group # error messages — bearer tokens (access_token, id_token, refresh_token) # must not be exposed to callers. # Assign directly rather than relying on nonlocal mutation so that Pyright # can track that received_response is non-None from this point on. received_response = { k: v for k, v in combined_response.items() if k not in _OAUTH_TOKEN_FIELDS } # In the PKCE path verify_and_process is skipped, so generic_sso.access_token # is never set. Read the token directly from the exchange response instead so # process_sso_jwt_access_token can extract JWT-embedded roles/teams. access_token_str: Optional[str] = combined_response.get("access_token") else: result = await generic_sso.verify_and_process( request, params=token_exchange_params, headers=additional_generic_sso_headers_dict, ) access_token_str = generic_sso.access_token process_sso_jwt_access_token( access_token_str, sso_jwt_handler, result, role_mappings=role_mappings ) # Delete the single-use PKCE verifier only after all downstream processing # (response_convertor and process_sso_jwt_access_token) has completed # successfully. Deleting earlier would consume the verifier on a transient # failure, forcing the user to restart the entire OAuth flow from scratch. if pkce_cache_key: await SSOAuthenticationHandler._delete_pkce_verifier(pkce_cache_key) except Exception as e: _handle_generic_sso_error( e, generic_authorization_endpoint, generic_token_endpoint, additional_generic_sso_headers_dict, ) verbose_proxy_logger.debug("generic result: %s", result) return result or {}, received_response async def create_team_member_add_task(team_id, user_info): """Create a task for adding a member to a team.""" try: member = Member(user_id=user_info.user_id, role="user") team_member_add_request = TeamMemberAddRequest( member=member, team_id=team_id, ) return await team_member_add( data=team_member_add_request, user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), ) except Exception as e: verbose_proxy_logger.debug( f"[Non-Blocking] Error trying to add sso user to db: {e}" ) async def add_missing_team_member( user_info: Union[NewUserResponse, LiteLLM_UserTable], sso_teams: List[str] ): """ - Get missing teams (diff b/w user_info.team_ids and sso_teams) - Add missing user to missing teams """ # Handle None as empty list for new users user_teams = user_info.teams if user_info.teams is not None else [] missing_teams = set(sso_teams) - set(user_teams) missing_teams_list = list(missing_teams) tasks = [] tasks = [ create_team_member_add_task(team_id, user_info) for team_id in missing_teams_list ] try: await asyncio.gather(*tasks) except Exception as e: verbose_proxy_logger.debug( f"[Non-Blocking] Error trying to add sso user to db: {e}" ) def get_disabled_non_admin_personal_key_creation(): key_generation_settings = litellm.key_generation_settings if key_generation_settings is None: return False personal_key_generation = ( key_generation_settings.get("personal_key_generation") or {} ) allowed_user_roles = personal_key_generation.get("allowed_user_roles") or [] return bool("proxy_admin" in allowed_user_roles) async def get_existing_user_info_from_db( user_id: Optional[str], user_email: Optional[str], prisma_client: PrismaClient, user_api_key_cache: DualCache, proxy_logging_obj: ProxyLogging, ) -> Optional[LiteLLM_UserTable]: try: user_info = await get_user_object( user_id=user_id, user_email=user_email, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, user_id_upsert=False, parent_otel_span=None, proxy_logging_obj=proxy_logging_obj, sso_user_id=user_id, ) except Exception as e: verbose_proxy_logger.debug(f"Error getting user object: {e}") user_info = None return user_info async def get_user_info_from_db( result: Union[CustomOpenID, OpenID, dict], prisma_client: PrismaClient, user_api_key_cache: DualCache, proxy_logging_obj: ProxyLogging, user_email: Optional[str], user_defined_values: Optional[SSOUserDefinedValues], alternate_user_id: Optional[str] = None, ) -> Optional[Union[LiteLLM_UserTable, NewUserResponse]]: try: potential_user_ids = [] if alternate_user_id is not None: potential_user_ids.append(alternate_user_id) if not isinstance(result, dict): _id = getattr(result, "id", None) if _id is not None and isinstance(_id, str): potential_user_ids.append(_id) else: _id = result.get("id", None) if _id is not None and isinstance(_id, str): potential_user_ids.append(_id) user_email = normalize_email( getattr(result, "email", None) if not isinstance(result, dict) else result.get("email", None) ) user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]] = None for user_id in potential_user_ids: user_info = await get_existing_user_info_from_db( user_id=user_id, user_email=user_email, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, proxy_logging_obj=proxy_logging_obj, ) if user_info is not None: break verbose_proxy_logger.debug( f"user_info: {user_info}; litellm.default_internal_user_params: {litellm.default_internal_user_params}" ) # Upsert SSO User to LiteLLM DB user_info = await SSOAuthenticationHandler.upsert_sso_user( result=result, user_info=user_info, user_email=user_email, user_defined_values=user_defined_values, prisma_client=prisma_client, ) await SSOAuthenticationHandler.add_user_to_teams_from_sso_response( result=result, user_info=user_info, ) return user_info except Exception as e: verbose_proxy_logger.exception( f"[Non-Blocking] Error trying to add sso user to db: {e}" ) return None def _should_use_role_from_sso_response(sso_role: Optional[str]) -> bool: """returns true if SSO upsert should use the 'role' defined on the SSO response""" if sso_role is None: return False if not is_valid_litellm_user_role(sso_role): verbose_proxy_logger.debug( f"SSO role '{sso_role}' is not a valid LiteLLM user role. " "Ignoring role from SSO response. See LitellmUserRoles enum for valid roles." ) return False return True def _build_sso_user_update_data( result: Optional[Union["CustomOpenID", OpenID, dict]], user_email: Optional[str], user_id: Optional[str], ) -> dict: """ Build the update data dictionary for SSO user upsert. Args: result: The SSO response containing user information user_email: The user's email from SSO user_id: The user's ID for logging purposes Returns: dict: Update data containing user_email and optionally user_role if valid """ update_data: dict = {"user_email": normalize_email(user_email)} # Get SSO role from result and include if valid sso_role = getattr(result, "user_role", None) if sso_role is not None: # Convert enum to string if needed sso_role_str = ( sso_role.value if isinstance(sso_role, LitellmUserRoles) else sso_role ) # Only include if it's a valid LiteLLM role if _should_use_role_from_sso_response(sso_role_str): update_data["user_role"] = sso_role_str verbose_proxy_logger.info( f"Updating user {user_id} role from SSO: {sso_role_str}" ) return update_data def apply_user_info_values_to_sso_user_defined_values( user_info: Optional[Union[LiteLLM_UserTable, NewUserResponse]], user_defined_values: Optional[SSOUserDefinedValues], ) -> Optional[SSOUserDefinedValues]: if user_defined_values is None: return None if user_info is not None and user_info.user_id is not None: user_defined_values["user_id"] = user_info.user_id # SSO role takes precedence - only use DB role if SSO didn't provide one # This ensures SSO is the authoritative source for user roles sso_role = user_defined_values.get("user_role") db_role = user_info.user_role if user_info else None if _should_use_role_from_sso_response(sso_role): # SSO provided a valid role, keep it and log that we're using it verbose_proxy_logger.info( f"Using SSO role: {sso_role} (DB role was: {db_role})" ) else: # SSO didn't provide a valid role, fall back to DB role or default if user_info is None or user_info.user_role is None: user_defined_values[ "user_role" ] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value verbose_proxy_logger.debug( "No SSO or DB role found, using default: INTERNAL_USER_VIEW_ONLY" ) else: user_defined_values["user_role"] = user_info.user_role verbose_proxy_logger.debug(f"Using DB role: {user_info.user_role}") # Preserve the user's existing models from the database if user_info is not None and hasattr(user_info, "models") and user_info.models: user_defined_values["models"] = user_info.models return user_defined_values async def check_and_update_if_proxy_admin_id( user_role: str, user_id: str, prisma_client: Optional[PrismaClient] ): """ - Check if user role in DB is admin - If not, update user role in DB to admin role """ proxy_admin_id = os.getenv("PROXY_ADMIN_ID") if proxy_admin_id is not None and proxy_admin_id == user_id: if user_role and user_role == LitellmUserRoles.PROXY_ADMIN.value: return user_role if prisma_client: await prisma_client.db.litellm_usertable.update( where={"user_id": user_id}, data={"user_role": LitellmUserRoles.PROXY_ADMIN.value}, ) user_role = LitellmUserRoles.PROXY_ADMIN.value return user_role @router.get("/sso/callback", tags=["experimental"], include_in_schema=False) async def auth_callback(request: Request, state: Optional[str] = None): # noqa: PLR0915 """Verify login""" verbose_proxy_logger.info(f"Starting SSO callback with state: {state}") # Check if this is a CLI login (state starts with our CLI prefix) from litellm.constants import LITELLM_CLI_SESSION_TOKEN_PREFIX from litellm.proxy._types import LiteLLM_JWTAuth from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.proxy_server import ( general_settings, jwt_handler, master_key, prisma_client, user_api_key_cache, ) if prisma_client is None: raise HTTPException( status_code=500, detail=CommonProxyErrors.db_not_connected_error.value ) sso_jwt_handler: Optional[JWTHandler] = None ui_access_mode = general_settings.get("ui_access_mode", None) if ui_access_mode is not None and isinstance(ui_access_mode, dict): sso_jwt_handler = JWTHandler() sso_jwt_handler.update_environment( prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, litellm_jwtauth=LiteLLM_JWTAuth( team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get( "sso_group_jwt_field", None ), ), leeway=0, ) microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) received_response: Optional[dict] = None # get url from request if master_key is None: raise ProxyException( message="Master Key not set for Proxy. Please set Master Key to use Admin UI. Set `LITELLM_MASTER_KEY` in .env or set general_settings:master_key in config.yaml. https://docs.litellm.ai/docs/proxy/virtual_keys. If set, use `--detailed_debug` to debug issue.", type=ProxyErrorTypes.auth_error, param="master_key", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso( request=request, sso_callback_route="sso/callback" ) verbose_proxy_logger.info(f"Redirecting to {redirect_url}") result = None if google_client_id is not None: result = await GoogleSSOHandler.get_google_callback_response( request=request, google_client_id=google_client_id, redirect_url=redirect_url, ) elif microsoft_client_id is not None: result = await MicrosoftSSOHandler.get_microsoft_callback_response( request=request, microsoft_client_id=microsoft_client_id, redirect_url=redirect_url, ) elif generic_client_id is not None: result, received_response = await get_generic_sso_response( request=request, jwt_handler=jwt_handler, generic_client_id=generic_client_id, redirect_url=redirect_url, sso_jwt_handler=sso_jwt_handler, ) if result is None: raise HTTPException( status_code=401, detail="Result not returned by SSO provider.", ) if state and state.startswith(f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:"): # Extract the key ID and existing_key from the state # State format: {PREFIX}:{key}:{existing_key} or {PREFIX}:{key} state_parts = state.split(":", 2) # Split into max 3 parts key_id = state_parts[1] if len(state_parts) > 1 else None existing_key = state_parts[2] if len(state_parts) > 2 else None verbose_proxy_logger.info( f"CLI SSO callback detected for key: {key_id}, existing_key: {existing_key}" ) return await cli_sso_callback( request=request, key=key_id, existing_key=existing_key, result=result ) return await SSOAuthenticationHandler.get_redirect_response_from_openid( result=result, request=request, received_response=received_response, generic_client_id=generic_client_id, ui_access_mode=ui_access_mode, ) async def cli_sso_callback( request: Request, key: Optional[str] = None, existing_key: Optional[str] = None, result: Optional[Union[OpenID, dict]] = None, ): """CLI SSO callback - stores session info for JWT generation on polling""" verbose_proxy_logger.info( f"CLI SSO callback for key: {key}, existing_key: {existing_key}" ) from litellm.proxy.proxy_server import ( prisma_client, proxy_logging_obj, user_api_key_cache, ) if not key or not key.startswith("sk-"): raise HTTPException( status_code=400, detail="Invalid key parameter. Must be a valid key ID starting with 'sk-'", ) if prisma_client is None: raise HTTPException( status_code=500, detail=CommonProxyErrors.db_not_connected_error.value ) if result is None: raise HTTPException( status_code=500, detail="SSO authentication failed - no result returned from provider", ) # After None check, cast to non-None type for type checker result_non_none: Union[OpenID, dict] = cast(Union[OpenID, dict], result) parsed_openid_result = SSOAuthenticationHandler._get_user_email_and_id_from_result( result=result_non_none ) verbose_proxy_logger.debug(f"parsed_openid_result: {parsed_openid_result}") try: # Get full user info from DB user_info = await get_user_info_from_db( result=result_non_none, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, proxy_logging_obj=proxy_logging_obj, user_email=parsed_openid_result.get("user_email"), user_defined_values=None, alternate_user_id=parsed_openid_result.get("user_id"), ) if user_info is None: raise HTTPException( status_code=500, detail="Failed to retrieve user information from SSO" ) # Store session info in cache (10 min TTL) from litellm.constants import CLI_SSO_SESSION_CACHE_KEY_PREFIX # Get all teams from user_info - CLI will let user select which one teams: List[str] = [] if hasattr(user_info, "teams") and user_info.teams: teams = user_info.teams if isinstance(user_info.teams, list) else [] # Also fetch team aliases for a better CLI UX. We keep the original # "teams" list of IDs for backwards compatibility and add an # optional "team_details" field containing objects with both # team_id and team_alias. team_details: List[Dict[str, Any]] = [] try: if teams: prisma_teams = await prisma_client.db.litellm_teamtable.find_many( where={"team_id": {"in": teams}} ) for team_row in prisma_teams: team_dict = team_row.model_dump() team_details.append( { "team_id": team_dict.get("team_id"), "team_alias": team_dict.get("team_alias"), } ) except Exception as e: # If anything goes wrong here, fall back gracefully without # impacting the SSO flow. verbose_proxy_logger.error( f"Error fetching team details for CLI SSO session: {e}" ) session_data = { "user_id": user_info.user_id, "user_role": user_info.user_role, "models": user_info.models if hasattr(user_info, "models") else [], "user_email": parsed_openid_result.get("user_email"), "teams": teams, # Optional rich metadata for clients that want nicer display "team_details": team_details, } cache_key = f"{CLI_SSO_SESSION_CACHE_KEY_PREFIX}:{key}" user_api_key_cache.set_cache(key=cache_key, value=session_data, ttl=600) verbose_proxy_logger.info( f"Stored CLI SSO session for user: {user_info.user_id}, teams: {teams}, num_teams: {len(teams)}" ) # Return success page from fastapi.responses import HTMLResponse from litellm.proxy.common_utils.html_forms.cli_sso_success import ( render_cli_sso_success_page, ) html_content = render_cli_sso_success_page() return HTMLResponse(content=html_content, status_code=200) except Exception as e: verbose_proxy_logger.error(f"Error with CLI SSO callback: {e}") raise HTTPException( status_code=500, detail=f"Failed to process CLI SSO: {str(e)}" ) @router.get("/sso/cli/poll/{key_id}", tags=["experimental"], include_in_schema=False) async def cli_poll_key(key_id: str, team_id: Optional[str] = None): """ CLI polling endpoint - retrieves session from cache and generates JWT. Flow: 1. First poll (no team_id): Returns teams list without generating JWT 2. Second poll (with team_id): Generates JWT with selected team and deletes session Args: key_id: The session key ID team_id: Optional team ID to assign to the JWT. If provided, must be one of user's teams. """ from litellm.constants import CLI_SSO_SESSION_CACHE_KEY_PREFIX from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken from litellm.proxy.proxy_server import user_api_key_cache if not key_id.startswith("sk-"): raise HTTPException(status_code=400, detail="Invalid key ID format") try: # Look up session in cache cache_key = f"{CLI_SSO_SESSION_CACHE_KEY_PREFIX}:{key_id}" session_data = user_api_key_cache.get_cache(key=cache_key) if session_data: user_teams = session_data.get("teams", []) user_team_details = session_data.get("team_details") user_id = session_data["user_id"] verbose_proxy_logger.info( f"CLI poll: user={user_id}, team_id={team_id}, user_teams={user_teams}, num_teams={len(user_teams)}" ) # If no team_id provided and user has teams, return teams list for selection # Don't generate JWT yet - let CLI select a team first. For newer # clients we return rich team details (id + alias); older clients # can continue to rely on the simple "teams" list. if team_id is None and len(user_teams) > 1: verbose_proxy_logger.info( f"Returning teams list for user {user_id} to select from: {user_teams}" ) # Best-effort construction of team_details if it wasn't # already cached for some reason. team_details_response: Optional[List[Dict[str, Any]]] = None if isinstance(user_team_details, list) and user_team_details: team_details_response = user_team_details elif user_teams: team_details_response = [ {"team_id": t, "team_alias": None} for t in user_teams ] return { "status": "ready", "user_id": user_id, "teams": user_teams, "team_details": team_details_response, "requires_team_selection": True, } # Validate team_id if provided if team_id is not None: if team_id not in user_teams: raise HTTPException( status_code=403, detail=f"User does not belong to team: {team_id}. Available teams: {user_teams}", ) else: # If no team_id provided and user has 0 or 1 team, use first team (or None) team_id = user_teams[0] if len(user_teams) > 0 else None # Create user object for JWT generation user_info = LiteLLM_UserTable( user_id=user_id, user_role=session_data["user_role"], models=session_data.get("models", []), max_budget=litellm.max_ui_session_budget, ) # Generate CLI JWT on-demand (expiration configurable via LITELLM_CLI_JWT_EXPIRATION_HOURS) # Pass selected team_id to ensure JWT has correct team jwt_token = ExperimentalUIJWTToken.get_cli_jwt_auth_token( user_info=user_info, team_id=team_id ) # Delete cache entry (single-use) user_api_key_cache.delete_cache(key=cache_key) verbose_proxy_logger.info( f"CLI JWT generated for user: {user_id}, team: {team_id}" ) return { "status": "ready", "key": jwt_token, "user_id": user_id, "team_id": team_id, "teams": user_teams, # Echo back any team details we have so clients can # present nicer information if needed. "team_details": user_team_details, } else: return {"status": "pending"} except Exception as e: verbose_proxy_logger.error(f"Error polling for CLI JWT: {e}") raise HTTPException( status_code=500, detail=f"Error checking session status: {str(e)}" ) async def insert_sso_user( result_openid: Optional[Union[OpenID, dict]], user_defined_values: Optional[SSOUserDefinedValues] = None, ) -> NewUserResponse: """ Helper function to create a New User in LiteLLM DB after a successful SSO login Args: result_openid (OpenID): User information in OpenID format if the login was successful. user_defined_values (Optional[SSOUserDefinedValues], optional): LiteLLM SSOValues / fields that were read Returns: Tuple[str, str]: User ID and User Role """ verbose_proxy_logger.debug( f"Inserting SSO user into DB. User values: {user_defined_values}" ) if result_openid is None: raise ValueError("result_openid is None") if isinstance(result_openid, dict): result_openid = OpenID(**result_openid) if user_defined_values is None: raise ValueError("user_defined_values is None") # Apply default_internal_user_params if litellm.default_internal_user_params: # Preserve the SSO-extracted role if it's a valid LiteLLM role, # regardless of how it was determined (role_mappings, Microsoft app_roles, # GENERIC_USER_ROLE_ATTRIBUTE, custom SSO handler, etc.) sso_role = user_defined_values.get("user_role") if _should_use_role_from_sso_response(sso_role): # Preserve the SSO-extracted role, but apply other defaults preserved_role = sso_role user_defined_values.update(litellm.default_internal_user_params) # type: ignore user_defined_values["user_role"] = preserved_role # Restore preserved role verbose_proxy_logger.debug( f"Preserved SSO-extracted role '{preserved_role}'" ) else: # SSO didn't provide a valid role, apply all defaults including role user_defined_values.update(litellm.default_internal_user_params) # type: ignore # Set budget for internal users if user_defined_values.get("user_role") == LitellmUserRoles.INTERNAL_USER.value: if user_defined_values.get("max_budget") is None: user_defined_values["max_budget"] = litellm.max_internal_user_budget if user_defined_values.get("budget_duration") is None: user_defined_values[ "budget_duration" ] = litellm.internal_user_budget_duration if user_defined_values["user_role"] is None: user_defined_values["user_role"] = LitellmUserRoles.INTERNAL_USER_VIEW_ONLY new_user_request = NewUserRequest( user_id=user_defined_values["user_id"], user_email=normalize_email(user_defined_values["user_email"]), user_role=user_defined_values["user_role"], # type: ignore max_budget=user_defined_values["max_budget"], budget_duration=user_defined_values["budget_duration"], sso_user_id=user_defined_values["user_id"], auto_create_key=False, ) if result_openid and hasattr(result_openid, "provider"): new_user_request.metadata = { "auth_provider": getattr(result_openid, "provider") } response = await new_user( data=new_user_request, user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), ) return response @router.get( "/sso/get/ui_settings", tags=["experimental"], include_in_schema=False, dependencies=[Depends(user_api_key_auth)], ) async def get_ui_settings(request: Request): from litellm.proxy.proxy_server import general_settings, proxy_state _proxy_base_url = os.getenv("PROXY_BASE_URL", None) _logout_url = os.getenv("PROXY_LOGOUT_URL", None) _api_doc_base_url = os.getenv("LITELLM_UI_API_DOC_BASE_URL", None) _is_sso_enabled = _has_user_setup_sso() disable_expensive_db_queries = ( proxy_state.get_proxy_state_variable("spend_logs_row_count") > MAX_SPENDLOG_ROWS_TO_QUERY ) default_team_disabled = general_settings.get("default_team_disabled", False) if "PROXY_DEFAULT_TEAM_DISABLED" in os.environ: if os.environ["PROXY_DEFAULT_TEAM_DISABLED"].lower() == "true": default_team_disabled = True return { "PROXY_BASE_URL": _proxy_base_url, "PROXY_LOGOUT_URL": _logout_url, "LITELLM_UI_API_DOC_BASE_URL": _api_doc_base_url, "DEFAULT_TEAM_DISABLED": default_team_disabled, "SSO_ENABLED": _is_sso_enabled, "NUM_SPEND_LOGS_ROWS": proxy_state.get_proxy_state_variable( "spend_logs_row_count" ), "DISABLE_EXPENSIVE_DB_QUERIES": disable_expensive_db_queries, } @router.get( "/sso/readiness", tags=["experimental"], dependencies=[Depends(user_api_key_auth)], ) async def sso_readiness(): """ Health endpoint for checking SSO readiness. Checks if the configured SSO provider has all required environment variables set in memory. """ microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) # Determine which SSO provider is configured configured_provider = None if google_client_id is not None: configured_provider = "google" elif microsoft_client_id is not None: configured_provider = "microsoft" elif generic_client_id is not None: configured_provider = "generic" # If no SSO is configured, return healthy (SSO is optional) if configured_provider is None: return { "status": "healthy", "sso_configured": False, "message": "No SSO provider configured", } # Check required environment variables for the configured provider missing_vars = [] if configured_provider == "google": google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) if google_client_secret is None: missing_vars.append("GOOGLE_CLIENT_SECRET") elif configured_provider == "microsoft": microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) if microsoft_client_secret is None: missing_vars.append("MICROSOFT_CLIENT_SECRET") if microsoft_tenant is None: missing_vars.append("MICROSOFT_TENANT") elif configured_provider == "generic": generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) generic_authorization_endpoint = os.getenv( "GENERIC_AUTHORIZATION_ENDPOINT", None ) generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) if generic_client_secret is None: missing_vars.append("GENERIC_CLIENT_SECRET") if generic_authorization_endpoint is None: missing_vars.append("GENERIC_AUTHORIZATION_ENDPOINT") if generic_token_endpoint is None: missing_vars.append("GENERIC_TOKEN_ENDPOINT") if generic_userinfo_endpoint is None: missing_vars.append("GENERIC_USERINFO_ENDPOINT") # If all required variables are present, return healthy if len(missing_vars) == 0: return { "status": "healthy", "sso_configured": True, "provider": configured_provider, "message": f"{configured_provider.capitalize()} SSO is properly configured", } # If some variables are missing, return unhealthy raise HTTPException( status_code=503, detail={ "status": "unhealthy", "sso_configured": True, "provider": configured_provider, "missing_environment_variables": missing_vars, "message": f"{configured_provider.capitalize()} SSO is configured but missing required environment variables: {', '.join(missing_vars)}", }, ) class SSOAuthenticationHandler: """ Handler for SSO Authentication across all SSO providers """ @staticmethod async def get_sso_login_redirect( redirect_url: str, google_client_id: Optional[str] = None, microsoft_client_id: Optional[str] = None, generic_client_id: Optional[str] = None, state: Optional[str] = None, ) -> Optional[RedirectResponse]: """ Step 1. Call Get Login Redirect for the SSO provider. Send the redirect response to `redirect_url` Args: redirect_url (str): The URL to redirect the user to after login google_client_id (Optional[str], optional): The Google Client ID. Defaults to None. microsoft_client_id (Optional[str], optional): The Microsoft Client ID. Defaults to None. generic_client_id (Optional[str], optional): The Generic Client ID. Defaults to None. Returns: RedirectResponse: The redirect response from the SSO provider. """ # Google SSO Auth if google_client_id is not None: from fastapi_sso.sso.google import GoogleSSO google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) if google_client_secret is None: raise ProxyException( message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="GOOGLE_CLIENT_SECRET", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) google_sso = GoogleSSO( client_id=google_client_id, client_secret=google_client_secret, redirect_uri=redirect_url, ) verbose_proxy_logger.info( f"In /google-login/key/generate, \nGOOGLE_REDIRECT_URI: {redirect_url}\nGOOGLE_CLIENT_ID: {google_client_id}" ) with google_sso: return await google_sso.get_login_redirect(state=state) # Microsoft SSO Auth elif microsoft_client_id is not None: microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) if microsoft_client_secret is None: raise ProxyException( message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="MICROSOFT_CLIENT_SECRET", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) microsoft_sso = CustomMicrosoftSSO( client_id=microsoft_client_id, client_secret=microsoft_client_secret, tenant=microsoft_tenant, redirect_uri=redirect_url, allow_insecure_http=True, ) with microsoft_sso: return await microsoft_sso.get_login_redirect(state=state) elif generic_client_id is not None: from fastapi_sso.sso.base import DiscoveryDocument from fastapi_sso.sso.generic import create_provider generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None) generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split( " " ) generic_authorization_endpoint = os.getenv( "GENERIC_AUTHORIZATION_ENDPOINT", None ) generic_token_endpoint = os.getenv("GENERIC_TOKEN_ENDPOINT", None) generic_userinfo_endpoint = os.getenv("GENERIC_USERINFO_ENDPOINT", None) if generic_client_secret is None: raise ProxyException( message="GENERIC_CLIENT_SECRET not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="GENERIC_CLIENT_SECRET", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) if generic_authorization_endpoint is None: raise ProxyException( message="GENERIC_AUTHORIZATION_ENDPOINT not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="GENERIC_AUTHORIZATION_ENDPOINT", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) if generic_token_endpoint is None: raise ProxyException( message="GENERIC_TOKEN_ENDPOINT not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="GENERIC_TOKEN_ENDPOINT", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) if generic_userinfo_endpoint is None: raise ProxyException( message="GENERIC_USERINFO_ENDPOINT not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="GENERIC_USERINFO_ENDPOINT", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) verbose_proxy_logger.debug( f"authorization_endpoint: {generic_authorization_endpoint}\ntoken_endpoint: {generic_token_endpoint}\nuserinfo_endpoint: {generic_userinfo_endpoint}" ) verbose_proxy_logger.debug( f"GENERIC_REDIRECT_URI: {redirect_url}\nGENERIC_CLIENT_ID: {generic_client_id}\n" ) discovery = DiscoveryDocument( authorization_endpoint=generic_authorization_endpoint, token_endpoint=generic_token_endpoint, userinfo_endpoint=generic_userinfo_endpoint, ) SSOProvider = create_provider(name="oidc", discovery_document=discovery) generic_sso = SSOProvider( client_id=generic_client_id, client_secret=generic_client_secret, redirect_uri=redirect_url, allow_insecure_http=True, scope=generic_scope, ) return await SSOAuthenticationHandler.get_generic_sso_redirect_response( generic_sso=generic_sso, state=state, generic_authorization_endpoint=generic_authorization_endpoint, ) raise ValueError( "Unknown SSO provider. Please setup SSO with client IDs https://docs.litellm.ai/docs/proxy/admin_ui_sso" ) @staticmethod async def get_generic_sso_redirect_response( generic_sso: Any, state: Optional[str] = None, generic_authorization_endpoint: Optional[str] = None, ) -> Optional[RedirectResponse]: """ Get the redirect response for Generic SSO """ from urllib.parse import parse_qs, urlencode, urlparse, urlunparse from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache with generic_sso: # TODO: state should be a random string and added to the user session with cookie # or a cryptographicly signed state that we can verify stateless # For simplification we are using a static state, this is not perfect but some # SSO providers do not allow stateless verification ( redirect_params, code_verifier, ) = SSOAuthenticationHandler._get_generic_sso_redirect_params( state=state, generic_authorization_endpoint=generic_authorization_endpoint, ) # Separate PKCE params from state params (fastapi-sso doesn't accept code_challenge) pkce_params = {} state_only_params = {} for key, value in redirect_params.items(): if key in ("code_challenge", "code_challenge_method"): pkce_params[key] = value else: state_only_params[key] = value # Get the redirect response from fastapi-sso with only state param redirect_response = await generic_sso.get_login_redirect(**state_only_params) # type: ignore # If PKCE is enabled, add PKCE parameters to the redirect URL if code_verifier and "state" in redirect_params: # Store code_verifier in cache (10 min TTL). Wrap in dict for proper # JSON serialization in Redis. Use Redis when available so callbacks # landing on another pod can retrieve it (multi-pod SSO). cache_key = f"pkce_verifier:{redirect_params['state']}" if redis_usage_cache is not None: await redis_usage_cache.async_set_cache( key=cache_key, value={"code_verifier": code_verifier}, ttl=600, ) else: await user_api_key_cache.async_set_cache( key=cache_key, value={"code_verifier": code_verifier}, ttl=600, ) verbose_proxy_logger.debug( "PKCE code_verifier stored in cache (TTL: 600s)" ) # Add PKCE parameters to the authorization URL if pkce_params: parsed_url = urlparse(str(redirect_response.headers["location"])) query_params = parse_qs(parsed_url.query) # Add PKCE parameters for key, value in pkce_params.items(): query_params[key] = [value] # Reconstruct the URL with PKCE parameters new_query = urlencode(query_params, doseq=True) new_url = urlunparse( ( parsed_url.scheme, parsed_url.netloc, parsed_url.path, parsed_url.params, new_query, parsed_url.fragment, ) ) # Update the redirect response redirect_response.headers["location"] = new_url return redirect_response @staticmethod def _get_generic_sso_redirect_params( state: Optional[str] = None, generic_authorization_endpoint: Optional[str] = None, ) -> Tuple[dict, Optional[str]]: """ Get redirect parameters for Generic SSO with proper state priority handling. Optionally generates PKCE parameters if GENERIC_CLIENT_USE_PKCE is enabled. Priority order: 1. CLI state (if provided) 2. GENERIC_CLIENT_STATE environment variable 3. Generated UUID (required by Okta and most OAuth providers) Args: state: Optional state parameter (e.g., CLI state) generic_authorization_endpoint: Authorization endpoint URL Returns: Tuple[dict, Optional[str]]: - Redirect parameters for SSO login (may include PKCE params) - code_verifier (if PKCE is enabled, None otherwise) """ redirect_params = {} code_verifier: Optional[str] = None if state: # CLI state takes priority # the litellm proxy cli sends the "state" parameter to the proxy server for auth. We should maintain the state parameter for the cli if it is provided redirect_params["state"] = state else: generic_client_state = os.getenv("GENERIC_CLIENT_STATE", None) if generic_client_state: redirect_params["state"] = generic_client_state else: redirect_params["state"] = uuid.uuid4().hex # Handle PKCE (Proof Key for Code Exchange) if enabled # Set GENERIC_CLIENT_USE_PKCE=true to enable PKCE for enhanced OAuth security use_pkce = os.getenv("GENERIC_CLIENT_USE_PKCE", "false").lower() == "true" if use_pkce: ( code_verifier, code_challenge, ) = SSOAuthenticationHandler.generate_pkce_params() redirect_params["code_challenge"] = code_challenge redirect_params["code_challenge_method"] = "S256" verbose_proxy_logger.debug("PKCE enabled for authorization request") return redirect_params, code_verifier @staticmethod def should_use_sso_handler( google_client_id: Optional[str] = None, microsoft_client_id: Optional[str] = None, generic_client_id: Optional[str] = None, ) -> bool: if ( google_client_id is not None or microsoft_client_id is not None or generic_client_id is not None ): return True return False @staticmethod def get_redirect_url_for_sso( request: Request, sso_callback_route: str, existing_key: Optional[str] = None, ) -> str: """ Get the redirect URL for SSO Note: existing_key is not added to the URL to avoid changing the callback URL. It should be passed via the state parameter instead. """ from litellm.proxy.utils import get_custom_url redirect_url = get_custom_url(request_base_url=str(request.base_url)) if redirect_url.endswith("/"): redirect_url += sso_callback_route else: redirect_url += "/" + sso_callback_route return redirect_url @staticmethod async def upsert_sso_user( result: Optional[Union[CustomOpenID, OpenID, dict]], user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]], user_email: Optional[str], user_defined_values: Optional[SSOUserDefinedValues], prisma_client: PrismaClient, ): """ Connects the SSO Users to the User Table in LiteLLM DB - If user on LiteLLM DB, update the user_email and user_role (if SSO provides valid role) with the SSO values - If user not on LiteLLM DB, insert the user into LiteLLM DB """ try: if user_info is not None: user_id = user_info.user_id update_data = _build_sso_user_update_data( result=result, user_email=user_email, user_id=user_id, ) await prisma_client.db.litellm_usertable.update_many( where={"user_id": user_id}, data=update_data ) else: verbose_proxy_logger.info( "user not in DB, inserting user into LiteLLM DB" ) # user not in DB, insert User into LiteLLM DB user_info = await insert_sso_user( result_openid=result, user_defined_values=user_defined_values, ) return user_info except Exception as e: verbose_proxy_logger.exception( f"Error upserting SSO user into LiteLLM DB: {e}" ) return user_info @staticmethod async def add_user_to_teams_from_sso_response( result: Optional[Union[CustomOpenID, OpenID, dict]], user_info: Optional[Union[NewUserResponse, LiteLLM_UserTable]], ): """ Adds the user as a team member to the teams specified in the SSO responses `team_ids` field The `team_ids` field is populated by litellm after processing the SSO response """ if user_info is None: verbose_proxy_logger.debug( "User not found in LiteLLM DB, skipping team member addition" ) return sso_teams = getattr(result, "team_ids", []) await add_missing_team_member(user_info=user_info, sso_teams=sso_teams) @staticmethod def verify_user_in_restricted_sso_group( general_settings: Dict, result: Optional[Union[CustomOpenID, OpenID, dict]], received_response: Optional[dict], ) -> Literal[True]: """ when ui_access_mode.type == "restricted_sso_group": - result.team_ids should contain the restricted_sso_group - if not, raise a ProxyException - if so, return True - if result.team_ids is None, return False - if result.team_ids is an empty list, return False - if result.team_ids is a list, return True if the restricted_sso_group is in the list, otherwise return False """ ui_access_mode = cast( Optional[Union[Dict, str]], general_settings.get("ui_access_mode") ) if ui_access_mode is None: return True if isinstance(ui_access_mode, str): return True team_ids = getattr(result, "team_ids", []) if ui_access_mode.get("type") == "restricted_sso_group": restricted_sso_group = ui_access_mode.get("restricted_sso_group") if restricted_sso_group not in team_ids: raise ProxyException( message=f"User is not in the restricted SSO group: {restricted_sso_group}. User groups: {team_ids}. Received SSO response: {received_response}", type=ProxyErrorTypes.auth_error, param="restricted_sso_group", code=status.HTTP_403_FORBIDDEN, ) return True @staticmethod async def create_litellm_team_from_sso_group( litellm_team_id: str, litellm_team_name: Optional[str] = None, ): """ Creates a Litellm Team from a SSO Group ID Your SSO provider might have groups that should be created on LiteLLM Use this helper to create a Litellm Team from a SSO Group ID Args: litellm_team_id (str): The ID of the Litellm Team litellm_team_name (Optional[str]): The name of the Litellm Team """ from litellm.proxy.proxy_server import prisma_client if prisma_client is None: raise ProxyException( message="Prisma client not found. Set it in the proxy_server.py file", type=ProxyErrorTypes.auth_error, param="prisma_client", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) try: team_obj = await prisma_client.db.litellm_teamtable.find_first( where={"team_id": litellm_team_id} ) verbose_proxy_logger.debug(f"Team object: {team_obj}") # only create a new team if it doesn't exist if team_obj: verbose_proxy_logger.debug( f"Team already exists: {litellm_team_id} - {litellm_team_name}" ) return team_request: NewTeamRequest = NewTeamRequest( team_id=litellm_team_id, team_alias=litellm_team_name, ) if litellm.default_team_params: team_request = SSOAuthenticationHandler._cast_and_deepcopy_litellm_default_team_params( default_team_params=litellm.default_team_params, litellm_team_id=litellm_team_id, litellm_team_name=litellm_team_name, team_request=team_request, ) await new_team( data=team_request, # params used for Audit Logging http_request=Request(scope={"type": "http", "method": "POST"}), user_api_key_dict=UserAPIKeyAuth( token="", key_alias=f"litellm.{MicrosoftSSOHandler.__name__}", ), ) except Exception as e: verbose_proxy_logger.exception(f"Error creating Litellm Team: {e}") @staticmethod def _cast_and_deepcopy_litellm_default_team_params( default_team_params: Union[DefaultTeamSSOParams, Dict], team_request: NewTeamRequest, litellm_team_id: str, litellm_team_name: Optional[str] = None, ) -> NewTeamRequest: """ Casts and deepcopies the litellm.default_team_params to a NewTeamRequest object - Ensures we create a new DefaultTeamSSOParams object - Handle the case where litellm.default_team_params is a dict or a DefaultTeamSSOParams object - Adds the litellm_team_id and litellm_team_name to the DefaultTeamSSOParams object """ if isinstance(default_team_params, dict): _team_request = deepcopy(default_team_params) _team_request["team_id"] = litellm_team_id _team_request["team_alias"] = litellm_team_name team_request = NewTeamRequest(**_team_request) elif isinstance(litellm.default_team_params, DefaultTeamSSOParams): _default_team_params = deepcopy(litellm.default_team_params) _new_team_request = team_request.model_dump() _new_team_request.update(_default_team_params) team_request = NewTeamRequest(**_new_team_request) return team_request @staticmethod def _get_cli_state( source: Optional[str], key: Optional[str], existing_key: Optional[str] = None ) -> Optional[str]: """ Checks the request 'source' if a cli state token was passed in This is used to authenticate through the CLI login flow. The state parameter format is: {PREFIX}:{key}:{existing_key} - If existing_key is provided, it's included in the state - The state parameter is used to pass data through the OAuth flow without changing the callback URL """ from litellm.constants import ( LITELLM_CLI_SESSION_TOKEN_PREFIX, LITELLM_CLI_SOURCE_IDENTIFIER, ) if source == LITELLM_CLI_SOURCE_IDENTIFIER and key: if existing_key: return f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:{key}:{existing_key}" else: return f"{LITELLM_CLI_SESSION_TOKEN_PREFIX}:{key}" else: return None @staticmethod def _get_user_email_and_id_from_result( result: Optional[Union[OpenID, dict]], generic_client_id: Optional[str] = None, ) -> ParsedOpenIDResult: """ Gets the user email and id from the OpenID result after validating the email domain """ user_email: Optional[str] = normalize_email(getattr(result, "email", None)) user_id: Optional[str] = ( getattr(result, "id", None) if result is not None else None ) user_role: Optional[str] = None if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None: email_domain = user_email.split("@")[1] allowed_domains = os.getenv("ALLOWED_EMAIL_DOMAINS").split(",") # type: ignore if email_domain not in allowed_domains: raise HTTPException( status_code=401, detail={ "message": "The email domain={}, is not an allowed email domain={}. Contact your admin to change this.".format( email_domain, allowed_domains ) }, ) # Extract user_role from result (works for all SSO providers) if result is not None: _user_role = getattr(result, "user_role", None) if _user_role is not None: # Convert enum to string if needed user_role = ( _user_role.value if isinstance(_user_role, LitellmUserRoles) else _user_role ) verbose_proxy_logger.debug( f"Extracted user_role from SSO result: {user_role}" ) # generic client id - override with custom attribute name if specified if generic_client_id is not None and result is not None: generic_user_role_attribute_name = os.getenv( "GENERIC_USER_ROLE_ATTRIBUTE", "role" ) user_id = getattr(result, "id", None) user_email = normalize_email(getattr(result, "email", None)) if user_role is None: _role_from_attr = getattr(result, generic_user_role_attribute_name, None) # type: ignore if _role_from_attr is not None: # Convert enum to string if needed user_role = ( _role_from_attr.value if isinstance(_role_from_attr, LitellmUserRoles) else _role_from_attr ) if user_id is None and result is not None: _first_name = getattr(result, "first_name", "") or "" _last_name = getattr(result, "last_name", "") or "" user_id = _first_name + _last_name if user_email is not None and (user_id is None or len(user_id) == 0): user_id = user_email return ParsedOpenIDResult( user_email=user_email, user_id=user_id, user_role=user_role, ) @staticmethod async def get_redirect_response_from_openid( # noqa: PLR0915 result: Union[OpenID, dict, CustomOpenID], request: Request, received_response: Optional[dict] = None, generic_client_id: Optional[str] = None, ui_access_mode: Optional[Dict] = None, ) -> RedirectResponse: import jwt from litellm.proxy.proxy_server import ( general_settings, generate_key_helper_fn, master_key, premium_user, proxy_logging_obj, user_api_key_cache, user_custom_sso, ) from litellm.proxy.utils import get_prisma_client_or_throw from litellm.types.proxy.ui_sso import ReturnedUITokenObject prisma_client = get_prisma_client_or_throw( "Prisma client is None, connect a database to your proxy" ) # User is Authe'd in - generate key for the UI to access Proxy parsed_openid_result = ( SSOAuthenticationHandler._get_user_email_and_id_from_result( result=result, generic_client_id=generic_client_id ) ) user_email = parsed_openid_result.get("user_email") user_id = parsed_openid_result.get("user_id") user_role = parsed_openid_result.get("user_role") verbose_proxy_logger.info(f"SSO callback result: {result}") user_info = None user_id_models: List = [] max_internal_user_budget = litellm.max_internal_user_budget internal_user_budget_duration = litellm.internal_user_budget_duration # User might not be already created on first generation of key # But if it is, we want their models preferences default_ui_key_values: Dict[str, Any] = { "duration": LITELLM_UI_SESSION_DURATION, "key_max_budget": litellm.max_ui_session_budget, "aliases": {}, "config": {}, "spend": 0, "team_id": "litellm-dashboard", } user_defined_values: Optional[SSOUserDefinedValues] = None if user_custom_sso is not None: if inspect.iscoroutinefunction(user_custom_sso): user_defined_values = await user_custom_sso(result) # type: ignore else: raise ValueError("user_custom_sso must be a coroutine function") elif user_id is not None: user_defined_values = SSOUserDefinedValues( models=user_id_models, user_id=user_id, user_email=user_email, max_budget=max_internal_user_budget, user_role=user_role, budget_duration=internal_user_budget_duration, ) # (IF SET) Verify user is in restricted SSO group SSOAuthenticationHandler.verify_user_in_restricted_sso_group( general_settings=general_settings, result=result, received_response=received_response, ) user_info = await get_user_info_from_db( result=result, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, proxy_logging_obj=proxy_logging_obj, user_email=user_email, user_defined_values=user_defined_values, alternate_user_id=user_id, ) user_defined_values = apply_user_info_values_to_sso_user_defined_values( user_info=user_info, user_defined_values=user_defined_values ) if user_defined_values is None: raise Exception( "Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues" ) verbose_proxy_logger.info( f"user_defined_values for creating ui key: {user_defined_values}" ) default_ui_key_values.update(user_defined_values) default_ui_key_values["request_type"] = "key" response = await generate_key_helper_fn( **default_ui_key_values, # type: ignore table_name="key", ) key = response["token"] # type: ignore user_id = response["user_id"] # type: ignore user_role = ( user_defined_values["user_role"] or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value ) if user_id and isinstance(user_id, str): user_role = await check_and_update_if_proxy_admin_id( user_role=user_role, user_id=user_id, prisma_client=prisma_client ) verbose_proxy_logger.debug( f"user_role: {user_role}; ui_access_mode: {ui_access_mode}" ) ## CHECK IF ROLE ALLOWED TO USE PROXY ## is_admin_only_access = check_is_admin_only_access(ui_access_mode or {}) if is_admin_only_access: has_access = has_admin_ui_access(user_role or "") if not has_access: raise HTTPException( status_code=401, detail={ "error": f"User not allowed to access proxy. User role={user_role}, proxy mode={ui_access_mode}" }, ) disabled_non_admin_personal_key_creation = ( get_disabled_non_admin_personal_key_creation() ) litellm_dashboard_ui = get_custom_url( request_base_url=str(request.base_url), route="ui/" ) if get_secret_bool("EXPERIMENTAL_UI_LOGIN"): _user_info: Optional[LiteLLM_UserTable] = None if ( user_defined_values is not None and user_defined_values["user_id"] is not None ): _user_info = LiteLLM_UserTable( user_id=user_defined_values["user_id"], user_role=user_defined_values["user_role"] or user_role, models=[], max_budget=litellm.max_ui_session_budget, ) if _user_info is None: raise HTTPException( status_code=401, detail={ "error": "User Information is required for experimental UI login" }, ) key = ExperimentalUIJWTToken.get_experimental_ui_login_jwt_auth_token( _user_info ) returned_ui_token_object = ReturnedUITokenObject( user_id=cast(str, user_id), key=key, user_email=user_email, user_role=user_role or LitellmUserRoles.INTERNAL_USER_VIEW_ONLY.value, login_method="sso", premium_user=premium_user, auth_header_name=general_settings.get( "litellm_key_header_name", "Authorization" ), disabled_non_admin_personal_key_creation=disabled_non_admin_personal_key_creation, server_root_path=get_server_root_path(), ) jwt_token = jwt.encode( cast(dict, returned_ui_token_object), master_key or "", algorithm="HS256", ) if user_id is not None and isinstance(user_id, str): litellm_dashboard_ui += "?login=success" verbose_proxy_logger.info(f"Redirecting to {litellm_dashboard_ui}") redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303) redirect_response.set_cookie(key="token", value=jwt_token) return redirect_response @staticmethod async def prepare_token_exchange_parameters( request: Request, generic_include_client_id: bool, ) -> dict: """ Prepare token exchange parameters for Generic SSO. Args: request: Request object generic_include_client_id: Generic OAuth Client ID Returns: dict: Token exchange parameters """ # Prepare token exchange parameters (may add code_verifier: str later) token_params: Dict[str, Any] = {"include_client_id": generic_include_client_id} # Retrieve PKCE code_verifier if PKCE was used in authorization. # Gate on GENERIC_CLIENT_USE_PKCE to avoid an unnecessary Redis round-trip # on every non-PKCE SSO callback. query_params = dict(request.query_params) state = query_params.get("state") use_pkce = os.getenv("GENERIC_CLIENT_USE_PKCE", "false").lower() == "true" if use_pkce and not state: verbose_proxy_logger.warning( "PKCE is enabled (GENERIC_CLIENT_USE_PKCE=true) but no 'state' parameter " "was found in the callback. The PKCE verifier cannot be retrieved without " "a state value — the token exchange will proceed without code_verifier, " "which the provider may reject. Ensure your OAuth provider returns 'state' " "in the callback redirect." ) if state and use_pkce: from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache cache_key = f"pkce_verifier:{state}" if redis_usage_cache is not None: cached_data = await redis_usage_cache.async_get_cache(key=cache_key) else: cached_data = await user_api_key_cache.async_get_cache(key=cache_key) code_verifier = None # Track why code_verifier is absent for accurate strict-mode diagnostics. _empty_value_in_dict = False # dict format correct but value is empty/null if cached_data: # Extract code_verifier from dict (stored as dict for JSON serialization) if isinstance(cached_data, dict) and "code_verifier" in cached_data: code_verifier = cached_data["code_verifier"] if not code_verifier: # Dict format is correct but value is empty or null. This is # a distinct case from an unrecognized format — the entry exists # but was stored with an empty/null verifier (data integrity issue). _empty_value_in_dict = True verbose_proxy_logger.warning( "PKCE verifier dict for state '%s' has an empty/null code_verifier " "value — may indicate a storage bug. Treating as a cache miss.", state, ) else: verbose_proxy_logger.debug("PKCE code_verifier retrieved from cache") elif isinstance(cached_data, str): # Handle legacy format (plain string) for backward compatibility code_verifier = cached_data verbose_proxy_logger.warning( "Retrieved code_verifier in legacy plain-string format. " "Future storage will use dict format." ) else: # Defer the detailed ERROR log to the strict-mode branch below # (which includes state and a diagnostic message). Log at DEBUG # here to avoid duplicate ERROR entries in the same request. verbose_proxy_logger.debug( "Unexpected PKCE verifier cache format (type=%s); skipping.", type(cached_data).__name__, ) if code_verifier: # Add code_verifier to token exchange parameters. token_params["code_verifier"] = code_verifier # Return the cache key so the caller can delete it *after* a # successful token exchange (avoids losing the verifier on retry # if the exchange fails partway through). token_params["_pkce_cache_key"] = cache_key else: await SSOAuthenticationHandler._handle_missing_pkce_verifier( state=state, cache_key=cache_key, cached_data=cached_data, empty_value_in_dict=_empty_value_in_dict, redis_usage_cache=redis_usage_cache, user_api_key_cache=user_api_key_cache, ) return token_params @staticmethod async def _handle_missing_pkce_verifier( state: Optional[str], cache_key: str, cached_data: object, empty_value_in_dict: bool, redis_usage_cache: object, user_api_key_cache: object, ) -> None: """Handle the case where PKCE verifier could not be extracted from cache. In strict mode (PKCE_STRICT_CACHE_MISS=true) raises ProxyException. Otherwise logs a warning and returns (token exchange proceeds without verifier). """ active_cache = redis_usage_cache if redis_usage_cache is not None else user_api_key_cache strict_cache_miss = ( os.getenv("PKCE_STRICT_CACHE_MISS", "false").lower() == "true" ) if strict_cache_miss: if empty_value_in_dict: await SSOAuthenticationHandler._delete_pkce_verifier(cache_key) raise ProxyException( message=( f"PKCE verifier for state '{state}' was found in cache but " f"has an empty or null code_verifier value — possible storage bug." ), type=ProxyErrorTypes.auth_error, param="PKCE_CACHE_MISS", code=status.HTTP_401_UNAUTHORIZED, ) elif cached_data is not None: await SSOAuthenticationHandler._delete_pkce_verifier(cache_key) verbose_proxy_logger.error( "PKCE verifier for state '%s' has an unrecognized format (type=%s); " "treating as a cache miss. Investigate the cached value — it may be " "a corrupt or stale entry.", state, type(cached_data).__name__, ) raise ProxyException( message=( f"PKCE verifier for state '{state}' has an unrecognized format " f"(type={type(cached_data).__name__}). The cached entry may be corrupt." ), type=ProxyErrorTypes.auth_error, param="PKCE_CACHE_MISS", code=status.HTTP_401_UNAUTHORIZED, ) else: if redis_usage_cache is not None: cause = ( "The authorization and callback were likely handled by different " "instances — the verifier was stored on one pod but not found on another." ) else: cause = ( "The verifier may have expired (TTL), been lost on a pod restart, " "or the PKCE authorization step was never completed. " "Configure Redis so all proxy instances share the PKCE verifier." ) verbose_proxy_logger.error( "PKCE is enabled but no verifier found in cache for state '%s'. " "%s Cache type: %s.", state, cause, type(active_cache).__name__, ) raise ProxyException( message=f"PKCE verifier not found in cache for state '{state}'. {cause}", type=ProxyErrorTypes.auth_error, param="PKCE_CACHE_MISS", code=status.HTTP_401_UNAUTHORIZED, ) else: if cached_data is not None: await SSOAuthenticationHandler._delete_pkce_verifier(cache_key) verbose_proxy_logger.warning( "PKCE is enabled but verifier not found in cache for state '%s' " "(cache type: %s, raw data present: %s). " "Continuing without code_verifier — set PKCE_STRICT_CACHE_MISS=true to fail fast instead.", state, type(active_cache).__name__, cached_data is not None, ) @staticmethod async def _delete_pkce_verifier(cache_key: str) -> None: """Delete a single-use PKCE verifier from cache after a successful exchange. Failure is non-fatal: a leftover verifier is a minor security concern (unused key in cache) but not worth aborting an otherwise-successful login. """ from litellm.proxy.proxy_server import redis_usage_cache, user_api_key_cache try: if redis_usage_cache is not None: await redis_usage_cache.async_delete_cache(key=cache_key) else: await user_api_key_cache.async_delete_cache(key=cache_key) except Exception as exc: verbose_proxy_logger.warning( "PKCE: failed to delete verifier cache key '%s' (best-effort cleanup): %s", cache_key, exc, ) @staticmethod def generate_pkce_params() -> Tuple[str, str]: """ Generate PKCE (Proof Key for Code Exchange) parameters for OAuth 2.0. Returns: Tuple[str, str]: (code_verifier, code_challenge) - code_verifier: Random 43-128 character string (we use 43 for efficiency) - code_challenge: Base64-URL-encoded SHA256 hash of the code_verifier Reference: https://datatracker.ietf.org/doc/html/rfc7636 """ # Generate a cryptographically random code_verifier (43 characters) # Using 32 random bytes which becomes 43 characters when base64-url-encoded code_verifier = ( base64.urlsafe_b64encode(secrets.token_bytes(32)) .decode("utf-8") .rstrip("=") ) # Generate code_challenge using S256 method (SHA256) code_challenge_bytes = hashlib.sha256(code_verifier.encode("utf-8")).digest() code_challenge = ( base64.urlsafe_b64encode(code_challenge_bytes).decode("utf-8").rstrip("=") ) return code_verifier, code_challenge @staticmethod def _validate_token_response(response: "httpx.Response") -> dict: """ Parse and validate the token endpoint response. Ensures the response is valid JSON, a dict, and contains a non-null access_token string. Raises ProxyException on any validation failure. """ try: token_response_raw = response.json() except Exception as json_err: verbose_proxy_logger.error( "Failed to parse token response as JSON: %s. Body: %s", json_err, response.text[:500], ) raise ProxyException( message=f"Token endpoint returned invalid JSON: {json_err}", type=ProxyErrorTypes.auth_error, param="token_exchange", code=status.HTTP_401_UNAUTHORIZED, ) if not isinstance(token_response_raw, dict): verbose_proxy_logger.error( "Token endpoint returned non-dict JSON (type=%s). Body: %s", type(token_response_raw).__name__, response.text[:500], ) raise ProxyException( message=( f"Token endpoint returned unexpected response format " f"(expected JSON object, got {type(token_response_raw).__name__})" ), type=ProxyErrorTypes.auth_error, param="token_exchange", code=status.HTTP_401_UNAUTHORIZED, ) token_response: dict = token_response_raw access_token_val = token_response.get("access_token") if not isinstance(access_token_val, str) or not access_token_val: error = token_response.get("error") error_desc = token_response.get("error_description", "") if error: detail = f"{error} - {error_desc}" if error_desc else error else: detail = ( "token endpoint returned HTTP 200 but no access_token " f"(response keys: {sorted(token_response.keys())})" ) verbose_proxy_logger.error( "Token response missing or null access_token. detail=%s", detail ) raise ProxyException( message=f"Token exchange failed: {detail}", type=ProxyErrorTypes.auth_error, param="token_exchange", code=status.HTTP_401_UNAUTHORIZED, ) return token_response @staticmethod async def _pkce_token_exchange( authorization_code: str, code_verifier: str, client_id: str, client_secret: Optional[str], token_endpoint: str, userinfo_endpoint: Optional[str], include_client_id: bool, redirect_url: Optional[str], additional_headers: Dict[str, str], ) -> dict: """ Performs a direct OAuth token exchange including the PKCE code_verifier. fastapi-sso does not forward code_verifier, so when PKCE is enabled we bypass it and call the token endpoint ourselves, then fetch user info. Returns a combined dict of the token response and user info, suitable for passing to a response_convertor. """ verbose_proxy_logger.debug( "PKCE: performing direct token exchange (code_verifier length=%d)", len(code_verifier), ) token_data: Dict[str, str] = { "grant_type": "authorization_code", "code": authorization_code, "code_verifier": code_verifier, } # Only include redirect_uri when set — omitting it avoids sending the # literal string "None" to the provider if the env var is missing. if redirect_url: token_data["redirect_uri"] = redirect_url request_headers = { **additional_headers, "Content-Type": "application/x-www-form-urlencoded", # must not be overridden "Accept": "application/json", } if not include_client_id: # Use Basic Auth only when a secret is available; public PKCE clients omit it. if client_secret: credentials = base64.b64encode( f"{client_id}:{client_secret}".encode() ).decode() request_headers["Authorization"] = f"Basic {credentials}" else: token_data["client_id"] = client_id else: token_data["client_id"] = client_id if client_secret: token_data["client_secret"] = client_secret http_client = get_async_httpx_client( llm_provider=httpxSpecialProvider.SSO_HANDLER ) try: response = await http_client.post( url=token_endpoint, data=token_data, headers=request_headers, timeout=30.0, ) except Exception as exc: # Catch network-level errors (SSL, DNS, TCP, timeout, etc.) and # wrap them as a clean ProxyException rather than leaking raw # httpx or OS exceptions to callers. verbose_proxy_logger.error("PKCE token endpoint unreachable: %s", exc) raise ProxyException( message=f"Token endpoint request failed: {exc}", type=ProxyErrorTypes.auth_error, param="token_exchange", code=status.HTTP_401_UNAUTHORIZED, ) from exc if response.status_code != 200: verbose_proxy_logger.error( "PKCE token exchange failed. status=%s body=%s", response.status_code, response.text[:500], ) raise ProxyException( message=f"Token exchange failed: {response.status_code} - {response.text[:500]}", type=ProxyErrorTypes.auth_error, param="token_exchange", code=status.HTTP_401_UNAUTHORIZED, ) token_response = SSOAuthenticationHandler._validate_token_response(response) verbose_proxy_logger.debug( "PKCE token exchange successful. id_token_present=%s", bool(token_response.get("id_token")), ) # Bearer credentials (access_token, id_token, refresh_token) are always sourced # from token_response — not from userinfo — in the merge step below. userinfo = await SSOAuthenticationHandler._get_pkce_userinfo( access_token=token_response["access_token"], id_token=token_response.get("id_token"), userinfo_endpoint=userinfo_endpoint, additional_headers=additional_headers, ) # Merge: userinfo takes precedence for identity claims (sub, email, name, …) per # the OpenID Connect spec (userinfo is the authoritative source for identity). # Bearer credentials (access_token, id_token, refresh_token) from the token endpoint # take precedence over same-named fields in userinfo — non-standard providers sometimes # include token fields in userinfo, which must not shadow the real bearer token. # If a bearer field is absent from the token response, any userinfo-provided value # is preserved as a fallback (useful for non-standard providers that omit id_token # from the token response but include it in userinfo). # # Three-way merge semantics for each bearer-credential field: # 1. token_response has a non-null value → use it (token endpoint is authoritative) # 2. token_response explicitly sent null → remove the key so callers get a clean # absence signal; the null from the token endpoint overrides userinfo too # 3. field absent from token_response → leave whatever userinfo provided as-is # (e.g. userinfo-provided id_token from a non-standard provider) merged = {**token_response, **userinfo} for field in _OAUTH_TOKEN_FIELDS: if token_response.get(field) is not None: # Case 1: non-null in token_response — restore authoritative value. merged[field] = token_response[field] elif field in token_response: # Case 2: key exists but value is explicitly null — remove from merged. merged.pop(field, None) # Case 3: field absent from token_response — leave userinfo value as-is. return merged @staticmethod async def _get_pkce_userinfo( access_token: str, id_token: Optional[str], userinfo_endpoint: Optional[str], additional_headers: Dict[str, str], ) -> dict: """ Fetches user info from the userinfo endpoint. Falls back to decoding the id_token if the endpoint is unavailable. """ # None = request not yet attempted, failed, or returned empty/null (treated as failure # so the id_token fallback can be attempted instead of returning a session with no claims). userinfo: Optional[dict] = None if userinfo_endpoint: try: client = get_async_httpx_client( llm_provider=httpxSpecialProvider.SSO_HANDLER ) resp = await client.get( url=userinfo_endpoint, headers={ **additional_headers, "Authorization": f"Bearer {access_token}", # must not be overridden }, ) if resp.status_code == 200: try: userinfo_raw = resp.json() if not userinfo_raw: # JSON null (None) or empty dict ({}) — no identity claims. # Treat as failure so id_token fallback can be attempted. verbose_proxy_logger.warning( "Userinfo endpoint returned an empty or null response " "(type=%s); treating as failure and attempting id_token fallback. " "Check your provider's userinfo endpoint configuration.", type(userinfo_raw).__name__, ) userinfo = None else: userinfo = userinfo_raw except Exception as json_err: verbose_proxy_logger.warning( "Userinfo endpoint returned non-JSON response (status 200): %s", json_err, ) else: verbose_proxy_logger.warning( "Userinfo endpoint returned %s (body: %s), falling back to id_token", resp.status_code, resp.text[:500], ) except Exception as e: verbose_proxy_logger.warning( "Userinfo endpoint error: %s, falling back to id_token", e ) # Only fall back to id_token when the userinfo request failed (None). # Empty dict ({}) and JSON null are both treated as failure (set to None above) since # they contain no identity claims — id_token fallback is attempted in that case too. # Explicitly check for a non-empty string to avoid attempting JWT decode on # a blank or non-string id_token field from a misbehaving provider. if userinfo is None and isinstance(id_token, str) and id_token: try: userinfo = jwt.decode(id_token, options={"verify_signature": False}) if not userinfo: # jwt.decode returned an empty dict (payload-free JWT or provider bug). # Treat this the same as a missing userinfo — the session would have no # identity claims, which is equivalent to a broken session. verbose_proxy_logger.warning( "id_token decoded to an empty payload — treating as failure." ) userinfo = None except Exception as decode_err: verbose_proxy_logger.error("Failed to decode id_token: %s", decode_err) raise ProxyException( message=f"Failed to decode id_token JWT: {decode_err}", type=ProxyErrorTypes.auth_error, param="userinfo", code=status.HTTP_401_UNAUTHORIZED, ) if userinfo is None: id_token_attempted = isinstance(id_token, str) and bool(id_token) if userinfo_endpoint: if id_token_attempted: detail = ( "userinfo endpoint failed and id_token was present but " "decoded to an empty payload — no identity claims available" ) else: detail = "userinfo endpoint failed and no id_token was present in the token response" else: if id_token_attempted: detail = ( "no userinfo endpoint is configured (GENERIC_USERINFO_ENDPOINT) " "and id_token decoded to an empty payload — no identity claims available" ) else: detail = "no userinfo endpoint is configured (GENERIC_USERINFO_ENDPOINT) and no id_token was present" raise ProxyException( message=f"SSO user info unavailable: {detail}.", type=ProxyErrorTypes.auth_error, param="userinfo", code=status.HTTP_401_UNAUTHORIZED, ) return userinfo class MicrosoftSSOHandler: """ Handles Microsoft SSO callback response and returns a CustomOpenID object """ graph_api_base_url = "https://graph.microsoft.com/v1.0" graph_api_user_groups_endpoint = f"{graph_api_base_url}/me/memberOf" """ Constants """ MAX_GRAPH_API_PAGES = 200 # used for debugging to show the user groups litellm found from Graph API GRAPH_API_RESPONSE_KEY = "graph_api_user_groups" @staticmethod async def get_microsoft_callback_response( request: Request, microsoft_client_id: str, redirect_url: str, return_raw_sso_response: bool = False, ) -> Union[CustomOpenID, OpenID, dict]: """ Get the Microsoft SSO callback response Args: return_raw_sso_response: If True, return the raw SSO response """ microsoft_client_secret = os.getenv("MICROSOFT_CLIENT_SECRET", None) microsoft_tenant = os.getenv("MICROSOFT_TENANT", None) if microsoft_client_secret is None: raise ProxyException( message="MICROSOFT_CLIENT_SECRET not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="MICROSOFT_CLIENT_SECRET", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) if microsoft_tenant is None: raise ProxyException( message="MICROSOFT_TENANT not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="MICROSOFT_TENANT", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) microsoft_sso = CustomMicrosoftSSO( client_id=microsoft_client_id, client_secret=microsoft_client_secret, tenant=microsoft_tenant, redirect_uri=redirect_url, allow_insecure_http=True, ) original_msft_result = ( await microsoft_sso.verify_and_process( request=request, convert_response=False, # type: ignore ) or {} ) user_team_ids = await MicrosoftSSOHandler.get_user_groups_from_graph_api( access_token=microsoft_sso.access_token ) # Extract app roles from the id_token JWT app_roles = MicrosoftSSOHandler.get_app_roles_from_id_token( id_token=microsoft_sso.id_token ) verbose_proxy_logger.debug(f"Extracted app roles from id_token: {app_roles}") # Combine groups and app roles user_role: Optional[LitellmUserRoles] = None if app_roles: # Check if any app role is a valid LitellmUserRoles for role_str in app_roles: role = get_litellm_user_role(role_str) if role is not None: user_role = role verbose_proxy_logger.debug( f"Found valid LitellmUserRoles '{role.value}' in app_roles" ) break verbose_proxy_logger.debug( f"Combined team_ids (groups + app roles): {user_team_ids}" ) # if user is trying to get the raw sso response for debugging, return the raw sso response if return_raw_sso_response: original_msft_result[ MicrosoftSSOHandler.GRAPH_API_RESPONSE_KEY ] = user_team_ids original_msft_result["app_roles"] = app_roles return original_msft_result or {} result = MicrosoftSSOHandler.openid_from_response( response=original_msft_result, team_ids=user_team_ids, user_role=user_role, ) return result @staticmethod def openid_from_response( response: Optional[dict], team_ids: List[str], user_role: Optional[LitellmUserRoles], ) -> CustomOpenID: response = response or {} verbose_proxy_logger.debug(f"Microsoft SSO Callback Response: {response}") openid_response = CustomOpenID( email=normalize_email( response.get(MICROSOFT_USER_EMAIL_ATTRIBUTE) or response.get("mail") ), display_name=response.get(MICROSOFT_USER_DISPLAY_NAME_ATTRIBUTE), provider="microsoft", id=response.get(MICROSOFT_USER_ID_ATTRIBUTE), first_name=response.get(MICROSOFT_USER_FIRST_NAME_ATTRIBUTE), last_name=response.get(MICROSOFT_USER_LAST_NAME_ATTRIBUTE), team_ids=team_ids, user_role=user_role, ) verbose_proxy_logger.debug(f"Microsoft SSO OpenID Response: {openid_response}") return openid_response @staticmethod def get_app_roles_from_id_token(id_token: Optional[str]) -> List[str]: """ Extract app roles from the Microsoft Entra ID (Azure AD) id_token JWT. App roles are assigned in the Azure AD Enterprise Application and appear in the 'app_roles' claim of the id_token. Args: id_token (Optional[str]): The JWT id_token from Microsoft SSO Returns: List[str]: List of app role names assigned to the user """ if not id_token: verbose_proxy_logger.debug("No id_token provided for app role extraction") return [] try: import jwt # Decode the JWT without signature verification # (signature is already verified by fastapi_sso) decoded_token = jwt.decode(id_token, options={"verify_signature": False}) # Extract app_roles claim from the token ## check for both 'roles' and 'app_roles' claims roles = decoded_token.get("app_roles", []) or decoded_token.get("roles", []) if roles and isinstance(roles, list): verbose_proxy_logger.debug( f"Found {len(roles)} app role(s) in id_token: {roles}" ) return roles else: verbose_proxy_logger.debug( "No app roles found in id_token or roles claim is not a list" ) return [] except Exception as e: verbose_proxy_logger.error(f"Error extracting app roles from id_token: {e}") return [] @staticmethod async def get_user_groups_from_graph_api( access_token: Optional[str] = None, ) -> List[str]: """ Returns a list of `team_ids` the user belongs to from the Microsoft Graph API Args: access_token (Optional[str]): Microsoft Graph API access token Returns: List[str]: List of group IDs the user belongs to """ try: async_client = get_async_httpx_client( llm_provider=httpxSpecialProvider.SSO_HANDLER ) # Handle MSFT Enterprise Application Groups service_principal_id = os.getenv("MICROSOFT_SERVICE_PRINCIPAL_ID", None) service_principal_group_ids: Optional[List[str]] = [] service_principal_teams: Optional[List[MicrosoftServicePrincipalTeam]] = [] if service_principal_id: ( service_principal_group_ids, service_principal_teams, ) = await MicrosoftSSOHandler.get_group_ids_from_service_principal( service_principal_id=service_principal_id, async_client=async_client, access_token=access_token, ) verbose_proxy_logger.debug( f"Service principal group IDs: {service_principal_group_ids}" ) if len(service_principal_group_ids) > 0: await MicrosoftSSOHandler.create_litellm_teams_from_service_principal_team_ids( service_principal_teams=service_principal_teams, ) # Fetch user membership from Microsoft Graph API all_group_ids = [] next_link: Optional[ str ] = MicrosoftSSOHandler.graph_api_user_groups_endpoint auth_headers = {"Authorization": f"Bearer {access_token}"} page_count = 0 while ( next_link is not None and page_count < MicrosoftSSOHandler.MAX_GRAPH_API_PAGES ): group_ids, next_link = await MicrosoftSSOHandler.fetch_and_parse_groups( url=next_link, headers=auth_headers, async_client=async_client ) all_group_ids.extend(group_ids) page_count += 1 if ( next_link is not None and page_count >= MicrosoftSSOHandler.MAX_GRAPH_API_PAGES ): verbose_proxy_logger.warning( f"Reached maximum page limit of {MicrosoftSSOHandler.MAX_GRAPH_API_PAGES}. Some groups may not be included." ) # If service_principal_group_ids is not empty, only return group_ids that are in both all_group_ids and service_principal_group_ids if service_principal_group_ids and len(service_principal_group_ids) > 0: all_group_ids = [ group_id for group_id in all_group_ids if group_id in service_principal_group_ids ] return all_group_ids except Exception as e: verbose_proxy_logger.error( f"Error getting user groups from Microsoft Graph API: {e}" ) return [] @staticmethod async def fetch_and_parse_groups( url: str, headers: dict, async_client: AsyncHTTPHandler ) -> Tuple[List[str], Optional[str]]: """Helper function to fetch and parse group data from a URL""" response = await async_client.get(url, headers=headers) response_json = response.json() response_typed = await MicrosoftSSOHandler._cast_graph_api_response_dict( response=response_json ) group_ids = MicrosoftSSOHandler._get_group_ids_from_graph_api_response( response=response_typed ) return group_ids, response_typed.get("odata_nextLink") @staticmethod def _get_group_ids_from_graph_api_response( response: MicrosoftGraphAPIUserGroupResponse, ) -> List[str]: group_ids = [] for _object in response.get("value", []) or []: _group_id = _object.get("id") if _group_id is not None: group_ids.append(_group_id) return group_ids @staticmethod async def _cast_graph_api_response_dict( response: dict, ) -> MicrosoftGraphAPIUserGroupResponse: directory_objects: List[MicrosoftGraphAPIUserGroupDirectoryObject] = [] for _object in response.get("value", []): directory_objects.append( MicrosoftGraphAPIUserGroupDirectoryObject( odata_type=_object.get("@odata.type"), id=_object.get("id"), deletedDateTime=_object.get("deletedDateTime"), description=_object.get("description"), displayName=_object.get("displayName"), roleTemplateId=_object.get("roleTemplateId"), ) ) return MicrosoftGraphAPIUserGroupResponse( odata_context=response.get("@odata.context"), odata_nextLink=response.get("@odata.nextLink"), value=directory_objects, ) @staticmethod async def get_group_ids_from_service_principal( service_principal_id: str, async_client: AsyncHTTPHandler, access_token: Optional[str] = None, ) -> Tuple[List[str], List[MicrosoftServicePrincipalTeam]]: """ Gets the groups belonging to the Service Principal Application Service Principal Id is an `Enterprise Application` in Azure AD Users use Enterprise Applications to manage Groups and Users on Microsoft Entra ID """ base_url = "https://graph.microsoft.com/v1.0" # Endpoint to get app role assignments for the given service principal endpoint = f"/servicePrincipals/{service_principal_id}/appRoleAssignedTo" url = base_url + endpoint headers = { "Authorization": f"Bearer {access_token}", "Content-Type": "application/json", } response = await async_client.get(url, headers=headers) response_json = response.json() verbose_proxy_logger.debug( f"Response from service principal app role assigned to: {response_json}" ) group_ids: List[str] = [] service_principal_teams: List[MicrosoftServicePrincipalTeam] = [] for _object in response_json.get("value", []): if _object.get("principalType") == "Group": # Append the group ID to the list group_ids.append(_object.get("principalId")) # Append the service principal team to the list service_principal_teams.append( MicrosoftServicePrincipalTeam( principalDisplayName=_object.get("principalDisplayName"), principalId=_object.get("principalId"), ) ) return group_ids, service_principal_teams @staticmethod async def create_litellm_teams_from_service_principal_team_ids( service_principal_teams: List[MicrosoftServicePrincipalTeam], ): """ Creates Litellm Teams from the Service Principal Group IDs When a user sets a `SERVICE_PRINCIPAL_ID` in the env, litellm will fetch groups under that service principal and create Litellm Teams from them """ verbose_proxy_logger.debug( f"Creating Litellm Teams from Service Principal Teams: {service_principal_teams}" ) for service_principal_team in service_principal_teams: litellm_team_id: Optional[str] = service_principal_team.get("principalId") litellm_team_name: Optional[str] = service_principal_team.get( "principalDisplayName" ) if not litellm_team_id: verbose_proxy_logger.debug( f"Skipping team creation for {litellm_team_name} because it has no principalId" ) continue await SSOAuthenticationHandler.create_litellm_team_from_sso_group( litellm_team_id=litellm_team_id, litellm_team_name=litellm_team_name, ) class GoogleSSOHandler: """ Handles Google SSO callback response and returns a CustomOpenID object """ @staticmethod async def get_google_callback_response( request: Request, google_client_id: str, redirect_url: str, return_raw_sso_response: bool = False, ) -> Union[OpenID, dict]: """ Get the Google SSO callback response Args: return_raw_sso_response: If True, return the raw SSO response """ from fastapi_sso.sso.google import GoogleSSO google_client_secret = os.getenv("GOOGLE_CLIENT_SECRET", None) if google_client_secret is None: raise ProxyException( message="GOOGLE_CLIENT_SECRET not set. Set it in .env file", type=ProxyErrorTypes.auth_error, param="GOOGLE_CLIENT_SECRET", code=status.HTTP_500_INTERNAL_SERVER_ERROR, ) google_sso = GoogleSSO( client_id=google_client_id, redirect_uri=redirect_url, client_secret=google_client_secret, ) # if user is trying to get the raw sso response for debugging, return the raw sso response if return_raw_sso_response: return ( await google_sso.verify_and_process( request=request, convert_response=False, # type: ignore ) or {} ) result = await google_sso.verify_and_process(request) return result or {} @router.get("/sso/debug/login", tags=["experimental"], include_in_schema=False) async def debug_sso_login(request: Request): """ Create Proxy API Keys using Google Workspace SSO. Requires setting PROXY_BASE_URL in .env PROXY_BASE_URL should be the your deployed proxy endpoint, e.g. PROXY_BASE_URL="https://litellm-production-7002.up.railway.app/" Example: """ from litellm.proxy.proxy_server import premium_user microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) ####### Check if user is a Enterprise / Premium User ####### if ( microsoft_client_id is not None or google_client_id is not None or generic_client_id is not None ): if premium_user is not True: raise ProxyException( message="You must be a LiteLLM Enterprise user to use SSO. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/cx9p-5yf-2nm/litellm-introductions You are seeing this error message because You set one of `MICROSOFT_CLIENT_ID`, `GOOGLE_CLIENT_ID`, or `GENERIC_CLIENT_ID` in your env. Please unset this", type=ProxyErrorTypes.auth_error, param="premium_user", code=status.HTTP_403_FORBIDDEN, ) # get url from request redirect_url = SSOAuthenticationHandler.get_redirect_url_for_sso( request=request, sso_callback_route="sso/debug/callback", ) # Check if we should use SSO handler if ( SSOAuthenticationHandler.should_use_sso_handler( microsoft_client_id=microsoft_client_id, google_client_id=google_client_id, generic_client_id=generic_client_id, ) is True ): return await SSOAuthenticationHandler.get_sso_login_redirect( redirect_url=redirect_url, microsoft_client_id=microsoft_client_id, google_client_id=google_client_id, generic_client_id=generic_client_id, ) @router.get("/sso/debug/callback", tags=["experimental"], include_in_schema=False) async def debug_sso_callback(request: Request): """ Returns the OpenID object returned by the SSO provider """ import json from fastapi.responses import HTMLResponse from litellm.proxy._types import LiteLLM_JWTAuth from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.proxy_server import ( general_settings, jwt_handler, prisma_client, user_api_key_cache, ) sso_jwt_handler: Optional[JWTHandler] = None ui_access_mode = general_settings.get("ui_access_mode", None) if ui_access_mode is not None and isinstance(ui_access_mode, dict): sso_jwt_handler = JWTHandler() sso_jwt_handler.update_environment( prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, litellm_jwtauth=LiteLLM_JWTAuth( team_ids_jwt_field=general_settings.get("ui_access_mode", {}).get( "sso_group_jwt_field", None ), ), leeway=0, ) microsoft_client_id = os.getenv("MICROSOFT_CLIENT_ID", None) google_client_id = os.getenv("GOOGLE_CLIENT_ID", None) generic_client_id = os.getenv("GENERIC_CLIENT_ID", None) redirect_url = os.getenv("PROXY_BASE_URL", str(request.base_url)) if redirect_url.endswith("/"): redirect_url += "sso/debug/callback" else: redirect_url += "/sso/debug/callback" result = None if google_client_id is not None: result = await GoogleSSOHandler.get_google_callback_response( request=request, google_client_id=google_client_id, redirect_url=redirect_url, return_raw_sso_response=True, ) elif microsoft_client_id is not None: result = await MicrosoftSSOHandler.get_microsoft_callback_response( request=request, microsoft_client_id=microsoft_client_id, redirect_url=redirect_url, return_raw_sso_response=True, ) elif generic_client_id is not None: result, _ = await get_generic_sso_response( request=request, jwt_handler=jwt_handler, generic_client_id=generic_client_id, redirect_url=redirect_url, sso_jwt_handler=sso_jwt_handler, ) # If result is None, return a basic error message if result is None: return HTMLResponse( content="
No data was returned from the SSO provider.
", status_code=400, ) # Convert the OpenID object to a dictionary if hasattr(result, "__dict__"): result_dict = result.__dict__ else: result_dict = dict(result) # Filter out any None values and convert to JSON serializable format filtered_result = {} for key, value in result_dict.items(): if value is not None and not key.startswith("_"): if isinstance(value, (str, int, float, bool)) or value is None: filtered_result[key] = value else: try: # Try to convert to string or another JSON serializable format filtered_result[key] = str(value) except Exception as e: filtered_result[key] = f"Complex value (not displayable): {str(e)}" # Replace the placeholder in the template with the actual data html_content = jwt_display_template.replace( "const userData = SSO_DATA;", f"const userData = {json.dumps(filtered_result, indent=2)};", ) return HTMLResponse(content=html_content)