import asyncio
import copy
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from fastapi import Request
from starlette.datastructures import Headers

import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm._service_logger import ServiceLogging
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from litellm.proxy._types import (
    AddTeamCallback,
    CommonProxyErrors,
    LitellmDataForBackendLLMCall,
    LitellmUserRoles,
    SpecialHeaders,
    TeamCallbackMetadata,
    UserAPIKeyAuth,
)
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers

# Cache special headers as a frozenset for O(1) lookup performance
_SPECIAL_HEADERS_CACHE = frozenset(
    v.value.lower() for v in SpecialHeaders._member_map_.values()
)
from litellm.router import Router
from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS
from litellm.types.services import ServiceTypes
from litellm.types.utils import (
    LlmProviders,
    ProviderSpecificHeader,
    StandardLoggingUserAPIKeyMetadata,
    SupportedCacheControls,
)

service_logger_obj = ServiceLogging()  # used for tracking latency on OTEL

if TYPE_CHECKING:
    from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig
    from litellm.types.proxy.policy_engine import PolicyMatchContext

    ProxyConfig = _ProxyConfig
else:
    ProxyConfig = Any
    PolicyMatchContext = Any


def parse_cache_control(cache_control: str) -> Dict[str, Union[str, bool]]:
    """
    Parse a `Cache-Control` header value into a dict of directives.

    Directives are comma-separated; whitespace around each directive is
    ignored. `name=value` directives map `name` to the string `value` (only the
    first `=` splits, so values may themselves contain `=`). Bare directives
    (e.g. `no-cache`) map to `True`. Empty directives are skipped.

    Args:
        cache_control: Raw header value, e.g. `"no-cache, s-maxage=60"`.

    Returns:
        Mapping of directive name to its string value or `True`.
    """
    cache_dict: Dict[str, Union[str, bool]] = {}
    for raw_directive in cache_control.split(","):
        directive = raw_directive.strip()
        if not directive:
            continue
        name, separator, value = directive.partition("=")
        if separator:
            cache_dict[name.strip()] = value.strip()
        else:
            cache_dict[directive] = True
    return cache_dict


# Path fragments for which the request metadata field is named "litellm_metadata"
LITELLM_METADATA_ROUTES = (
    "batches",
    "/v1/messages",
    "responses",
    "files",
)


def _get_metadata_variable_name(request: Request) -> str:
    """
    Helper to return what the "metadata" field should be called in the request data

    For all /thread or /assistant endpoints, and for any path containing one of
    `LITELLM_METADATA_ROUTES`, we need to call this "litellm_metadata"

    For ALL other endpoints we call this "metadata"
    """
    path = request.url.path
    if "thread" in path or "assistant" in path:
        return "litellm_metadata"
    if any(route in path for route in LITELLM_METADATA_ROUTES):
        return "litellm_metadata"
    return "metadata"


def get_chain_id_from_headers(headers: Optional[Dict[str, str]]) -> Optional[str]:
    """
    Extract chain id for call chaining from request headers.

    x-litellm-trace-id and x-litellm-session-id are interchangeable; when both are
    present, x-litellm-trace-id takes precedence. Header keys are matched
    case-insensitively so this works with raw header dicts from any transport.

    Used by MCP (and other paths that have raw_headers but no Request) to set
    litellm_trace_id/litellm_session_id for spend logs and logging consistency.
    """
    if not headers:
        return None
    normalized = {k.lower(): v for k, v in headers.items() if isinstance(k, str)}
    return normalized.get("x-litellm-trace-id") or normalized.get(
        "x-litellm-session-id"
    )


def safe_add_api_version_from_query_params(data: dict, request: Request):
    """
    Copy the `api-version` query parameter (if any) into `data["api_version"]`.

    Best-effort: a `KeyError` is ignored and any other exception is logged
    instead of being raised.
    """
    try:
        if hasattr(request, "query_params"):
            query_params = dict(request.query_params)
            if "api-version" in query_params:
                data["api_version"] = query_params["api-version"]
    except KeyError:
        pass
    except Exception as e:
        verbose_logger.exception(
            "error checking api version in query params: %s", str(e)
        )


def convert_key_logging_metadata_to_callback(
    data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata]
) -> TeamCallbackMetadata:
    """
    Fold one `AddTeamCallback` entry into a `TeamCallbackMetadata` accumulator.

    Args:
        data: The callback entry (name, type and variables) to add.
        team_callback_settings_obj: Accumulator to update; a new one is created
            when `None`.

    Returns:
        The updated (or newly created) `TeamCallbackMetadata`. The callback name
        is appended to the success and/or failure list according to
        `callback_type`, without duplicates; each callback variable is stored as
        a string after `litellm.utils.get_secret` resolution.
    """
    if team_callback_settings_obj is None:
        team_callback_settings_obj = TeamCallbackMetadata()
    if data.callback_type == "success":
        if team_callback_settings_obj.success_callback is None:
            team_callback_settings_obj.success_callback = []

        if data.callback_name not in team_callback_settings_obj.success_callback:
            team_callback_settings_obj.success_callback.append(data.callback_name)
    elif data.callback_type == "failure":
        if team_callback_settings_obj.failure_callback is None:
            team_callback_settings_obj.failure_callback = []

        if data.callback_name not in team_callback_settings_obj.failure_callback:
            team_callback_settings_obj.failure_callback.append(data.callback_name)
    elif (
        not data.callback_type or data.callback_type == "success_and_failure"
    ):  # assume 'success_and_failure' = litellm.callbacks
        if team_callback_settings_obj.success_callback is None:
            team_callback_settings_obj.success_callback = []
        if team_callback_settings_obj.failure_callback is None:
            team_callback_settings_obj.failure_callback = []
        if team_callback_settings_obj.callbacks is None:
            team_callback_settings_obj.callbacks = []

        if data.callback_name not in team_callback_settings_obj.success_callback:
            team_callback_settings_obj.success_callback.append(data.callback_name)

        if data.callback_name not in team_callback_settings_obj.failure_callback:
            team_callback_settings_obj.failure_callback.append(data.callback_name)

        if data.callback_name not in team_callback_settings_obj.callbacks:
            team_callback_settings_obj.callbacks.append(data.callback_name)

    for var, value in data.callback_vars.items():
        # callback_vars is only initialised when there is at least one variable
        if team_callback_settings_obj.callback_vars is None:
            team_callback_settings_obj.callback_vars = {}
        team_callback_settings_obj.callback_vars[var] = str(
            litellm.utils.get_secret(value, default_value=value) or value
        )

    return team_callback_settings_obj


class KeyAndTeamLoggingSettings:
    """
    Helper class to get the dynamic logging settings for the key and team
    """

    @staticmethod
    def get_key_dynamic_logging_settings(user_api_key_dict: UserAPIKeyAuth):
        """Return `metadata["logging"]` of the key, or `None` when absent."""
        if (
            user_api_key_dict.metadata is not None
            and "logging" in user_api_key_dict.metadata
        ):
            return user_api_key_dict.metadata["logging"]
        return None

    @staticmethod
    def get_team_dynamic_logging_settings(user_api_key_dict: UserAPIKeyAuth):
        """Return `team_metadata["logging"]` of the key, or `None` when absent."""
        if (
            user_api_key_dict.team_metadata is not None
            and "logging" in user_api_key_dict.team_metadata
        ):
            return user_api_key_dict.team_metadata["logging"]
        return None


