chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,670 @@
|
||||
"""Helpers for handling MCP-aware `/chat/completions` requests."""
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
List,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
from litellm.responses.mcp.litellm_proxy_mcp_handler import (
|
||||
LiteLLM_Proxy_MCP_Handler,
|
||||
)
|
||||
from litellm.responses.utils import ResponsesAPIRequestUtils
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
|
||||
def _add_mcp_metadata_to_response(
    response: Union[ModelResponse, CustomStreamWrapper],
    openai_tools: Optional[List],
    tool_calls: Optional[List] = None,
    tool_results: Optional[List] = None,
) -> None:
    """
    Attach MCP metadata to a completion response in place.

    Lets clients see which MCP tools were available (``mcp_list_tools``),
    which were invoked (``mcp_tool_calls``), and what they returned
    (``mcp_call_results``). Only truthy payloads are written.

    For a ``ModelResponse`` the metadata is merged into every
    ``choices[].message.provider_specific_fields``. For a
    ``CustomStreamWrapper`` it is stashed in ``_hidden_params`` under the
    ``"mcp_metadata"`` key; the wrapper's
    ``_add_mcp_metadata_to_final_chunk()`` later copies it onto the final
    chunk's ``delta.provider_specific_fields``.
    """
    # Field name -> payload; shared by both the streaming and non-streaming
    # branches below.
    payload_by_field = [
        ("mcp_list_tools", openai_tools),
        ("mcp_tool_calls", tool_calls),
        ("mcp_call_results", tool_results),
    ]

    if isinstance(response, CustomStreamWrapper):
        # Streaming: park the metadata in _hidden_params for the wrapper to
        # surface on the final chunk.
        if not hasattr(response, "_hidden_params"):
            response._hidden_params = {}

        mcp_metadata = {field: value for field, value in payload_by_field if value}
        if mcp_metadata:
            response._hidden_params["mcp_metadata"] = mcp_metadata
        return

    if not isinstance(response, ModelResponse):
        return

    if not getattr(response, "choices", None):
        # No choices to annotate.
        return

    # Non-streaming: merge the metadata into every choice's message.
    for choice in response.choices:
        message = getattr(choice, "message", None)
        if message is None:
            continue

        # Preserve any provider fields that are already present.
        provider_fields = getattr(message, "provider_specific_fields", None) or {}
        for field, value in payload_by_field:
            if value:
                provider_fields[field] = value

        setattr(message, "provider_specific_fields", provider_fields)
