chore: initial snapshot for gitea/github upload

Author: Your Name
Date: 2026-03-26 16:04:46 +08:00
Commit: a699a1ac98

3497 changed files with 1586237 additions and 0 deletions


@@ -0,0 +1,140 @@
"""
Handler for transforming responses api requests to litellm.completion requests
"""
from typing import Any, Coroutine, Dict, Optional, Union
import litellm
from litellm.responses.litellm_completion_transformation.streaming_iterator import (
LiteLLMCompletionStreamingIterator,
)
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from litellm.responses.streaming_iterator import BaseResponsesAPIStreamingIterator
from litellm.types.llms.openai import (
ResponseInputParam,
ResponsesAPIOptionalRequestParams,
ResponsesAPIResponse,
)
from litellm.types.utils import ModelResponse
class LiteLLMCompletionTransformationHandler:
def response_api_handler(
self,
model: str,
input: Union[str, ResponseInputParam],
responses_api_request: ResponsesAPIOptionalRequestParams,
custom_llm_provider: Optional[str] = None,
_is_async: bool = False,
stream: Optional[bool] = None,
extra_headers: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[
ResponsesAPIResponse,
BaseResponsesAPIStreamingIterator,
Coroutine[
Any, Any, Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]
],
]:
litellm_completion_request: dict = LiteLLMCompletionResponsesConfig.transform_responses_api_request_to_chat_completion_request(
model=model,
input=input,
responses_api_request=responses_api_request,
custom_llm_provider=custom_llm_provider,
stream=stream,
extra_headers=extra_headers,
**kwargs,
)
if _is_async:
return self.async_response_api_handler(
litellm_completion_request=litellm_completion_request,
request_input=input,
responses_api_request=responses_api_request,
**kwargs,
)
        # Merge kwargs first so values from the transformed request win and
        # duplicate keys cannot raise "got multiple values for keyword argument".
        completion_args = {}
        completion_args.update(kwargs)
        completion_args.update(litellm_completion_request)
        litellm_completion_response: Union[
            ModelResponse, litellm.CustomStreamWrapper
        ] = litellm.completion(**completion_args)
if isinstance(litellm_completion_response, ModelResponse):
responses_api_response: ResponsesAPIResponse = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
chat_completion_response=litellm_completion_response,
request_input=input,
responses_api_request=responses_api_request,
)
return responses_api_response
elif isinstance(litellm_completion_response, litellm.CustomStreamWrapper):
return LiteLLMCompletionStreamingIterator(
model=model,
litellm_custom_stream_wrapper=litellm_completion_response,
request_input=input,
responses_api_request=responses_api_request,
custom_llm_provider=custom_llm_provider,
litellm_metadata=kwargs.get("litellm_metadata", {}),
)
raise ValueError(
f"Unexpected response type: {type(litellm_completion_response)}"
)
async def async_response_api_handler(
self,
litellm_completion_request: dict,
request_input: Union[str, ResponseInputParam],
responses_api_request: ResponsesAPIOptionalRequestParams,
**kwargs,
) -> Union[ResponsesAPIResponse, BaseResponsesAPIStreamingIterator]:
previous_response_id: Optional[str] = responses_api_request.get(
"previous_response_id"
)
if previous_response_id:
litellm_completion_request = await LiteLLMCompletionResponsesConfig.async_responses_api_session_handler(
previous_response_id=previous_response_id,
litellm_completion_request=litellm_completion_request,
)
acompletion_args = {}
acompletion_args.update(kwargs)
acompletion_args.update(litellm_completion_request)
litellm_completion_response: Union[
ModelResponse, litellm.CustomStreamWrapper
] = await litellm.acompletion(
**acompletion_args,
)
if isinstance(litellm_completion_response, ModelResponse):
responses_api_response: ResponsesAPIResponse = LiteLLMCompletionResponsesConfig.transform_chat_completion_response_to_responses_api_response(
chat_completion_response=litellm_completion_response,
request_input=request_input,
responses_api_request=responses_api_request,
)
return responses_api_response
elif isinstance(litellm_completion_response, litellm.CustomStreamWrapper):
return LiteLLMCompletionStreamingIterator(
model=litellm_completion_request.get("model") or "",
litellm_custom_stream_wrapper=litellm_completion_response,
request_input=request_input,
responses_api_request=responses_api_request,
custom_llm_provider=litellm_completion_request.get(
"custom_llm_provider"
),
litellm_metadata=kwargs.get("litellm_metadata", {}),
)
raise ValueError(
f"Unexpected response type: {type(litellm_completion_response)}"
)
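Example (illustrative, not part of the diff): a minimal sketch of driving the handler above. The model name, input text, and empty request params are assumptions.

# Illustrative sketch -- model, input, and params below are assumptions.
handler = LiteLLMCompletionTransformationHandler()

# Sync, non-streaming: returns a ResponsesAPIResponse directly.
response = handler.response_api_handler(
    model="gpt-4o-mini",
    input="What is the capital of France?",
    responses_api_request={},
)

# With _is_async=True the same call returns a coroutine that must be awaited:
# response = await handler.response_api_handler(..., _is_async=True)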


@@ -0,0 +1,315 @@
import json
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpendLogsPayload
from litellm.proxy.spend_tracking.cold_storage_handler import ColdStorageHandler
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionResponseMessage,
GenericChatCompletionMessage,
ResponseInputParam,
)
from litellm.types.utils import ChatCompletionMessageToolCall, Message, ModelResponse
if TYPE_CHECKING:
from litellm.responses.litellm_completion_transformation.transformation import (
ChatCompletionSession,
)
else:
ChatCompletionSession = Any
########################################################
# Cold Storage Handler
########################################################
COLD_STORAGE_HANDLER = ColdStorageHandler()
########################################################
class ResponsesSessionHandler:
@staticmethod
async def get_chat_completion_message_history_for_previous_response_id(
previous_response_id: str,
) -> ChatCompletionSession:
"""
Return the chat completion message history for a previous response id
"""
from litellm.responses.litellm_completion_transformation.transformation import (
ChatCompletionSession,
)
verbose_proxy_logger.debug(
"inside get_chat_completion_message_history_for_previous_response_id"
)
all_spend_logs: List[
SpendLogsPayload
] = await ResponsesSessionHandler.get_all_spend_logs_for_previous_response_id(
previous_response_id
)
verbose_proxy_logger.debug(
"found %s spend logs for this response id", len(all_spend_logs)
)
litellm_session_id: Optional[str] = None
if len(all_spend_logs) > 0:
litellm_session_id = all_spend_logs[0].get("session_id")
chat_completion_message_history: List[
Union[
AllMessageValues,
GenericChatCompletionMessage,
ChatCompletionMessageToolCall,
ChatCompletionResponseMessage,
Message,
]
] = []
for spend_log in all_spend_logs:
chat_completion_message_history = await ResponsesSessionHandler.extend_chat_completion_message_with_spend_log_payload(
spend_log=spend_log,
chat_completion_message_history=chat_completion_message_history,
)
verbose_proxy_logger.debug(
"chat_completion_message_history %s",
json.dumps(chat_completion_message_history, indent=4, default=str),
)
return ChatCompletionSession(
messages=chat_completion_message_history,
litellm_session_id=litellm_session_id,
)
@staticmethod
async def extend_chat_completion_message_with_spend_log_payload(
spend_log: SpendLogsPayload,
chat_completion_message_history: List[
Union[
AllMessageValues,
GenericChatCompletionMessage,
ChatCompletionMessageToolCall,
ChatCompletionResponseMessage,
Message,
]
],
):
"""
Extend the chat completion message history with the spend log payload
"""
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
proxy_server_request_dict = (
await ResponsesSessionHandler.get_proxy_server_request_from_spend_log(
spend_log=spend_log,
)
)
response_input_param: Optional[Union[str, ResponseInputParam]] = None
_messages: Optional[Union[str, ResponseInputParam]] = None
############################################################
# Add Input messages for this Spend Log
############################################################
if proxy_server_request_dict:
_response_input_param = proxy_server_request_dict.get("input", None)
_messages = proxy_server_request_dict.get("messages", None)
if isinstance(_response_input_param, str):
response_input_param = _response_input_param
elif isinstance(_response_input_param, dict):
response_input_param = cast(ResponseInputParam, _response_input_param)
if response_input_param:
chat_completion_messages = LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
input=response_input_param,
responses_api_request=proxy_server_request_dict or {},
)
chat_completion_message_history.extend(chat_completion_messages)
############################################################
# Check if `messages` field is present in the proxy server request dict
############################################################
elif _messages:
# ensure all messages are /chat/completions/messages
# certain requests can be stored as Responses API format - this ensures they are transformed to /chat/completions/messages
chat_completion_messages = LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
input=_messages,
responses_api_request=proxy_server_request_dict or {},
)
chat_completion_message_history.extend(chat_completion_messages)
############################################################
# Add Output messages for this Spend Log
############################################################
        _response_output = spend_log.get("response", "{}")
        # a truthy dict is already non-empty, so one check suffices
        if isinstance(_response_output, dict) and _response_output:
# transform `ChatCompletion Response` to `ResponsesAPIResponse`
model_response = ModelResponse(**_response_output)
for choice in model_response.choices:
if hasattr(choice, "message"):
chat_completion_message_history.append(getattr(choice, "message"))
return chat_completion_message_history
@staticmethod
async def get_proxy_server_request_from_spend_log(
spend_log: SpendLogsPayload,
) -> Optional[dict]:
"""
Get the parsed proxy server request from the spend log
"""
proxy_server_request: Union[str, dict] = (
spend_log.get("proxy_server_request") or "{}"
)
proxy_server_request_dict: Optional[dict] = None
if isinstance(proxy_server_request, dict):
proxy_server_request_dict = proxy_server_request
else:
proxy_server_request_dict = json.loads(proxy_server_request)
############################################################
# Check if user has setup cold storage for session handling
############################################################
if ResponsesSessionHandler._should_check_cold_storage_for_full_payload(
proxy_server_request_dict
):
# Try to get cold storage object key from spend log metadata
_proxy_server_request_dict: Optional[dict] = None
cold_storage_object_key = (
ResponsesSessionHandler._get_cold_storage_object_key_from_spend_log(
spend_log
)
)
if cold_storage_object_key:
# Use the object key directly from metadata
_proxy_server_request_dict = await ResponsesSessionHandler.get_proxy_server_request_from_cold_storage_with_object_key(
object_key=cold_storage_object_key,
)
if _proxy_server_request_dict:
proxy_server_request_dict = _proxy_server_request_dict
return proxy_server_request_dict
@staticmethod
def _get_cold_storage_object_key_from_spend_log(
spend_log: SpendLogsPayload,
) -> Optional[str]:
"""
Extract the cold storage object key from spend log metadata.
Args:
spend_log: The spend log payload containing metadata
Returns:
Optional[str]: The cold storage object key if found, None otherwise
"""
try:
metadata_str = spend_log.get("metadata", "{}")
if isinstance(metadata_str, str):
metadata_dict = json.loads(metadata_str)
return metadata_dict.get("cold_storage_object_key")
elif isinstance(metadata_str, dict):
return metadata_str.get("cold_storage_object_key")
return None
except (json.JSONDecodeError, TypeError, AttributeError):
verbose_proxy_logger.debug(
"Failed to parse metadata from spend log to extract cold storage object key"
)
return None
@staticmethod
async def get_proxy_server_request_from_cold_storage_with_object_key(
object_key: str,
) -> Optional[dict]:
"""
Get the proxy server request from cold storage using the object key directly.
Args:
object_key: The S3/GCS object key to retrieve
Returns:
Optional[dict]: The proxy server request dict or None if not found
"""
verbose_proxy_logger.debug(
"inside get_proxy_server_request_from_cold_storage_with_object_key..."
)
proxy_server_request_dict = await COLD_STORAGE_HANDLER.get_proxy_server_request_from_cold_storage_with_object_key(
object_key=object_key,
)
return proxy_server_request_dict
@staticmethod
def _should_check_cold_storage_for_full_payload(
proxy_server_request_dict: Optional[dict],
) -> bool:
"""
Only check cold storage when both are true
1. `LITELLM_TRUNCATED_PAYLOAD_FIELD` is in the proxy server request dict
2. `litellm.cold_storage_custom_logger` is not None
"""
from litellm.constants import LITELLM_TRUNCATED_PAYLOAD_FIELD
configured_cold_storage_custom_logger = litellm.cold_storage_custom_logger
if configured_cold_storage_custom_logger is None:
return False
if proxy_server_request_dict is None:
return True
if len(proxy_server_request_dict) == 0:
return True
if LITELLM_TRUNCATED_PAYLOAD_FIELD in str(proxy_server_request_dict):
return True
return False
@staticmethod
async def get_all_spend_logs_for_previous_response_id(
previous_response_id: str,
) -> List[SpendLogsPayload]:
"""
Get all spend logs for a previous response id
SQL query
SELECT session_id FROM spend_logs WHERE response_id = previous_response_id, SELECT * FROM spend_logs WHERE session_id = session_id
"""
from litellm.proxy.proxy_server import prisma_client
verbose_proxy_logger.debug("decoding response id=%s", previous_response_id)
decoded_response_id = (
ResponsesAPIRequestUtils._decode_responses_api_response_id(
previous_response_id
)
)
previous_response_id = decoded_response_id.get(
"response_id", previous_response_id
)
if prisma_client is None:
return []
query = """
WITH matching_session AS (
SELECT session_id
FROM "LiteLLM_SpendLogs"
WHERE request_id = $1
)
SELECT *
FROM "LiteLLM_SpendLogs"
WHERE session_id IN (SELECT session_id FROM matching_session)
ORDER BY "endTime" ASC;
"""
spend_logs = await prisma_client.db.query_raw(query, previous_response_id)
verbose_proxy_logger.debug(
"Found the following spend logs for previous response id %s: %s",
previous_response_id,
json.dumps(spend_logs, indent=4, default=str),
)
return spend_logs
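Example (illustrative, not part of the diff): a sketch of reconstructing a session from a previous response id; the response id is made up.

# Illustrative sketch -- assumes an async context and a configured prisma_client.
session = await ResponsesSessionHandler.get_chat_completion_message_history_for_previous_response_id(
    previous_response_id="resp_abc123",  # hypothetical id
)
# session.messages holds the prior turns in /chat/completions format,
# ready to be prepended to the next request in the same session.
print(session.litellm_session_id, len(session.messages))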

[File diff suppressed because it is too large]


@@ -0,0 +1,670 @@
"""Helpers for handling MCP-aware `/chat/completions` requests."""
from typing import (
Any,
List,
Optional,
Union,
cast,
)
from litellm.responses.mcp.litellm_proxy_mcp_handler import (
LiteLLM_Proxy_MCP_Handler,
)
from litellm.responses.utils import ResponsesAPIRequestUtils
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
def _add_mcp_metadata_to_response(
response: Union[ModelResponse, CustomStreamWrapper],
openai_tools: Optional[List],
tool_calls: Optional[List] = None,
tool_results: Optional[List] = None,
) -> None:
"""
Add MCP metadata to response's provider_specific_fields.
This function adds MCP-related information to the response so that
clients can access which tools were available, which were called, and
what results were returned.
For ModelResponse: adds to choices[].message.provider_specific_fields
For CustomStreamWrapper: stores in _hidden_params and automatically adds to
final chunk's delta.provider_specific_fields via CustomStreamWrapper._add_mcp_metadata_to_final_chunk()
"""
if isinstance(response, CustomStreamWrapper):
# For streaming, store MCP metadata in _hidden_params
# CustomStreamWrapper._add_mcp_metadata_to_final_chunk() will automatically
# add it to the final chunk's delta.provider_specific_fields
if not hasattr(response, "_hidden_params"):
response._hidden_params = {}
mcp_metadata = {}
if openai_tools:
mcp_metadata["mcp_list_tools"] = openai_tools
if tool_calls:
mcp_metadata["mcp_tool_calls"] = tool_calls
if tool_results:
mcp_metadata["mcp_call_results"] = tool_results
if mcp_metadata:
response._hidden_params["mcp_metadata"] = mcp_metadata
return
if not isinstance(response, ModelResponse):
return
if not hasattr(response, "choices") or not response.choices:
return
# Add MCP metadata to all choices' messages
for choice in response.choices:
message = getattr(choice, "message", None)
if message is not None:
# Get existing provider_specific_fields or create new dict
provider_fields = getattr(message, "provider_specific_fields", None) or {}
# Add MCP metadata
if openai_tools:
provider_fields["mcp_list_tools"] = openai_tools
if tool_calls:
provider_fields["mcp_tool_calls"] = tool_calls
if tool_results:
provider_fields["mcp_call_results"] = tool_results
# Set the provider_specific_fields
setattr(message, "provider_specific_fields", provider_fields)
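Example (illustrative, not part of the diff): where the metadata added above lands on a non-streaming response.

# Illustrative sketch -- `response` is assumed to be a ModelResponse that has
# already been passed through _add_mcp_metadata_to_response.
fields = response.choices[0].message.provider_specific_fields or {}
available_tools = fields.get("mcp_list_tools", [])
tool_calls = fields.get("mcp_tool_calls", [])
tool_results = fields.get("mcp_call_results", [])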
async def acompletion_with_mcp( # noqa: PLR0915
model: str,
messages: List,
tools: Optional[List] = None,
**kwargs: Any,
) -> Union[ModelResponse, CustomStreamWrapper]:
"""
Async completion with MCP integration.
This function handles MCP tool integration following the same pattern as aresponses_api_with_mcp.
It's designed to be called from the synchronous completion() function and return a coroutine.
When MCP tools with server_url="litellm_proxy" are provided, this function will:
1. Get available tools from the MCP server manager
2. Transform them to OpenAI format
3. Call acompletion with the transformed tools
4. If require_approval="never" and tool calls are returned, automatically execute them
5. Make a follow-up call with the tool results
"""
from litellm import acompletion as litellm_acompletion
# Parse MCP tools and separate from other tools
(
mcp_tools_with_litellm_proxy,
other_tools,
) = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools)
if not mcp_tools_with_litellm_proxy:
# No MCP tools, proceed with regular completion
return await litellm_acompletion(
model=model,
messages=messages,
tools=tools,
**kwargs,
)
# Extract user_api_key_auth from metadata or kwargs
user_api_key_auth = kwargs.get("user_api_key_auth") or (
(kwargs.get("metadata", {}) or {}).get("user_api_key_auth")
)
# Extract MCP auth headers before fetching tools (needed for dynamic auth)
(
mcp_auth_header,
mcp_server_auth_headers,
oauth2_headers,
raw_headers,
) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request(
secret_fields=kwargs.get("secret_fields"),
tools=tools,
)
# Process MCP tools (pass auth headers for dynamic auth)
(
deduplicated_mcp_tools,
tool_server_map,
) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_without_openai_transform(
user_api_key_auth=user_api_key_auth,
mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy,
litellm_trace_id=kwargs.get("litellm_trace_id"),
mcp_auth_header=mcp_auth_header,
mcp_server_auth_headers=mcp_server_auth_headers,
)
openai_tools = LiteLLM_Proxy_MCP_Handler._transform_mcp_tools_to_openai(
deduplicated_mcp_tools,
target_format="chat",
)
# Combine with other tools
all_tools = openai_tools + other_tools if (openai_tools or other_tools) else None
# Determine if we should auto-execute tools
should_auto_execute = LiteLLM_Proxy_MCP_Handler._should_auto_execute_tools(
mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy
)
# Prepare call parameters
# Remove keys that shouldn't be passed to acompletion
clean_kwargs = {k: v for k, v in kwargs.items() if k not in ["acompletion"]}
base_call_args = {
"model": model,
"messages": messages,
"tools": all_tools,
"_skip_mcp_handler": True, # Prevent recursion
**clean_kwargs,
}
# If not auto-executing, just make the call with transformed tools
if not should_auto_execute:
response = await litellm_acompletion(**base_call_args)
if isinstance(response, (ModelResponse, CustomStreamWrapper)):
_add_mcp_metadata_to_response(
response=response,
openai_tools=openai_tools,
)
return response
# For auto-execute: handle streaming vs non-streaming differently
stream = kwargs.get("stream", False)
mock_tool_calls = base_call_args.pop("mock_tool_calls", None)
if stream:
# Streaming mode: make initial call with streaming, collect chunks, detect tool calls
initial_call_args = dict(base_call_args)
initial_call_args["stream"] = True
if mock_tool_calls is not None:
initial_call_args["mock_tool_calls"] = mock_tool_calls
# Make initial streaming call
initial_stream = await litellm_acompletion(**initial_call_args)
if not isinstance(initial_stream, CustomStreamWrapper):
# Not a stream, return as-is
if isinstance(initial_stream, ModelResponse):
_add_mcp_metadata_to_response(
response=initial_stream,
openai_tools=openai_tools,
)
return initial_stream
# Create a custom async generator that collects chunks and handles tool execution
from litellm.main import stream_chunk_builder
from litellm.types.utils import ModelResponseStream
class MCPStreamingIterator:
"""Custom iterator that collects chunks, detects tool calls, and adds MCP metadata to final chunk."""
def __init__(
self,
stream_wrapper,
messages,
tool_server_map,
user_api_key_auth,
mcp_auth_header,
mcp_server_auth_headers,
oauth2_headers,
raw_headers,
litellm_call_id,
litellm_trace_id,
openai_tools,
base_call_args,
):
self.stream_wrapper = stream_wrapper
self.messages = messages
self.tool_server_map = tool_server_map
self.user_api_key_auth = user_api_key_auth
self.mcp_auth_header = mcp_auth_header
self.mcp_server_auth_headers = mcp_server_auth_headers
self.oauth2_headers = oauth2_headers
self.raw_headers = raw_headers
self.litellm_call_id = litellm_call_id
self.litellm_trace_id = litellm_trace_id
self.openai_tools = openai_tools
self.base_call_args = base_call_args
self.collected_chunks: List[ModelResponseStream] = []
self.tool_calls: Optional[List] = None
self.tool_results: Optional[List] = None
self.complete_response: Optional[ModelResponse] = None
self.stream_exhausted = False
self.tool_execution_done = False
self.follow_up_stream = None
self.follow_up_iterator = None
self.follow_up_exhausted = False
        def __aiter__(self):
            # Must be a plain (non-async) method: `async def __aiter__` returns
            # a coroutine, which `async for` rejects on Python 3.8+.
            return self
def _add_mcp_list_tools_to_chunk(
self, chunk: ModelResponseStream
) -> ModelResponseStream:
"""Add mcp_list_tools to the first chunk."""
from litellm.types.utils import (
StreamingChoices,
add_provider_specific_fields,
)
if not self.openai_tools:
return chunk
if hasattr(chunk, "choices") and chunk.choices:
for choice in chunk.choices:
if (
isinstance(choice, StreamingChoices)
and hasattr(choice, "delta")
and choice.delta
):
# Get existing provider_specific_fields or create new dict
existing_fields = (
getattr(choice.delta, "provider_specific_fields", None)
or {}
)
provider_fields = dict(
existing_fields
) # Create a copy to avoid mutating the original
# Add only mcp_list_tools to first chunk
provider_fields["mcp_list_tools"] = self.openai_tools
# Use add_provider_specific_fields to ensure proper setting
# This function handles Pydantic model attribute setting correctly
add_provider_specific_fields(choice.delta, provider_fields)
return chunk
def _add_mcp_tool_metadata_to_final_chunk(
self, chunk: ModelResponseStream
) -> ModelResponseStream:
"""Add mcp_tool_calls and mcp_call_results to the final chunk."""
from litellm.types.utils import (
StreamingChoices,
add_provider_specific_fields,
)
if hasattr(chunk, "choices") and chunk.choices:
for choice in chunk.choices:
if (
isinstance(choice, StreamingChoices)
and hasattr(choice, "delta")
and choice.delta
):
# Get existing provider_specific_fields or create new dict
# Access the attribute directly to handle Pydantic model attributes correctly
existing_fields = {}
if hasattr(choice.delta, "provider_specific_fields"):
attr_value = getattr(
choice.delta, "provider_specific_fields", None
)
if attr_value is not None:
# Create a copy to avoid mutating the original
existing_fields = (
dict(attr_value)
if isinstance(attr_value, dict)
else {}
)
provider_fields = existing_fields
# Add tool_calls and tool_results if available
if self.tool_calls:
provider_fields["mcp_tool_calls"] = self.tool_calls
if self.tool_results:
provider_fields["mcp_call_results"] = self.tool_results
# Use add_provider_specific_fields to ensure proper setting
# This function handles Pydantic model attribute setting correctly
add_provider_specific_fields(choice.delta, provider_fields)
return chunk
async def __anext__(self):
# Phase 1: Collect and yield initial stream chunks
if not self.stream_exhausted:
# Get the iterator from the stream wrapper
if not hasattr(self, "_stream_iterator"):
self._stream_iterator = self.stream_wrapper.__aiter__()
# Add mcp_list_tools to the first chunk (available from the start)
_add_mcp_metadata_to_response(
response=self.stream_wrapper,
openai_tools=self.openai_tools,
)
try:
chunk = await self._stream_iterator.__anext__()
self.collected_chunks.append(chunk)
# Add mcp_list_tools to the first chunk
if len(self.collected_chunks) == 1:
chunk = self._add_mcp_list_tools_to_chunk(chunk)
# Check if this is the final chunk (has finish_reason)
is_final = (
hasattr(chunk, "choices")
and chunk.choices
and hasattr(chunk.choices[0], "finish_reason")
and chunk.choices[0].finish_reason is not None
)
if is_final:
# This is the final chunk, mark stream as exhausted
self.stream_exhausted = True
# Process tool calls after we've collected all chunks
await self._process_tool_calls()
# Apply MCP metadata (tool_calls and tool_results) to final chunk
chunk = self._add_mcp_tool_metadata_to_final_chunk(chunk)
# If we have tool results, prepare follow-up call immediately
if self.tool_results and self.complete_response:
await self._prepare_follow_up_call()
return chunk
except StopAsyncIteration:
self.stream_exhausted = True
# Process tool calls after stream is exhausted
await self._process_tool_calls()
# If we have chunks, yield the final one with metadata
if self.collected_chunks:
final_chunk = self.collected_chunks[-1]
final_chunk = self._add_mcp_tool_metadata_to_final_chunk(
final_chunk
)
# If we have tool results, prepare follow-up call
if self.tool_results and self.complete_response:
await self._prepare_follow_up_call()
return final_chunk
# Phase 2: Yield follow-up stream chunks if available
if self.follow_up_stream and not self.follow_up_exhausted:
if not self.follow_up_iterator:
self.follow_up_iterator = self.follow_up_stream.__aiter__()
from litellm._logging import verbose_logger
verbose_logger.debug("Follow-up stream iterator created")
try:
chunk = await self.follow_up_iterator.__anext__()
from litellm._logging import verbose_logger
verbose_logger.debug(f"Follow-up chunk yielded: {chunk}")
return chunk
except StopAsyncIteration:
self.follow_up_exhausted = True
from litellm._logging import verbose_logger
verbose_logger.debug("Follow-up stream exhausted")
# After follow-up stream is exhausted, check if we need to raise StopAsyncIteration
raise StopAsyncIteration
# If we're here and follow_up_stream is None but we expected it, log a warning
if (
self.stream_exhausted
and self.tool_results
and self.complete_response
and self.follow_up_stream is None
):
from litellm._logging import verbose_logger
verbose_logger.warning(
"Follow-up stream was not created despite having tool results"
)
raise StopAsyncIteration
async def _process_tool_calls(self):
"""Process tool calls after streaming completes."""
if self.tool_execution_done:
return
self.tool_execution_done = True
if not self.collected_chunks:
return
# Build complete response from chunks
complete_response = stream_chunk_builder(
chunks=self.collected_chunks,
messages=self.messages,
)
if isinstance(complete_response, ModelResponse):
self.complete_response = complete_response
# Extract tool calls from complete response
self.tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response(
response=complete_response
)
if self.tool_calls:
# Execute tool calls
self.tool_results = (
await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
tool_server_map=self.tool_server_map,
tool_calls=self.tool_calls,
user_api_key_auth=self.user_api_key_auth,
mcp_auth_header=self.mcp_auth_header,
mcp_server_auth_headers=self.mcp_server_auth_headers,
oauth2_headers=self.oauth2_headers,
raw_headers=self.raw_headers,
litellm_call_id=self.litellm_call_id,
litellm_trace_id=self.litellm_trace_id,
)
)
async def _prepare_follow_up_call(self):
"""Prepare and initiate follow-up call with tool results."""
if self.follow_up_stream is not None:
return # Already prepared
if not self.tool_results or not self.complete_response:
return
# Create follow-up messages with tool results
follow_up_messages = (
LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat(
original_messages=self.messages,
response=self.complete_response,
tool_results=self.tool_results,
)
)
# Make follow-up call with streaming
follow_up_call_args = dict(self.base_call_args)
follow_up_call_args["messages"] = follow_up_messages
follow_up_call_args["stream"] = True
# Ensure follow-up call doesn't trigger MCP handler again
follow_up_call_args["_skip_mcp_handler"] = True
# Import litellm here to ensure we get the patched version
# This ensures the patch works correctly in tests
import litellm
follow_up_response = await litellm.acompletion(**follow_up_call_args)
# Ensure follow-up response is a CustomStreamWrapper
if isinstance(follow_up_response, CustomStreamWrapper):
self.follow_up_stream = follow_up_response
from litellm._logging import verbose_logger
verbose_logger.debug("Follow-up stream created successfully")
else:
# Unexpected response type - log and set to None
from litellm._logging import verbose_logger
verbose_logger.warning(
f"Follow-up response is not a CustomStreamWrapper: {type(follow_up_response)}"
)
self.follow_up_stream = None
# Create the custom iterator
iterator = MCPStreamingIterator(
stream_wrapper=initial_stream,
messages=messages,
tool_server_map=tool_server_map,
user_api_key_auth=user_api_key_auth,
mcp_auth_header=mcp_auth_header,
mcp_server_auth_headers=mcp_server_auth_headers,
oauth2_headers=oauth2_headers,
raw_headers=raw_headers,
litellm_call_id=kwargs.get("litellm_call_id"),
litellm_trace_id=kwargs.get("litellm_trace_id"),
openai_tools=openai_tools,
base_call_args=base_call_args,
)
# Create a wrapper class that delegates to our custom iterator
# We'll use a simple approach: just replace the __aiter__ method
class MCPStreamWrapper(CustomStreamWrapper):
def __init__(self, original_wrapper, custom_iterator):
# Initialize with the same parameters as original wrapper
super().__init__(
completion_stream=None,
model=getattr(original_wrapper, "model", "unknown"),
logging_obj=getattr(original_wrapper, "logging_obj", None),
custom_llm_provider=getattr(
original_wrapper, "custom_llm_provider", None
),
stream_options=getattr(original_wrapper, "stream_options", None),
make_call=getattr(original_wrapper, "make_call", None),
_response_headers=getattr(
original_wrapper, "_response_headers", None
),
)
self._original_wrapper = original_wrapper
self._custom_iterator = custom_iterator
# Copy important attributes from original wrapper
if hasattr(original_wrapper, "_hidden_params"):
self._hidden_params = original_wrapper._hidden_params
# For synchronous iteration, we need to run the async iterator
self._sync_iterator = None
self._sync_loop = None
def __aiter__(self):
return self._custom_iterator
def __iter__(self):
# For synchronous iteration, create a sync wrapper
if self._sync_iterator is None:
import asyncio
try:
self._sync_loop = asyncio.get_event_loop()
except RuntimeError:
self._sync_loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._sync_loop)
self._sync_iterator = _SyncIteratorWrapper(
self._custom_iterator, self._sync_loop
)
return self._sync_iterator
def __next__(self):
# Delegate to sync iterator
if self._sync_iterator is None:
self.__iter__()
return next(self._sync_iterator)
def __getattr__(self, name):
# Delegate all other attributes to original wrapper
return getattr(self._original_wrapper, name)
# Helper class to wrap async iterator for sync iteration
class _SyncIteratorWrapper:
def __init__(self, async_iterator, loop):
self._async_iterator = async_iterator
self._loop = loop
self._iterator = None
def __iter__(self):
return self
def __next__(self):
if self._iterator is None:
# __aiter__ might be async, so we need to await it
aiter_result = self._async_iterator.__aiter__()
if hasattr(aiter_result, "__await__"):
# It's a coroutine, await it
self._iterator = self._loop.run_until_complete(aiter_result)
else:
# It's already an iterator
self._iterator = aiter_result
try:
return self._loop.run_until_complete(self._iterator.__anext__())
except StopAsyncIteration:
raise StopIteration
return cast(CustomStreamWrapper, MCPStreamWrapper(initial_stream, iterator))
# Non-streaming mode: use existing logic
initial_call_args = dict(base_call_args)
initial_call_args["stream"] = False
if mock_tool_calls is not None:
initial_call_args["mock_tool_calls"] = mock_tool_calls
# Make initial call
initial_response = await litellm_acompletion(**initial_call_args)
if not isinstance(initial_response, ModelResponse):
return initial_response
# Extract tool calls from response
tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response(
response=initial_response
)
if not tool_calls:
_add_mcp_metadata_to_response(
response=initial_response,
openai_tools=openai_tools,
)
return initial_response
# Execute tool calls
tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
tool_server_map=tool_server_map,
tool_calls=tool_calls,
user_api_key_auth=user_api_key_auth,
mcp_auth_header=mcp_auth_header,
mcp_server_auth_headers=mcp_server_auth_headers,
oauth2_headers=oauth2_headers,
raw_headers=raw_headers,
litellm_call_id=kwargs.get("litellm_call_id"),
litellm_trace_id=kwargs.get("litellm_trace_id"),
)
if not tool_results:
_add_mcp_metadata_to_response(
response=initial_response,
openai_tools=openai_tools,
tool_calls=tool_calls,
)
return initial_response
# Create follow-up messages with tool results
follow_up_messages = LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat(
original_messages=messages,
response=initial_response,
tool_results=tool_results,
)
# Make follow-up call with original stream setting
follow_up_call_args = dict(base_call_args)
follow_up_call_args["messages"] = follow_up_messages
follow_up_call_args["stream"] = stream
response = await litellm_acompletion(**follow_up_call_args)
if isinstance(response, (ModelResponse, CustomStreamWrapper)):
_add_mcp_metadata_to_response(
response=response,
openai_tools=openai_tools,
tool_calls=tool_calls,
tool_results=tool_results,
)
return response
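Example (illustrative, not part of the diff): a sketch of calling acompletion_with_mcp with an MCP tool routed through the proxy, following the docstring above. The model, message, and tool fields are assumptions.

import asyncio

async def main():
    # Illustrative sketch -- tool config fields mirror the docstring above.
    response = await acompletion_with_mcp(
        model="gpt-4o-mini",
        messages=[{"role": "user", "content": "List my open tickets"}],
        tools=[
            {
                "type": "mcp",
                "server_url": "litellm_proxy",  # resolve tools via the MCP server manager
                "require_approval": "never",    # auto-execute returned tool calls
            }
        ],
    )
    print(response)

asyncio.run(main())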


@@ -0,0 +1,798 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.responses.streaming_iterator import BaseResponsesAPIStreamingIterator
from litellm.types.llms.openai import (
BaseLiteLLMOpenAIResponseObject,
MCPCallArgumentsDeltaEvent,
MCPCallArgumentsDoneEvent,
MCPCallCompletedEvent,
MCPCallFailedEvent,
MCPCallInProgressEvent,
MCPListToolsCompletedEvent,
MCPListToolsFailedEvent,
MCPListToolsInProgressEvent,
ResponsesAPIResponse,
ResponsesAPIStreamEvents,
ResponsesAPIStreamingResponse,
ToolParam,
)
if TYPE_CHECKING:
from mcp.types import Tool as MCPTool
else:
MCPTool = Any
async def create_mcp_list_tools_events(
mcp_tools_with_litellm_proxy: List[ToolParam],
user_api_key_auth: Any,
base_item_id: str,
pre_processed_mcp_tools: List[Any],
) -> List[ResponsesAPIStreamingResponse]:
"""Create MCP discovery events using pre-processed tools from the parent"""
events: List[ResponsesAPIStreamingResponse] = []
try:
# Extract MCP server names
mcp_servers = []
for tool in mcp_tools_with_litellm_proxy:
if isinstance(tool, dict) and "server_url" in tool:
server_url = tool.get("server_url")
if isinstance(server_url, str) and server_url.startswith(
"litellm_proxy/mcp/"
):
server_name = server_url.split("/")[-1]
mcp_servers.append(server_name)
# Emit list tools in progress event
in_progress_event = MCPListToolsInProgressEvent(
type=ResponsesAPIStreamEvents.MCP_LIST_TOOLS_IN_PROGRESS,
sequence_number=1,
output_index=0,
item_id=base_item_id,
)
events.append(in_progress_event)
# Use the pre-processed MCP tools that were already fetched, filtered, and deduplicated by the parent
filtered_mcp_tools = pre_processed_mcp_tools
# Convert tools to dict format for the event
mcp_tools_dict = []
for tool in filtered_mcp_tools:
if hasattr(tool, "model_dump") and callable(getattr(tool, "model_dump")):
# Type cast to help mypy understand this is safe after hasattr check
mcp_tools_dict.append(cast(Any, tool).model_dump())
elif hasattr(tool, "__dict__"):
mcp_tools_dict.append(tool.__dict__)
else:
mcp_tools_dict.append({"name": getattr(tool, "name", str(tool))})
# Emit list tools completed event
completed_event = MCPListToolsCompletedEvent(
type=ResponsesAPIStreamEvents.MCP_LIST_TOOLS_COMPLETED,
sequence_number=2,
output_index=0,
item_id=base_item_id,
)
events.append(completed_event)
# Add output_item.done event with the actual tools list (matching OpenAI format)
from litellm.types.llms.openai import OutputItemDoneEvent
# Extract server label from the first MCP tool config
server_label = ""
if mcp_tools_with_litellm_proxy:
first_tool = mcp_tools_with_litellm_proxy[0]
if isinstance(first_tool, dict):
server_label_value = first_tool.get("server_label", "")
server_label = (
str(server_label_value) if server_label_value is not None else ""
)
# Format tools for OpenAI output_item.done format
formatted_tools = []
for tool in filtered_mcp_tools:
tool_dict = {
"name": getattr(tool, "name", "unknown"),
"description": getattr(tool, "description", ""),
"annotations": {"read_only": False},
}
# Add input_schema if available
if hasattr(tool, "inputSchema"):
tool_dict["input_schema"] = getattr(tool, "inputSchema")
elif hasattr(tool, "input_schema"):
tool_dict["input_schema"] = getattr(tool, "input_schema")
formatted_tools.append(tool_dict)
# Create the output_item.done event with MCP tools list
output_item_done_event = OutputItemDoneEvent(
type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
output_index=0,
item=BaseLiteLLMOpenAIResponseObject(
**{
"id": base_item_id,
"type": "mcp_list_tools",
"server_label": server_label,
"tools": formatted_tools,
}
),
)
events.append(output_item_done_event)
verbose_logger.debug(f"Created {len(events)} MCP discovery events")
except Exception as e:
verbose_logger.error(f"Error creating MCP list tools events: {e}")
import traceback
traceback.print_exc()
# Emit failed event on error
failed_event = MCPListToolsFailedEvent(
type=ResponsesAPIStreamEvents.MCP_LIST_TOOLS_FAILED,
sequence_number=2,
output_index=0,
item_id=base_item_id,
)
events.append(failed_event)
# Still emit output_item.done event even on failure (with empty tools list)
from litellm.types.llms.openai import OutputItemDoneEvent
output_item_done_event = OutputItemDoneEvent(
type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
output_index=0,
item=BaseLiteLLMOpenAIResponseObject(
**{
"id": base_item_id,
"type": "mcp_list_tools",
"server_label": "",
"tools": [],
}
),
)
events.append(output_item_done_event)
return events
def create_mcp_call_events(
tool_name: str,
tool_call_id: str,
arguments: str,
result: Optional[str] = None,
base_item_id: Optional[str] = None,
sequence_start: int = 1,
) -> List[ResponsesAPIStreamingResponse]:
"""Create MCP call events following OpenAI's specification"""
events: List[ResponsesAPIStreamingResponse] = []
item_id = base_item_id or f"mcp_{uuid.uuid4().hex[:8]}"
# MCP call in progress event
in_progress_event = MCPCallInProgressEvent(
type=ResponsesAPIStreamEvents.MCP_CALL_IN_PROGRESS,
sequence_number=sequence_start,
output_index=0,
item_id=item_id,
)
events.append(in_progress_event)
# MCP call arguments delta event (streaming the arguments)
arguments_delta_event = MCPCallArgumentsDeltaEvent(
type=ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DELTA,
output_index=0,
item_id=item_id,
delta=arguments, # JSON string with arguments
sequence_number=sequence_start + 1,
)
events.append(arguments_delta_event)
# MCP call arguments done event
arguments_done_event = MCPCallArgumentsDoneEvent(
type=ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DONE,
output_index=0,
item_id=item_id,
arguments=arguments, # Complete JSON string with finalized arguments
sequence_number=sequence_start + 2,
)
events.append(arguments_done_event)
# MCP call completed event (or failed if result indicates failure)
if result is not None:
completed_event = MCPCallCompletedEvent(
type=ResponsesAPIStreamEvents.MCP_CALL_COMPLETED,
sequence_number=sequence_start + 3,
item_id=item_id,
output_index=0,
)
events.append(completed_event)
# Add output_item.done event with the tool call result
from litellm.types.llms.openai import OutputItemDoneEvent
output_item_done_event = OutputItemDoneEvent(
type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
output_index=0,
item=BaseLiteLLMOpenAIResponseObject(
**{
"id": item_id,
"type": "mcp_call",
"approval_request_id": f"mcpr_{uuid.uuid4().hex[:8]}",
"arguments": arguments,
"error": None,
"name": tool_name,
"output": result,
"server_label": "litellm",
}
),
)
events.append(output_item_done_event)
else:
failed_event = MCPCallFailedEvent(
type=ResponsesAPIStreamEvents.MCP_CALL_FAILED,
sequence_number=sequence_start + 3,
item_id=item_id,
output_index=0,
)
events.append(failed_event)
return events
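Example (illustrative, not part of the diff): the event sequence produced for one successful tool call. The tool name, arguments, and result are made up.

# Illustrative sketch -- all values below are made up.
events = create_mcp_call_events(
    tool_name="get_weather",
    tool_call_id="call_123",
    arguments='{"city": "Paris"}',
    result='{"temp_c": 18}',
)
# Five events: in-progress, arguments delta, arguments done,
# call completed, then output_item.done carrying the result.
for event in events:
    print(event.type)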
class MCPEnhancedStreamingIterator(BaseResponsesAPIStreamingIterator):
"""
A complete MCP streaming iterator that handles the entire flow:
1. Immediately emits MCP discovery events
2. Makes the first LLM call and streams its response
3. Handles tool execution and follow-up calls for auto-execute tools
4. Emits tool execution events in the stream
"""
def __init__(
self,
base_iterator: Any, # Can be None - will be created internally
mcp_events: List[ResponsesAPIStreamingResponse],
tool_server_map: dict[str, str],
mcp_tools_with_litellm_proxy: Optional[List[Any]] = None,
user_api_key_auth: Any = None,
original_request_params: Optional[Dict[str, Any]] = None,
):
# MCP setup
self.mcp_tools_with_litellm_proxy = mcp_tools_with_litellm_proxy or []
self.user_api_key_auth = user_api_key_auth
self.original_request_params = original_request_params or {}
self.should_auto_execute = self._should_auto_execute_tools()
# Streaming state management
self.phase = "initial_response" # initial_response -> mcp_discovery -> tool_execution -> follow_up_response -> finished
self.finished = False
# Event queues and generation flags
self.mcp_discovery_events: List[
ResponsesAPIStreamingResponse
] = mcp_events # Pre-generated MCP discovery events
self.tool_execution_events: List[ResponsesAPIStreamingResponse] = []
self.mcp_discovery_generated = True # Events are already generated
self.mcp_events = (
mcp_events # Store the initial MCP events for backward compatibility
)
self.tool_server_map = tool_server_map
# Iterator references
self.base_iterator: Optional[
Union[Any, ResponsesAPIResponse]
] = base_iterator # Will be created when needed
self.follow_up_iterator: Optional[Any] = None
# Response collection for tool execution
self.collected_response: Optional[ResponsesAPIResponse] = None
# Set up model metadata (will be updated when we get the real iterator)
self.model = self.original_request_params.get("model", "unknown")
self.litellm_metadata = {}
self.custom_llm_provider = self.original_request_params.get(
"custom_llm_provider", None
)
self.litellm_call_id = self.original_request_params.get("litellm_call_id")
self.litellm_trace_id = self.original_request_params.get("litellm_trace_id")
self._extract_mcp_headers_from_params()
# Mark as async iterator
self.is_async = True
# Track if we've emitted initial OpenAI lifecycle events
self.initial_events_emitted = False
# Cache the response ID to ensure consistency across all events
self._cached_response_id: Optional[str] = None
def _extract_mcp_headers_from_params(self) -> None:
"""Extract MCP headers from original request params to pass to tool calls"""
from typing import Dict, Optional
from starlette.datastructures import Headers
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
MCPRequestHandler,
)
# Extract headers from secret_fields in original_request_params
raw_headers_from_request: Optional[Dict[str, str]] = None
secret_fields = self.original_request_params.get("secret_fields")
if secret_fields and isinstance(secret_fields, dict):
raw_headers_from_request = secret_fields.get("raw_headers")
# Extract MCP-specific headers
self.mcp_auth_header: Optional[str] = None
self.mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None
self.oauth2_headers: Optional[Dict[str, str]] = None
self.raw_headers: Optional[Dict[str, str]] = raw_headers_from_request
if raw_headers_from_request:
headers_obj = Headers(raw_headers_from_request)
self.mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers(
headers_obj
)
self.mcp_server_auth_headers = (
MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj)
)
self.oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(
headers_obj
)
# Also check if headers are provided in tools array (from request body)
tools = self.original_request_params.get("tools")
if tools:
for tool in tools:
if isinstance(tool, dict) and tool.get("type") == "mcp":
tool_headers = tool.get("headers", {})
if tool_headers and isinstance(tool_headers, dict):
# Merge tool headers into mcp_server_auth_headers
headers_obj_from_tool = Headers(tool_headers)
tool_mcp_server_auth_headers = (
MCPRequestHandler._get_mcp_server_auth_headers_from_headers(
headers_obj_from_tool
)
)
if tool_mcp_server_auth_headers:
if self.mcp_server_auth_headers is None:
self.mcp_server_auth_headers = {}
# Merge the headers from tool into existing headers
for (
server_alias,
headers_dict,
) in tool_mcp_server_auth_headers.items():
if server_alias not in self.mcp_server_auth_headers:
self.mcp_server_auth_headers[server_alias] = {}
self.mcp_server_auth_headers[server_alias].update(
headers_dict
)
# Also merge raw headers
if self.raw_headers is None:
self.raw_headers = {}
self.raw_headers.update(tool_headers)
def _should_auto_execute_tools(self) -> bool:
"""Check if tools should be auto-executed"""
from litellm.responses.mcp.litellm_proxy_mcp_handler import (
LiteLLM_Proxy_MCP_Handler,
)
return LiteLLM_Proxy_MCP_Handler._should_auto_execute_tools(
self.mcp_tools_with_litellm_proxy
)
def __aiter__(self):
return self
async def __anext__(self) -> ResponsesAPIStreamingResponse:
"""
Phase-based streaming:
1. initial_response - Stream the first LLM response (includes response.created, response.in_progress, response.output_item.added)
2. mcp_discovery - Emit MCP discovery events (after response.output_item.added)
3. continue_initial_response - Continue streaming the initial response content
4. tool_execution - Emit tool execution events
5. follow_up_response - Stream the follow-up response
6. finished - End iteration
"""
# Phase 1: Initial Response Stream (emit standard OpenAI events first)
if self.phase == "initial_response":
result = await self._handle_initial_response_phase()
if result is not None:
return result
# Phase 2: MCP Discovery Events (after response.output_item.added)
if self.phase == "mcp_discovery":
# Emit MCP discovery events
if self.mcp_discovery_events:
return self.mcp_discovery_events.pop(0)
self.phase = "continue_initial_response"
# Fall through to continue processing the initial response
# Phase 3: Continue Initial Response (after MCP discovery events)
if self.phase == "continue_initial_response":
try:
return await self._process_base_iterator_chunk()
except StopAsyncIteration:
# Initial response ended, move to next phase
if self.should_auto_execute and self.collected_response:
self.phase = "tool_execution"
await self._generate_tool_execution_events()
else:
self.phase = "finished"
raise
# Phase 4: Tool Execution Events
if self.phase == "tool_execution":
# Emit any queued tool execution events
if self.tool_execution_events:
return self.tool_execution_events.pop(0)
# Move to follow-up response phase
self.phase = "follow_up_response"
await self._create_follow_up_iterator()
# Phase 5: Follow-up Response Stream
if self.phase == "follow_up_response":
if self.follow_up_iterator:
try:
return await cast(Any, self.follow_up_iterator).__anext__() # type: ignore[attr-defined]
except StopAsyncIteration:
self.phase = "finished"
raise
else:
self.phase = "finished"
raise StopAsyncIteration
# Phase 6: Finished
if self.phase == "finished":
raise StopAsyncIteration
# Should not reach here
raise StopAsyncIteration
async def _handle_initial_response_phase(
self,
) -> Optional[ResponsesAPIStreamingResponse]:
"""
Handle Phase 1: Initial Response Stream.
Returns a chunk to emit, or None to fall through to the next phase.
Raises StopAsyncIteration when the stream is exhausted with no auto-execution.
"""
if self.base_iterator is None:
await self._create_initial_response_iterator()
if self.base_iterator is None:
# LLM call failed — still emit MCP discovery events before finishing
if self.mcp_discovery_events:
self.phase = "mcp_discovery"
else:
self.phase = "finished"
raise StopAsyncIteration
return None
if self.base_iterator:
if hasattr(self.base_iterator, "__anext__"):
try:
chunk = await cast(Any, self.base_iterator).__anext__() # type: ignore[attr-defined]
# Capture the response ID from the first event to ensure consistency
if self._cached_response_id is None and hasattr(chunk, "response"):
response_obj = getattr(chunk, "response", None)
if response_obj and hasattr(response_obj, "id"):
self._cached_response_id = response_obj.id
verbose_logger.debug(
f"Cached response ID: {self._cached_response_id}"
)
# After emitting response.output_item.added, transition to MCP discovery
if not self.initial_events_emitted and hasattr(chunk, "type"):
chunk_type = getattr(chunk, "type", None)
if chunk_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED:
self.initial_events_emitted = True
self.phase = "mcp_discovery"
return chunk
# If auto-execution is enabled, check for completed responses
if self.should_auto_execute and self._is_response_completed(chunk):
response_obj = getattr(chunk, "response", None)
if isinstance(response_obj, ResponsesAPIResponse):
self.collected_response = response_obj
self.phase = "tool_execution"
await self._generate_tool_execution_events()
return chunk
except StopAsyncIteration:
if self.should_auto_execute and self.collected_response:
self.phase = "tool_execution"
await self._generate_tool_execution_events()
else:
self.phase = "finished"
raise
else:
# base_iterator is not async iterable (likely a ResponsesAPIResponse)
if self.should_auto_execute and isinstance(
self.base_iterator, ResponsesAPIResponse
):
self.collected_response = self.base_iterator
self.phase = "tool_execution"
await self._generate_tool_execution_events()
else:
self.phase = "finished"
raise StopAsyncIteration
return None
def _is_response_completed(self, chunk: ResponsesAPIStreamingResponse) -> bool:
"""Check if this chunk indicates the response is completed"""
from litellm.types.llms.openai import ResponsesAPIStreamEvents
return (
getattr(chunk, "type", None) == ResponsesAPIStreamEvents.RESPONSE_COMPLETED
)
async def _process_base_iterator_chunk(self) -> ResponsesAPIStreamingResponse:
"""
Process a chunk from the base iterator with response ID consistency enforcement.
"""
if not self.base_iterator or not hasattr(self.base_iterator, "__anext__"):
raise StopAsyncIteration
chunk = await cast(Any, self.base_iterator).__anext__() # type: ignore[attr-defined]
# Ensure response ID consistency - update chunk if needed
if self._cached_response_id and hasattr(chunk, "response"):
response_obj = getattr(chunk, "response", None)
if response_obj and hasattr(response_obj, "id"):
if response_obj.id != self._cached_response_id:
verbose_logger.debug(
f"Updating response ID from {response_obj.id} to {self._cached_response_id}"
)
response_obj.id = self._cached_response_id
# If auto-execution is enabled, check for completed responses
if self.should_auto_execute and self._is_response_completed(chunk):
# Collect the response for tool execution
response_obj = getattr(chunk, "response", None)
if isinstance(response_obj, ResponsesAPIResponse):
self.collected_response = response_obj
# Move to tool execution phase after emitting this chunk
self.phase = "tool_execution"
await self._generate_tool_execution_events()
return chunk
async def _create_initial_response_iterator(self) -> None:
"""Create the initial response iterator by making the first LLM call"""
try:
# Import the core aresponses function that doesn't have MCP logic
from litellm.responses.main import aresponses
# Make the initial response API call - but avoid the MCP wrapper
params = self.original_request_params.copy()
params["stream"] = True # Ensure streaming
# Use the pre-fetched all_tools from original_request_params (no re-processing needed)
            # Copy all params as-is; the tools were already processed upstream.
            params_for_llm = dict(params)
tools_count = (
len(params_for_llm.get("tools", []))
if params_for_llm.get("tools")
else 0
)
verbose_logger.debug(f"Making LLM call with {tools_count} tools")
response = await aresponses(**params_for_llm)
# Set the base iterator
if hasattr(response, "__aiter__") or hasattr(response, "__iter__"):
self.base_iterator = response
# Copy metadata from the real iterator
self.model = getattr(response, "model", self.model)
self.litellm_metadata = getattr(response, "litellm_metadata", {})
self.custom_llm_provider = getattr(
response, "custom_llm_provider", self.custom_llm_provider
)
verbose_logger.debug(
f"Created base iterator: {type(self.base_iterator)}"
)
else:
# Non-streaming response - this shouldn't happen but handle it
verbose_logger.warning(f"Got non-streaming response: {type(response)}")
self.base_iterator = None
self.phase = "finished"
except Exception as e:
verbose_logger.error(f"Error creating initial response iterator: {e}")
import traceback
traceback.print_exc()
self.base_iterator = None
# Don't set phase to "finished" here — let __anext__ emit any
# pre-generated MCP discovery events before ending the iteration.
async def _generate_tool_execution_events(self) -> None:
"""Generate tool execution events and execute tools"""
if not self.collected_response:
return
from litellm.responses.mcp.litellm_proxy_mcp_handler import (
LiteLLM_Proxy_MCP_Handler,
)
try:
# Extract tool calls from the response
if self.collected_response is not None:
tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_response(self.collected_response) # type: ignore[arg-type]
else:
tool_calls = []
if not tool_calls:
return
for tool_call in tool_calls:
(
tool_name,
tool_arguments,
tool_call_id,
) = LiteLLM_Proxy_MCP_Handler._extract_tool_call_details(tool_call)
if tool_name and tool_call_id:
# Create MCP call events for this tool execution
call_events = create_mcp_call_events(
tool_name=tool_name,
tool_call_id=tool_call_id,
arguments=tool_arguments or "{}", # JSON string with arguments
result=None, # Will be set after execution
base_item_id=f"mcp_{uuid.uuid4().hex[:8]}",
sequence_start=len(self.tool_execution_events) + 1,
)
# Add the in_progress and arguments events (not the completed event yet)
self.tool_execution_events.extend(call_events[:-1])
# Execute the tools
tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
tool_server_map=self.tool_server_map,
tool_calls=tool_calls,
user_api_key_auth=self.user_api_key_auth,
mcp_auth_header=self.mcp_auth_header,
mcp_server_auth_headers=self.mcp_server_auth_headers,
oauth2_headers=self.oauth2_headers,
raw_headers=self.raw_headers,
litellm_call_id=self.litellm_call_id,
litellm_trace_id=self.litellm_trace_id,
)
# Create completion events and output_item.done events for tool execution
for tool_result in tool_results:
tool_call_id = tool_result.get("tool_call_id", "unknown")
result_text = tool_result.get("result", "")
# Find matching tool name and arguments
tool_name = "unknown"
tool_arguments = "{}"
for tool_call in tool_calls:
(
name,
args,
call_id,
) = LiteLLM_Proxy_MCP_Handler._extract_tool_call_details(tool_call)
if call_id == tool_call_id:
tool_name = name or "unknown"
tool_arguments = args or "{}"
break
item_id = f"mcp_{uuid.uuid4().hex[:8]}"
# Create the completion event
completed_event = MCPCallCompletedEvent(
type=ResponsesAPIStreamEvents.MCP_CALL_COMPLETED,
sequence_number=len(self.tool_execution_events) + 1,
item_id=item_id,
output_index=0,
)
self.tool_execution_events.append(completed_event)
# Create output_item.done event with the tool call result
from litellm.types.llms.openai import OutputItemDoneEvent
output_item_done_event = OutputItemDoneEvent(
type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
output_index=0,
item=BaseLiteLLMOpenAIResponseObject(
**{
"id": item_id,
"type": "mcp_call",
"approval_request_id": f"mcpr_{uuid.uuid4().hex[:8]}",
"arguments": tool_arguments,
"error": None,
"name": tool_name,
"output": result_text,
"server_label": "litellm", # or extract from tool config
}
),
)
self.tool_execution_events.append(output_item_done_event)
# Store tool results for follow-up call
self.tool_results = tool_results
except Exception as e:
verbose_logger.error(f"Error in tool execution: {e}")
import traceback
traceback.print_exc()
self.tool_results = []
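    # Event-flow sketch (illustrative, mirroring the code above): for each tool
    # call, the in-progress and arguments events from create_mcp_call_events
    # are queued first (call_events[:-1]); after _execute_tool_calls returns,
    # an MCP_CALL_COMPLETED event and an OUTPUT_ITEM_DONE event carrying the
    # tool output are appended, so consumers see progress before results.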
async def _create_follow_up_iterator(self) -> None:
"""Create the follow-up response iterator with tool results"""
if not self.collected_response or not hasattr(self, "tool_results"):
return
from litellm.responses.main import aresponses
from litellm.responses.mcp.litellm_proxy_mcp_handler import (
LiteLLM_Proxy_MCP_Handler,
)
        try:
            # Create follow-up input (collected_response is guaranteed
            # non-None by the guard at the top of this method)
            follow_up_input = LiteLLM_Proxy_MCP_Handler._create_follow_up_input(
                response=self.collected_response,  # type: ignore[arg-type]
                tool_results=self.tool_results,
                original_input=self.original_request_params.get("input"),
            )
            # Make follow-up call with streaming
            follow_up_params = self.original_request_params.copy()
            follow_up_params.update(
                {
                    "input": follow_up_input,
                    "stream": True,
                }
            )
            # Remove tool_choice to avoid forcing more tool calls
            follow_up_params.pop("tool_choice", None)
follow_up_response = await aresponses(**follow_up_params)
# Set up the follow-up iterator
if hasattr(follow_up_response, "__aiter__"):
self.follow_up_iterator = follow_up_response
except Exception as e:
verbose_logger.error(f"Error creating follow-up iterator: {e}")
import traceback
traceback.print_exc()
self.follow_up_iterator = None
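    # Follow-up sketch (illustrative): the original request params are
    # re-submitted via aresponses() with the tool outputs folded into the
    # input and stream=True, so the model continues generating with the MCP
    # results in context; tool_choice is dropped to avoid forcing more calls.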
def __iter__(self):
return self
def __next__(self) -> ResponsesAPIStreamingResponse:
# First, emit any queued MCP events
if self.mcp_events: # type: ignore[attr-defined]
return self.mcp_events.pop(0) # type: ignore[attr-defined]
# Then delegate to the base iterator
if not self.is_async:
try:
if self.base_iterator and hasattr(self.base_iterator, "__next__"):
return next(cast(Any, self.base_iterator)) # type: ignore[arg-type]
else:
raise StopIteration
except StopIteration:
self.finished = True
raise
else:
raise RuntimeError("Cannot use sync iteration on async iterator")
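# Consumption sketch (illustrative; assumes an instance of this wrapper):
#
#   async for event in iterator:   # async path; queued MCP events first,
#       handle(event)              # then the base/follow-up stream
#
#   for event in iterator:         # sync path; raises RuntimeError when the
#       handle(event)              # wrapper was built for async iteration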

File diff suppressed because it is too large

View File

@@ -0,0 +1,726 @@
import base64
from typing import (
Any,
Dict,
Iterable,
List,
Optional,
Type,
Union,
cast,
get_type_hints,
overload,
)
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
from litellm.types.llms.openai import (
ResponseAPIUsage,
ResponsesAPIOptionalRequestParams,
ResponsesAPIResponse,
ResponseText,
)
from litellm.types.responses.main import DecodedResponseId
from litellm.types.utils import (
CompletionTokensDetailsWrapper,
PromptTokensDetailsWrapper,
SpecialEnums,
Usage,
)
class ResponsesAPIRequestUtils:
"""Helper utils for constructing ResponseAPI requests"""
@staticmethod
def _check_valid_arg(
supported_params: Optional[List[str]],
non_default_params: Dict,
drop_params: Optional[bool],
custom_llm_provider: Optional[str],
model: str,
):
if supported_params is None:
return
unsupported_params = {}
for k in non_default_params.keys():
if k not in supported_params:
unsupported_params[k] = non_default_params[k]
if unsupported_params:
if litellm.drop_params is True or (
drop_params is not None and drop_params is True
):
pass
else:
raise litellm.UnsupportedParamsError(
status_code=500,
message=f"{custom_llm_provider} does not support parameters: {unsupported_params}, for model={model}. To drop these, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\n",
)
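    # Behavior sketch (illustrative): with supported_params=["temperature"] and
    # non_default_params={"temperature": 0.2, "logit_bias": {...}}, logit_bias
    # is flagged as unsupported; it is silently dropped when drop_params is
    # enabled, otherwise an UnsupportedParamsError is raised.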
@staticmethod
def get_optional_params_responses_api(
model: str,
responses_api_provider_config: BaseResponsesAPIConfig,
response_api_optional_params: ResponsesAPIOptionalRequestParams,
allowed_openai_params: Optional[List[str]] = None,
) -> Dict:
"""
Get optional parameters for the responses API.
        Args:
            model: The model name
            responses_api_provider_config: The provider configuration for the responses API
            response_api_optional_params: Candidate optional request parameters to filter and map
            allowed_openai_params: Additional OpenAI params the caller explicitly allows
Returns:
A dictionary of supported parameters for the responses API
"""
from litellm.utils import _apply_openai_param_overrides
        # Get supported parameters for the model
supported_params = responses_api_provider_config.get_supported_openai_params(
model
)
non_default_params = cast(Dict, response_api_optional_params)
# Check for unsupported parameters
ResponsesAPIRequestUtils._check_valid_arg(
supported_params=supported_params + (allowed_openai_params or []),
non_default_params=non_default_params,
drop_params=litellm.drop_params,
custom_llm_provider=responses_api_provider_config.custom_llm_provider,
model=model,
)
# Map parameters to provider-specific format
mapped_params = responses_api_provider_config.map_openai_params(
response_api_optional_params=response_api_optional_params,
model=model,
drop_params=litellm.drop_params,
)
# add any allowed_openai_params to the mapped_params
mapped_params = _apply_openai_param_overrides(
optional_params=mapped_params,
non_default_params=non_default_params,
allowed_openai_params=allowed_openai_params or [],
)
return mapped_params
@staticmethod
def get_requested_response_api_optional_param(
params: Dict[str, Any],
) -> ResponsesAPIOptionalRequestParams:
"""
Filter parameters to only include those defined in ResponsesAPIOptionalRequestParams.
Args:
params: Dictionary of parameters to filter
Returns:
ResponsesAPIOptionalRequestParams instance with only the valid parameters
"""
from litellm.utils import PreProcessNonDefaultParams
valid_keys = get_type_hints(ResponsesAPIOptionalRequestParams).keys()
custom_llm_provider = params.pop("custom_llm_provider", None)
special_params = params.pop("kwargs", {})
additional_drop_params = params.pop("additional_drop_params", None)
non_default_params = (
PreProcessNonDefaultParams.base_pre_process_non_default_params(
passed_params=params,
special_params=special_params,
custom_llm_provider=custom_llm_provider,
additional_drop_params=additional_drop_params,
default_param_values={k: None for k in valid_keys},
additional_endpoint_specific_params=["input"],
)
)
# decode previous_response_id if it's a litellm encoded id
if "previous_response_id" in non_default_params:
decoded_previous_response_id = ResponsesAPIRequestUtils.decode_previous_response_id_to_original_previous_response_id(
non_default_params["previous_response_id"]
)
non_default_params["previous_response_id"] = decoded_previous_response_id
if "metadata" in non_default_params:
from litellm.utils import add_openai_metadata
converted_metadata = add_openai_metadata(non_default_params["metadata"])
if converted_metadata is not None:
non_default_params["metadata"] = converted_metadata
else:
non_default_params.pop("metadata", None)
return cast(ResponsesAPIOptionalRequestParams, non_default_params)
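    # Filtering sketch (illustrative, hypothetical values): given
    # params={"temperature": 0.2, "previous_response_id": "resp_<b64>", ...},
    # only keys declared on ResponsesAPIOptionalRequestParams survive, a
    # litellm-encoded previous_response_id is decoded back to the provider's
    # original id, and metadata is normalized via add_openai_metadata.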
# fmt: off
@overload
@staticmethod
def _update_responses_api_response_id_with_model_id(
responses_api_response: ResponsesAPIResponse,
custom_llm_provider: Optional[str],
litellm_metadata: Optional[Dict[str, Any]] = None,
) -> ResponsesAPIResponse:
...
@overload
@staticmethod
def _update_responses_api_response_id_with_model_id(
responses_api_response: Dict[str, Any],
custom_llm_provider: Optional[str],
litellm_metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
...
# fmt: on
@staticmethod
def _update_responses_api_response_id_with_model_id(
responses_api_response: Union[ResponsesAPIResponse, Dict[str, Any]],
custom_llm_provider: Optional[str],
litellm_metadata: Optional[Dict[str, Any]] = None,
) -> Union[ResponsesAPIResponse, Dict[str, Any]]:
"""Update the responses_api_response_id with model_id and custom_llm_provider.
Handles both ``ResponsesAPIResponse`` objects and plain dictionaries returned
by some streaming providers.
"""
litellm_metadata = litellm_metadata or {}
model_info: Dict[str, Any] = litellm_metadata.get("model_info", {}) or {}
model_id = model_info.get("id")
# access the response id based on the object type
if isinstance(responses_api_response, dict):
response_id = responses_api_response.get("id")
else:
response_id = getattr(responses_api_response, "id", None)
# If no response_id, return the response as-is (likely an error response)
if response_id is None:
return responses_api_response
updated_id = ResponsesAPIRequestUtils._build_responses_api_response_id(
model_id=model_id,
custom_llm_provider=custom_llm_provider,
response_id=response_id,
)
if isinstance(responses_api_response, dict):
responses_api_response["id"] = updated_id
else:
responses_api_response.id = updated_id
if litellm_metadata.get("encrypted_content_affinity_enabled"):
responses_api_response = (
ResponsesAPIRequestUtils._update_encrypted_content_item_ids_in_response(
response=responses_api_response,
model_id=model_id,
)
)
return responses_api_response
@staticmethod
def _build_encrypted_item_id(model_id: str, item_id: str) -> str:
"""Encode model_id into an output item ID for encrypted-content items.
Format: ``encitem_{base64("litellm:model_id:{model_id};item_id:{original_id}")}``
"""
assembled = f"litellm:model_id:{model_id};item_id:{item_id}"
encoded = base64.b64encode(assembled.encode("utf-8")).decode("utf-8")
return f"encitem_{encoded}"
@staticmethod
def _decode_encrypted_item_id(encoded_id: str) -> Optional[Dict[str, str]]:
"""Decode a litellm-encoded encrypted-content item ID.
Returns a dict with ``model_id`` and ``item_id`` keys, or ``None`` if
the string is not a litellm-encoded item ID.
"""
if not encoded_id.startswith("encitem_"):
return None
try:
cleaned = encoded_id[len("encitem_") :]
# Restore any padding that may have been stripped in transit
missing = len(cleaned) % 4
if missing:
cleaned += "=" * (4 - missing)
decoded = base64.b64decode(cleaned.encode("utf-8")).decode("utf-8")
# Split on first ";" only so that semicolons inside item_id are preserved
parts = decoded.split(";", 1)
if len(parts) < 2:
return None
model_id = parts[0].replace("litellm:model_id:", "")
item_id = parts[1].replace("item_id:", "")
return {"model_id": model_id, "item_id": item_id}
except Exception:
return None
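    # Round-trip sketch (hypothetical IDs, format per the docstrings above):
    #   encoded = ResponsesAPIRequestUtils._build_encrypted_item_id("mdl-1", "item_abc")
    #   # -> "encitem_" + base64("litellm:model_id:mdl-1;item_id:item_abc")
    #   ResponsesAPIRequestUtils._decode_encrypted_item_id(encoded)
    #   # -> {"model_id": "mdl-1", "item_id": "item_abc"}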
@staticmethod
def _wrap_encrypted_content_with_model_id(
encrypted_content: str, model_id: str
) -> str:
"""Wrap encrypted_content with model_id metadata for affinity routing.
When Codex or other clients send items with encrypted_content but no ID,
we encode the model_id directly into the encrypted_content itself.
Format: ``litellm_enc:{base64("model_id:{model_id}")};{original_encrypted_content}``
"""
metadata = f"model_id:{model_id}"
encoded_metadata = base64.b64encode(metadata.encode("utf-8")).decode("utf-8")
return f"litellm_enc:{encoded_metadata};{encrypted_content}"
@staticmethod
def _unwrap_encrypted_content_with_model_id(
wrapped_content: str,
) -> tuple[Optional[str], str]:
"""Unwrap encrypted_content to extract model_id and original content.
Returns:
Tuple of (model_id, original_encrypted_content).
If not wrapped, returns (None, original_content).
"""
if not wrapped_content.startswith("litellm_enc:"):
return None, wrapped_content
try:
# Split on first ";" to separate metadata from content
parts = wrapped_content.split(";", 1)
if len(parts) < 2:
return None, wrapped_content
metadata_b64 = parts[0].replace("litellm_enc:", "")
original_content = parts[1]
# Restore padding if needed
missing = len(metadata_b64) % 4
if missing:
metadata_b64 += "=" * (4 - missing)
decoded_metadata = base64.b64decode(metadata_b64.encode("utf-8")).decode(
"utf-8"
)
model_id = decoded_metadata.replace("model_id:", "")
return model_id, original_content
except Exception:
return None, wrapped_content
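    # Round-trip sketch (hypothetical values):
    #   wrapped = ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id("<cipher>", "mdl-1")
    #   # -> "litellm_enc:" + base64("model_id:mdl-1") + ";<cipher>"
    #   ResponsesAPIRequestUtils._unwrap_encrypted_content_with_model_id(wrapped)
    #   # -> ("mdl-1", "<cipher>")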
@staticmethod
def _update_encrypted_content_item_ids_in_response(
response: Union["ResponsesAPIResponse", Dict[str, Any]],
model_id: Optional[str],
) -> Union["ResponsesAPIResponse", Dict[str, Any]]:
"""Rewrite item IDs for output items that contain ``encrypted_content``.
Encodes ``model_id`` into the item ID so that follow-up requests can be
routed back to the originating deployment without any cache lookup.
For items without an ID (e.g., from Codex), encodes model_id directly
into the encrypted_content itself.
"""
if not model_id:
return response
output: Optional[list] = None
if isinstance(response, dict):
output = response.get("output")
else:
output = getattr(response, "output", None)
if not isinstance(output, list):
return response
for item in output:
if isinstance(item, dict):
item_id = item.get("id")
encrypted_content = item.get("encrypted_content")
if encrypted_content and isinstance(encrypted_content, str):
# Always wrap encrypted_content with model_id for redundancy
item[
"encrypted_content"
] = ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id(
encrypted_content, model_id
)
# Also encode the ID if present
if item_id and isinstance(item_id, str):
item["id"] = ResponsesAPIRequestUtils._build_encrypted_item_id(
model_id, item_id
)
else:
item_id = getattr(item, "id", None)
encrypted_content = getattr(item, "encrypted_content", None)
if encrypted_content and isinstance(encrypted_content, str):
# Always wrap encrypted_content with model_id for redundancy
try:
item.encrypted_content = ResponsesAPIRequestUtils._wrap_encrypted_content_with_model_id(
encrypted_content, model_id
)
except AttributeError:
pass
# Also encode the ID if present
if item_id and isinstance(item_id, str):
try:
item.id = ResponsesAPIRequestUtils._build_encrypted_item_id(
model_id, item_id
)
except AttributeError:
pass
return response
@staticmethod
def _restore_encrypted_content_item_ids_in_input(request_input: Any) -> Any:
"""Decode litellm-encoded item IDs in request input back to original IDs.
Called before forwarding the request to the upstream provider so the
provider receives the original item IDs and unwrapped encrypted_content.
Handles both:
1. Items with encoded IDs (encitem_...)
2. Items with wrapped encrypted_content (litellm_enc:...)
"""
if not isinstance(request_input, list):
return request_input
for item in request_input:
if isinstance(item, dict):
item_id = item.get("id")
if item_id and isinstance(item_id, str):
decoded = ResponsesAPIRequestUtils._decode_encrypted_item_id(
item_id
)
if decoded:
item["id"] = decoded["item_id"]
encrypted_content = item.get("encrypted_content")
if encrypted_content and isinstance(encrypted_content, str):
(
_,
unwrapped,
) = ResponsesAPIRequestUtils._unwrap_encrypted_content_with_model_id(
encrypted_content
)
if unwrapped != encrypted_content:
item["encrypted_content"] = unwrapped
return request_input
@staticmethod
def _build_responses_api_response_id(
custom_llm_provider: Optional[str],
model_id: Optional[str],
response_id: str,
) -> str:
"""Build the responses_api_response_id"""
assembled_id: str = str(
SpecialEnums.LITELLM_MANAGED_RESPONSE_COMPLETE_STR.value
).format(custom_llm_provider, model_id, response_id)
base64_encoded_id: str = base64.b64encode(assembled_id.encode("utf-8")).decode(
"utf-8"
)
return f"resp_{base64_encoded_id}"
@staticmethod
def _decode_responses_api_response_id(
response_id: str,
) -> DecodedResponseId:
"""
Decode the responses_api_response_id
Returns:
DecodedResponseId: Structured tuple with custom_llm_provider, model_id, and response_id
"""
try:
# Remove prefix and decode
cleaned_id = response_id.replace("resp_", "")
decoded_id = base64.b64decode(cleaned_id.encode("utf-8")).decode("utf-8")
# Parse components using known prefixes
if ";" not in decoded_id:
return DecodedResponseId(
custom_llm_provider=None,
model_id=None,
response_id=response_id,
)
parts = decoded_id.split(";")
# Format: litellm:custom_llm_provider:{};model_id:{};response_id:{}
custom_llm_provider = None
model_id = None
if (
len(parts) >= 3
): # Full format with custom_llm_provider, model_id, and response_id
custom_llm_provider_part = parts[0]
model_id_part = parts[1]
response_part = parts[2]
custom_llm_provider = custom_llm_provider_part.replace(
"litellm:custom_llm_provider:", ""
)
model_id = model_id_part.replace("model_id:", "")
decoded_response_id = response_part.replace("response_id:", "")
else:
decoded_response_id = response_id
return DecodedResponseId(
custom_llm_provider=custom_llm_provider,
model_id=model_id,
response_id=decoded_response_id,
)
except Exception as e:
verbose_logger.debug(f"Error decoding response_id '{response_id}': {e}")
return DecodedResponseId(
custom_llm_provider=None,
model_id=None,
response_id=response_id,
)
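    # Round-trip sketch (hypothetical values; assumes the SpecialEnums template
    # is "litellm:custom_llm_provider:{};model_id:{};response_id:{}", which is
    # the layout the parser above expects):
    #   rid = ResponsesAPIRequestUtils._build_responses_api_response_id(
    #       custom_llm_provider="openai", model_id="mdl-1", response_id="resp_abc"
    #   )
    #   ResponsesAPIRequestUtils._decode_responses_api_response_id(rid)
    #   # -> {"custom_llm_provider": "openai", "model_id": "mdl-1",
    #   #     "response_id": "resp_abc"}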
@staticmethod
def get_model_id_from_response_id(response_id: Optional[str]) -> Optional[str]:
"""Get the model_id from the response_id"""
if response_id is None:
return None
decoded_response_id = (
ResponsesAPIRequestUtils._decode_responses_api_response_id(response_id)
)
return decoded_response_id.get("model_id") or None
@staticmethod
def decode_previous_response_id_to_original_previous_response_id(
previous_response_id: str,
) -> str:
"""
Decode the previous_response_id to the original previous_response_id
Why?
        - LiteLLM encodes the `custom_llm_provider` and `model_id` into the `previous_response_id`; this helps maintain session consistency when load balancing multiple deployments of the same model.
        - We cannot send the litellm-encoded b64 id to the upstream LLM API, hence we decode it back to the original `previous_response_id`.
Args:
previous_response_id: The previous_response_id to decode
Returns:
The original previous_response_id
"""
decoded_response_id = (
ResponsesAPIRequestUtils._decode_responses_api_response_id(
previous_response_id
)
)
return decoded_response_id.get("response_id", previous_response_id)
@staticmethod
def convert_text_format_to_text_param(
text_format: Optional[Union[Type["BaseModel"], dict]],
text: Optional["ResponseText"] = None,
) -> Optional["ResponseText"]:
"""
Convert text_format parameter to text parameter for the responses API.
Args:
text_format: Pydantic model class or dict to convert to response format
text: Existing text parameter (if provided, text_format is ignored)
Returns:
ResponseText object with the converted format, or None if conversion fails
"""
if text_format is not None and text is None:
from litellm.llms.base_llm.base_utils import type_to_response_format_param
# Convert Pydantic model to response format
response_format = type_to_response_format_param(text_format)
if response_format is not None:
# Create ResponseText object with the format
# The responses API expects the format to have name at the top level
text = {
"format": {
"type": response_format["type"],
"name": response_format["json_schema"]["name"],
"schema": response_format["json_schema"]["schema"],
"strict": response_format["json_schema"]["strict"],
}
}
return text
return text
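    # Conversion sketch (hypothetical model; assumes type_to_response_format_param
    # emits the OpenAI json_schema response_format shape consumed above):
    #   class Answer(BaseModel):
    #       value: int
    #   ResponsesAPIRequestUtils.convert_text_format_to_text_param(Answer)
    #   # -> {"format": {"type": "json_schema", "name": "Answer",
    #   #                "schema": {...}, "strict": True}}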
@staticmethod
def extract_mcp_headers_from_request(
secret_fields: Optional[Dict[str, Any]],
tools: Optional[Iterable[Any]],
) -> tuple[
Optional[str],
Optional[Dict[str, Dict[str, str]]],
Optional[Dict[str, str]],
Optional[Dict[str, str]],
]:
"""
Extract MCP auth headers from the request to pass to MCP server.
Headers from tools.headers in request body should be passed to MCP server.
"""
from starlette.datastructures import Headers
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
MCPRequestHandler,
)
# Extract headers from secret_fields which contains the original request headers
raw_headers_from_request: Optional[Dict[str, str]] = None
if secret_fields and isinstance(secret_fields, dict):
raw_headers_from_request = secret_fields.get("raw_headers")
# Extract MCP-specific headers using MCPRequestHandler methods
mcp_auth_header: Optional[str] = None
mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None
oauth2_headers: Optional[Dict[str, str]] = None
if raw_headers_from_request:
headers_obj = Headers(raw_headers_from_request)
mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers(
headers_obj
)
mcp_server_auth_headers = (
MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj)
)
oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(
headers_obj
)
if tools:
for tool in tools:
if isinstance(tool, dict) and tool.get("type") == "mcp":
tool_headers = tool.get("headers", {})
if tool_headers and isinstance(tool_headers, dict):
# Merge tool headers into mcp_server_auth_headers
# Extract server-specific headers from tool.headers
headers_obj_from_tool = Headers(tool_headers)
tool_mcp_server_auth_headers = (
MCPRequestHandler._get_mcp_server_auth_headers_from_headers(
headers_obj_from_tool
)
)
if tool_mcp_server_auth_headers:
if mcp_server_auth_headers is None:
mcp_server_auth_headers = {}
# Merge the headers from tool into existing headers
for (
server_alias,
headers_dict,
) in tool_mcp_server_auth_headers.items():
if server_alias not in mcp_server_auth_headers:
mcp_server_auth_headers[server_alias] = {}
mcp_server_auth_headers[server_alias].update(
headers_dict
)
# Also merge raw headers (non-prefixed headers from tool.headers)
if raw_headers_from_request is None:
raw_headers_from_request = {}
raw_headers_from_request.update(tool_headers)
return (
mcp_auth_header,
mcp_server_auth_headers,
oauth2_headers,
raw_headers_from_request,
)
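    # Flow sketch (illustrative): auth headers arrive either via
    # secret_fields["raw_headers"] (the original request headers) or via
    # tools=[{"type": "mcp", "headers": {...}}]; tool-level headers are parsed
    # with the same MCPRequestHandler helpers, merged into the per-server auth
    # map, and folded into raw_headers_from_request as a fallback.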
class ResponseAPILoggingUtils:
@staticmethod
def _is_response_api_usage(usage: Union[dict, ResponseAPIUsage]) -> bool:
"""returns True if usage is from OpenAI Response API"""
if isinstance(usage, ResponseAPIUsage):
return True
if "input_tokens" in usage and "output_tokens" in usage:
return True
return False
@staticmethod
def _transform_response_api_usage_to_chat_usage(
usage_input: Optional[Union[dict, ResponseAPIUsage]],
) -> Usage:
"""
Transforms ResponseAPIUsage or ImageUsage to a Usage object.
Both have the same spec with input_tokens, output_tokens, and
input_tokens_details (text_tokens, image_tokens).
"""
if usage_input is None:
return Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
response_api_usage: ResponseAPIUsage
if isinstance(usage_input, dict):
total_tokens = usage_input.get("total_tokens")
if total_tokens is None:
input_tokens = usage_input.get("input_tokens")
output_tokens = usage_input.get("output_tokens")
if input_tokens is not None and output_tokens is not None:
total_tokens = input_tokens + output_tokens
usage_input["total_tokens"] = total_tokens
response_api_usage = ResponseAPIUsage(**usage_input)
else:
response_api_usage = usage_input
prompt_tokens: int = response_api_usage.input_tokens or 0
completion_tokens: int = response_api_usage.output_tokens or 0
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
if response_api_usage.input_tokens_details:
if isinstance(response_api_usage.input_tokens_details, dict):
prompt_tokens_details = PromptTokensDetailsWrapper(
**response_api_usage.input_tokens_details
)
else:
prompt_tokens_details = PromptTokensDetailsWrapper(
cached_tokens=getattr(
response_api_usage.input_tokens_details, "cached_tokens", None
),
audio_tokens=getattr(
response_api_usage.input_tokens_details, "audio_tokens", None
),
text_tokens=getattr(
response_api_usage.input_tokens_details, "text_tokens", None
),
image_tokens=getattr(
response_api_usage.input_tokens_details, "image_tokens", None
),
)
completion_tokens_details: Optional[CompletionTokensDetailsWrapper] = None
output_tokens_details = getattr(
response_api_usage, "output_tokens_details", None
)
if output_tokens_details:
completion_tokens_details = CompletionTokensDetailsWrapper(
reasoning_tokens=getattr(
output_tokens_details, "reasoning_tokens", None
),
image_tokens=getattr(output_tokens_details, "image_tokens", None),
text_tokens=getattr(output_tokens_details, "text_tokens", None),
)
chat_usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
prompt_tokens_details=prompt_tokens_details,
completion_tokens_details=completion_tokens_details,
)
# Preserve cost attribute if it exists on ResponseAPIUsage
if hasattr(response_api_usage, "cost") and response_api_usage.cost is not None:
setattr(chat_usage, "cost", response_api_usage.cost)
return chat_usage
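    # Transformation sketch (hypothetical numbers): a Responses API usage of
    #   {"input_tokens": 10, "output_tokens": 5}
    # becomes Usage(prompt_tokens=10, completion_tokens=5, total_tokens=15);
    # token-detail objects map onto the prompt/completion details wrappers and
    # a provider-supplied `cost` attribute is carried over verbatim.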