def _get_dynamic_logging_metadata(
    user_api_key_dict: UserAPIKeyAuth, proxy_config: ProxyConfig
) -> Optional[TeamCallbackMetadata]:
    """
    Resolve the callback settings that apply to this key.

    Exactly one source is used, checked in this order:
    1. key-level `metadata["logging"]`
    2. team-level `team_metadata["logging"]`
    3. deprecated `team_metadata["callback_settings"]`
    4. team config loaded through `proxy_config` (when the key has a `team_id`)

    Returns:
        The resolved `TeamCallbackMetadata`, or `None` when no source applies.
    """
    callback_settings_obj: Optional[TeamCallbackMetadata] = None
    key_dynamic_logging_settings: Optional[
        dict
    ] = KeyAndTeamLoggingSettings.get_key_dynamic_logging_settings(user_api_key_dict)
    team_dynamic_logging_settings: Optional[
        dict
    ] = KeyAndTeamLoggingSettings.get_team_dynamic_logging_settings(user_api_key_dict)
    #########################################################################################
    # Key-based callbacks
    #########################################################################################
    if key_dynamic_logging_settings is not None:
        for item in key_dynamic_logging_settings:
            callback_settings_obj = convert_key_logging_metadata_to_callback(
                data=AddTeamCallback(**item),
                team_callback_settings_obj=callback_settings_obj,
            )
    #########################################################################################
    # Team-based callbacks
    #########################################################################################
    elif team_dynamic_logging_settings is not None:
        for item in team_dynamic_logging_settings:
            callback_settings_obj = convert_key_logging_metadata_to_callback(
                data=AddTeamCallback(**item),
                team_callback_settings_obj=callback_settings_obj,
            )
    #########################################################################################
    # Deprecated format - maintained for backwards compatibility
    #########################################################################################
    elif (
        user_api_key_dict.team_metadata is not None
        and "callback_settings" in user_api_key_dict.team_metadata
    ):
        # Example of the deprecated format:
        # callback_settings = {
        #     'callback_vars': {'langfuse_public_key': 'pk', 'langfuse_secret_key': 'sk_'},
        #     'failure_callback': [],
        #     'success_callback': ['langfuse', 'langfuse'],
        # }
        team_metadata = user_api_key_dict.team_metadata
        callback_settings = team_metadata.get("callback_settings", None) or {}
        callback_settings_obj = TeamCallbackMetadata(**callback_settings)
        verbose_proxy_logger.debug(
            "Team callback settings activated: %s", callback_settings_obj
        )
    #########################################################################################
    # Enter here when configured on the config.yaml file.
    #########################################################################################
    elif user_api_key_dict.team_id is not None:
        callback_settings_obj = (
            LiteLLMProxyRequestSetup.add_team_based_callbacks_from_config(
                team_id=user_api_key_dict.team_id, proxy_config=proxy_config
            )
        )
    return callback_settings_obj


def clean_headers(
    headers: Headers,
    litellm_key_header_name: Optional[str] = None,
    forward_llm_provider_auth_headers: bool = False,
    authenticated_with_header: Optional[str] = None,
) -> dict:
    """
    Removes litellm api key from headers

    Args:
        headers: Request headers
        litellm_key_header_name: Custom header name for LiteLLM API key
        forward_llm_provider_auth_headers: Whether to forward provider auth headers
        authenticated_with_header: Which header was used for LiteLLM authentication
            (e.g., "x-litellm-api-key", "authorization", "x-api-key")

    Returns:
        Cleaned headers dict
    """
    from litellm.llms.anthropic.common_utils import is_anthropic_oauth_key

    cleaned: dict = {}
    litellm_key_lower = (
        litellm_key_header_name.lower()
        if litellm_key_header_name is not None
        else None
    )
    # None never equals a header name, so "was not the auth header" reduces to `!=`
    auth_header_lower = (
        authenticated_with_header.lower()
        if authenticated_with_header is not None
        else None
    )

    for header, value in headers.items():
        header_lower = header.lower()

        if header_lower == "authorization" and is_anthropic_oauth_key(value):
            # Forward an Anthropic OAuth token unless `authorization` itself was
            # the header used for LiteLLM authentication
            if auth_header_lower != "authorization":
                cleaned[header] = value
        # Special handling for x-api-key: forward it based on authenticated_with_header
        elif header_lower == "x-api-key":
            if forward_llm_provider_auth_headers and auth_header_lower != "x-api-key":
                cleaned[header] = value
        elif (
            forward_llm_provider_auth_headers
            and header_lower in _SPECIAL_HEADERS_CACHE
        ):
            if litellm_key_lower and header_lower == litellm_key_lower:
                continue
            if header_lower == "authorization":
                continue
            # Never forward x-litellm-api-key (it's for proxy auth only)
            if header_lower == "x-litellm-api-key":
                continue
            cleaned[header] = value
        # Check if header should be excluded: either in special headers cache or matches custom litellm key
        elif header_lower not in _SPECIAL_HEADERS_CACHE and (
            litellm_key_lower is None or header_lower != litellm_key_lower
        ):
            cleaned[header] = value
    return cleaned