|
||||
|
||||
|
||||
async def acompletion_with_mcp(  # noqa: PLR0915
    model: str,
    messages: List,
    tools: Optional[List] = None,
    **kwargs: Any,
) -> Union[ModelResponse, CustomStreamWrapper]:
    """
    Async completion with MCP integration.

    This function handles MCP tool integration following the same pattern as aresponses_api_with_mcp.
    It's designed to be called from the synchronous completion() function and return a coroutine.

    When MCP tools with server_url="litellm_proxy" are provided, this function will:
    1. Get available tools from the MCP server manager
    2. Transform them to OpenAI format
    3. Call acompletion with the transformed tools
    4. If require_approval="never" and tool calls are returned, automatically execute them
    5. Make a follow-up call with the tool results

    Args:
        model: Model name passed through to litellm.acompletion.
        messages: Chat messages list passed through unchanged.
        tools: Optional tool list; MCP tools (server_url="litellm_proxy") are
            split out and resolved, other tools are forwarded as-is.
        **kwargs: Remaining acompletion kwargs. Keys consumed here include
            "user_api_key_auth", "metadata", "secret_fields", "stream",
            "litellm_call_id", "litellm_trace_id", and "mock_tool_calls".

    Returns:
        A ModelResponse (non-streaming) or CustomStreamWrapper (streaming),
        annotated with MCP metadata via _add_mcp_metadata_to_response.
    """
    from litellm import acompletion as litellm_acompletion

    # Parse MCP tools and separate from other tools
    (
        mcp_tools_with_litellm_proxy,
        other_tools,
    ) = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools)

    if not mcp_tools_with_litellm_proxy:
        # No MCP tools, proceed with regular completion
        return await litellm_acompletion(
            model=model,
            messages=messages,
            tools=tools,
            **kwargs,
        )

    # Extract user_api_key_auth from metadata or kwargs
    user_api_key_auth = kwargs.get("user_api_key_auth") or (
        (kwargs.get("metadata", {}) or {}).get("user_api_key_auth")
    )

    # Extract MCP auth headers before fetching tools (needed for dynamic auth)
    (
        mcp_auth_header,
        mcp_server_auth_headers,
        oauth2_headers,
        raw_headers,
    ) = ResponsesAPIRequestUtils.extract_mcp_headers_from_request(
        secret_fields=kwargs.get("secret_fields"),
        tools=tools,
    )

    # Process MCP tools (pass auth headers for dynamic auth)
    (
        deduplicated_mcp_tools,
        tool_server_map,
    ) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_without_openai_transform(
        user_api_key_auth=user_api_key_auth,
        mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy,
        litellm_trace_id=kwargs.get("litellm_trace_id"),
        mcp_auth_header=mcp_auth_header,
        mcp_server_auth_headers=mcp_server_auth_headers,
    )

    # Convert the discovered MCP tools into OpenAI chat tool definitions.
    openai_tools = LiteLLM_Proxy_MCP_Handler._transform_mcp_tools_to_openai(
        deduplicated_mcp_tools,
        target_format="chat",
    )

    # Combine with other tools
    all_tools = openai_tools + other_tools if (openai_tools or other_tools) else None

    # Determine if we should auto-execute tools
    should_auto_execute = LiteLLM_Proxy_MCP_Handler._should_auto_execute_tools(
        mcp_tools_with_litellm_proxy=mcp_tools_with_litellm_proxy
    )

    # Prepare call parameters
    # Remove keys that shouldn't be passed to acompletion
    clean_kwargs = {k: v for k, v in kwargs.items() if k not in ["acompletion"]}

    base_call_args = {
        "model": model,
        "messages": messages,
        "tools": all_tools,
        "_skip_mcp_handler": True,  # Prevent recursion
        **clean_kwargs,
    }

    # If not auto-executing, just make the call with transformed tools
    if not should_auto_execute:
        response = await litellm_acompletion(**base_call_args)
        if isinstance(response, (ModelResponse, CustomStreamWrapper)):
            _add_mcp_metadata_to_response(
                response=response,
                openai_tools=openai_tools,
            )
        return response

    # For auto-execute: handle streaming vs non-streaming differently
    stream = kwargs.get("stream", False)
    # mock_tool_calls is popped so the follow-up call never replays the mock.
    mock_tool_calls = base_call_args.pop("mock_tool_calls", None)

    if stream:
        # Streaming mode: make initial call with streaming, collect chunks, detect tool calls
        initial_call_args = dict(base_call_args)
        initial_call_args["stream"] = True
        if mock_tool_calls is not None:
            initial_call_args["mock_tool_calls"] = mock_tool_calls

        # Make initial streaming call
        initial_stream = await litellm_acompletion(**initial_call_args)

        if not isinstance(initial_stream, CustomStreamWrapper):
            # Not a stream, return as-is
            if isinstance(initial_stream, ModelResponse):
                _add_mcp_metadata_to_response(
                    response=initial_stream,
                    openai_tools=openai_tools,
                )
            return initial_stream

        # Create a custom async generator that collects chunks and handles tool execution
        from litellm.main import stream_chunk_builder
        from litellm.types.utils import ModelResponseStream

        class MCPStreamingIterator:
            """Custom iterator that collects chunks, detects tool calls, and adds MCP metadata to final chunk."""

            def __init__(
                self,
                stream_wrapper,
                messages,
                tool_server_map,
                user_api_key_auth,
                mcp_auth_header,
                mcp_server_auth_headers,
                oauth2_headers,
                raw_headers,
                litellm_call_id,
                litellm_trace_id,
                openai_tools,
                base_call_args,
            ):
                self.stream_wrapper = stream_wrapper
                self.messages = messages
                self.tool_server_map = tool_server_map
                self.user_api_key_auth = user_api_key_auth
                self.mcp_auth_header = mcp_auth_header
                self.mcp_server_auth_headers = mcp_server_auth_headers
                self.oauth2_headers = oauth2_headers
                self.raw_headers = raw_headers
                self.litellm_call_id = litellm_call_id
                self.litellm_trace_id = litellm_trace_id
                self.openai_tools = openai_tools
                self.base_call_args = base_call_args
                # All chunks seen so far, used to rebuild the full response.
                self.collected_chunks: List[ModelResponseStream] = []
                self.tool_calls: Optional[List] = None
                self.tool_results: Optional[List] = None
                self.complete_response: Optional[ModelResponse] = None
                # Phase flags: initial stream done -> tools executed -> follow-up done.
                self.stream_exhausted = False
                self.tool_execution_done = False
                self.follow_up_stream = None
                self.follow_up_iterator = None
                self.follow_up_exhausted = False

            # NOTE(review): __aiter__ is declared async, so calling it returns a
            # coroutine rather than the iterator. _SyncIteratorWrapper.__next__
            # below explicitly awaits it when needed, and MCPStreamWrapper's
            # __aiter__ bypasses it entirely — confirm async iteration paths.
            async def __aiter__(self):
                return self

            def _add_mcp_list_tools_to_chunk(
                self, chunk: ModelResponseStream
            ) -> ModelResponseStream:
                """Add mcp_list_tools to the first chunk."""
                from litellm.types.utils import (
                    StreamingChoices,
                    add_provider_specific_fields,
                )

                if not self.openai_tools:
                    return chunk

                if hasattr(chunk, "choices") and chunk.choices:
                    for choice in chunk.choices:
                        if (
                            isinstance(choice, StreamingChoices)
                            and hasattr(choice, "delta")
                            and choice.delta
                        ):
                            # Get existing provider_specific_fields or create new dict
                            existing_fields = (
                                getattr(choice.delta, "provider_specific_fields", None)
                                or {}
                            )
                            provider_fields = dict(
                                existing_fields
                            )  # Create a copy to avoid mutating the original

                            # Add only mcp_list_tools to first chunk
                            provider_fields["mcp_list_tools"] = self.openai_tools

                            # Use add_provider_specific_fields to ensure proper setting
                            # This function handles Pydantic model attribute setting correctly
                            add_provider_specific_fields(choice.delta, provider_fields)

                return chunk

            def _add_mcp_tool_metadata_to_final_chunk(
                self, chunk: ModelResponseStream
            ) -> ModelResponseStream:
                """Add mcp_tool_calls and mcp_call_results to the final chunk."""
                from litellm.types.utils import (
                    StreamingChoices,
                    add_provider_specific_fields,
                )

                if hasattr(chunk, "choices") and chunk.choices:
                    for choice in chunk.choices:
                        if (
                            isinstance(choice, StreamingChoices)
                            and hasattr(choice, "delta")
                            and choice.delta
                        ):
                            # Get existing provider_specific_fields or create new dict
                            # Access the attribute directly to handle Pydantic model attributes correctly
                            existing_fields = {}
                            if hasattr(choice.delta, "provider_specific_fields"):
                                attr_value = getattr(
                                    choice.delta, "provider_specific_fields", None
                                )
                                if attr_value is not None:
                                    # Create a copy to avoid mutating the original
                                    existing_fields = (
                                        dict(attr_value)
                                        if isinstance(attr_value, dict)
                                        else {}
                                    )

                            provider_fields = existing_fields

                            # Add tool_calls and tool_results if available
                            if self.tool_calls:
                                provider_fields["mcp_tool_calls"] = self.tool_calls
                            if self.tool_results:
                                provider_fields["mcp_call_results"] = self.tool_results

                            # Use add_provider_specific_fields to ensure proper setting
                            # This function handles Pydantic model attribute setting correctly
                            add_provider_specific_fields(choice.delta, provider_fields)

                return chunk

            async def __anext__(self):
                # Phase 1: Collect and yield initial stream chunks
                if not self.stream_exhausted:
                    # Get the iterator from the stream wrapper
                    if not hasattr(self, "_stream_iterator"):
                        self._stream_iterator = self.stream_wrapper.__aiter__()
                        # Add mcp_list_tools to the first chunk (available from the start)
                        _add_mcp_metadata_to_response(
                            response=self.stream_wrapper,
                            openai_tools=self.openai_tools,
                        )

                    try:
                        chunk = await self._stream_iterator.__anext__()
                        self.collected_chunks.append(chunk)

                        # Add mcp_list_tools to the first chunk
                        if len(self.collected_chunks) == 1:
                            chunk = self._add_mcp_list_tools_to_chunk(chunk)

                        # Check if this is the final chunk (has finish_reason)
                        is_final = (
                            hasattr(chunk, "choices")
                            and chunk.choices
                            and hasattr(chunk.choices[0], "finish_reason")
                            and chunk.choices[0].finish_reason is not None
                        )

                        if is_final:
                            # This is the final chunk, mark stream as exhausted
                            self.stream_exhausted = True
                            # Process tool calls after we've collected all chunks
                            await self._process_tool_calls()
                            # Apply MCP metadata (tool_calls and tool_results) to final chunk
                            chunk = self._add_mcp_tool_metadata_to_final_chunk(chunk)
                            # If we have tool results, prepare follow-up call immediately
                            if self.tool_results and self.complete_response:
                                await self._prepare_follow_up_call()

                        return chunk
                    except StopAsyncIteration:
                        self.stream_exhausted = True
                        # Process tool calls after stream is exhausted
                        await self._process_tool_calls()
                        # If we have chunks, yield the final one with metadata
                        # NOTE(review): this re-yields the last already-yielded
                        # chunk (now annotated) — consumers may see it twice;
                        # confirm intended.
                        if self.collected_chunks:
                            final_chunk = self.collected_chunks[-1]
                            final_chunk = self._add_mcp_tool_metadata_to_final_chunk(
                                final_chunk
                            )
                            # If we have tool results, prepare follow-up call
                            if self.tool_results and self.complete_response:
                                await self._prepare_follow_up_call()
                            return final_chunk

                # Phase 2: Yield follow-up stream chunks if available
                if self.follow_up_stream and not self.follow_up_exhausted:
                    if not self.follow_up_iterator:
                        self.follow_up_iterator = self.follow_up_stream.__aiter__()
                        from litellm._logging import verbose_logger

                        verbose_logger.debug("Follow-up stream iterator created")

                    try:
                        chunk = await self.follow_up_iterator.__anext__()
                        from litellm._logging import verbose_logger

                        verbose_logger.debug(f"Follow-up chunk yielded: {chunk}")
                        return chunk
                    except StopAsyncIteration:
                        self.follow_up_exhausted = True
                        from litellm._logging import verbose_logger

                        verbose_logger.debug("Follow-up stream exhausted")
                        # After follow-up stream is exhausted, check if we need to raise StopAsyncIteration
                        raise StopAsyncIteration

                # If we're here and follow_up_stream is None but we expected it, log a warning
                if (
                    self.stream_exhausted
                    and self.tool_results
                    and self.complete_response
                    and self.follow_up_stream is None
                ):
                    from litellm._logging import verbose_logger

                    verbose_logger.warning(
                        "Follow-up stream was not created despite having tool results"
                    )

                raise StopAsyncIteration

            async def _process_tool_calls(self):
                """Process tool calls after streaming completes."""
                # Guard: only run once even if called from both exit paths.
                if self.tool_execution_done:
                    return

                self.tool_execution_done = True

                if not self.collected_chunks:
                    return

                # Build complete response from chunks
                complete_response = stream_chunk_builder(
                    chunks=self.collected_chunks,
                    messages=self.messages,
                )

                if isinstance(complete_response, ModelResponse):
                    self.complete_response = complete_response
                    # Extract tool calls from complete response
                    self.tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response(
                        response=complete_response
                    )

                    if self.tool_calls:
                        # Execute tool calls
                        self.tool_results = (
                            await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
                                tool_server_map=self.tool_server_map,
                                tool_calls=self.tool_calls,
                                user_api_key_auth=self.user_api_key_auth,
                                mcp_auth_header=self.mcp_auth_header,
                                mcp_server_auth_headers=self.mcp_server_auth_headers,
                                oauth2_headers=self.oauth2_headers,
                                raw_headers=self.raw_headers,
                                litellm_call_id=self.litellm_call_id,
                                litellm_trace_id=self.litellm_trace_id,
                            )
                        )

            async def _prepare_follow_up_call(self):
                """Prepare and initiate follow-up call with tool results."""
                if self.follow_up_stream is not None:
                    return  # Already prepared

                if not self.tool_results or not self.complete_response:
                    return

                # Create follow-up messages with tool results
                follow_up_messages = (
                    LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat(
                        original_messages=self.messages,
                        response=self.complete_response,
                        tool_results=self.tool_results,
                    )
                )

                # Make follow-up call with streaming
                follow_up_call_args = dict(self.base_call_args)
                follow_up_call_args["messages"] = follow_up_messages
                follow_up_call_args["stream"] = True
                # Ensure follow-up call doesn't trigger MCP handler again
                follow_up_call_args["_skip_mcp_handler"] = True

                # Import litellm here to ensure we get the patched version
                # This ensures the patch works correctly in tests
                import litellm

                follow_up_response = await litellm.acompletion(**follow_up_call_args)

                # Ensure follow-up response is a CustomStreamWrapper
                if isinstance(follow_up_response, CustomStreamWrapper):
                    self.follow_up_stream = follow_up_response
                    from litellm._logging import verbose_logger

                    verbose_logger.debug("Follow-up stream created successfully")
                else:
                    # Unexpected response type - log and set to None
                    from litellm._logging import verbose_logger

                    verbose_logger.warning(
                        f"Follow-up response is not a CustomStreamWrapper: {type(follow_up_response)}"
                    )
                    self.follow_up_stream = None

        # Create the custom iterator
        iterator = MCPStreamingIterator(
            stream_wrapper=initial_stream,
            messages=messages,
            tool_server_map=tool_server_map,
            user_api_key_auth=user_api_key_auth,
            mcp_auth_header=mcp_auth_header,
            mcp_server_auth_headers=mcp_server_auth_headers,
            oauth2_headers=oauth2_headers,
            raw_headers=raw_headers,
            litellm_call_id=kwargs.get("litellm_call_id"),
            litellm_trace_id=kwargs.get("litellm_trace_id"),
            openai_tools=openai_tools,
            base_call_args=base_call_args,
        )

        # Create a wrapper class that delegates to our custom iterator
        # We'll use a simple approach: just replace the __aiter__ method
        class MCPStreamWrapper(CustomStreamWrapper):
            def __init__(self, original_wrapper, custom_iterator):
                # Initialize with the same parameters as original wrapper
                super().__init__(
                    completion_stream=None,
                    model=getattr(original_wrapper, "model", "unknown"),
                    logging_obj=getattr(original_wrapper, "logging_obj", None),
                    custom_llm_provider=getattr(
                        original_wrapper, "custom_llm_provider", None
                    ),
                    stream_options=getattr(original_wrapper, "stream_options", None),
                    make_call=getattr(original_wrapper, "make_call", None),
                    _response_headers=getattr(
                        original_wrapper, "_response_headers", None
                    ),
                )
                self._original_wrapper = original_wrapper
                self._custom_iterator = custom_iterator
                # Copy important attributes from original wrapper
                if hasattr(original_wrapper, "_hidden_params"):
                    self._hidden_params = original_wrapper._hidden_params
                # For synchronous iteration, we need to run the async iterator
                self._sync_iterator = None
                self._sync_loop = None

            def __aiter__(self):
                # Hand the custom iterator straight to async consumers.
                return self._custom_iterator

            def __iter__(self):
                # For synchronous iteration, create a sync wrapper
                # NOTE(review): driving an async iterator via
                # run_until_complete will fail if a loop is already running in
                # this thread — confirm sync iteration is only used from
                # non-async contexts.
                if self._sync_iterator is None:
                    import asyncio

                    try:
                        self._sync_loop = asyncio.get_event_loop()
                    except RuntimeError:
                        self._sync_loop = asyncio.new_event_loop()
                        asyncio.set_event_loop(self._sync_loop)
                    self._sync_iterator = _SyncIteratorWrapper(
                        self._custom_iterator, self._sync_loop
                    )
                return self._sync_iterator

            def __next__(self):
                # Delegate to sync iterator
                if self._sync_iterator is None:
                    self.__iter__()
                return next(self._sync_iterator)

            def __getattr__(self, name):
                # Delegate all other attributes to original wrapper
                return getattr(self._original_wrapper, name)

        # Helper class to wrap async iterator for sync iteration
        class _SyncIteratorWrapper:
            def __init__(self, async_iterator, loop):
                self._async_iterator = async_iterator
                self._loop = loop
                self._iterator = None

            def __iter__(self):
                return self

            def __next__(self):
                if self._iterator is None:
                    # __aiter__ might be async, so we need to await it
                    aiter_result = self._async_iterator.__aiter__()
                    if hasattr(aiter_result, "__await__"):
                        # It's a coroutine, await it
                        self._iterator = self._loop.run_until_complete(aiter_result)
                    else:
                        # It's already an iterator
                        self._iterator = aiter_result
                try:
                    return self._loop.run_until_complete(self._iterator.__anext__())
                except StopAsyncIteration:
                    raise StopIteration

        return cast(CustomStreamWrapper, MCPStreamWrapper(initial_stream, iterator))

    # Non-streaming mode: use existing logic
    initial_call_args = dict(base_call_args)
    initial_call_args["stream"] = False
    if mock_tool_calls is not None:
        initial_call_args["mock_tool_calls"] = mock_tool_calls

    # Make initial call
    initial_response = await litellm_acompletion(**initial_call_args)

    if not isinstance(initial_response, ModelResponse):
        return initial_response

    # Extract tool calls from response
    tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_chat_response(
        response=initial_response
    )

    if not tool_calls:
        _add_mcp_metadata_to_response(
            response=initial_response,
            openai_tools=openai_tools,
        )
        return initial_response

    # Execute tool calls
    tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
        tool_server_map=tool_server_map,
        tool_calls=tool_calls,
        user_api_key_auth=user_api_key_auth,
        mcp_auth_header=mcp_auth_header,
        mcp_server_auth_headers=mcp_server_auth_headers,
        oauth2_headers=oauth2_headers,
        raw_headers=raw_headers,
        litellm_call_id=kwargs.get("litellm_call_id"),
        litellm_trace_id=kwargs.get("litellm_trace_id"),
    )

    if not tool_results:
        _add_mcp_metadata_to_response(
            response=initial_response,
            openai_tools=openai_tools,
            tool_calls=tool_calls,
        )
        return initial_response

    # Create follow-up messages with tool results
    follow_up_messages = LiteLLM_Proxy_MCP_Handler._create_follow_up_messages_for_chat(
        original_messages=messages,
        response=initial_response,
        tool_results=tool_results,
    )

    # Make follow-up call with original stream setting
    follow_up_call_args = dict(base_call_args)
    follow_up_call_args["messages"] = follow_up_messages
    follow_up_call_args["stream"] = stream

    response = await litellm_acompletion(**follow_up_call_args)
    if isinstance(response, (ModelResponse, CustomStreamWrapper)):
        _add_mcp_metadata_to_response(
            response=response,
            openai_tools=openai_tools,
            tool_calls=tool_calls,
            tool_results=tool_results,
        )
    return response
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,798 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.responses.streaming_iterator import BaseResponsesAPIStreamingIterator
|
||||
from litellm.types.llms.openai import (
|
||||
BaseLiteLLMOpenAIResponseObject,
|
||||
MCPCallArgumentsDeltaEvent,
|
||||
MCPCallArgumentsDoneEvent,
|
||||
MCPCallCompletedEvent,
|
||||
MCPCallFailedEvent,
|
||||
MCPCallInProgressEvent,
|
||||
MCPListToolsCompletedEvent,
|
||||
MCPListToolsFailedEvent,
|
||||
MCPListToolsInProgressEvent,
|
||||
ResponsesAPIResponse,
|
||||
ResponsesAPIStreamEvents,
|
||||
ResponsesAPIStreamingResponse,
|
||||
ToolParam,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.types import Tool as MCPTool
|
||||
else:
|
||||
MCPTool = Any
|
||||
|
||||
|
||||
async def create_mcp_list_tools_events(
    mcp_tools_with_litellm_proxy: List[ToolParam],
    user_api_key_auth: Any,
    base_item_id: str,
    pre_processed_mcp_tools: List[Any],
) -> List[ResponsesAPIStreamingResponse]:
    """Create MCP discovery events using pre-processed tools from the parent

    Emits, in order: MCP_LIST_TOOLS_IN_PROGRESS, MCP_LIST_TOOLS_COMPLETED,
    and an OUTPUT_ITEM_DONE event carrying the tool list (matching OpenAI's
    mcp_list_tools output item format). On any exception it instead emits
    MCP_LIST_TOOLS_FAILED plus an OUTPUT_ITEM_DONE with an empty tool list.

    NOTE(review): `user_api_key_auth` is accepted but never read in this body.
    """

    events: List[ResponsesAPIStreamingResponse] = []

    try:
        # Extract MCP server names
        # NOTE(review): `mcp_servers` is built but never used afterwards —
        # dead code kept as-is; confirm before removing.
        mcp_servers = []
        for tool in mcp_tools_with_litellm_proxy:
            if isinstance(tool, dict) and "server_url" in tool:
                server_url = tool.get("server_url")
                if isinstance(server_url, str) and server_url.startswith(
                    "litellm_proxy/mcp/"
                ):
                    server_name = server_url.split("/")[-1]
                    mcp_servers.append(server_name)

        # Emit list tools in progress event
        in_progress_event = MCPListToolsInProgressEvent(
            type=ResponsesAPIStreamEvents.MCP_LIST_TOOLS_IN_PROGRESS,
            sequence_number=1,
            output_index=0,
            item_id=base_item_id,
        )
        events.append(in_progress_event)

        # Use the pre-processed MCP tools that were already fetched, filtered, and deduplicated by the parent
        filtered_mcp_tools = pre_processed_mcp_tools

        # Convert tools to dict format for the event
        # NOTE(review): `mcp_tools_dict` is built but never used afterwards
        # (the completed event below carries no tool payload); its model_dump
        # calls can still raise and trigger the except branch, so removal
        # would change error-path behavior — confirm before cleaning up.
        mcp_tools_dict = []
        for tool in filtered_mcp_tools:
            if hasattr(tool, "model_dump") and callable(getattr(tool, "model_dump")):
                # Type cast to help mypy understand this is safe after hasattr check
                mcp_tools_dict.append(cast(Any, tool).model_dump())
            elif hasattr(tool, "__dict__"):
                mcp_tools_dict.append(tool.__dict__)
            else:
                mcp_tools_dict.append({"name": getattr(tool, "name", str(tool))})

        # Emit list tools completed event
        completed_event = MCPListToolsCompletedEvent(
            type=ResponsesAPIStreamEvents.MCP_LIST_TOOLS_COMPLETED,
            sequence_number=2,
            output_index=0,
            item_id=base_item_id,
        )
        events.append(completed_event)

        # Add output_item.done event with the actual tools list (matching OpenAI format)
        from litellm.types.llms.openai import OutputItemDoneEvent

        # Extract server label from the first MCP tool config
        server_label = ""
        if mcp_tools_with_litellm_proxy:
            first_tool = mcp_tools_with_litellm_proxy[0]
            if isinstance(first_tool, dict):
                server_label_value = first_tool.get("server_label", "")
                server_label = (
                    str(server_label_value) if server_label_value is not None else ""
                )

        # Format tools for OpenAI output_item.done format
        formatted_tools = []
        for tool in filtered_mcp_tools:
            tool_dict = {
                "name": getattr(tool, "name", "unknown"),
                "description": getattr(tool, "description", ""),
                "annotations": {"read_only": False},
            }

            # Add input_schema if available
            # (MCP SDK uses camelCase `inputSchema`; fall back to snake_case)
            if hasattr(tool, "inputSchema"):
                tool_dict["input_schema"] = getattr(tool, "inputSchema")
            elif hasattr(tool, "input_schema"):
                tool_dict["input_schema"] = getattr(tool, "input_schema")

            formatted_tools.append(tool_dict)

        # Create the output_item.done event with MCP tools list
        output_item_done_event = OutputItemDoneEvent(
            type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
            output_index=0,
            item=BaseLiteLLMOpenAIResponseObject(
                **{
                    "id": base_item_id,
                    "type": "mcp_list_tools",
                    "server_label": server_label,
                    "tools": formatted_tools,
                }
            ),
        )
        events.append(output_item_done_event)

        verbose_logger.debug(f"Created {len(events)} MCP discovery events")

    except Exception as e:
        verbose_logger.error(f"Error creating MCP list tools events: {e}")
        import traceback

        traceback.print_exc()

        # Emit failed event on error
        failed_event = MCPListToolsFailedEvent(
            type=ResponsesAPIStreamEvents.MCP_LIST_TOOLS_FAILED,
            sequence_number=2,
            output_index=0,
            item_id=base_item_id,
        )
        events.append(failed_event)

        # Still emit output_item.done event even on failure (with empty tools list)
        from litellm.types.llms.openai import OutputItemDoneEvent

        output_item_done_event = OutputItemDoneEvent(
            type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
            output_index=0,
            item=BaseLiteLLMOpenAIResponseObject(
                **{
                    "id": base_item_id,
                    "type": "mcp_list_tools",
                    "server_label": "",
                    "tools": [],
                }
            ),
        )
        events.append(output_item_done_event)

    return events
|
||||
|
||||
|
||||
def create_mcp_call_events(
    tool_name: str,
    tool_call_id: str,
    arguments: str,
    result: Optional[str] = None,
    base_item_id: Optional[str] = None,
    sequence_start: int = 1,
) -> List[ResponsesAPIStreamingResponse]:
    """Build the ordered stream events for one MCP tool call (OpenAI spec).

    The sequence is always: in-progress, arguments delta, arguments done.
    When ``result`` is not None a completed event plus an output_item.done
    item carrying the output follow; otherwise a failed event is appended.
    Sequence numbers run from ``sequence_start``.
    """
    call_item_id = base_item_id or f"mcp_{uuid.uuid4().hex[:8]}"

    stream_events: List[ResponsesAPIStreamingResponse] = [
        # The call has started.
        MCPCallInProgressEvent(
            type=ResponsesAPIStreamEvents.MCP_CALL_IN_PROGRESS,
            sequence_number=sequence_start,
            output_index=0,
            item_id=call_item_id,
        ),
        # Stream the JSON argument string as a single delta...
        MCPCallArgumentsDeltaEvent(
            type=ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DELTA,
            output_index=0,
            item_id=call_item_id,
            delta=arguments,
            sequence_number=sequence_start + 1,
        ),
        # ...then signal the arguments are finalized.
        MCPCallArgumentsDoneEvent(
            type=ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DONE,
            output_index=0,
            item_id=call_item_id,
            arguments=arguments,
            sequence_number=sequence_start + 2,
        ),
    ]

    if result is None:
        # No output available: report the call as failed.
        stream_events.append(
            MCPCallFailedEvent(
                type=ResponsesAPIStreamEvents.MCP_CALL_FAILED,
                sequence_number=sequence_start + 3,
                item_id=call_item_id,
                output_index=0,
            )
        )
        return stream_events

    stream_events.append(
        MCPCallCompletedEvent(
            type=ResponsesAPIStreamEvents.MCP_CALL_COMPLETED,
            sequence_number=sequence_start + 3,
            item_id=call_item_id,
            output_index=0,
        )
    )

    # Attach the tool output as an output_item.done item in OpenAI's format.
    from litellm.types.llms.openai import OutputItemDoneEvent

    stream_events.append(
        OutputItemDoneEvent(
            type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
            output_index=0,
            item=BaseLiteLLMOpenAIResponseObject(
                **{
                    "id": call_item_id,
                    "type": "mcp_call",
                    "approval_request_id": f"mcpr_{uuid.uuid4().hex[:8]}",
                    "arguments": arguments,
                    "error": None,
                    "name": tool_name,
                    "output": result,
                    "server_label": "litellm",
                }
            ),
        )
    )

    return stream_events
|
||||
|
||||
|
||||
class MCPEnhancedStreamingIterator(BaseResponsesAPIStreamingIterator):
    """
    A complete MCP streaming iterator that handles the entire flow:
    1. Immediately emits MCP discovery events
    2. Makes the first LLM call and streams its response
    3. Handles tool execution and follow-up calls for auto-execute tools
    4. Emits tool execution events in the stream

    Internally this is a phase-based async state machine (see `__anext__`);
    `self.phase` transitions through:
    initial_response -> mcp_discovery -> continue_initial_response ->
    tool_execution -> follow_up_response -> finished
    """

    def __init__(
        self,
        base_iterator: Any,  # Can be None - will be created internally
        mcp_events: List[ResponsesAPIStreamingResponse],
        tool_server_map: dict[str, str],
        mcp_tools_with_litellm_proxy: Optional[List[Any]] = None,
        user_api_key_auth: Any = None,
        original_request_params: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            base_iterator: Existing stream for the first LLM call, or None to
                have this class create it lazily via `aresponses`.
            mcp_events: Pre-generated MCP discovery events to emit after the
                first `response.output_item.added` chunk.
            tool_server_map: Maps tool names to their MCP server (passed
                through to `LiteLLM_Proxy_MCP_Handler._execute_tool_calls`).
            mcp_tools_with_litellm_proxy: MCP tool configs; used to decide
                whether tools should be auto-executed.
            user_api_key_auth: Proxy auth object forwarded to tool execution.
            original_request_params: The original request kwargs; reused for
                the initial and follow-up LLM calls and for header extraction.
        """
        # MCP setup
        self.mcp_tools_with_litellm_proxy = mcp_tools_with_litellm_proxy or []
        self.user_api_key_auth = user_api_key_auth
        self.original_request_params = original_request_params or {}
        self.should_auto_execute = self._should_auto_execute_tools()

        # Streaming state management
        self.phase = "initial_response"  # initial_response -> mcp_discovery -> tool_execution -> follow_up_response -> finished
        self.finished = False

        # Event queues and generation flags
        self.mcp_discovery_events: List[
            ResponsesAPIStreamingResponse
        ] = mcp_events  # Pre-generated MCP discovery events
        self.tool_execution_events: List[ResponsesAPIStreamingResponse] = []
        self.mcp_discovery_generated = True  # Events are already generated
        self.mcp_events = (
            mcp_events  # Store the initial MCP events for backward compatibility
        )
        self.tool_server_map = tool_server_map

        # Iterator references
        self.base_iterator: Optional[
            Union[Any, ResponsesAPIResponse]
        ] = base_iterator  # Will be created when needed
        self.follow_up_iterator: Optional[Any] = None

        # Response collection for tool execution
        self.collected_response: Optional[ResponsesAPIResponse] = None

        # Set up model metadata (will be updated when we get the real iterator)
        self.model = self.original_request_params.get("model", "unknown")
        self.litellm_metadata = {}
        self.custom_llm_provider = self.original_request_params.get(
            "custom_llm_provider", None
        )
        self.litellm_call_id = self.original_request_params.get("litellm_call_id")
        self.litellm_trace_id = self.original_request_params.get("litellm_trace_id")

        # Populates self.mcp_auth_header / mcp_server_auth_headers /
        # oauth2_headers / raw_headers from the request params.
        self._extract_mcp_headers_from_params()

        # Mark as async iterator
        self.is_async = True

        # Track if we've emitted initial OpenAI lifecycle events
        self.initial_events_emitted = False

        # Cache the response ID to ensure consistency across all events
        self._cached_response_id: Optional[str] = None

    def _extract_mcp_headers_from_params(self) -> None:
        """Extract MCP headers from original request params to pass to tool calls"""
        from typing import Dict, Optional

        from starlette.datastructures import Headers

        from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
            MCPRequestHandler,
        )

        # Extract headers from secret_fields in original_request_params
        raw_headers_from_request: Optional[Dict[str, str]] = None
        secret_fields = self.original_request_params.get("secret_fields")
        if secret_fields and isinstance(secret_fields, dict):
            raw_headers_from_request = secret_fields.get("raw_headers")

        # Extract MCP-specific headers
        self.mcp_auth_header: Optional[str] = None
        self.mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None
        self.oauth2_headers: Optional[Dict[str, str]] = None
        self.raw_headers: Optional[Dict[str, str]] = raw_headers_from_request

        if raw_headers_from_request:
            headers_obj = Headers(raw_headers_from_request)
            self.mcp_auth_header = MCPRequestHandler._get_mcp_auth_header_from_headers(
                headers_obj
            )
            self.mcp_server_auth_headers = (
                MCPRequestHandler._get_mcp_server_auth_headers_from_headers(headers_obj)
            )
            self.oauth2_headers = MCPRequestHandler._get_oauth2_headers_from_headers(
                headers_obj
            )

        # Also check if headers are provided in tools array (from request body)
        tools = self.original_request_params.get("tools")
        if tools:
            for tool in tools:
                if isinstance(tool, dict) and tool.get("type") == "mcp":
                    tool_headers = tool.get("headers", {})
                    if tool_headers and isinstance(tool_headers, dict):
                        # Merge tool headers into mcp_server_auth_headers
                        headers_obj_from_tool = Headers(tool_headers)
                        tool_mcp_server_auth_headers = (
                            MCPRequestHandler._get_mcp_server_auth_headers_from_headers(
                                headers_obj_from_tool
                            )
                        )

                        if tool_mcp_server_auth_headers:
                            if self.mcp_server_auth_headers is None:
                                self.mcp_server_auth_headers = {}
                            # Merge the headers from tool into existing headers
                            # (per-server dicts are merged key-by-key; tool
                            # headers win on key collisions).
                            for (
                                server_alias,
                                headers_dict,
                            ) in tool_mcp_server_auth_headers.items():
                                if server_alias not in self.mcp_server_auth_headers:
                                    self.mcp_server_auth_headers[server_alias] = {}
                                self.mcp_server_auth_headers[server_alias].update(
                                    headers_dict
                                )

                        # Also merge raw headers
                        if self.raw_headers is None:
                            self.raw_headers = {}
                        self.raw_headers.update(tool_headers)

    def _should_auto_execute_tools(self) -> bool:
        """Check if tools should be auto-executed"""
        from litellm.responses.mcp.litellm_proxy_mcp_handler import (
            LiteLLM_Proxy_MCP_Handler,
        )

        return LiteLLM_Proxy_MCP_Handler._should_auto_execute_tools(
            self.mcp_tools_with_litellm_proxy
        )

    def __aiter__(self):
        """Return self; this object is its own async iterator."""
        return self

    async def __anext__(self) -> ResponsesAPIStreamingResponse:
        """
        Phase-based streaming:
        1. initial_response - Stream the first LLM response (includes response.created, response.in_progress, response.output_item.added)
        2. mcp_discovery - Emit MCP discovery events (after response.output_item.added)
        3. continue_initial_response - Continue streaming the initial response content
        4. tool_execution - Emit tool execution events
        5. follow_up_response - Stream the follow-up response
        6. finished - End iteration

        Note: phases fall through within a single call — e.g. when the
        discovery queue drains, the same call continues into
        continue_initial_response rather than returning.
        """

        # Phase 1: Initial Response Stream (emit standard OpenAI events first)
        if self.phase == "initial_response":
            result = await self._handle_initial_response_phase()
            if result is not None:
                return result

        # Phase 2: MCP Discovery Events (after response.output_item.added)
        if self.phase == "mcp_discovery":
            # Emit MCP discovery events
            if self.mcp_discovery_events:
                return self.mcp_discovery_events.pop(0)
            self.phase = "continue_initial_response"
            # Fall through to continue processing the initial response

        # Phase 3: Continue Initial Response (after MCP discovery events)
        if self.phase == "continue_initial_response":
            try:
                return await self._process_base_iterator_chunk()
            except StopAsyncIteration:
                # Initial response ended, move to next phase
                if self.should_auto_execute and self.collected_response:
                    self.phase = "tool_execution"
                    await self._generate_tool_execution_events()
                else:
                    self.phase = "finished"
                    raise

        # Phase 4: Tool Execution Events
        if self.phase == "tool_execution":
            # Emit any queued tool execution events
            if self.tool_execution_events:
                return self.tool_execution_events.pop(0)

            # Move to follow-up response phase
            self.phase = "follow_up_response"
            await self._create_follow_up_iterator()

        # Phase 5: Follow-up Response Stream
        if self.phase == "follow_up_response":
            if self.follow_up_iterator:
                try:
                    return await cast(Any, self.follow_up_iterator).__anext__()  # type: ignore[attr-defined]
                except StopAsyncIteration:
                    self.phase = "finished"
                    raise
            else:
                self.phase = "finished"
                raise StopAsyncIteration

        # Phase 6: Finished
        if self.phase == "finished":
            raise StopAsyncIteration

        # Should not reach here
        raise StopAsyncIteration

    async def _handle_initial_response_phase(
        self,
    ) -> Optional[ResponsesAPIStreamingResponse]:
        """
        Handle Phase 1: Initial Response Stream.

        Returns a chunk to emit, or None to fall through to the next phase.
        Raises StopAsyncIteration when the stream is exhausted with no auto-execution.
        """
        if self.base_iterator is None:
            await self._create_initial_response_iterator()

        if self.base_iterator is None:
            # LLM call failed — still emit MCP discovery events before finishing
            if self.mcp_discovery_events:
                self.phase = "mcp_discovery"
            else:
                self.phase = "finished"
                raise StopAsyncIteration
            return None

        if self.base_iterator:
            if hasattr(self.base_iterator, "__anext__"):
                try:
                    chunk = await cast(Any, self.base_iterator).__anext__()  # type: ignore[attr-defined]

                    # Capture the response ID from the first event to ensure consistency
                    if self._cached_response_id is None and hasattr(chunk, "response"):
                        response_obj = getattr(chunk, "response", None)
                        if response_obj and hasattr(response_obj, "id"):
                            self._cached_response_id = response_obj.id
                            verbose_logger.debug(
                                f"Cached response ID: {self._cached_response_id}"
                            )

                    # After emitting response.output_item.added, transition to MCP discovery
                    if not self.initial_events_emitted and hasattr(chunk, "type"):
                        chunk_type = getattr(chunk, "type", None)
                        if chunk_type == ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED:
                            self.initial_events_emitted = True
                            self.phase = "mcp_discovery"
                            return chunk

                    # If auto-execution is enabled, check for completed responses
                    if self.should_auto_execute and self._is_response_completed(chunk):
                        response_obj = getattr(chunk, "response", None)
                        if isinstance(response_obj, ResponsesAPIResponse):
                            self.collected_response = response_obj
                            self.phase = "tool_execution"
                            await self._generate_tool_execution_events()

                    return chunk
                except StopAsyncIteration:
                    if self.should_auto_execute and self.collected_response:
                        self.phase = "tool_execution"
                        await self._generate_tool_execution_events()
                    else:
                        self.phase = "finished"
                        raise
            else:
                # base_iterator is not async iterable (likely a ResponsesAPIResponse)
                if self.should_auto_execute and isinstance(
                    self.base_iterator, ResponsesAPIResponse
                ):
                    self.collected_response = self.base_iterator
                    self.phase = "tool_execution"
                    await self._generate_tool_execution_events()
                else:
                    self.phase = "finished"
                    raise StopAsyncIteration
        return None

    def _is_response_completed(self, chunk: ResponsesAPIStreamingResponse) -> bool:
        """Check if this chunk indicates the response is completed"""
        from litellm.types.llms.openai import ResponsesAPIStreamEvents

        return (
            getattr(chunk, "type", None) == ResponsesAPIStreamEvents.RESPONSE_COMPLETED
        )

    async def _process_base_iterator_chunk(self) -> ResponsesAPIStreamingResponse:
        """
        Process a chunk from the base iterator with response ID consistency enforcement.

        Raises:
            StopAsyncIteration: when the base iterator is missing, not async
                iterable, or exhausted.
        """
        if not self.base_iterator or not hasattr(self.base_iterator, "__anext__"):
            raise StopAsyncIteration

        chunk = await cast(Any, self.base_iterator).__anext__()  # type: ignore[attr-defined]

        # Ensure response ID consistency - update chunk if needed
        if self._cached_response_id and hasattr(chunk, "response"):
            response_obj = getattr(chunk, "response", None)
            if response_obj and hasattr(response_obj, "id"):
                if response_obj.id != self._cached_response_id:
                    verbose_logger.debug(
                        f"Updating response ID from {response_obj.id} to {self._cached_response_id}"
                    )
                    response_obj.id = self._cached_response_id

        # If auto-execution is enabled, check for completed responses
        if self.should_auto_execute and self._is_response_completed(chunk):
            # Collect the response for tool execution
            response_obj = getattr(chunk, "response", None)
            if isinstance(response_obj, ResponsesAPIResponse):
                self.collected_response = response_obj
                # Move to tool execution phase after emitting this chunk
                self.phase = "tool_execution"
                await self._generate_tool_execution_events()

        return chunk

    async def _create_initial_response_iterator(self) -> None:
        """Create the initial response iterator by making the first LLM call"""
        try:
            # Import the core aresponses function that doesn't have MCP logic
            from litellm.responses.main import aresponses

            # Make the initial response API call - but avoid the MCP wrapper
            params = self.original_request_params.copy()
            params["stream"] = True  # Ensure streaming

            # Use the pre-fetched all_tools from original_request_params (no re-processing needed)
            params_for_llm = {}
            for key, value in params.items():
                params_for_llm[
                    key
                ] = value  # Copy all params as-is since tools are already processed

            tools_count = (
                len(params_for_llm.get("tools", []))
                if params_for_llm.get("tools")
                else 0
            )
            verbose_logger.debug(f"Making LLM call with {tools_count} tools")
            response = await aresponses(**params_for_llm)

            # Set the base iterator
            if hasattr(response, "__aiter__") or hasattr(response, "__iter__"):
                self.base_iterator = response
                # Copy metadata from the real iterator
                self.model = getattr(response, "model", self.model)
                self.litellm_metadata = getattr(response, "litellm_metadata", {})
                self.custom_llm_provider = getattr(
                    response, "custom_llm_provider", self.custom_llm_provider
                )
                verbose_logger.debug(
                    f"Created base iterator: {type(self.base_iterator)}"
                )
            else:
                # Non-streaming response - this shouldn't happen but handle it
                verbose_logger.warning(f"Got non-streaming response: {type(response)}")
                self.base_iterator = None
                self.phase = "finished"

        except Exception as e:
            verbose_logger.error(f"Error creating initial response iterator: {e}")
            import traceback

            traceback.print_exc()
            self.base_iterator = None
            # Don't set phase to "finished" here — let __anext__ emit any
            # pre-generated MCP discovery events before ending the iteration.

    async def _generate_tool_execution_events(self) -> None:
        """Generate tool execution events and execute tools"""
        if not self.collected_response:
            return
        from litellm.responses.mcp.litellm_proxy_mcp_handler import (
            LiteLLM_Proxy_MCP_Handler,
        )

        try:
            # Extract tool calls from the response
            if self.collected_response is not None:
                tool_calls = LiteLLM_Proxy_MCP_Handler._extract_tool_calls_from_response(self.collected_response)  # type: ignore[arg-type]
            else:
                tool_calls = []
            if not tool_calls:
                return

            for tool_call in tool_calls:
                (
                    tool_name,
                    tool_arguments,
                    tool_call_id,
                ) = LiteLLM_Proxy_MCP_Handler._extract_tool_call_details(tool_call)
                if tool_name and tool_call_id:
                    # Create MCP call events for this tool execution
                    call_events = create_mcp_call_events(
                        tool_name=tool_name,
                        tool_call_id=tool_call_id,
                        arguments=tool_arguments or "{}",  # JSON string with arguments
                        result=None,  # Will be set after execution
                        base_item_id=f"mcp_{uuid.uuid4().hex[:8]}",
                        sequence_start=len(self.tool_execution_events) + 1,
                    )
                    # Add the in_progress and arguments events (not the completed event yet)
                    # (with result=None the last event is the "failed" event,
                    # which call_events[:-1] deliberately drops)
                    self.tool_execution_events.extend(call_events[:-1])

            # Execute the tools
            tool_results = await LiteLLM_Proxy_MCP_Handler._execute_tool_calls(
                tool_server_map=self.tool_server_map,
                tool_calls=tool_calls,
                user_api_key_auth=self.user_api_key_auth,
                mcp_auth_header=self.mcp_auth_header,
                mcp_server_auth_headers=self.mcp_server_auth_headers,
                oauth2_headers=self.oauth2_headers,
                raw_headers=self.raw_headers,
                litellm_call_id=self.litellm_call_id,
                litellm_trace_id=self.litellm_trace_id,
            )

            # Create completion events and output_item.done events for tool execution
            for tool_result in tool_results:
                tool_call_id = tool_result.get("tool_call_id", "unknown")
                result_text = tool_result.get("result", "")

                # Find matching tool name and arguments
                tool_name = "unknown"
                tool_arguments = "{}"
                for tool_call in tool_calls:
                    (
                        name,
                        args,
                        call_id,
                    ) = LiteLLM_Proxy_MCP_Handler._extract_tool_call_details(tool_call)
                    if call_id == tool_call_id:
                        tool_name = name or "unknown"
                        tool_arguments = args or "{}"
                        break

                item_id = f"mcp_{uuid.uuid4().hex[:8]}"

                # Create the completion event
                completed_event = MCPCallCompletedEvent(
                    type=ResponsesAPIStreamEvents.MCP_CALL_COMPLETED,
                    sequence_number=len(self.tool_execution_events) + 1,
                    item_id=item_id,
                    output_index=0,
                )
                self.tool_execution_events.append(completed_event)

                # Create output_item.done event with the tool call result
                from litellm.types.llms.openai import OutputItemDoneEvent

                output_item_done_event = OutputItemDoneEvent(
                    type=ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE,
                    output_index=0,
                    item=BaseLiteLLMOpenAIResponseObject(
                        **{
                            "id": item_id,
                            "type": "mcp_call",
                            "approval_request_id": f"mcpr_{uuid.uuid4().hex[:8]}",
                            "arguments": tool_arguments,
                            "error": None,
                            "name": tool_name,
                            "output": result_text,
                            "server_label": "litellm",  # or extract from tool config
                        }
                    ),
                )
                self.tool_execution_events.append(output_item_done_event)

            # Store tool results for follow-up call
            self.tool_results = tool_results

        except Exception as e:
            # Best-effort: tool execution failures are logged, not raised, so
            # the stream can still finish.
            verbose_logger.error(f"Error in tool execution: {e}")
            import traceback

            traceback.print_exc()
            self.tool_results = []

    async def _create_follow_up_iterator(self) -> None:
        """Create the follow-up response iterator with tool results"""
        if not self.collected_response or not hasattr(self, "tool_results"):
            return

        from litellm.responses.main import aresponses
        from litellm.responses.mcp.litellm_proxy_mcp_handler import (
            LiteLLM_Proxy_MCP_Handler,
        )

        try:
            # Create follow-up input
            if self.collected_response is not None:
                follow_up_input = LiteLLM_Proxy_MCP_Handler._create_follow_up_input(
                    response=self.collected_response,  # type: ignore[arg-type]
                    tool_results=self.tool_results,
                    original_input=self.original_request_params.get("input"),
                )

                # Make follow-up call with streaming
                follow_up_params = self.original_request_params.copy()
                follow_up_params.update(
                    {
                        "input": follow_up_input,
                        "stream": True,
                    }
                )
            else:
                return
            # Remove tool_choice to avoid forcing more tool calls
            follow_up_params.pop("tool_choice", None)

            follow_up_response = await aresponses(**follow_up_params)

            # Set up the follow-up iterator
            if hasattr(follow_up_response, "__aiter__"):
                self.follow_up_iterator = follow_up_response

        except Exception as e:
            verbose_logger.error(f"Error creating follow-up iterator: {e}")
            import traceback

            traceback.print_exc()
            self.follow_up_iterator = None

    def __iter__(self):
        """Return self; sync-iteration entry point (see __next__)."""
        return self

    def __next__(self) -> ResponsesAPIStreamingResponse:
        """
        Sync iteration support.

        Drains the queued MCP events first, then delegates to the base
        iterator. Raises RuntimeError when this instance was built for async
        iteration (is_async is set to True in __init__, so this path only
        works if a caller resets it).
        """
        # First, emit any queued MCP events
        if self.mcp_events:  # type: ignore[attr-defined]
            return self.mcp_events.pop(0)  # type: ignore[attr-defined]

        # Then delegate to the base iterator
        if not self.is_async:
            try:
                if self.base_iterator and hasattr(self.base_iterator, "__next__"):
                    return next(cast(Any, self.base_iterator))  # type: ignore[arg-type]
                else:
                    raise StopIteration
            except StopIteration:
                self.finished = True
                raise
        else:
            raise RuntimeError("Cannot use sync iteration on async iterator")
|
||||
Reference in New Issue
Block a user