727 lines
28 KiB
Python
727 lines
28 KiB
Python
import base64
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
Optional,
|
|
Type,
|
|
Union,
|
|
cast,
|
|
get_type_hints,
|
|
overload,
|
|
)
|
|
|
|
from pydantic import BaseModel
|
|
|
|
import litellm
|
|
from litellm._logging import verbose_logger
|
|
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
|
|
from litellm.types.llms.openai import (
|
|
ResponseAPIUsage,
|
|
ResponsesAPIOptionalRequestParams,
|
|
ResponsesAPIResponse,
|
|
ResponseText,
|
|
)
|
|
from litellm.types.responses.main import DecodedResponseId
|
|
from litellm.types.utils import (
|
|
CompletionTokensDetailsWrapper,
|
|
PromptTokensDetailsWrapper,
|
|
SpecialEnums,
|
|
Usage,
|
|
)
|
|
|
|
|
|
class ResponsesAPIRequestUtils:
|
|
"""Helper utils for constructing ResponseAPI requests"""
|
|
|
|
@staticmethod
|
|
def _check_valid_arg(
|
|
supported_params: Optional[List[str]],
|
|
non_default_params: Dict,
|
|
drop_params: Optional[bool],
|
|
custom_llm_provider: Optional[str],
|
|
model: str,
|
|
):
|
|
if supported_params is None:
|
|
return
|
|
unsupported_params = {}
|
|
for k in non_default_params.keys():
|
|
if k not in supported_params:
|
|
unsupported_params[k] = non_default_params[k]
|
|
if unsupported_params:
|
|
if litellm.drop_params is True or (
|
|
drop_params is not None and drop_params is True
|
|
):
|
|
pass
|
|
else:
|
|
raise litellm.UnsupportedParamsError(
|
|
status_code=500,
|
|
message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
|
|
)
|
|
|
|
@staticmethod
|
|
def get_optional_params_responses_api(
|
|
model: str,
|
|
responses_api_provider_config: BaseResponsesAPIConfig,
|
|
response_api_optional_params: ResponsesAPIOptionalRequestParams,
|
|
allowed_openai_params: Optional[List[str]] = None,
|
|
) -> Dict:
|
|
"""
|
|
Get optional parameters for the responses API.
|
|
|
|
Args:
|
|
params: Dictionary of all parameters
|
|
model: The model name
|
|
responses_api_provider_config: The provider configuration for responses API
|
|
|
|
Returns:
|
|
A dictionary of supported parameters for the responses API
|
|
"""
|
|
from litellm.utils import _apply_openai_param_overrides
|
|
|
|
# Remove None values and internal parameters
|
|
# Get supported parameters for the model
|
|
supported_params = responses_api_provider_config.get_supported_openai_params(
|
|
model
|
|
)
|
|
|
|
non_default_params = cast(Dict, response_api_optional_params)
|
|
# Check for unsupported parameters
|
|
ResponsesAPIRequestUtils._check_valid_arg(
|
|
supported_params=supported_params + (allowed_openai_params or []),
|
|
non_default_params=non_default_params,
|
|
drop_params=litellm.drop_params,
|
|
custom_llm_provider=responses_api_provider_config.custom_llm_provider,
|
|
model=model,
|
|
)
|
|
|
|
# Map parameters to provider-specific format
|
|
mapped_params = responses_api_provider_config.map_openai_params(
|
|
response_api_optional_params=response_api_optional_params,
|
|
model=model,
|
|
drop_params=litellm.drop_params,
|
|
)
|
|
|
|
# add any allowed_openai_params to the mapped_params
|
|
mapped_params = _apply_openai_param_overrides(
|
|
optional_params=mapped_params,
|
|
non_default_params=non_default_params,
|
|
allowed_openai_params=allowed_openai_params or [],
|
|
)
|
|
|
|
return mapped_params
|
|
|
|
@staticmethod
|
|
def get_requested_response_api_optional_param(
|
|
params: Dict[str, Any],
|
|
) -> ResponsesAPIOptionalRequestParams:
|
|
"""
|
|
Filter parameters to only include those defined in ResponsesAPIOptionalRequestParams.
|
|
|
|
Args:
|
|
params: Dictionary of parameters to filter
|
|
|
|
Returns:
|
|
ResponsesAPIOptionalRequestParams instance with only the valid parameters
|
|
"""
|
|
from litellm.utils import PreProcessNonDefaultParams
|
|
|
|
valid_keys = get_type_hints(ResponsesAPIOptionalRequestParams).keys()
|
|
custom_llm_provider = params.pop("custom_llm_provider", None)
|
|
special_params = params.pop("kwargs", {})
|
|
|
|
additional_drop_params = params.pop("additional_drop_params", None)
|
|
non_default_params = (
|
|
PreProcessNonDefaultParams.base_pre_process_non_default_params(
|
|
passed_params=params,
|
|
special_params=special_params,
|
|
custom_llm_provider=custom_llm_provider,
|
|
additional_drop_params=additional_drop_params,
|
|
default_param_values={k: None for k in valid_keys},
|
|
additional_endpoint_specific_params=["input"],
|
|
)
|
|
)
|
|
|
|
# decode previous_response_id if it's a litellm encoded id
|
|
if "previous_response_id" in non_default_params:
|
|
decoded_previous_response_id = ResponsesAPIRequestUtils.decode_previous_response_id_to_original_previous_response_id(
|
|
non_default_params["previous_response_id"]
|
|
)
|
|
non_default_params["previous_response_id"] = decoded_previous_response_id
|
|
|
|
if "metadata" in non_default_params:
|
|
from litellm.utils import add_openai_metadata
|
|
|
|
converted_metadata = add_openai_metadata(non_default_params["metadata"])
|
|
if converted_metadata is not None:
|
|
non_default_params["metadata"] = converted_metadata
|
|
else:
|
|
non_default_params.pop("metadata", None)
|
|
|
|
return cast(ResponsesAPIOptionalRequestParams, non_default_params)
|
|
|
|
# fmt: off
|
|
@overload
|
|
@staticmethod
|
|
def _update_responses_api_response_id_with_model_id(
|
|
responses_api_response: ResponsesAPIResponse,
|
|
custom_llm_provider: Optional[str],
|
|
litellm_metadata: Optional[Dict[str, Any]] = None,
|
|
) -> ResponsesAPIResponse:
|
|
...
|
|
|
|
@overload
|
|
@staticmethod
|
|
def _update_responses_api_response_id_with_model_id(
|
|
responses_api_response: Dict[str, Any],
|
|
custom_llm_provider: Optional[str],
|
|
litellm_metadata: Optional[Dict[str, Any]] = None,
|
|
) -> Dict[str, Any]:
|
|
...
|
|
|
|
# fmt: on
|
|
|
|
@staticmethod
|
|
def _update_responses_api_response_id_with_model_id(
|
|
responses_api_response: Union[ResponsesAPIResponse, Dict[str, Any]],
|
|
custom_llm_provider: Optional[str],
|
|
litellm_metadata: Optional[Dict[str, Any]] = None,
|
|
) -> Union[ResponsesAPIResponse, Dict[str, Any]]:
|
|
"""Update the responses_api_response_id with model_id and custom_llm_provider.
|
|
|
|
Handles both ``ResponsesAPIResponse`` objects and plain dictionaries returned
|
|
by some streaming providers.
|
|
"""
|
|
litellm_metadata = litellm_metadata or {}
|
|
model_info: Dict[str, Any] = litellm_metadata.get("model_info", {}) or {}
|
|
model_id = model_info.get("id")
|
|
|
|
# access the response id based on the object type
|
|
if isinstance(responses_api_response, dict):
|
|
response_id = responses_api_response.get("id")
|
|
else:
|
|
response_id = getattr(responses_api_response, "id", None)
|
|
|
|
# If no response_id, return the response as-is (likely an error response)
|
|
if response_id is None:
|
|
return responses_api_response
|
|
|
|
updated_id = ResponsesAPIRequestUtils._build_responses_api_response_id(
|
|
model_id=model_id,
|
|
custom_llm_provider=custom_llm_provider,
|
|
response_id=response_id,
|
|
)
|
|
|
|
if isinstance(responses_api_response, dict):
|
|
responses_api_response["id"] = updated_id
|
|
else:
|
|
responses_api_response.id = updated_id
|
|
|
|
if litellm_metadata.get("encrypted_content_affinity_enabled"):
|
|
responses_api_response = (
|
|
ResponsesAPIRequestUtils._update_encrypted_content_item_ids_in_response(
|
|
response=responses_api_response,
|
|
model_id=model_id,
|
|
)
|
|
)
|
|
|
|
return responses_api_response
|
|
|
|
@staticmethod
|
|
def _build_encrypted_item_id(model_id: str, item_id: str) -> str:
|
|
"""Encode model_id into an output item ID for encrypted-content items.
|
|
|
|
Format: ``encitem_{base64("litellm:model_id:{model_id};item_id:{original_id}")}``
|
|
"""
|
|
assembled = f"litellm:model_id:{model_id};item_id:{item_id}"
|
|
encoded = base64.b64encode(assembled.encode("utf-8")).decode("utf-8")
|
|
return f"encitem_{encoded}"
|
|
|
|
@staticmethod
|
|
def _decode_encrypted_item_id(encoded_id: str) -> Optional[Dict[str, str]]:
|
|
"""Decode a litellm-encoded encrypted-content item ID.
|
|
|
|
Returns a dict with ``model_id`` and ``item_id`` keys, or ``None`` if
|
|
the string is not a litellm-encoded item ID.
|
|
"""
|
|
if not encoded_id.startswith("encitem_"):
|
|
return None
|
|
try:
|
|
cleaned = encoded_id[len("encitem_") :]
|
|
# Restore any padding that may have been stripped in transit
|
|
missing = len(cleaned) % 4
|
|
if missing:
|
|
cleaned += "=" * (4 - missing)
|
|
decoded = base64.b64decode(cleaned.encode("utf-8")).decode("utf-8")
|
|
# Split on first ";" only so that semicolons inside item_id are preserved
|
|
parts = decoded.split(";", 1)
|
|
if len(parts) < 2:
|
|
return None
|
|
model_id = parts[0].replace("litellm:model_id:", "")
|
|
item_id = parts[1].replace("item_id:", "")
|
|
return {"model_id": model_id, "item_id": item_id}
|
|
except Exception:
|
|
return None
|
|
|
|
@staticmethod
|
|
def _wrap_encrypted_content_with_model_id(
|
|
encrypted_content: str, model_id: str
|
|
) -> str:
|
|
"""Wrap encrypted_content with model_id metadata for affinity routing.
|
|
|
|
When Codex or other clients send items with encrypted_content but no ID,
|
|
we encode the model_id directly into the encrypted_content itself.
|
|
|
|
Format: ``litellm_enc:{base64("model_id:{model_id}")};{original_encrypted_content}``
|
|
"""
|
|
metadata = f"model_id:{model_id}"
|
|
encoded_metadata = base64.b64encode(metadata.encode("utf-8")).decode("utf-8")
|
|
return f"litellm_enc:{encoded_metadata};{encrypted_content}"
|
|
|
|
@staticmethod
|
|
def _unwrap_encrypted_content_with_model_id(
|
|
wrapped_content: str,
|
|
) -> tuple[Optional[str], str]:
|
|
"""Unwrap encrypted_content to extract model_id and original content.
|
|
|
|
Returns:
|
|
Tuple of (model_id, original_encrypted_content).
|
|
If not wrapped, returns (None, original_content).
|
|
"""
|
|
if not wrapped_content.startswith("litellm_enc:"):
|
|
return None, wrapped_content
|
|
|
|
try:
|
|
# Split on first ";" to separate metadata from content
|
|
parts = wrapped_content.split(";", 1)
|
|
if len(parts) < 2:
|
|
return None, wrapped_content
|
|
|
|
metadata_b64 = parts[0].replace("litellm_enc:", "")
|
|
original_content = parts[1]
|
|
|
|
# Restore padding if needed
|
|
missing = len(metadata_b64) % 4
|
|
if missing:
|
|
metadata_b64 += "=" * (4 - missing)
|
|
|
|
decoded_metadata = base64.b64decode(metadata_b64.encode("utf-8")).decode(
|
|
"utf-8"
|
|
)
|
|
model_id = decoded_metadata.replace("model_id:", "")
|
|
return model_id, original_content
|
|
except Exception:
|
|
return None, wrapped_content
|
|
|
|
@staticmethod
|
|
def _update_encrypted_content_item_ids_in_response(
|
|
response: Union["ResponsesAPIResponse", Dict[str, Any]],
|
|
model_id: Optional[str],
|
|
) -> Union["ResponsesAPIResponse", Dict[str, Any]]:
|
|
"""Rewrite item IDs for output items that contain ``encrypted_content``.
|
|
|
|
Encodes ``model_id`` into the item ID so that follow-up requests can be
|
|
routed back to the originating deployment without any cache lookup.
|
|
|
|
For items without an ID (e.g., from Codex), encodes model_id directly
|
|
into the encrypted_content itself.
|
|
"""
|
|
if not model_id:
|
|
return response
|
|
|
|
output: Optional[list] = None
|
|
if isinstance(response, dict):
|
|
output = response.get("output")
|
|
else:
|
|
output = getattr(response, "output", None)
|
|
|
|
if not isinstance(output, list):
|
|
return response
|
|
|
|
for item in output:
|
|
if isinstance(item, dict):
|
|
item_id = item.get("id")
|
|
encrypted_content = item.get("encrypted_content")
|
|
|
|
if encrypted_content and isinstance(encrypted_content, str):
|
|
# Always wrap encrypted_content with model_id for redundancy
|
|
item[
|
|
"encrypted_content"
|
|
] = ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id(
|
|
encrypted_content, model_id
|
|
)
|
|
# Also encode the ID if present
|
|
if item_id and isinstance(item_id, str):
|
|
item["id"] = ResponsesAPIRequestUtils._build_encrypted_item_id(
|
|
model_id, item_id
|
|
)
|
|
else:
|
|
item_id = getattr(item, "id", None)
|
|
encrypted_content = getattr(item, "encrypted_content", None)
|
|
|
|
if encrypted_content and isinstance(encrypted_content, str):
|
|
# Always wrap encrypted_content with model_id for redundancy
|
|
try:
|
|
item.encrypted_content = ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id(
|
|
encrypted_content, model_id
|
|
)
|
|
except AttributeError:
|
|
pass
|
|
# Also encode the ID if present
|
|
if item_id and isinstance(item_id, str):
|
|
try:
|
|
item.id = ResponsesAPIRequestUtils._build_encrypted_item_id(
|
|
model_id, item_id
|
|
)
|
|
except AttributeError:
|
|
pass
|
|
|
|
return response
|
|
|
|
@staticmethod
|
|
def _restore_encrypted_content_item_ids_in_input(request_input: Any) -> Any:
|
|
"""Decode litellm-encoded item IDs in request input back to original IDs.
|
|
|
|
Called before forwarding the request to the upstream provider so the
|
|
provider receives the original item IDs and unwrapped encrypted_content.
|
|
|
|
Handles both:
|
|
1. Items with encoded IDs (encitem_...)
|
|
2. Items with wrapped encrypted_content (litellm_enc:...)
|
|
"""
|
|
if not isinstance(request_input, list):
|
|
return request_input
|
|
|
|
for item in request_input:
|
|
if isinstance(item, dict):
|
|
item_id = item.get("id")
|
|
if item_id and isinstance(item_id, str):
|
|
decoded = ResponsesAPIRequestUtils._decode_encrypted_item_id(
|
|
item_id
|
|
)
|
|
if decoded:
|
|
item["id"] = decoded["item_id"]
|
|
|
|
encrypted_content = item.get("encrypted_content")
|
|
if encrypted_content and isinstance(encrypted_content, str):
|
|
(
|
|
_,
|
|
unwrapped,
|
|
) = ResponsesAPIRequestUtils._unwrap_encrypted_content_with_model_id(
|
|
encrypted_content
|
|
)
|
|
if unwrapped != encrypted_content:
|
|
item["encrypted_content"] = unwrapped
|
|
|
|
return request_input
|
|
|
|
@staticmethod
|
|
def _build_responses_api_response_id(
|
|
custom_llm_provider: Optional[str],
|
|
model_id: Optional[str],
|
|
response_id: str,
|
|
) -> str:
|
|
"""Build the responses_api_response_id"""
|
|
assembled_id: str = str(
|
|
SpecialEnums.LITELLM_MANAGED_RESPONSE_COMPLETE_STR.value
|
|
).format(custom_llm_provider, model_id, response_id)
|
|
base64_encoded_id: str = base64.b64encode(assembled_id.encode("utf-8")).decode(
|
|
"utf-8"
|
|
)
|
|
return f"resp_{base64_encoded_id}"
|
|
|
|
@staticmethod
|
|
def _decode_responses_api_response_id(
|
|
response_id: str,
|
|
) -> DecodedResponseId:
|
|
"""
|
|
Decode the responses_api_response_id
|
|
|
|
Returns:
|
|
DecodedResponseId: Structured tuple with custom_llm_provider, model_id, and response_id
|
|
"""
|
|
try:
|
|
# Remove prefix and decode
|
|
cleaned_id = response_id.replace("resp_", "")
|
|
decoded_id = base64.b64decode(cleaned_id.encode("utf-8")).decode("utf-8")
|
|
|
|
# Parse components using known prefixes
|
|
if ";" not in decoded_id:
|
|
return DecodedResponseId(
|
|
custom_llm_provider=None,
|
|
model_id=None,
|
|
response_id=response_id,
|
|
)
|
|
|
|
parts = decoded_id.split(";")
|
|
|
|
# Format: litellm:custom_llm_provider:{};model_id:{};response_id:{}
|
|
custom_llm_provider = None
|
|
model_id = None
|
|
|
|
if (
|
|
len(parts) >= 3
|
|
): # Full format with custom_llm_provider, model_id, and response_id
|
|
custom_llm_provider_part = parts[0]
|
|
model_id_part = parts[1]
|
|
response_part = parts[2]
|
|
|
|
custom_llm_provider = custom_llm_provider_part.replace(
|
|
"litellm:custom_llm_provider:", ""
|
|
)
|
|
model_id = model_id_part.replace("model_id:", "")
|
|
decoded_response_id = response_part.replace("response_id:", "")
|
|
else:
|
|
decoded_response_id = response_id
|
|
|
|
return DecodedResponseId(
|
|
custom_llm_provider=custom_llm_provider,
|
|
model_id=model_id,
|
|
response_id=decoded_response_id,
|
|
)
|
|
except Exception as e:
|
|
verbose_logger.debug(f"Error decoding response_id '{response_id}': {e}")
|
|
return DecodedResponseId(
|
|
custom_llm_provider=None,
|
|
model_id=None,
|
|
response_id=response_id,
|
|
)
|
|
|
|
@staticmethod
|
|
def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
|
|
"""Get the model_id from the response_id"""
|
|
if response_id is None:
|
|
return None
|
|
decoded_response_id = (
|
|
ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
|
|
)
|
|
return decoded_response_id.get("model_id") or None
|
|
|
|
@staticmethod
|
|
def decode_previous_response_id_to_original_previous_response_id(
|
|
previous_response_id: str,
|
|
) -> str:
|
|
"""
|
|
Decode the previous_response_id to the original previous_response_id
|
|
|
|
Why?
|
|
- LiteLLM encodes the `custom_llm_provider` and `model_id` into the `previous_response_id` this helps with maintaining session consistency when load balancing multiple deployments of the same model.
|
|
- We cannot send the litellm encoded b64 to the upstream llm api, hence we decode it to the original `previous_response_id`
|
|
|
|
Args:
|
|
previous_response_id: The previous_response_id to decode
|
|
|
|
Returns:
|
|
The original previous_response_id
|
|
"""
|
|
decoded_response_id = (
|
|
ResponsesAPIRequestUtils._decode_responses_api_response_id(
|
|
previous_response_id
|
|
)
|
|
)
|
|
return decoded_response_id.get("response_id", previous_response_id)
|
|
|
|
@staticmethod
|
|
def convert_text_format_to_text_param(
|
|
text_format: Optional[Union[Type["BaseModel"], dict]],
|
|
text: Optional["ResponseText"] = None,
|
|
) -> Optional["ResponseText"]:
|
|
"""
|
|
Convert text_format parameter to text parameter for the responses API.
|
|
|
|
Args:
|
|
text_format: Pydantic model class or dict to convert to response format
|
|
text: Existing text parameter (if provided, text_format is ignored)
|
|
|
|
Returns:
|
|
ResponseText object with the converted format, or None if conversion fails
|
|
"""
|
|
if text_format is not None and text is None:
|
|
from litellm.llms.base_llm.base_utils import type_to_response_format_param
|
|
|
|
# Convert Pydantic model to response format
|
|
response_format = type_to_response_format_param(text_format)
|
|
if response_format is not None:
|
|
# Create ResponseText object with the format
|
|
# The responses API expects the format to have name at the top level
|
|
text = {
|
|
"format": {
|
|
"type": response_format["type"],
|
|
"name": response_format["json_schema"]["name"],
|
|
"schema": response_format["json_schema"]["schema"],
|
|
"strict": response_format["json_schema"]["strict"],
|
|
}
|
|
}
|
|
return text
|
|
return text
|
|
|
|
@staticmethod
|
|
def extract_mcp_headers_from_request(
|
|
secret_fields: Optional[Dict[str, Any]],
|
|
tools: Optional[Iterable[Any]],
|
|
) -> tuple[
|
|
Optional[str],
|
|
Optional[Dict[str, Dict[str, str]]],
|
|
Optional[Dict[str, str]],
|
|
Optional[Dict[str, str]],
|
|
]:
|
|
"""
|
|
Extract MCP auth headers from the request to pass to MCP server.
|
|
Headers from tools.headers in request body should be passed to MCP server.
|
|
"""
|
|
from starlette.datastructures import Headers
|
|
|
|
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
|
|
MCPRequestHandler,
|
|
)
|
|
|
|
# Extract headers from secret_fields which contains the original request headers
|
|
raw_headers_from_request: Optional[Dict[str, str]] = None
|
|
if secret_fields and isinstance(secret_fields, dict):
|
|
raw_headers_from_request = secret_fields.get("raw_headers")
|
|
|
|
# Extract MCP-specific headers using MCPRequestHandler methods
|
|
mcp_auth_header: Optional[str] = None
|
|
mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None
|
|
oauth2_headers: Optional[Dict[str, str]] = None
|
|
|
|
if raw_headers_from_request:
|
|
headers_obj = Headers(raw_headers_from_request)
|
|
mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers(
|
|
headers_obj
|
|
)
|
|
mcp_server_auth_headers = (
|
|
MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj)
|
|
)
|
|
oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(
|
|
headers_obj
|
|
)
|
|
|
|
if tools:
|
|
for tool in tools:
|
|
if isinstance(tool, dict) and tool.get("type") == "mcp":
|
|
tool_headers = tool.get("headers", {})
|
|
if tool_headers and isinstance(tool_headers, dict):
|
|
# Merge tool headers into mcp_server_auth_headers
|
|
# Extract server-specific headers from tool.headers
|
|
headers_obj_from_tool = Headers(tool_headers)
|
|
tool_mcp_server_auth_headers = (
|
|
MCPRequestHandler._get_mcp_server_auth_headers_from_headers(
|
|
headers_obj_from_tool
|
|
)
|
|
)
|
|
if tool_mcp_server_auth_headers:
|
|
if mcp_server_auth_headers is None:
|
|
mcp_server_auth_headers = {}
|
|
# Merge the headers from tool into existing headers
|
|
for (
|
|
server_alias,
|
|
headers_dict,
|
|
) in tool_mcp_server_auth_headers.items():
|
|
if server_alias not in mcp_server_auth_headers:
|
|
mcp_server_auth_headers[server_alias] = {}
|
|
mcp_server_auth_headers[server_alias].update(
|
|
headers_dict
|
|
)
|
|
# Also merge raw headers (non-prefixed headers from tool.headers)
|
|
if raw_headers_from_request is None:
|
|
raw_headers_from_request = {}
|
|
raw_headers_from_request.update(tool_headers)
|
|
|
|
return (
|
|
mcp_auth_header,
|
|
mcp_server_auth_headers,
|
|
oauth2_headers,
|
|
raw_headers_from_request,
|
|
)
|
|
|
|
|
|
class ResponseAPILoggingUtils:
|
|
@staticmethod
|
|
def _is_response_api_usage(usage: Union[dict, ResponseAPIUsage]) -> bool:
|
|
"""returns True if usage is from OpenAI Response API"""
|
|
if isinstance(usage, ResponseAPIUsage):
|
|
return True
|
|
if "input_tokens" in usage and "output_tokens" in usage:
|
|
return True
|
|
return False
|
|
|
|
@staticmethod
|
|
def _transform_response_api_usage_to_chat_usage(
|
|
usage_input: Optional[Union[dict, ResponseAPIUsage]],
|
|
) -> Usage:
|
|
"""
|
|
Transforms ResponseAPIUsage or ImageUsage to a Usage object.
|
|
|
|
Both have the same spec with input_tokens, output_tokens, and
|
|
input_tokens_details (text_tokens, image_tokens).
|
|
"""
|
|
if usage_input is None:
|
|
return Usage(
|
|
prompt_tokens=0,
|
|
completion_tokens=0,
|
|
total_tokens=0,
|
|
)
|
|
response_api_usage: ResponseAPIUsage
|
|
if isinstance(usage_input, dict):
|
|
total_tokens = usage_input.get("total_tokens")
|
|
if total_tokens is None:
|
|
input_tokens = usage_input.get("input_tokens")
|
|
output_tokens = usage_input.get("output_tokens")
|
|
if input_tokens is not None and output_tokens is not None:
|
|
total_tokens = input_tokens + output_tokens
|
|
usage_input["total_tokens"] = total_tokens
|
|
response_api_usage = ResponseAPIUsage(**usage_input)
|
|
else:
|
|
response_api_usage = usage_input
|
|
prompt_tokens: int = response_api_usage.input_tokens or 0
|
|
completion_tokens: int = response_api_usage.output_tokens or 0
|
|
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
|
|
if response_api_usage.input_tokens_details:
|
|
if isinstance(response_api_usage.input_tokens_details, dict):
|
|
prompt_tokens_details = PromptTokensDetailsWrapper(
|
|
**response_api_usage.input_tokens_details
|
|
)
|
|
else:
|
|
prompt_tokens_details = PromptTokensDetailsWrapper(
|
|
cached_tokens=getattr(
|
|
response_api_usage.input_tokens_details, "cached_tokens", None
|
|
),
|
|
audio_tokens=getattr(
|
|
response_api_usage.input_tokens_details, "audio_tokens", None
|
|
),
|
|
text_tokens=getattr(
|
|
response_api_usage.input_tokens_details, "text_tokens", None
|
|
),
|
|
image_tokens=getattr(
|
|
response_api_usage.input_tokens_details, "image_tokens", None
|
|
),
|
|
)
|
|
completion_tokens_details: Optional[CompletionTokensDetailsWrapper] = None
|
|
output_tokens_details = getattr(
|
|
response_api_usage, "output_tokens_details", None
|
|
)
|
|
if output_tokens_details:
|
|
completion_tokens_details = CompletionTokensDetailsWrapper(
|
|
reasoning_tokens=getattr(
|
|
output_tokens_details, "reasoning_tokens", None
|
|
),
|
|
image_tokens=getattr(output_tokens_details, "image_tokens", None),
|
|
text_tokens=getattr(output_tokens_details, "text_tokens", None),
|
|
)
|
|
|
|
chat_usage = Usage(
|
|
prompt_tokens=prompt_tokens,
|
|
completion_tokens=completion_tokens,
|
|
total_tokens=prompt_tokens + completion_tokens,
|
|
prompt_tokens_details=prompt_tokens_details,
|
|
completion_tokens_details=completion_tokens_details,
|
|
)
|
|
|
|
# Preserve cost attribute if it exists on ResponseAPIUsage
|
|
if hasattr(response_api_usage, "cost") and response_api_usage.cost is not None:
|
|
setattr(chat_usage, "cost", response_api_usage.cost)
|
|
|
|
return chat_usage
|