class LiteLLMProxyRequestSetup:
    """Static helpers that enrich an incoming proxy request before the LLM call."""

    @staticmethod
    def _get_timeout_from_request(headers: dict) -> Optional[float]:
        """
        Workaround for client request from Vercel's AI SDK.

        Allows the user to set a timeout in the request headers.

        Example:

        ```js
        const openaiProvider = createOpenAI({
            baseURL: liteLLM.baseURL,
            apiKey: liteLLM.apiKey,
            compatibility: "compatible",
            headers: {
                "x-litellm-timeout": "90"
            },
        });
        ```

        Raises:
            ValueError: If the header value cannot be converted to `float`.
        """
        timeout_header = headers.get("x-litellm-timeout", None)
        if timeout_header is not None:
            return float(timeout_header)
        return None

    @staticmethod
    def _get_stream_timeout_from_request(headers: dict) -> Optional[float]:
        """
        Get the `stream_timeout` from the request headers.

        Raises:
            ValueError: If the header value cannot be converted to `float`.
        """
        stream_timeout_header = headers.get("x-litellm-stream-timeout", None)
        if stream_timeout_header is not None:
            return float(stream_timeout_header)
        return None

    @staticmethod
    def _get_num_retries_from_request(headers: dict) -> Optional[int]:
        """
        Workaround for client request from Vercel's AI SDK.

        Reads `x-litellm-num-retries` from the request headers.

        Raises:
            ValueError: If the header value cannot be converted to `int`.
        """
        num_retries_header = headers.get("x-litellm-num-retries", None)
        if num_retries_header is not None:
            return int(num_retries_header)
        return None

    @staticmethod
    def _get_spend_logs_metadata_from_request_headers(headers: dict) -> Optional[dict]:
        """
        Get the `spend_logs_metadata` from the request headers.

        Returns the result of `safe_json_loads` on the
        `x-litellm-spend-logs-metadata` header, or `None` when it is absent.
        """
        from litellm.litellm_core_utils.safe_json_loads import safe_json_loads

        spend_logs_metadata_header = headers.get("x-litellm-spend-logs-metadata", None)
        if spend_logs_metadata_header is not None:
            return safe_json_loads(spend_logs_metadata_header)
        return None

    @staticmethod
    def _get_forwardable_headers(
        headers: Union[Headers, dict],
    ):
        """
        Get the headers that should be forwarded to the LLM Provider.

        Looks for any `x-` headers and sends them to the LLM Provider.

        [07/09/2025] - Support 'anthropic-beta' header as well.
        """
        forwarded_headers = {}
        for header, value in headers.items():
            header_lower = header.lower()
            if header_lower.startswith("x-") and not header_lower.startswith(
                "x-stainless"
            ):  # causes openai sdk to fail
                forwarded_headers[header] = value
            elif header_lower.startswith("anthropic-beta"):
                forwarded_headers[header] = value

        return forwarded_headers

    @staticmethod
    def _get_case_insensitive_header(headers: dict, key: str) -> Optional[str]:
        """
        Get a case-insensitive header from the headers dictionary.

        Returns the value of the first header whose name matches `key`
        ignoring case, or `None` when there is no match.
        """
        wanted = key.lower()
        for header, value in headers.items():
            if header.lower() == wanted:
                return value
        return None

    @staticmethod
    def add_internal_user_from_user_mapping(
        general_settings: Optional[Dict],
        user_api_key_dict: UserAPIKeyAuth,
        headers: dict,
    ) -> UserAPIKeyAuth:
        """
        Set `user_api_key_dict.user_id` from the header configured for the
        internal-user role in `general_settings["user_header_mappings"]`.

        The key object is returned unchanged when no mapping is configured, no
        internal-user header is mapped, or the header is missing/empty.
        """
        if general_settings is None:
            return user_api_key_dict
        user_header_mapping = general_settings.get("user_header_mappings")
        if not user_header_mapping:
            return user_api_key_dict
        header_name = LiteLLMProxyRequestSetup.get_internal_user_header_from_mapping(
            user_header_mapping
        )
        if not header_name:
            return user_api_key_dict
        header_value = LiteLLMProxyRequestSetup._get_case_insensitive_header(
            headers, header_name
        )
        if header_value:
            user_api_key_dict.user_id = header_value
        return user_api_key_dict

    @staticmethod
    def get_user_from_headers(
        headers: dict, general_settings: Optional[Dict] = None
    ) -> Optional[str]:
        """
        Get the user from the specified header if `general_settings.user_header_name` is set.

        Raises:
            TypeError: If `user_header_name` is set to a non-string value.
        """
        if general_settings is None:
            return None

        header_name = general_settings.get("user_header_name")
        if header_name is None or header_name == "":
            return None

        if not isinstance(header_name, str):
            raise TypeError(
                f"Expected user_header_name to be a str but got {type(header_name)}"
            )

        user = LiteLLMProxyRequestSetup._get_case_insensitive_header(
            headers, header_name
        )
        if user is not None:
            verbose_logger.info(f'found user "{user}" in header "{header_name}"')
        return user

    @staticmethod
    def get_openai_org_id_from_headers(
        headers: dict, general_settings: Optional[Dict] = None
    ) -> Optional[str]:
        """
        Get the OpenAI Org ID from the headers.

        Returns `None` when `general_settings` is provided and
        `forward_openai_org_id` is not `True`, or when no
        `openai-organization` header is present.
        """
        if (
            general_settings is not None
            and general_settings.get("forward_openai_org_id") is not True
        ):
            return None
        for header, value in headers.items():
            if header.lower() == "openai-organization":
                verbose_logger.info(f"found openai org id: {value}, sending to llm")
                return value
        return None

    @staticmethod
    def add_headers_to_llm_call(
        headers: dict, user_api_key_dict: UserAPIKeyAuth
    ) -> dict:
        """
        Add headers to the LLM call

        - Checks request headers for forwardable headers
        - Checks if user information should be added to the headers
        """
        returned_headers = LiteLLMProxyRequestSetup._get_forwardable_headers(headers)

        if litellm.add_user_information_to_llm_headers is True:
            litellm_logging_metadata_headers = (
                LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
                    user_api_key_dict=user_api_key_dict
                )
            )
            for k, v in litellm_logging_metadata_headers.items():
                if v is not None:
                    returned_headers["x-litellm-{}".format(k)] = v

        return returned_headers

    @staticmethod
    def add_headers_to_llm_call_by_model_group(
        data: dict, headers: dict, user_api_key_dict: UserAPIKeyAuth
    ) -> dict:
        """
        Add headers to the LLM call by model group

        Headers are only attached when the requested model matches
        `litellm.model_group_settings.forward_client_headers_to_llm_api`.
        """
        from litellm.proxy.auth.auth_checks import _check_model_access_helper
        from litellm.proxy.proxy_server import llm_router

        data_model = data.get("model")

        if (
            data_model is not None
            and litellm.model_group_settings is not None
            and litellm.model_group_settings.forward_client_headers_to_llm_api
            is not None
            and _check_model_access_helper(
                model=data_model,
                llm_router=llm_router,
                models=litellm.model_group_settings.forward_client_headers_to_llm_api,
                team_model_aliases=user_api_key_dict.team_model_aliases,
                team_id=user_api_key_dict.team_id,
            )  # handles aliases, wildcards, etc.
        ):
            _headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
                headers, user_api_key_dict
            )
            if _headers != {}:
                data["headers"] = _headers
        return data

    @staticmethod
    def get_internal_user_header_from_mapping(user_header_mapping) -> Optional[str]:
        """
        Return the `header_name` of the first mapping entry whose
        `litellm_user_role` is the internal-user role (compared
        case-insensitively), or `None`.

        Accepts either a single mapping dict or a list of them; non-dict
        entries and entries missing a role or header name are skipped.
        """
        if not user_header_mapping:
            return None
        items = (
            user_header_mapping
            if isinstance(user_header_mapping, list)
            else [user_header_mapping]
        )
        for item in items:
            if not isinstance(item, dict):
                continue
            role = item.get("litellm_user_role")
            header_name = item.get("header_name")
            if role is None or not header_name:
                continue
            if str(role).lower() == str(LitellmUserRoles.INTERNAL_USER).lower():
                return header_name
        return None

    @staticmethod
    def add_litellm_data_for_backend_llm_call(
        *,
        headers: dict,
        user_api_key_dict: UserAPIKeyAuth,
        general_settings: Optional[Dict[str, Any]] = None,
    ) -> LitellmDataForBackendLLMCall:
        """
        - Adds forwardable headers (when `forward_client_headers_to_llm_api` is True)
        - Adds org id
        - Adds timeout / stream_timeout / num_retries from the `x-litellm-*` headers
        """
        data = LitellmDataForBackendLLMCall()
        if (
            general_settings
            and general_settings.get("forward_client_headers_to_llm_api") is True
        ):
            _headers = LiteLLMProxyRequestSetup.add_headers_to_llm_call(
                headers, user_api_key_dict
            )
            if _headers != {}:
                data["headers"] = _headers
        _organization = LiteLLMProxyRequestSetup.get_openai_org_id_from_headers(
            headers, general_settings
        )
        if _organization is not None:
            data["organization"] = _organization

        timeout = LiteLLMProxyRequestSetup._get_timeout_from_request(headers)
        if timeout is not None:
            data["timeout"] = timeout

        stream_timeout = LiteLLMProxyRequestSetup._get_stream_timeout_from_request(
            headers
        )
        if stream_timeout is not None:
            data["stream_timeout"] = stream_timeout

        num_retries = LiteLLMProxyRequestSetup._get_num_retries_from_request(headers)
        if num_retries is not None:
            data["num_retries"] = num_retries

        return data

    @staticmethod
    def add_litellm_metadata_from_request_headers(
        headers: dict,
        data: dict,
        _metadata_variable_name: str,
    ) -> dict:
        """
        Add litellm metadata from request headers

        Reads spend-logs metadata, agent id and trace/session id from the
        `x-litellm-*` headers. The chain id is also written to
        `data["litellm_session_id"]` and `data["litellm_trace_id"]`. The
        collected values are merged into `data[_metadata_variable_name]` only
        when that entry exists and is a dict.

        Relevant issue: https://github.com/BerriAI/litellm/issues/14008
        """
        from litellm.proxy._types import LitellmMetadataFromRequestHeaders

        metadata_from_headers = LitellmMetadataFromRequestHeaders()
        spend_logs_metadata = (
            LiteLLMProxyRequestSetup._get_spend_logs_metadata_from_request_headers(
                headers
            )
        )
        if spend_logs_metadata is not None:
            metadata_from_headers["spend_logs_metadata"] = spend_logs_metadata

        #########################################################################################
        # Finally update the requests metadata with the `metadata_from_headers`
        #########################################################################################
        agent_id_from_header = headers.get("x-litellm-agent-id")
        # x-litellm-trace-id and x-litellm-session-id are interchangeable for call chaining
        chain_id = headers.get("x-litellm-trace-id") or headers.get(
            "x-litellm-session-id"
        )
        if agent_id_from_header:
            metadata_from_headers["agent_id"] = agent_id_from_header
            verbose_proxy_logger.debug(
                f"Extracted agent_id from header: {agent_id_from_header}"
            )
        if chain_id:
            metadata_from_headers["trace_id"] = chain_id
            metadata_from_headers["session_id"] = chain_id
            data["litellm_session_id"] = chain_id
            data["litellm_trace_id"] = chain_id
            verbose_proxy_logger.debug(
                f"Extracted chain_id from header (trace-id/session-id): {chain_id}"
            )

        # `.get` so a request without the metadata field is left alone instead
        # of raising KeyError
        request_metadata = data.get(_metadata_variable_name)
        if isinstance(request_metadata, dict):
            request_metadata.update(metadata_from_headers)
        return data

    @staticmethod
    def get_sanitized_user_information_from_key(
        user_api_key_dict: UserAPIKeyAuth,
    ) -> StandardLoggingUserAPIKeyMetadata:
        """Build the loggable `StandardLoggingUserAPIKeyMetadata` view of a key."""
        user_api_key_logged_metadata = StandardLoggingUserAPIKeyMetadata(
            user_api_key_hash=user_api_key_dict.api_key,  # just the hashed token
            user_api_key_alias=user_api_key_dict.key_alias,
            user_api_key_spend=user_api_key_dict.spend,
            user_api_key_max_budget=user_api_key_dict.max_budget,
            user_api_key_team_id=user_api_key_dict.team_id,
            user_api_key_project_id=user_api_key_dict.project_id,
            user_api_key_user_id=user_api_key_dict.user_id,
            user_api_key_org_id=user_api_key_dict.org_id,
            user_api_key_team_alias=user_api_key_dict.team_alias,
            user_api_key_end_user_id=user_api_key_dict.end_user_id,
            user_api_key_user_email=user_api_key_dict.user_email,
            user_api_key_request_route=user_api_key_dict.request_route,
            user_api_key_budget_reset_at=(
                user_api_key_dict.budget_reset_at.isoformat()
                if user_api_key_dict.budget_reset_at
                else None
            ),
            user_api_key_auth_metadata=user_api_key_dict.metadata,
        )
        return user_api_key_logged_metadata

    @staticmethod
    def add_user_api_key_auth_to_request_metadata(
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        _metadata_variable_name: str,
    ) -> dict:
        """
        Adds the `UserAPIKeyAuth` object to the request metadata.

        Expects `data[_metadata_variable_name]` to already be a dict.
        """
        user_api_key_logged_metadata = (
            LiteLLMProxyRequestSetup.get_sanitized_user_information_from_key(
                user_api_key_dict=user_api_key_dict
            )
        )
        data[_metadata_variable_name].update(user_api_key_logged_metadata)
        data[_metadata_variable_name][
            "user_api_key"
        ] = user_api_key_dict.api_key  # this is just the hashed token

        # Key-owned agent_id for spend attribution; keep existing (e.g. from header) if key has none
        _key_agent_id = getattr(user_api_key_dict, "agent_id", None)
        _existing_agent_id = data[_metadata_variable_name].get("agent_id")
        _resolved_agent_id = _key_agent_id or _existing_agent_id
        data[_metadata_variable_name]["agent_id"] = _resolved_agent_id

        data[_metadata_variable_name]["user_api_end_user_max_budget"] = getattr(
            user_api_key_dict, "end_user_max_budget", None
        )

        # Add the full UserAPIKeyAuth object for MCP server access control
        data[_metadata_variable_name]["user_api_key_auth"] = user_api_key_dict
        return data

    @staticmethod
    def add_management_endpoint_metadata_to_request_metadata(
        data: dict,
        management_endpoint_metadata: dict,
        _metadata_variable_name: str,
    ) -> dict:
        """
        Adds the `UserAPIKeyAuth` metadata to the request metadata.

        ignore any sensitive fields like logging, api_key, etc.

        `data` is returned untouched when it has no `_metadata_variable_name`
        entry.
        """
        if _metadata_variable_name not in data:
            return data

        from litellm.proxy._types import (
            LiteLLM_ManagementEndpoint_MetadataFields,
            LiteLLM_ManagementEndpoint_MetadataFields_Premium,
        )

        # ignore any special fields (built once instead of once per key)
        excluded_fields = (
            LiteLLM_ManagementEndpoint_MetadataFields_Premium
            + LiteLLM_ManagementEndpoint_MetadataFields
        )
        added_metadata = {
            k: v
            for k, v in management_endpoint_metadata.items()
            if k not in excluded_fields
        }

        if data[_metadata_variable_name].get("user_api_key_auth_metadata") is None:
            data[_metadata_variable_name]["user_api_key_auth_metadata"] = {}
        data[_metadata_variable_name]["user_api_key_auth_metadata"].update(
            added_metadata
        )
        return data

    @staticmethod
    def add_key_level_controls(
        key_metadata: Optional[dict], data: dict, _metadata_variable_name: str
    ):
        """
        Apply key-level settings (cache controls, tags, guardrail toggle,
        spend-logs metadata, fallback toggle, remaining key metadata) to `data`.

        Values already present in the request's `spend_logs_metadata` take
        precedence over the key's values.
        """
        if key_metadata is None:
            return data
        if "cache" in key_metadata:
            data["cache"] = {}
            if isinstance(key_metadata["cache"], dict):
                for k, v in key_metadata["cache"].items():
                    if k in SupportedCacheControls:
                        data["cache"][k] = v

        ## KEY-LEVEL SPEND LOGS / TAGS
        if "tags" in key_metadata and key_metadata["tags"] is not None:
            data[_metadata_variable_name][
                "tags"
            ] = LiteLLMProxyRequestSetup._merge_tags(
                request_tags=data[_metadata_variable_name].get("tags"),
                tags_to_add=key_metadata["tags"],
            )
        if "disable_global_guardrails" in key_metadata and isinstance(
            key_metadata["disable_global_guardrails"], bool
        ):
            data[_metadata_variable_name]["disable_global_guardrails"] = key_metadata[
                "disable_global_guardrails"
            ]
        if "spend_logs_metadata" in key_metadata and isinstance(
            key_metadata["spend_logs_metadata"], dict
        ):
            if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
                data[_metadata_variable_name]["spend_logs_metadata"], dict
            ):
                for key, value in key_metadata["spend_logs_metadata"].items():
                    if (
                        key not in data[_metadata_variable_name]["spend_logs_metadata"]
                    ):  # don't override k-v pair sent by request (user request)
                        data[_metadata_variable_name]["spend_logs_metadata"][
                            key
                        ] = value
            else:
                # Store a copy: later merges write into the request's dict and
                # must not leak into the key's own metadata object
                data[_metadata_variable_name]["spend_logs_metadata"] = dict(
                    key_metadata["spend_logs_metadata"]
                )

        ## KEY-LEVEL DISABLE FALLBACKS
        if "disable_fallbacks" in key_metadata and isinstance(
            key_metadata["disable_fallbacks"], bool
        ):
            data["disable_fallbacks"] = key_metadata["disable_fallbacks"]

        ## KEY-LEVEL METADATA
        data = LiteLLMProxyRequestSetup.add_management_endpoint_metadata_to_request_metadata(
            data=data,
            management_endpoint_metadata=key_metadata,
            _metadata_variable_name=_metadata_variable_name,
        )
        return data

    @staticmethod
    def _merge_tags(request_tags: Optional[list], tags_to_add: Optional[list]) -> list:
        """
        Helper function to merge two lists of tags, ensuring no duplicates.

        Args:
            request_tags (Optional[list]): List of tags from the original request
            tags_to_add (Optional[list]): List of tags to add

        Returns:
            list: Combined list of unique tags
        """
        final_tags = []

        if request_tags and isinstance(request_tags, list):
            final_tags.extend(request_tags)

        if tags_to_add and isinstance(tags_to_add, list):
            for tag in tags_to_add:
                if tag not in final_tags:
                    final_tags.append(tag)

        return final_tags

    @staticmethod
    def add_team_based_callbacks_from_config(
        team_id: str,
        proxy_config: ProxyConfig,
    ) -> Optional[TeamCallbackMetadata]:
        """
        Add team-based callbacks from the config

        Returns `None` when `proxy_config.load_team_config` yields no non-empty
        dict. Callback variables come from the config's `callback_vars` entry,
        or from the whole team config when that entry is missing; `team_id`,
        `success_callback` and `failure_callback` are removed from them.
        """
        team_config = proxy_config.load_team_config(team_id=team_id)
        if not isinstance(team_config, dict) or len(team_config) == 0:
            return None

        callback_vars_dict = {**team_config.get("callback_vars", team_config)}
        callback_vars_dict.pop("team_id", None)
        callback_vars_dict.pop("success_callback", None)
        callback_vars_dict.pop("failure_callback", None)

        return TeamCallbackMetadata(
            success_callback=team_config.get("success_callback", None),
            failure_callback=team_config.get("failure_callback", None),
            callback_vars=callback_vars_dict,
        )

    @staticmethod
    def add_request_tag_to_metadata(
        llm_router: Optional[Router],
        headers: dict,
        data: dict,
    ) -> Optional[List[str]]:
        """
        Return the request tags, or `None` when none were sent.

        Tags are read from the `x-litellm-tags` header (comma-separated string
        or list); a list under `data["tags"]` takes precedence over the header.
        `llm_router` is not used by this method.
        """
        tags = None

        # Check request headers for tags
        if "x-litellm-tags" in headers:
            if isinstance(headers["x-litellm-tags"], str):
                _tags = headers["x-litellm-tags"].split(",")
                tags = [tag.strip() for tag in _tags]
            elif isinstance(headers["x-litellm-tags"], list):
                tags = headers["x-litellm-tags"]
        # Check request body for tags
        if "tags" in data and isinstance(data["tags"], list):
            tags = data["tags"]

        return tags


