chore: initial public snapshot for github upload
@@ -0,0 +1,744 @@
"""
LiteLLM A2A SDK functions.

Provides standalone functions with the @client decorator for LiteLLM logging integration.
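
Example (minimal sketch; assumes an A2A agent is reachable at the shown URL):
    ```python
    from uuid import uuid4

    from a2a.types import MessageSendParams, SendMessageRequest
    from litellm.a2a_protocol import asend_message, create_a2a_client

    a2a_client = await create_a2a_client(base_url="http://localhost:10001")
    request = SendMessageRequest(
        id=str(uuid4()),
        params=MessageSendParams(
            message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
        ),
    )
    response = await asend_message(a2a_client=a2a_client, request=request)
    ```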
"""

import asyncio
import datetime
import uuid
from typing import TYPE_CHECKING, Any, AsyncIterator, Coroutine, Dict, Optional, Union

import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.a2a_protocol.streaming_iterator import A2AStreamingIterator
from litellm.a2a_protocol.utils import A2ARequestUtils
from litellm.constants import DEFAULT_A2A_AGENT_TIMEOUT
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.types.agents import LiteLLMSendMessageResponse
from litellm.utils import client

if TYPE_CHECKING:
    from a2a.client import A2AClient as A2AClientType
    from a2a.types import AgentCard, SendMessageRequest, SendStreamingMessageRequest

# Runtime imports with availability check
A2A_SDK_AVAILABLE = False
A2ACardResolver: Any = None
_A2AClient: Any = None

try:
    from a2a.client import A2AClient as _A2AClient  # type: ignore[no-redef]

    A2A_SDK_AVAILABLE = True
except ImportError:
    pass

# Import our custom card resolver that supports multiple well-known paths
from litellm.a2a_protocol.card_resolver import LiteLLMA2ACardResolver
from litellm.a2a_protocol.exception_mapping_utils import (
    handle_a2a_localhost_retry,
    map_a2a_exception,
)
from litellm.a2a_protocol.exceptions import A2ALocalhostURLError

# Use our custom resolver instead of the default A2A SDK resolver
A2ACardResolver = LiteLLMA2ACardResolver

def _set_usage_on_logging_obj(
    kwargs: Dict[str, Any],
    prompt_tokens: int,
    completion_tokens: int,
) -> None:
    """
    Set usage on litellm_logging_obj for standard logging payload.

    Args:
        kwargs: The kwargs dict containing litellm_logging_obj
        prompt_tokens: Number of input tokens
        completion_tokens: Number of output tokens
    """
    litellm_logging_obj = kwargs.get("litellm_logging_obj")
    if litellm_logging_obj is not None:
        usage = litellm.Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        litellm_logging_obj.model_call_details["usage"] = usage

def _set_agent_id_on_logging_obj(
    kwargs: Dict[str, Any],
    agent_id: Optional[str],
) -> None:
    """
    Set agent_id on litellm_logging_obj for SpendLogs tracking.

    Args:
        kwargs: The kwargs dict containing litellm_logging_obj
        agent_id: The A2A agent ID
    """
    if agent_id is None:
        return

    litellm_logging_obj = kwargs.get("litellm_logging_obj")
    if litellm_logging_obj is not None:
        # Set agent_id directly on model_call_details (same pattern as custom_llm_provider)
        litellm_logging_obj.model_call_details["agent_id"] = agent_id

def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
    """
    Extract agent info and set model/custom_llm_provider for cost tracking.

    Sets model info on the litellm_logging_obj if available.
    Returns the agent name for logging.
    """
    agent_name = "unknown"

    # Try to get the agent card from our stored attribute first, then fall back to the SDK attribute
    agent_card = getattr(a2a_client, "_litellm_agent_card", None)
    if agent_card is None:
        agent_card = getattr(a2a_client, "agent_card", None)

    if agent_card is not None:
        agent_name = getattr(agent_card, "name", "unknown") or "unknown"

    # Build model string
    model = f"a2a_agent/{agent_name}"
    custom_llm_provider = "a2a_agent"

    # Set on litellm_logging_obj if available (for standard logging payload)
    litellm_logging_obj = kwargs.get("litellm_logging_obj")
    if litellm_logging_obj is not None:
        litellm_logging_obj.model = model
        litellm_logging_obj.custom_llm_provider = custom_llm_provider
        litellm_logging_obj.model_call_details["model"] = model
        litellm_logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider

    return agent_name

async def _send_message_via_completion_bridge(
    request: "SendMessageRequest",
    custom_llm_provider: str,
    api_base: Optional[str],
    litellm_params: Dict[str, Any],
) -> LiteLLMSendMessageResponse:
    """
    Route a send_message through the LiteLLM completion bridge (e.g. LangGraph, Bedrock AgentCore).

    Requires request; api_base is optional for providers that derive the endpoint from the model.
    """
    verbose_logger.info(
        f"A2A using completion bridge: provider={custom_llm_provider}, api_base={api_base}"
    )

    from litellm.a2a_protocol.litellm_completion_bridge.handler import (
        A2ACompletionBridgeHandler,
    )

    params = (
        request.params.model_dump(mode="json")
        if hasattr(request.params, "model_dump")
        else dict(request.params)
    )

    response_dict = await A2ACompletionBridgeHandler.handle_non_streaming(
        request_id=str(request.id),
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )

    return LiteLLMSendMessageResponse.from_dict(response_dict)

async def _execute_a2a_send_with_retry(
    a2a_client: Any,
    request: Any,
    agent_card: Any,
    card_url: Optional[str],
    api_base: Optional[str],
    agent_name: Optional[str],
) -> Any:
    """Send an A2A message with retry logic for localhost URL errors."""
    a2a_response = None
    for _ in range(2):  # max 2 attempts: original + 1 retry
        try:
            a2a_response = await a2a_client.send_message(request)
            break  # success, exit retry loop
        except A2ALocalhostURLError as e:
            # Localhost URL error raised directly - fix the URL and retry
            a2a_client = handle_a2a_localhost_retry(
                error=e,
                agent_card=agent_card,
                a2a_client=a2a_client,
                is_streaming=False,
            )
            card_url = agent_card.url if agent_card else None
        except Exception as e:
            try:
                map_a2a_exception(e, card_url, api_base, model=agent_name)
            except A2ALocalhostURLError as localhost_err:
                # Localhost URL error surfaced by exception mapping - fix and retry
                a2a_client = handle_a2a_localhost_retry(
                    error=localhost_err,
                    agent_card=agent_card,
                    a2a_client=a2a_client,
                    is_streaming=False,
                )
                card_url = agent_card.url if agent_card else None
                continue
            except Exception:
                # Re-raise the mapped exception
                raise
    if a2a_response is None:
        raise RuntimeError(
            "A2A send_message failed: no response received after retry attempts."
        )
    return a2a_response

@client
async def asend_message(
    a2a_client: Optional["A2AClientType"] = None,
    request: Optional["SendMessageRequest"] = None,
    api_base: Optional[str] = None,
    litellm_params: Optional[Dict[str, Any]] = None,
    agent_id: Optional[str] = None,
    agent_extra_headers: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> LiteLLMSendMessageResponse:
    """
    Async: Send a message to an A2A agent.

    Uses the @client decorator for LiteLLM logging and tracking.
    If litellm_params contains custom_llm_provider, routes through the completion bridge.

    Args:
        a2a_client: An initialized a2a.client.A2AClient instance (optional if using completion bridge)
        request: SendMessageRequest from a2a.types (optional if using completion bridge with api_base)
        api_base: API base URL (used to create a client for standard A2A when a2a_client is not given;
            passed through to the completion bridge when custom_llm_provider is set)
        litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
        agent_id: Optional agent ID for tracking in SpendLogs
        agent_extra_headers: Optional agent-level headers; they take precedence over LiteLLM's
            internal trace/agent-id headers
        **kwargs: Additional arguments passed to the client decorator

    Returns:
        LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params)

    Example (standard A2A):
        ```python
        from litellm.a2a_protocol import asend_message, create_a2a_client
        from a2a.types import SendMessageRequest, MessageSendParams
        from uuid import uuid4

        a2a_client = await create_a2a_client(base_url="http://localhost:10001")
        request = SendMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        response = await asend_message(a2a_client=a2a_client, request=request)
        ```

    Example (completion bridge with LangGraph):
        ```python
        from litellm.a2a_protocol import asend_message
        from a2a.types import SendMessageRequest, MessageSendParams
        from uuid import uuid4

        request = SendMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        response = await asend_message(
            request=request,
            api_base="http://localhost:2024",
            litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
        )
        ```
    """
    litellm_params = litellm_params or {}
    logging_obj = kwargs.get("litellm_logging_obj")
    trace_id = getattr(logging_obj, "litellm_trace_id", None) if logging_obj else None
    custom_llm_provider = litellm_params.get("custom_llm_provider")

    # Route through completion bridge if custom_llm_provider is set
    if custom_llm_provider:
        if request is None:
            raise ValueError("request is required for completion bridge")
        return await _send_message_via_completion_bridge(
            request=request,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            litellm_params=litellm_params,
        )

    # Standard A2A client flow
    if request is None:
        raise ValueError("request is required")

    # Create A2A client if not provided but api_base is available
    if a2a_client is None:
        if api_base is None:
            raise ValueError(
                "Either a2a_client or api_base is required for standard A2A flow"
            )
        trace_id = trace_id or str(uuid.uuid4())
        extra_headers: Dict[str, str] = {"X-LiteLLM-Trace-Id": trace_id}
        if agent_id:
            extra_headers["X-LiteLLM-Agent-Id"] = agent_id
        # Overlay agent-level headers (agent headers take precedence over LiteLLM internal ones)
        if agent_extra_headers:
            extra_headers.update(agent_extra_headers)
        a2a_client = await create_a2a_client(
            base_url=api_base, extra_headers=extra_headers
        )

    # Type assertion: a2a_client is guaranteed to be non-None here
    assert a2a_client is not None

    agent_name = _get_a2a_model_info(a2a_client, kwargs)

    verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")

    # Get agent card URL for localhost retry logic
    agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
        a2a_client, "agent_card", None
    )
    card_url = getattr(agent_card, "url", None) if agent_card else None

    # Ensure the message carries a context_id (default to the trace id)
    context_id = trace_id or str(uuid.uuid4())
    message = request.params.message
    if isinstance(message, dict):
        if message.get("context_id") is None:
            message["context_id"] = context_id
    else:
        if getattr(message, "context_id", None) is None:
            message.context_id = context_id

    a2a_response = await _execute_a2a_send_with_retry(
        a2a_client=a2a_client,
        request=request,
        agent_card=agent_card,
        card_url=card_url,
        api_base=api_base,
        agent_name=agent_name,
    )

    verbose_logger.info(f"A2A send_message completed, request_id={request.id}")

    # Wrap in LiteLLM response type for _hidden_params support
    response = LiteLLMSendMessageResponse.from_a2a_response(a2a_response)

    # Calculate token usage from request and response
    response_dict = a2a_response.model_dump(mode="json", exclude_none=True)
    (
        prompt_tokens,
        completion_tokens,
        _,
    ) = A2ARequestUtils.calculate_usage_from_request_response(
        request=request,
        response_dict=response_dict,
    )

    # Set usage on logging obj for standard logging payload
    _set_usage_on_logging_obj(
        kwargs=kwargs,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )

    # Set agent_id on logging obj for SpendLogs tracking
    _set_agent_id_on_logging_obj(kwargs=kwargs, agent_id=agent_id)

    return response

@client
def send_message(
    a2a_client: "A2AClientType",
    request: "SendMessageRequest",
    **kwargs: Any,
) -> Union[LiteLLMSendMessageResponse, Coroutine[Any, Any, LiteLLMSendMessageResponse]]:
    """
    Sync: Send a message to an A2A agent.

    Uses the @client decorator for LiteLLM logging and tracking.
    When called while an event loop is already running, returns a coroutine
    that the caller must await; otherwise it blocks and returns the response.

    Args:
        a2a_client: An initialized a2a.client.A2AClient instance
        request: SendMessageRequest from a2a.types
        **kwargs: Additional arguments passed to the client decorator

    Returns:
        LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params)
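
    Example (illustrative sketch; assumes a client created via create_a2a_client and
    reuses the request shape from asend_message):
        ```python
        from uuid import uuid4

        from a2a.types import MessageSendParams, SendMessageRequest

        # No running event loop here, so send_message blocks and returns the response
        request = SendMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        response = send_message(a2a_client=a2a_client, request=request)
        ```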
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None

    if loop is not None:
        # Already inside an event loop: return the coroutine for the caller to await
        return asend_message(a2a_client=a2a_client, request=request, **kwargs)
    else:
        return asyncio.run(
            asend_message(a2a_client=a2a_client, request=request, **kwargs)
        )

def _build_streaming_logging_obj(
    request: "SendStreamingMessageRequest",
    agent_name: str,
    agent_id: Optional[str],
    litellm_params: Optional[Dict[str, Any]],
    metadata: Optional[Dict[str, Any]],
    proxy_server_request: Optional[Dict[str, Any]],
) -> Logging:
    """Build logging object for streaming A2A requests."""
    start_time = datetime.datetime.now()
    model = f"a2a_agent/{agent_name}"

    logging_obj = Logging(
        model=model,
        messages=[{"role": "user", "content": "streaming-request"}],
        stream=False,
        call_type="asend_message_streaming",
        start_time=start_time,
        litellm_call_id=str(request.id),
        function_id=str(request.id),
    )
    logging_obj.model = model
    logging_obj.custom_llm_provider = "a2a_agent"
    logging_obj.model_call_details["model"] = model
    logging_obj.model_call_details["custom_llm_provider"] = "a2a_agent"
    if agent_id:
        logging_obj.model_call_details["agent_id"] = agent_id

    _litellm_params = litellm_params.copy() if litellm_params else {}
    if metadata:
        _litellm_params["metadata"] = metadata
    if proxy_server_request:
        _litellm_params["proxy_server_request"] = proxy_server_request

    logging_obj.litellm_params = _litellm_params
    logging_obj.optional_params = _litellm_params
    logging_obj.model_call_details["litellm_params"] = _litellm_params
    logging_obj.model_call_details["metadata"] = metadata or {}

    return logging_obj

async def asend_message_streaming(  # noqa: PLR0915
    a2a_client: Optional["A2AClientType"] = None,
    request: Optional["SendStreamingMessageRequest"] = None,
    api_base: Optional[str] = None,
    litellm_params: Optional[Dict[str, Any]] = None,
    agent_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    proxy_server_request: Optional[Dict[str, Any]] = None,
    agent_extra_headers: Optional[Dict[str, str]] = None,
) -> AsyncIterator[Any]:
    """
    Async: Send a streaming message to an A2A agent.

    If litellm_params contains custom_llm_provider, routes through the completion bridge.

    Args:
        a2a_client: An initialized a2a.client.A2AClient instance (optional if using completion bridge)
        request: SendStreamingMessageRequest from a2a.types
        api_base: API base URL (used to create a client for standard A2A when a2a_client is not given;
            passed through to the completion bridge when custom_llm_provider is set)
        litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
        agent_id: Optional agent ID for tracking in SpendLogs
        metadata: Optional metadata dict (contains user_api_key, user_id, team_id, etc.)
        proxy_server_request: Optional proxy server request data
        agent_extra_headers: Optional agent-level headers; they take precedence over LiteLLM's
            internal trace/agent-id headers

    Yields:
        SendStreamingMessageResponse chunks from the agent

    Example (completion bridge with LangGraph):
        ```python
        from litellm.a2a_protocol import asend_message_streaming
        from a2a.types import SendStreamingMessageRequest, MessageSendParams
        from uuid import uuid4

        request = SendStreamingMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        async for chunk in asend_message_streaming(
            request=request,
            api_base="http://localhost:2024",
            litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
        ):
            print(chunk)
        ```
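
    Example (standard A2A; a sketch assuming an agent at the same URL style as asend_message):
        ```python
        from litellm.a2a_protocol import asend_message_streaming, create_a2a_client
        from a2a.types import SendStreamingMessageRequest, MessageSendParams
        from uuid import uuid4

        a2a_client = await create_a2a_client(base_url="http://localhost:10001")
        request = SendStreamingMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        async for chunk in asend_message_streaming(a2a_client=a2a_client, request=request):
            print(chunk)
        ```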
    """
    litellm_params = litellm_params or {}
    custom_llm_provider = litellm_params.get("custom_llm_provider")

    # Route through completion bridge if custom_llm_provider is set
    if custom_llm_provider:
        if request is None:
            raise ValueError("request is required for completion bridge")
        # api_base is optional for providers that derive the endpoint from the model (e.g., bedrock/agentcore)

        verbose_logger.info(
            f"A2A streaming using completion bridge: provider={custom_llm_provider}"
        )

        from litellm.a2a_protocol.litellm_completion_bridge.handler import (
            A2ACompletionBridgeHandler,
        )

        # Extract params from request
        params = (
            request.params.model_dump(mode="json")
            if hasattr(request.params, "model_dump")
            else dict(request.params)
        )

        async for chunk in A2ACompletionBridgeHandler.handle_streaming(
            request_id=str(request.id),
            params=params,
            litellm_params=litellm_params,
            api_base=api_base,
        ):
            yield chunk
        return

    # Standard A2A client flow
    if request is None:
        raise ValueError("request is required")

    # Create A2A client if not provided but api_base is available
    if a2a_client is None:
        if api_base is None:
            raise ValueError(
                "Either a2a_client or api_base is required for standard A2A flow"
            )
        # Mirror the non-streaming path: always include trace and agent-id headers
        streaming_extra_headers: Dict[str, str] = {
            "X-LiteLLM-Trace-Id": str(request.id),
        }
        if agent_id:
            streaming_extra_headers["X-LiteLLM-Agent-Id"] = agent_id
        if agent_extra_headers:
            streaming_extra_headers.update(agent_extra_headers)
        a2a_client = await create_a2a_client(
            base_url=api_base, extra_headers=streaming_extra_headers
        )

    # Type assertion: a2a_client is guaranteed to be non-None here
    assert a2a_client is not None

    verbose_logger.info(f"A2A send_message_streaming request_id={request.id}")

    # Build logging object for streaming completion callbacks
    agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
        a2a_client, "agent_card", None
    )
    card_url = getattr(agent_card, "url", None) if agent_card else None
    agent_name = (
        (getattr(agent_card, "name", "unknown") or "unknown") if agent_card else "unknown"
    )

    logging_obj = _build_streaming_logging_obj(
        request=request,
        agent_name=agent_name,
        agent_id=agent_id,
        litellm_params=litellm_params,
        metadata=metadata,
        proxy_server_request=proxy_server_request,
    )

    # Retry loop: if the connection fails due to a localhost URL in the agent card, retry with the fixed URL.
    # Connection errors in streaming typically occur on the first chunk iteration.
    first_chunk = True
    for attempt in range(2):  # max 2 attempts: original + 1 retry
        stream = a2a_client.send_message_streaming(request)
        iterator = A2AStreamingIterator(
            stream=stream,
            request=request,
            logging_obj=logging_obj,
            agent_name=agent_name,
        )

        try:
            first_chunk = True
            async for chunk in iterator:
                if first_chunk:
                    first_chunk = False  # connection succeeded
                yield chunk
            return  # stream completed successfully
        except A2ALocalhostURLError as e:
            # Only retry on the first chunk, not mid-stream
            if first_chunk and attempt == 0:
                a2a_client = handle_a2a_localhost_retry(
                    error=e,
                    agent_card=agent_card,
                    a2a_client=a2a_client,
                    is_streaming=True,
                )
                card_url = agent_card.url if agent_card else None
            else:
                raise
        except Exception as e:
            # Only map the exception on the first chunk
            if first_chunk and attempt == 0:
                try:
                    map_a2a_exception(e, card_url, api_base, model=agent_name)
                except A2ALocalhostURLError as localhost_err:
                    # Localhost URL error - fix and retry
                    a2a_client = handle_a2a_localhost_retry(
                        error=localhost_err,
                        agent_card=agent_card,
                        a2a_client=a2a_client,
                        is_streaming=True,
                    )
                    card_url = agent_card.url if agent_card else None
                    continue
                except Exception:
                    # Re-raise the mapped exception
                    raise
            raise

async def create_a2a_client(
    base_url: str,
    timeout: float = 60.0,
    extra_headers: Optional[Dict[str, str]] = None,
) -> "A2AClientType":
    """
    Create an A2A client for the given agent URL.

    This resolves the agent card and returns a ready-to-use A2A client.
    The client can be reused for multiple requests.

    Args:
        base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
        timeout: Request timeout in seconds (default: 60.0)
        extra_headers: Optional additional headers to include in requests

    Returns:
        An initialized a2a.client.A2AClient instance

    Example:
        ```python
        from litellm.a2a_protocol import create_a2a_client, asend_message

        # Create client once
        client = await create_a2a_client(base_url="http://localhost:10001")

        # Reuse for multiple requests
        response1 = await asend_message(a2a_client=client, request=request1)
        response2 = await asend_message(a2a_client=client, request=request2)
        ```
    """
    if not A2A_SDK_AVAILABLE:
        raise ImportError(
            "The 'a2a' package is required for A2A agent invocation. "
            "Install it with: pip install a2a-sdk"
        )

    verbose_logger.info(f"Creating A2A client for {base_url}")

    # Use get_async_httpx_client with per-agent params so that different agents
    # (with different extra_headers) get separate cached clients. The params
    # dict is hashed into the cache key, keeping agent auth isolated while
    # still reusing connections within the same agent.
    #
    # Only pass params that AsyncHTTPHandler.__init__ accepts (e.g. timeout).
    # Use "disable_aiohttp_transport" key for cache-key-only data (it's
    # filtered out before reaching the constructor).
    _client_params: dict = {"timeout": timeout}
    if extra_headers:
        # Encode headers into a cache-key-only param so each unique header
        # set produces a distinct cache key.
        _client_params["disable_aiohttp_transport"] = str(sorted(extra_headers.items()))
    _async_handler = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.A2AProvider,
        params=_client_params,
    )
    httpx_client = _async_handler.client
    if extra_headers:
        httpx_client.headers.update(extra_headers)
        verbose_proxy_logger.debug(
            f"A2A client created with extra_headers={list(extra_headers.keys())}"
        )

    # Resolve agent card
    resolver = A2ACardResolver(
        httpx_client=httpx_client,
        base_url=base_url,
    )
    agent_card = await resolver.get_agent_card()

    verbose_logger.debug(
        f"Resolved agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
    )

    # Create A2A client
    a2a_client = _A2AClient(
        httpx_client=httpx_client,
        agent_card=agent_card,
    )

    # Store agent_card on client for later retrieval (SDK doesn't expose it)
    a2a_client._litellm_agent_card = agent_card  # type: ignore[attr-defined]

    verbose_logger.info(f"A2A client created for {base_url}")

    return a2a_client

async def aget_agent_card(
    base_url: str,
    timeout: float = DEFAULT_A2A_AGENT_TIMEOUT,
    extra_headers: Optional[Dict[str, str]] = None,
) -> "AgentCard":
    """
    Fetch the agent card from an A2A agent.

    Args:
        base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
        timeout: Request timeout in seconds (default: DEFAULT_A2A_AGENT_TIMEOUT)
        extra_headers: Optional additional headers to include in requests

    Returns:
        AgentCard from the A2A agent
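
    Example (a small sketch; assumes an agent serving a card at base_url):
        ```python
        from litellm.a2a_protocol import aget_agent_card

        card = await aget_agent_card(base_url="http://localhost:10001")
        print(card.name)
        ```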
    """
    if not A2A_SDK_AVAILABLE:
        raise ImportError(
            "The 'a2a' package is required for A2A agent invocation. "
            "Install it with: pip install a2a-sdk"
        )

    verbose_logger.info(f"Fetching agent card from {base_url}")

    # Use LiteLLM's cached httpx client
    http_handler = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.A2A,
        params={"timeout": timeout},
    )
    httpx_client = http_handler.client
    if extra_headers:
        # Apply caller-supplied headers (same as create_a2a_client)
        httpx_client.headers.update(extra_headers)

    resolver = A2ACardResolver(
        httpx_client=httpx_client,
        base_url=base_url,
    )
    agent_card = await resolver.get_agent_card()

    verbose_logger.info(
        f"Fetched agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
    )
    return agent_card