async def add_litellm_data_to_request(  # noqa: PLR0915
    data: dict,
    request: Request,
    user_api_key_dict: UserAPIKeyAuth,
    proxy_config: ProxyConfig,
    general_settings: Optional[Dict[str, Any]] = None,
    version: Optional[str] = None,
):
    """
    Adds LiteLLM-specific data to the request.

    Args:
        data (dict): The data dictionary to be modified.
        request (Request): The incoming request.
        user_api_key_dict (UserAPIKeyAuth): The user API key dictionary.
        general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None.
        version (Optional[str], optional): Version. Defaults to None.

    Returns:
        dict: The modified data dictionary.
    """

    from litellm.proxy.proxy_server import llm_router, premium_user
    from litellm.types.proxy.litellm_pre_call_utils import RedactedDict, SecretFields

    _raw_headers: Dict[str, str] = RedactedDict(_safe_get_request_headers(request))

    forward_llm_auth = False
    if general_settings:
        forward_llm_auth = general_settings.get(
            "forward_llm_provider_auth_headers", False
        )
    if not forward_llm_auth:
        forward_llm_auth = getattr(litellm, "forward_llm_provider_auth_headers", False)

    # Determine which header was used for authentication
    # This enables forwarding provider keys (e.g., x-api-key) when they weren't used for LiteLLM auth
    authenticated_with_header = None
    if "x-litellm-api-key" in request.headers:
        # If x-litellm-api-key is present, it was used for auth
        authenticated_with_header = "x-litellm-api-key"
    elif "authorization" in request.headers:
        # Authorization header was used for auth
        authenticated_with_header = "authorization"
    else:
        # x-api-key or another header was used for auth
        authenticated_with_header = "x-api-key"

    _headers: Dict[str, str] = clean_headers(
        request.headers,
        litellm_key_header_name=(
            general_settings.get("litellm_key_header_name")
            if general_settings is not None
            else None
        ),
        forward_llm_provider_auth_headers=forward_llm_auth,
        authenticated_with_header=authenticated_with_header,
    )

    verbose_proxy_logger.debug(f"Request Headers: {_headers}")
    verbose_proxy_logger.debug(f"Raw Headers: {_raw_headers}")

    if forward_llm_auth and "x-api-key" in _headers:
        data["api_key"] = _headers["x-api-key"]
        verbose_proxy_logger.debug(
            "Setting client-provided x-api-key as api_key parameter (will override deployment key)"
        )

    ##########################################################
    # Init - Proxy Server Request
    # we do this as soon as entering so we track the original request
    ##########################################################
    # Track arrival time for queue time metric
    arrival_time = time.time()

    data["proxy_server_request"] = {
        "url": str(request.url),
        "method": request.method,
        "headers": _headers,
        "body": copy.copy(data),  # use copy instead of deepcopy
        "arrival_time": arrival_time,  # Track when request arrived at proxy
    }

    safe_add_api_version_from_query_params(data, request)
    _metadata_variable_name = _get_metadata_variable_name(request)
    if data.get(_metadata_variable_name, None) is None:
        data[_metadata_variable_name] = {}

    data.update(
        LiteLLMProxyRequestSetup.add_litellm_data_for_backend_llm_call(
            headers=_headers,
            user_api_key_dict=user_api_key_dict,
            general_settings=general_settings,
        )
    )

    LiteLLMProxyRequestSetup.add_litellm_metadata_from_request_headers(
        headers=_headers,
        data=data,
        _metadata_variable_name=_metadata_variable_name,
    )

    # Add headers to metadata for guardrails to access (fixes #17477)
    # Guardrails use metadata["headers"] to access request headers (e.g., User-Agent)
    if _metadata_variable_name in data and isinstance(
        data[_metadata_variable_name], dict
    ):
        data[_metadata_variable_name]["headers"] = _headers

    # check for forwardable headers
    data = LiteLLMProxyRequestSetup.add_headers_to_llm_call_by_model_group(
        data=data, headers=_headers, user_api_key_dict=user_api_key_dict
    )

    user_api_key_dict = LiteLLMProxyRequestSetup.add_internal_user_from_user_mapping(
        general_settings, user_api_key_dict, _headers
    )

    # Parse user info from headers (fallback to general_settings.user_header_name)
    user = LiteLLMProxyRequestSetup.get_user_from_headers(_headers, general_settings)
    if user is not None:
        if user_api_key_dict.end_user_id is None:
            user_api_key_dict.end_user_id = user
        if "user" not in data:
            data["user"] = user

    data["secret_fields"] = SecretFields(raw_headers=_raw_headers)

    ## Dynamic api version (Azure OpenAI endpoints) ##
    try:
        query_params = request.query_params
        # Convert query parameters to a dictionary (optional)
        query_dict = dict(query_params)
    except KeyError:
        query_dict = {}

    ## check for api version in query params
    dynamic_api_version: Optional[str] = query_dict.get("api-version")

    if dynamic_api_version is not None:  # only pass, if set
        data["api_version"] = dynamic_api_version

    ## Forward any LLM API Provider specific headers in extra_headers
    add_provider_specific_headers_to_request(data=data, headers=_headers)

    ## Cache Controls
    # NOTE(review): this lookup is case-sensitive, while `_headers` keeps the key
    # casing yielded by `request.headers.items()`. If those keys are lower-cased
    # (presumably the case for Starlette `Headers` - confirm), this branch never
    # runs. Left unchanged here; verify before relying on header-driven ttl.
    cache_control_header = _headers.get("Cache-Control", None)
    if cache_control_header:
        cache_dict = parse_cache_control(cache_control_header)
        data["ttl"] = cache_dict.get("s-maxage")

    verbose_proxy_logger.debug("receiving data: %s", data)

    # Parse metadata if it's a string (e.g., from multipart/form-data)
    if "metadata" in data and data["metadata"] is not None:
        if isinstance(data["metadata"], str):
            data["metadata"] = safe_json_loads(data["metadata"])
            if not isinstance(data["metadata"], dict):
                verbose_proxy_logger.warning(
                    f"Failed to parse 'metadata' as JSON dict. Received value: {data['metadata']}"
                )
        data[_metadata_variable_name]["requester_metadata"] = copy.deepcopy(
            data["metadata"]
        )

    # Parse litellm_metadata if it's a string (e.g., from multipart/form-data or extra_body)
    if "litellm_metadata" in data and data["litellm_metadata"] is not None:
        if isinstance(data["litellm_metadata"], str):
            parsed_litellm_metadata = safe_json_loads(data["litellm_metadata"])
            if not isinstance(parsed_litellm_metadata, dict):
                verbose_proxy_logger.warning(
                    f"Failed to parse 'litellm_metadata' as JSON dict. Received value: {data['litellm_metadata']}"
                )
            else:
                data["litellm_metadata"] = parsed_litellm_metadata
        # Merge litellm_metadata into the metadata variable (preserving existing values)
        if isinstance(data["litellm_metadata"], dict):
            for key, value in data["litellm_metadata"].items():
                if key not in data[_metadata_variable_name]:
                    data[_metadata_variable_name][key] = value

    data = LiteLLMProxyRequestSetup.add_user_api_key_auth_to_request_metadata(
        data=data,
        user_api_key_dict=user_api_key_dict,
        _metadata_variable_name=_metadata_variable_name,
    )

    data[_metadata_variable_name]["litellm_api_version"] = version

    if general_settings is not None:
        data[_metadata_variable_name][
            "global_max_parallel_requests"
        ] = general_settings.get("global_max_parallel_requests", None)

    ### KEY-LEVEL Controls
    key_metadata = user_api_key_dict.metadata
    data = LiteLLMProxyRequestSetup.add_key_level_controls(
        key_metadata=key_metadata,
        data=data,
        _metadata_variable_name=_metadata_variable_name,
    )
    ## TEAM-LEVEL SPEND LOGS/TAGS
    team_metadata = user_api_key_dict.team_metadata or {}
    if "tags" in team_metadata and team_metadata["tags"] is not None:
        data[_metadata_variable_name]["tags"] = LiteLLMProxyRequestSetup._merge_tags(
            request_tags=data[_metadata_variable_name].get("tags"),
            tags_to_add=team_metadata["tags"],
        )
    if "disable_global_guardrails" in team_metadata and isinstance(
        team_metadata["disable_global_guardrails"], bool
    ):
        data[_metadata_variable_name]["disable_global_guardrails"] = team_metadata[
            "disable_global_guardrails"
        ]
    if "spend_logs_metadata" in team_metadata and isinstance(
        team_metadata["spend_logs_metadata"], dict
    ):
        if "spend_logs_metadata" in data[_metadata_variable_name] and isinstance(
            data[_metadata_variable_name]["spend_logs_metadata"], dict
        ):
            for key, value in team_metadata["spend_logs_metadata"].items():
                if (
                    key not in data[_metadata_variable_name]["spend_logs_metadata"]
                ):  # don't override k-v pair sent by request (user request)
                    data[_metadata_variable_name]["spend_logs_metadata"][key] = value
        else:
            data[_metadata_variable_name]["spend_logs_metadata"] = team_metadata[
                "spend_logs_metadata"
            ]

    ## PROJECT-LEVEL TAGS
    project_metadata = user_api_key_dict.project_metadata or {}
    if "tags" in project_metadata and project_metadata["tags"] is not None:
        data[_metadata_variable_name]["tags"] = LiteLLMProxyRequestSetup._merge_tags(
            request_tags=data[_metadata_variable_name].get("tags"),
            tags_to_add=project_metadata["tags"],
        )

    ## TEAM-LEVEL METADATA
    data = (
        LiteLLMProxyRequestSetup.add_management_endpoint_metadata_to_request_metadata(
            data=data,
            management_endpoint_metadata=team_metadata,
            _metadata_variable_name=_metadata_variable_name,
        )
    )

    # Team spend, budget - used by prometheus.py
    data[_metadata_variable_name][
        "user_api_key_team_max_budget"
    ] = user_api_key_dict.team_max_budget
    data[_metadata_variable_name][
        "user_api_key_team_spend"
    ] = user_api_key_dict.team_spend
    data[_metadata_variable_name][
        "user_api_key_request_route"
    ] = user_api_key_dict.request_route

    # API Key spend, budget - used by prometheus.py
    data[_metadata_variable_name]["user_api_key_spend"] = user_api_key_dict.spend
    data[_metadata_variable_name][
        "user_api_key_max_budget"
    ] = user_api_key_dict.max_budget
    data[_metadata_variable_name][
        "user_api_key_model_max_budget"
    ] = user_api_key_dict.model_max_budget
    data[_metadata_variable_name][
        "user_api_key_end_user_model_max_budget"
    ] = user_api_key_dict.end_user_model_max_budget

    # User spend, budget - used by prometheus.py
    # Follow same pattern as team and API key budgets
    data[_metadata_variable_name][
        "user_api_key_user_spend"
    ] = user_api_key_dict.user_spend
    data[_metadata_variable_name][
        "user_api_key_user_max_budget"
    ] = user_api_key_dict.user_max_budget

    data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
    data[_metadata_variable_name][
        "user_api_key_team_metadata"
    ] = user_api_key_dict.team_metadata
data[_metadata_variable_name]["user_api_key_object_permission_id"] = getattr( user_api_key_dict, "object_permission_id", None ) data[_metadata_variable_name]["user_api_key_team_object_permission_id"] = getattr( user_api_key_dict, "team_object_permission_id", None ) data[_metadata_variable_name]["headers"] = _headers data[_metadata_variable_name]["endpoint"] = str(request.url) # OTEL Controls / Tracing # Add the OTEL Parent Trace before sending it LiteLLM data[_metadata_variable_name][ "litellm_parent_otel_span" ] = user_api_key_dict.parent_otel_span _add_otel_traceparent_to_data(data, request=request) ### END-USER SPECIFIC PARAMS ### if user_api_key_dict.allowed_model_region is not None: data["allowed_model_region"] = user_api_key_dict.allowed_model_region start_time = time.time() ## [Enterprise Only] # Add User-IP Address requester_ip_address = "" if True: # Always set the IP Address if available # logic for tracking IP Address # logic for tracking IP Address if ( general_settings is not None and general_settings.get("use_x_forwarded_for") is True and request is not None and hasattr(request, "headers") and "x-forwarded-for" in request.headers ): requester_ip_address = request.headers["x-forwarded-for"] elif ( request is not None and hasattr(request, "client") and hasattr(request.client, "host") and request.client is not None ): requester_ip_address = request.client.host data[_metadata_variable_name]["requester_ip_address"] = requester_ip_address # Add User-Agent user_agent = "" if ( request is not None and hasattr(request, "headers") and "user-agent" in request.headers ): user_agent = request.headers["user-agent"] data[_metadata_variable_name]["user_agent"] = user_agent # Check if using tag based routing tags = LiteLLMProxyRequestSetup.add_request_tag_to_metadata( llm_router=llm_router, headers=_headers, data=data, ) if tags is not None: data[_metadata_variable_name]["tags"] = tags # Team Callbacks controls callback_settings_obj = _get_dynamic_logging_metadata( 
user_api_key_dict=user_api_key_dict, proxy_config=proxy_config ) if callback_settings_obj is not None: data["success_callback"] = callback_settings_obj.success_callback data["failure_callback"] = callback_settings_obj.failure_callback if callback_settings_obj.callback_vars is not None: # unpack callback_vars in data for k, v in callback_settings_obj.callback_vars.items(): data[k] = v # Add disabled callbacks from key metadata if ( user_api_key_dict.metadata and "litellm_disabled_callbacks" in user_api_key_dict.metadata ): disabled_callbacks = user_api_key_dict.metadata["litellm_disabled_callbacks"] if disabled_callbacks and isinstance(disabled_callbacks, list): data["litellm_disabled_callbacks"] = disabled_callbacks # Guardrails from key/team metadata and policy engine await move_guardrails_to_metadata( data=data, _metadata_variable_name=_metadata_variable_name, user_api_key_dict=user_api_key_dict, ) # Team Model Aliases _update_model_if_team_alias_exists( data=data, user_api_key_dict=user_api_key_dict, ) # Key Model Aliases _update_model_if_key_alias_exists( data=data, user_api_key_dict=user_api_key_dict, ) verbose_proxy_logger.debug( "[PROXY] returned data from litellm_pre_call_utils: %s", data ) ## ENFORCED PARAMS CHECK # loop through each enforced param # example enforced_params ['user', 'metadata', 'metadata.generation_name'] _enforced_params_check( request_body=data, general_settings=general_settings, user_api_key_dict=user_api_key_dict, premium_user=premium_user, ) end_time = time.time() asyncio.create_task( service_logger_obj.async_service_success_hook( service=ServiceTypes.PROXY_PRE_CALL, duration=end_time - start_time, call_type="add_litellm_data_to_request", start_time=start_time, end_time=end_time, parent_otel_span=user_api_key_dict.parent_otel_span, ) ) return data def _update_model_if_team_alias_exists( data: dict, user_api_key_dict: UserAPIKeyAuth, ) -> None: """ Update the model if the team alias exists If a alias map has been set on a team, then we 
want to make the request with the model the team alias is pointing to eg. - user calls `gpt-4o` - team.model_alias_map = { "gpt-4o": "gpt-4o-team-1" } - requested_model = "gpt-4o-team-1" """ _model = data.get("model") if ( _model and user_api_key_dict.team_model_aliases and _model in user_api_key_dict.team_model_aliases ): data["model"] = user_api_key_dict.team_model_aliases[_model] return def _update_model_if_key_alias_exists( data: dict, user_api_key_dict: UserAPIKeyAuth, ) -> None: """ Update the model if the key alias exists If an alias map has been set on a key, then we want to make the request with the model the key alias is pointing to eg. - user calls `modelAlias` - key.aliases = { "modelAlias": "xai/grok-4-fast-non-reasoning" } - requested_model = "xai/grok-4-fast-non-reasoning" """ _model = data.get("model") if ( _model and user_api_key_dict.aliases and isinstance(user_api_key_dict.aliases, dict) and _model in user_api_key_dict.aliases ): data["model"] = user_api_key_dict.aliases[_model] return def _get_enforced_params( general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth ) -> Optional[list]: enforced_params: Optional[list] = None if general_settings is not None: enforced_params = general_settings.get("enforced_params") if ( "service_account_settings" in general_settings and check_if_token_is_service_account(user_api_key_dict) is True ): service_account_settings = general_settings["service_account_settings"] if "enforced_params" in service_account_settings: if enforced_params is None: enforced_params = [] enforced_params.extend(service_account_settings["enforced_params"]) if user_api_key_dict.metadata.get("enforced_params", None) is not None: if enforced_params is None: enforced_params = [] enforced_params.extend(user_api_key_dict.metadata["enforced_params"]) return enforced_params def check_if_token_is_service_account(valid_token: UserAPIKeyAuth) -> bool: """ Checks if the token is a service account Returns: bool: True if token is a 
service account """ if valid_token.metadata: if "service_account_id" in valid_token.metadata: return True return False def _enforced_params_check( request_body: dict, general_settings: Optional[dict], user_api_key_dict: UserAPIKeyAuth, premium_user: bool, ) -> bool: """ If enforced params are set, check if the request body contains the enforced params. """ enforced_params: Optional[list] = _get_enforced_params( general_settings=general_settings, user_api_key_dict=user_api_key_dict ) if enforced_params is None: return True if enforced_params and premium_user is not True: raise ValueError( f"Enforced Params is an Enterprise feature. Enforced Params: {enforced_params}. {CommonProxyErrors.not_premium_user.value}" ) for enforced_param in enforced_params: _enforced_params = enforced_param.split(".") if len(_enforced_params) == 1: if _enforced_params[0] not in request_body: raise ValueError( f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param" ) elif len(_enforced_params) == 2: # this is a scenario where user requires request['metadata']['generation_name'] to exist if _enforced_params[0] not in request_body: raise ValueError( f"BadRequest please pass param={_enforced_params[0]} in request body. This is a required param" ) if _enforced_params[1] not in request_body[_enforced_params[0]]: raise ValueError( f"BadRequest please pass param=[{_enforced_params[0]}][{_enforced_params[1]}] in request body. This is a required param" ) return True def _add_guardrails_from_key_or_team_metadata( key_metadata: Optional[dict], team_metadata: Optional[dict], data: dict, metadata_variable_name: str, ) -> None: """ Helper add guardrails from key or team metadata to request data Key guardrails are set first, then team guardrails are appended (without duplicates). 
Args: key_metadata: The key metadata dictionary to check for guardrails team_metadata: The team metadata dictionary to check for guardrails data: The request data to update metadata_variable_name: The name of the metadata field in data """ from litellm.proxy.utils import _premium_user_check # Initialize guardrails set (avoiding duplicates) combined_guardrails = set() # Add key-level guardrails first if key_metadata and "guardrails" in key_metadata: if ( isinstance(key_metadata["guardrails"], list) and len(key_metadata["guardrails"]) > 0 ): _premium_user_check() combined_guardrails.update(key_metadata["guardrails"]) # Add team-level guardrails (set automatically handles duplicates) if team_metadata and "guardrails" in team_metadata: if ( isinstance(team_metadata["guardrails"], list) and len(team_metadata["guardrails"]) > 0 ): _premium_user_check() combined_guardrails.update(team_metadata["guardrails"]) # Set combined guardrails in metadata as list if combined_guardrails: data[metadata_variable_name]["guardrails"] = list(combined_guardrails) def _add_guardrails_from_policies_in_metadata( key_metadata: Optional[dict], team_metadata: Optional[dict], data: dict, metadata_variable_name: str, ) -> None: """ Helper to resolve guardrails from policies attached to key/team metadata. This function: 1. Gets policy names from key and team metadata 2. Resolves guardrails from those policies (including inheritance) 3. 
Adds resolved guardrails to request metadata Args: key_metadata: The key metadata dictionary to check for policies team_metadata: The team metadata dictionary to check for policies data: The request data to update metadata_variable_name: The name of the metadata field in data """ from litellm._logging import verbose_proxy_logger from litellm.proxy.policy_engine.policy_registry import get_policy_registry from litellm.proxy.policy_engine.policy_resolver import PolicyResolver from litellm.proxy.utils import _premium_user_check from litellm.types.proxy.policy_engine import PolicyMatchContext # Collect policy names from key and team metadata policy_names: set = set() # Add key-level policies first if key_metadata and "policies" in key_metadata: if ( isinstance(key_metadata["policies"], list) and len(key_metadata["policies"]) > 0 ): _premium_user_check() policy_names.update(key_metadata["policies"]) # Add team-level policies if team_metadata and "policies" in team_metadata: if ( isinstance(team_metadata["policies"], list) and len(team_metadata["policies"]) > 0 ): _premium_user_check() policy_names.update(team_metadata["policies"]) if not policy_names: return verbose_proxy_logger.debug( f"Policy engine: resolving guardrails from key/team policies: {policy_names}" ) # Check if policy registry is initialized registry = get_policy_registry() if not registry.is_initialized(): verbose_proxy_logger.debug( "Policy engine not initialized, skipping policy resolution from metadata" ) return # Build context for policy resolution (model from request data) context = PolicyMatchContext(model=data.get("model")) # Get all policies from registry all_policies = registry.get_all_policies() # Resolve guardrails from the specified policies resolved_guardrails: set = set() for policy_name in policy_names: if registry.has_policy(policy_name): resolved_policy = PolicyResolver.resolve_policy_guardrails( policy_name=policy_name, policies=all_policies, context=context, ) 
resolved_guardrails.update(resolved_policy.guardrails) verbose_proxy_logger.debug( f"Policy engine: resolved guardrails from policy '{policy_name}': {resolved_policy.guardrails}" ) else: verbose_proxy_logger.warning( f"Policy engine: policy '{policy_name}' not found in registry" ) if not resolved_guardrails: return # Add resolved guardrails to request metadata if metadata_variable_name not in data: data[metadata_variable_name] = {} existing_guardrails = data[metadata_variable_name].get("guardrails", []) if not isinstance(existing_guardrails, list): existing_guardrails = [] # Combine existing guardrails with policy-resolved guardrails (no duplicates) combined = set(existing_guardrails) combined.update(resolved_guardrails) data[metadata_variable_name]["guardrails"] = list(combined) # Store applied policies in metadata for tracking if "applied_policies" not in data[metadata_variable_name]: data[metadata_variable_name]["applied_policies"] = [] data[metadata_variable_name]["applied_policies"].extend(list(policy_names)) verbose_proxy_logger.debug( f"Policy engine: added guardrails from key/team policies to request metadata: {list(resolved_guardrails)}" ) async def move_guardrails_to_metadata( data: dict, _metadata_variable_name: str, user_api_key_dict: UserAPIKeyAuth, ): """ Helper to add guardrails from request to metadata - If guardrails set on API Key metadata then sets guardrails on request metadata - If guardrails not set on API key, then checks request metadata - Adds guardrails from policies attached to key/team metadata - Adds guardrails from policy engine based on team/key/model context """ # Early-out: skip all guardrails processing when nothing is configured key_metadata = user_api_key_dict.metadata team_metadata = user_api_key_dict.team_metadata has_key_config = key_metadata and ( "guardrails" in key_metadata or "policies" in key_metadata ) has_team_config = team_metadata and ( "guardrails" in team_metadata or "policies" in team_metadata ) has_request_config 
= ( "guardrails" in data or "guardrail_config" in data or "policies" in data ) # Only check policy engine if no local config (avoid import + registry lookup) if not (has_key_config or has_team_config or has_request_config): from litellm.proxy.policy_engine.policy_registry import get_policy_registry if not get_policy_registry().is_initialized(): # Nothing configured anywhere - clean up request body fields and return data.pop("policies", None) return # Check key-level guardrails _add_guardrails_from_key_or_team_metadata( key_metadata=user_api_key_dict.metadata, team_metadata=user_api_key_dict.team_metadata, data=data, metadata_variable_name=_metadata_variable_name, ) ######################################################################################### # Add guardrails from policies attached to key/team metadata ######################################################################################### _add_guardrails_from_policies_in_metadata( key_metadata=user_api_key_dict.metadata, team_metadata=user_api_key_dict.team_metadata, data=data, metadata_variable_name=_metadata_variable_name, ) ######################################################################################### # Add guardrails from policy engine based on team/key/model context ######################################################################################### await add_guardrails_from_policy_engine( data=data, metadata_variable_name=_metadata_variable_name, user_api_key_dict=user_api_key_dict, ) ######################################################################################### # User's might send "guardrails" in the request body, we need to add them to the request metadata. 
# Since downstream logic requires "guardrails" to be in the request metadata ######################################################################################### if "guardrails" in data: request_body_guardrails = data.pop("guardrails") if "guardrails" in data[_metadata_variable_name] and isinstance( data[_metadata_variable_name]["guardrails"], list ): data[_metadata_variable_name]["guardrails"].extend(request_body_guardrails) else: data[_metadata_variable_name]["guardrails"] = request_body_guardrails ######################################################################################### if "guardrail_config" in data: request_body_guardrail_config = data.pop("guardrail_config") if "guardrail_config" in data[_metadata_variable_name] and isinstance( data[_metadata_variable_name]["guardrail_config"], dict ): data[_metadata_variable_name]["guardrail_config"].update( request_body_guardrail_config ) else: data[_metadata_variable_name][ "guardrail_config" ] = request_body_guardrail_config def _is_policy_version_id(s: str) -> bool: """Return True if string is a policy version ID (starts with policy_ prefix).""" from litellm.proxy.policy_engine.policy_registry import POLICY_VERSION_ID_PREFIX return isinstance(s, str) and s.startswith(POLICY_VERSION_ID_PREFIX) def _extract_policy_id(s: str) -> Optional[str]: """Extract raw UUID from policy_ string, or None if not a valid version ID.""" from litellm.proxy.policy_engine.policy_registry import POLICY_VERSION_ID_PREFIX if not _is_policy_version_id(s): return None return s[len(POLICY_VERSION_ID_PREFIX) :].strip() or None def _match_and_track_policies( data: dict, context: "PolicyMatchContext", request_body_policies: Any, policies_override: Optional[Dict[str, Any]] = None, ) -> tuple[list[str], dict[str, str]]: """ Match policies via attachments and request body, track them in metadata. 
Returns: Tuple of (applied_policy_names, policy_reasons) """ from litellm._logging import verbose_proxy_logger from litellm.proxy.common_utils.callback_utils import ( add_policy_sources_to_metadata, add_policy_to_applied_policies_header, ) from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher # Get matching policies via attachments (with match reasons for attribution) attachment_registry = get_attachment_registry() matches_with_reasons = attachment_registry.get_attached_policies_with_reasons( context ) matching_policy_names = [m["policy_name"] for m in matches_with_reasons] policy_reasons = {m["policy_name"]: m["matched_via"] for m in matches_with_reasons} verbose_proxy_logger.debug( f"Policy engine: matched policies via attachments: {matching_policy_names}" ) # Combine attachment-based policies with dynamic request body policies all_policy_names = set(matching_policy_names) if request_body_policies and isinstance(request_body_policies, list): all_policy_names.update(request_body_policies) verbose_proxy_logger.debug( f"Policy engine: added dynamic policies from request body: {request_body_policies}" ) if not all_policy_names: return [], {} # Filter to only policies whose conditions match the context applied_policy_names = PolicyMatcher.get_policies_with_matching_conditions( policy_names=list(all_policy_names), context=context, policies=policies_override, ) verbose_proxy_logger.debug( f"Policy engine: applied policies (conditions matched): {applied_policy_names}" ) # Track applied policies in metadata for response headers for policy_name in applied_policy_names: add_policy_to_applied_policies_header( request_data=data, policy_name=policy_name ) # Track policy attribution sources for x-litellm-policy-sources header applied_reasons = { name: policy_reasons[name] for name in applied_policy_names if name in policy_reasons } add_policy_sources_to_metadata(request_data=data, 
policy_sources=applied_reasons) return applied_policy_names, policy_reasons def _apply_resolved_guardrails_to_metadata( data: dict, metadata_variable_name: str, context: "PolicyMatchContext", policy_names: Optional[List[str]] = None, policies: Optional[Dict[str, Any]] = None, ) -> None: """Apply resolved guardrails and pipelines to request metadata.""" from litellm._logging import verbose_proxy_logger from litellm.proxy.policy_engine.policy_resolver import PolicyResolver # Resolve guardrails from matching policies resolved_guardrails = PolicyResolver.resolve_guardrails_for_context( context=context, policies=policies, policy_names=policy_names, ) verbose_proxy_logger.debug( f"Policy engine: resolved guardrails: {resolved_guardrails}" ) # Resolve pipelines from matching policies pipelines = PolicyResolver.resolve_pipelines_for_context( context=context, policies=policies, policy_names=policy_names, ) # Add resolved guardrails to request metadata if metadata_variable_name not in data: data[metadata_variable_name] = {} # Track pipeline-managed guardrails to exclude from independent execution pipeline_managed_guardrails: set = set() if pipelines: pipeline_managed_guardrails = PolicyResolver.get_pipeline_managed_guardrails( pipelines ) data[metadata_variable_name]["_guardrail_pipelines"] = pipelines data[metadata_variable_name][ "_pipeline_managed_guardrails" ] = pipeline_managed_guardrails verbose_proxy_logger.debug( f"Policy engine: resolved {len(pipelines)} pipeline(s), " f"managed guardrails: {pipeline_managed_guardrails}" ) if not resolved_guardrails and not pipelines: return existing_guardrails = data[metadata_variable_name].get("guardrails", []) if not isinstance(existing_guardrails, list): existing_guardrails = [] # Combine existing guardrails with policy-resolved guardrails (no duplicates) # Exclude pipeline-managed guardrails from the flat list combined = set(existing_guardrails) combined.update(resolved_guardrails) combined -= pipeline_managed_guardrails 
data[metadata_variable_name]["guardrails"] = list(combined) verbose_proxy_logger.debug( f"Policy engine: added guardrails to request metadata: {list(combined)}" ) async def add_guardrails_from_policy_engine( data: dict, metadata_variable_name: str, user_api_key_dict: UserAPIKeyAuth, ) -> None: """ Add guardrails from the policy engine based on request context. This function: 1. Extracts "policies" from request body (if present) for dynamic policy application 2. Supports policy_ in policies to execute a specific version (e.g. published) 3. Gets matching policies based on team_alias, key_alias, and model (via attachments) 4. Combines dynamic policies with attachment-based policies 5. Resolves guardrails from all policies (including inheritance) 6. Adds guardrails to request metadata 7. Tracks applied policies in metadata for response headers 8. Removes "policies" from request body so it's not forwarded to LLM provider Args: data: The request data to update metadata_variable_name: The name of the metadata field in data user_api_key_dict: The user's API key authentication info """ from litellm._logging import verbose_proxy_logger from litellm.proxy.common_utils.http_parsing_utils import get_tags_from_request_body from litellm.proxy.policy_engine.policy_registry import get_policy_registry from litellm.types.proxy.policy_engine import PolicyMatchContext # Extract dynamic policies from request body (if present) request_body_policies_raw = data.pop("policies", None) registry = get_policy_registry() verbose_proxy_logger.debug( f"Policy engine: registry initialized={registry.is_initialized()}, " f"policy_count={len(registry.get_all_policies())}" ) if not registry.is_initialized(): verbose_proxy_logger.debug( "Policy engine not initialized, skipping policy matching" ) return # Extract tags and build context all_tags = get_tags_from_request_body(data) or None _team_alias = user_api_key_dict.team_alias _key_alias = user_api_key_dict.key_alias context = PolicyMatchContext( 
team_alias=_team_alias if isinstance(_team_alias, str) else None, key_alias=_key_alias if isinstance(_key_alias, str) else None, model=data.get("model"), tags=all_tags, ) verbose_proxy_logger.debug( f"Policy engine: matching policies for context team_alias={context.team_alias}, " f"key_alias={context.key_alias}, model={context.model}, tags={context.tags}" ) # Separate policy names from policy version IDs (policy_) request_body_names: List[str] = [] request_body_version_ids: List[str] = [] if request_body_policies_raw and isinstance(request_body_policies_raw, list): for item in request_body_policies_raw: if not isinstance(item, str): continue if _is_policy_version_id(item): policy_id = _extract_policy_id(item) if policy_id: request_body_version_ids.append(policy_id) else: request_body_names.append(item) # Resolve policy versions by ID from in-memory cache (populated by sync job; no DB in hot path) merged_policies: Dict[str, Any] = dict(registry.get_all_policies()) fetched_policy_names: List[str] = [] for policy_id in request_body_version_ids: result = registry.get_policy_by_id_for_request(policy_id=policy_id) if result is not None: pname, policy = result merged_policies[pname] = policy fetched_policy_names.append(pname) verbose_proxy_logger.debug( f"Policy engine: loaded version by ID policy_{policy_id} -> {pname}" ) else: verbose_proxy_logger.debug( f"Policy engine: policy version {policy_id} not found in cache, skipping" ) # Build request body list: names + policy names from fetched versions request_body_policies = request_body_names + fetched_policy_names # Match and track policies (with merged_policies when we have version overrides) applied_policy_names, _ = _match_and_track_policies( data, context, request_body_policies, policies_override=merged_policies if request_body_version_ids else None, ) # Resolve and apply guardrails. Use applied_policy_names so request-body policies # (names + version IDs) are included. 
Use merged_policies when we have version overrides. _apply_resolved_guardrails_to_metadata( data, metadata_variable_name, context, policy_names=applied_policy_names if applied_policy_names else None, policies=merged_policies if request_body_version_ids else None, ) def add_provider_specific_headers_to_request( data: dict, headers: dict, ): from litellm.llms.anthropic.common_utils import is_anthropic_oauth_key anthropic_headers = {} # boolean to indicate if a header was added added_header = False for header in ANTHROPIC_API_HEADERS: if header in headers: header_value = headers[header] anthropic_headers[header] = header_value added_header = True # Check for Authorization header with Anthropic OAuth token (sk-ant-oat*) # This needs to be handled via provider-specific headers to ensure it only # goes to Anthropic-compatible providers, not all providers in the router for header, value in headers.items(): if header.lower() == "authorization" and is_anthropic_oauth_key(value): anthropic_headers[header] = value added_header = True break if added_header is True: # Anthropic headers work across multiple providers # Store as comma-separated list so retrieval can match any of them data["provider_specific_header"] = ProviderSpecificHeader( custom_llm_provider=f"{LlmProviders.ANTHROPIC.value},{LlmProviders.BEDROCK.value},{LlmProviders.VERTEX_AI.value}", extra_headers=anthropic_headers, ) return def _add_otel_traceparent_to_data(data: dict, request: Request): from litellm.proxy.proxy_server import open_telemetry_logger if data is None: return if open_telemetry_logger is None: # if user is not use OTEL don't send extra_headers # relevant issue: https://github.com/BerriAI/litellm/issues/4448 return if litellm.forward_traceparent_to_llm_provider is True: if request.headers: if "traceparent" in request.headers: # we want to forward this to the LLM Provider # Relevant issue: https://github.com/BerriAI/litellm/issues/4419 # pass this in extra_headers if "extra_headers" not in data: 
data["extra_headers"] = {} _exra_headers = data["extra_headers"] if "traceparent" not in _exra_headers: _exra_headers["traceparent"] = request.headers["traceparent"]