chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,10 @@
"""
A2A Protocol Providers.
This module contains provider-specific implementations for the A2A protocol.
"""
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
__all__ = ["BaseA2AProviderConfig", "A2AProviderConfigManager"]

View File

@@ -0,0 +1,62 @@
"""
Base configuration for A2A protocol providers.
"""
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Dict
class BaseA2AProviderConfig(ABC):
    """
    Abstract interface for A2A protocol provider configurations.

    A concrete subclass defines how A2A requests are executed for one
    specific agent type.
    """

    @abstractmethod
    async def handle_non_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Execute a non-streaming A2A request against the agent.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the agent
            **kwargs: Additional provider-specific parameters

        Returns:
            A2A SendMessageResponse dict
        """

    @abstractmethod
    async def handle_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Execute a streaming A2A request against the agent.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the agent
            **kwargs: Additional provider-specific parameters

        Yields:
            A2A streaming response events
        """
        # The unreachable yield below keeps this abstract method an (empty)
        # async generator, matching the shape subclasses must implement.
        return
        yield {}  # pragma: no cover

View File

@@ -0,0 +1,47 @@
"""
A2A Provider Config Manager.
Manages provider-specific configurations for A2A protocol.
"""
from typing import Optional
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
class A2AProviderConfigManager:
    """
    Resolves provider-specific configurations for the A2A protocol.

    Analogous to ProviderConfigManager in litellm.utils, but scoped to A2A
    providers.
    """

    @staticmethod
    def get_provider_config(
        custom_llm_provider: Optional[str],
    ) -> Optional[BaseA2AProviderConfig]:
        """
        Look up the provider configuration for a given custom_llm_provider.

        Args:
            custom_llm_provider: The provider identifier (e.g., "pydantic_ai_agents")

        Returns:
            Provider configuration instance or None if not found
        """
        # Unknown or unset providers have no dedicated A2A config.
        # Register additional providers by adding branches before this guard.
        if custom_llm_provider != "pydantic_ai_agents":
            return None

        # Imported lazily so the provider package is only loaded on demand.
        from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
            PydanticAIProviderConfig,
        )

        return PydanticAIProviderConfig()

View File

@@ -0,0 +1,74 @@
# A2A to LiteLLM Completion Bridge
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
## Flow
```
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
```
## SDK Usage
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
```python
from litellm.a2a_protocol import asend_message, asend_message_streaming
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
from uuid import uuid4
# Non-streaming
request = SendMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
response = await asend_message(
request=request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
)
# Streaming
stream_request = SendStreamingMessageRequest(
id=str(uuid4()),
params=MessageSendParams(
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
)
)
async for chunk in asend_message_streaming(
request=stream_request,
api_base="http://localhost:2024",
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
):
print(chunk)
```
## Proxy Usage
Configure an agent with `custom_llm_provider` in `litellm_params`:
```yaml
agents:
- agent_name: my-langgraph-agent
agent_card_params:
name: "LangGraph Agent"
url: "http://localhost:2024" # Used as api_base
litellm_params:
custom_llm_provider: langgraph
model: agent
```
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
1. Detects `custom_llm_provider` in agent's `litellm_params`
2. Transforms A2A message → OpenAI messages
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
4. Transforms response → A2A format
## Classes
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)

View File

@@ -0,0 +1,5 @@
"""
LiteLLM Completion bridge provider for A2A protocol.
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
"""

View File

@@ -0,0 +1,301 @@
"""
Handler for A2A to LiteLLM completion bridge.
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
A2A Streaming Events (in order):
1. Task event (kind: "task") - Initial task creation with status "submitted"
2. Status update (kind: "status-update") - Status change to "working"
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
4. Status update (kind: "status-update") - Final status "completed" with final=true
"""
from typing import Any, AsyncIterator, Dict, Optional
import litellm
from litellm._logging import verbose_logger
from litellm.a2a_protocol.litellm_completion_bridge.pydantic_ai_transformation import (
PydanticAITransformation,
)
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
A2ACompletionBridgeTransformation,
A2AStreamingContext,
)
class A2ACompletionBridgeHandler:
    """
    Static methods for handling A2A requests via LiteLLM completion.
    """

    @staticmethod
    def _build_completion_params(
        litellm_params: Dict[str, Any],
        openai_messages: Any,
        api_base: Optional[str],
        stream: bool,
    ) -> Dict[str, Any]:
        """
        Build the kwargs dict passed to ``litellm.acompletion``.

        Shared by the streaming and non-streaming paths so the model-name
        resolution and litellm_params passthrough live in exactly one place
        (previously duplicated in both handlers).

        Args:
            litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
            openai_messages: Messages already converted to OpenAI format
            api_base: API base URL from agent_card_params
            stream: Whether to request a streaming completion

        Returns:
            Keyword arguments for litellm.acompletion
        """
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        model = litellm_params.get("model", "agent")
        # Build full model string if provider specified.
        # Skip prepending if model already starts with the provider prefix.
        if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
            full_model = f"{custom_llm_provider}/{model}"
        else:
            full_model = model
        completion_params: Dict[str, Any] = {
            "model": full_model,
            "messages": openai_messages,
            "api_base": api_base,
            "stream": stream,
        }
        # Pass through remaining litellm_params (api_key, client_id,
        # client_secret, tenant_id, etc.); model/provider are already encoded
        # in the full model string above.
        completion_params.update(
            {
                k: v
                for k, v in litellm_params.items()
                if k not in ("model", "custom_llm_provider")
            }
        )
        return completion_params

    @staticmethod
    async def handle_non_streaming(
        request_id: str,
        params: Dict[str, Any],
        litellm_params: Dict[str, Any],
        api_base: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Handle non-streaming A2A request via litellm.acompletion.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
            api_base: API base URL from agent_card_params

        Returns:
            A2A SendMessageResponse dict

        Raises:
            ValueError: If a Pydantic AI agent is targeted without api_base.
        """
        # Pydantic AI agents speak A2A natively - bypass the completion bridge.
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        if custom_llm_provider == "pydantic_ai_agents":
            if api_base is None:
                raise ValueError("api_base is required for Pydantic AI agents")
            verbose_logger.info(
                f"Pydantic AI: Routing to Pydantic AI agent at {api_base}"
            )
            # Send request directly to Pydantic AI agent
            return await PydanticAITransformation.send_non_streaming_request(
                api_base=api_base,
                request_id=request_id,
                params=params,
            )

        # Transform A2A message to OpenAI format
        message = params.get("message", {})
        openai_messages = (
            A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
        )

        completion_params = A2ACompletionBridgeHandler._build_completion_params(
            litellm_params=litellm_params,
            openai_messages=openai_messages,
            api_base=api_base,
            stream=False,
        )
        verbose_logger.info(
            f"A2A completion bridge: model={completion_params['model']}, api_base={api_base}"
        )

        # Call litellm.acompletion
        response = await litellm.acompletion(**completion_params)

        # Transform response to A2A format
        a2a_response = (
            A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
                response=response,
                request_id=request_id,
            )
        )
        verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")
        return a2a_response

    @staticmethod
    async def handle_streaming(
        request_id: str,
        params: Dict[str, Any],
        litellm_params: Dict[str, Any],
        api_base: Optional[str] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Handle streaming A2A request via litellm.acompletion with stream=True.

        Emits proper A2A streaming events:
        1. Task event (kind: "task") - Initial task with status "submitted"
        2. Status update (kind: "status-update") - Status "working"
        3. Artifact update (kind: "artifact-update") - Content delivery
        4. Status update (kind: "status-update") - Final "completed" status

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
            api_base: API base URL from agent_card_params

        Yields:
            A2A streaming response events

        Raises:
            ValueError: If a Pydantic AI agent is targeted without api_base.
        """
        # Pydantic AI agents don't stream natively - fetch once, then fake it.
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        if custom_llm_provider == "pydantic_ai_agents":
            if api_base is None:
                raise ValueError("api_base is required for Pydantic AI agents")
            verbose_logger.info(
                f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
            )
            # Get non-streaming response first.
            # NOTE(review): this feeds the *transformed* A2A response into
            # fake_streaming_from_response, while PydanticAIHandler uses the
            # raw task response - confirm which shape the bridge's
            # fake-streaming helper expects.
            response_data = await PydanticAITransformation.send_non_streaming_request(
                api_base=api_base,
                request_id=request_id,
                params=params,
            )
            # Convert to fake streaming
            async for chunk in PydanticAITransformation.fake_streaming_from_response(
                response_data=response_data,
                request_id=request_id,
            ):
                yield chunk
            return

        # Create streaming context from the inbound message.
        message = params.get("message", {})
        ctx = A2AStreamingContext(
            request_id=request_id,
            input_message=message,
        )

        # Transform A2A message to OpenAI format
        openai_messages = (
            A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
        )

        completion_params = A2ACompletionBridgeHandler._build_completion_params(
            litellm_params=litellm_params,
            openai_messages=openai_messages,
            api_base=api_base,
            stream=True,
        )
        verbose_logger.info(
            f"A2A completion bridge streaming: model={completion_params['model']}, api_base={api_base}"
        )

        # 1. Emit initial task event (kind: "task", status: "submitted")
        task_event = A2ACompletionBridgeTransformation.create_task_event(ctx)
        yield task_event

        # 2. Emit status update (kind: "status-update", status: "working")
        working_event = A2ACompletionBridgeTransformation.create_status_update_event(
            ctx=ctx,
            state="working",
            final=False,
            message_text="Processing request...",
        )
        yield working_event

        # Call litellm.acompletion with streaming
        response = await litellm.acompletion(**completion_params)

        # 3. Accumulate content and emit artifact update
        accumulated_text = ""
        chunk_count = 0
        async for chunk in response:  # type: ignore[union-attr]
            chunk_count += 1
            # Extract delta content defensively - chunks may lack choices/delta.
            content = ""
            if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
                choice = chunk.choices[0]
                if hasattr(choice, "delta") and choice.delta:
                    content = choice.delta.content or ""
            if content:
                accumulated_text += content

        # Emit artifact update with accumulated content
        if accumulated_text:
            artifact_event = (
                A2ACompletionBridgeTransformation.create_artifact_update_event(
                    ctx=ctx,
                    text=accumulated_text,
                )
            )
            yield artifact_event

        # 4. Emit final status update (kind: "status-update", status: "completed", final: true)
        completed_event = A2ACompletionBridgeTransformation.create_status_update_event(
            ctx=ctx,
            state="completed",
            final=True,
        )
        yield completed_event
        verbose_logger.info(
            f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
        )
# Convenience functions that delegate to the class methods
async def handle_a2a_completion(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Module-level convenience wrapper for non-streaming A2A completion.

    Delegates directly to ``A2ACompletionBridgeHandler.handle_non_streaming``.
    """
    handler = A2ACompletionBridgeHandler.handle_non_streaming
    return await handler(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )
async def handle_a2a_completion_streaming(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
    """
    Module-level convenience wrapper for streaming A2A completion.

    Forwards every event produced by
    ``A2ACompletionBridgeHandler.handle_streaming`` unchanged.
    """
    event_stream = A2ACompletionBridgeHandler.handle_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )
    async for event in event_stream:
        yield event

View File

@@ -0,0 +1,284 @@
"""
Transformation utilities for A2A <-> OpenAI message format conversion.
A2A Message Format:
{
"role": "user",
"parts": [{"kind": "text", "text": "Hello!"}],
"messageId": "abc123"
}
OpenAI Message Format:
{"role": "user", "content": "Hello!"}
A2A Streaming Events:
- Task event (kind: "task") - Initial task creation with status "submitted"
- Status update (kind: "status-update") - Status changes (working, completed)
- Artifact update (kind: "artifact-update") - Content/artifact delivery
"""
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
from uuid import uuid4
from litellm._logging import verbose_logger
class A2AStreamingContext:
    """
    Mutable state carried across one A2A streaming response.

    Holds the generated task/context identifiers plus the accumulation
    fields used while emitting streaming events.
    """

    def __init__(self, request_id: str, input_message: Dict[str, Any]):
        # Echoed back as the JSON-RPC "id" on every emitted event.
        self.request_id = request_id
        # Fresh identifiers minted for this task/context pair.
        self.task_id = str(uuid4())
        self.context_id = str(uuid4())
        # The inbound user message; replayed in the task's history.
        self.input_message = input_message
        # Streaming bookkeeping.
        self.accumulated_text = ""
        self.has_emitted_task = False
        self.has_emitted_working = False
class A2ACompletionBridgeTransformation:
    """
    Static methods for transforming between A2A and OpenAI message formats.
    """

    @staticmethod
    def a2a_message_to_openai_messages(
        a2a_message: Dict[str, Any],
    ) -> List[Dict[str, str]]:
        """
        Transform an A2A message to OpenAI message format.

        Args:
            a2a_message: A2A message with role, parts, and messageId

        Returns:
            List of OpenAI-format messages
        """
        role = a2a_message.get("role", "user")
        parts = a2a_message.get("parts", [])
        # Map A2A roles to OpenAI roles. A2A marks model output with role
        # "agent" (see openai_response_to_a2a_response below), which is not a
        # valid OpenAI role - translate it to "assistant". "user"/"system"
        # (and any unknown role) pass through unchanged.
        openai_role = "assistant" if role == "agent" else role
        # Extract text content from parts; non-text parts are ignored.
        content_parts = []
        for part in parts:
            kind = part.get("kind", "")
            if kind == "text":
                text = part.get("text", "")
                content_parts.append(text)
        content = "\n".join(content_parts) if content_parts else ""
        verbose_logger.debug(
            f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
        )
        return [{"role": openai_role, "content": content}]

    @staticmethod
    def openai_response_to_a2a_response(
        response: Any,
        request_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.

        Args:
            response: LiteLLM ModelResponse object
            request_id: Original A2A request ID

        Returns:
            A2A SendMessageResponse dict
        """
        # Extract content from the first choice's message, if present.
        content = ""
        if hasattr(response, "choices") and response.choices:
            choice = response.choices[0]
            if hasattr(choice, "message") and choice.message:
                content = choice.message.content or ""
        # Build A2A message (A2A uses role "agent" for model output).
        a2a_message = {
            "role": "agent",
            "parts": [{"kind": "text", "text": content}],
            "messageId": uuid4().hex,
        }
        # Build A2A JSON-RPC response envelope.
        a2a_response = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": a2a_message,
            },
        }
        verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")
        return a2a_response

    @staticmethod
    def _get_timestamp() -> str:
        """Get current timestamp in ISO format with timezone."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def create_task_event(
        ctx: A2AStreamingContext,
    ) -> Dict[str, Any]:
        """
        Create the initial task event with status 'submitted'.

        This is the first event emitted in an A2A streaming response; it
        replays the inbound message as the task's history.
        """
        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "contextId": ctx.context_id,
                "history": [
                    {
                        "contextId": ctx.context_id,
                        "kind": "message",
                        "messageId": ctx.input_message.get("messageId", uuid4().hex),
                        "parts": ctx.input_message.get("parts", []),
                        "role": ctx.input_message.get("role", "user"),
                        "taskId": ctx.task_id,
                    }
                ],
                "id": ctx.task_id,
                "kind": "task",
                "status": {
                    "state": "submitted",
                },
            },
        }

    @staticmethod
    def create_status_update_event(
        ctx: A2AStreamingContext,
        state: str,
        final: bool = False,
        message_text: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Create a status update event.

        Args:
            ctx: Streaming context
            state: Status state ('working', 'completed')
            final: Whether this is the final event
            message_text: Optional message text for 'working' status
        """
        status: Dict[str, Any] = {
            "state": state,
            "timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
        }
        # Only the 'working' state carries a progress message.
        if state == "working" and message_text:
            status["message"] = {
                "contextId": ctx.context_id,
                "kind": "message",
                "messageId": str(uuid4()),
                "parts": [{"kind": "text", "text": message_text}],
                "role": "agent",
                "taskId": ctx.task_id,
            }
        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "contextId": ctx.context_id,
                "final": final,
                "kind": "status-update",
                "status": status,
                "taskId": ctx.task_id,
            },
        }

    @staticmethod
    def create_artifact_update_event(
        ctx: A2AStreamingContext,
        text: str,
    ) -> Dict[str, Any]:
        """
        Create an artifact update event with content.

        Args:
            ctx: Streaming context
            text: The text content for the artifact
        """
        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "artifact": {
                    "artifactId": str(uuid4()),
                    "name": "response",
                    "parts": [{"kind": "text", "text": text}],
                },
                "contextId": ctx.context_id,
                "kind": "artifact-update",
                "taskId": ctx.task_id,
            },
        }

    @staticmethod
    def openai_chunk_to_a2a_chunk(
        chunk: Any,
        request_id: Optional[str] = None,
        is_final: bool = False,
    ) -> Optional[Dict[str, Any]]:
        """
        Transform a LiteLLM streaming chunk to A2A streaming format.

        NOTE: This method is deprecated for streaming. Use the event-based
        methods (create_task_event, create_status_update_event,
        create_artifact_update_event) instead for proper A2A streaming.

        Args:
            chunk: LiteLLM ModelResponse chunk
            request_id: Original A2A request ID
            is_final: Whether this is the final chunk

        Returns:
            A2A streaming chunk dict or None if no content
        """
        # Extract delta content defensively - chunks may lack choices/delta.
        content = ""
        if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
            choice = chunk.choices[0]
            if hasattr(choice, "delta") and choice.delta:
                content = choice.delta.content or ""
        # Content-free intermediate chunks are dropped; the final chunk is
        # always emitted so consumers see "final": true.
        if not content and not is_final:
            return None
        # Build A2A streaming chunk (legacy format)
        a2a_chunk = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": {
                    "role": "agent",
                    "parts": [{"kind": "text", "text": content}],
                    "messageId": uuid4().hex,
                },
                "final": is_final,
            },
        }
        return a2a_chunk

View File

@@ -0,0 +1,16 @@
"""
Pydantic AI agent provider for A2A protocol.
Pydantic AI agents follow A2A protocol but don't support streaming natively.
This provider handles fake streaming by converting non-streaming responses into streaming chunks.
"""
from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
PydanticAIProviderConfig,
)
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
PydanticAITransformation,
)
__all__ = ["PydanticAIHandler", "PydanticAITransformation", "PydanticAIProviderConfig"]

View File

@@ -0,0 +1,50 @@
"""
Pydantic AI provider configuration.
"""
from typing import Any, AsyncIterator, Dict
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
class PydanticAIProviderConfig(BaseA2AProviderConfig):
    """
    Provider configuration for Pydantic AI agents.

    Pydantic AI agents speak the A2A protocol but have no native streaming
    support; streaming is simulated by chunking a non-streaming response.
    """

    async def handle_non_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """Forward a non-streaming request to the Pydantic AI handler."""
        timeout = kwargs.get("timeout", 60.0)
        return await PydanticAIHandler.handle_non_streaming(
            request_id=request_id,
            params=params,
            api_base=api_base,
            timeout=timeout,
        )

    async def handle_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Forward a streaming request, faking the stream via chunking."""
        event_stream = PydanticAIHandler.handle_streaming(
            request_id=request_id,
            params=params,
            api_base=api_base,
            timeout=kwargs.get("timeout", 60.0),
            chunk_size=kwargs.get("chunk_size", 50),
            delay_ms=kwargs.get("delay_ms", 10),
        )
        async for event in event_stream:
            yield event

View File

@@ -0,0 +1,102 @@
"""
Handler for Pydantic AI agents.
Pydantic AI agents follow A2A protocol but don't support streaming natively.
This handler provides fake streaming by converting non-streaming responses into streaming chunks.
"""
from typing import Any, AsyncIterator, Dict
from litellm._logging import verbose_logger
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
PydanticAITransformation,
)
class PydanticAIHandler:
    """
    Request handler for Pydantic AI agents.

    Supports:
    - Direct non-streaming requests to Pydantic AI agents
    - Simulated ("fake") streaming, since these agents do not stream natively
    """

    @staticmethod
    async def handle_non_streaming(
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a non-streaming request to a Pydantic AI agent.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the Pydantic AI agent
            timeout: Request timeout in seconds

        Returns:
            A2A SendMessageResponse dict
        """
        verbose_logger.info(f"Pydantic AI: Routing to Pydantic AI agent at {api_base}")
        # Forward directly - the agent already speaks A2A.
        return await PydanticAITransformation.send_non_streaming_request(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

    @staticmethod
    async def handle_streaming(
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        timeout: float = 60.0,
        chunk_size: int = 50,
        delay_ms: int = 10,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Simulate a streaming response from a Pydantic AI agent.

        Pydantic AI agents have no native streaming, so this:
        1. Performs one non-streaming request
        2. Re-emits the completed response as streaming chunks

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the Pydantic AI agent
            timeout: Request timeout in seconds
            chunk_size: Number of characters per chunk
            delay_ms: Delay between chunks in milliseconds

        Yields:
            A2A streaming response events
        """
        verbose_logger.info(
            f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
        )
        # Fetch the untransformed task payload (history/artifacts intact) -
        # the fake-streaming converter works on the raw format.
        raw_response = await PydanticAITransformation.send_and_get_raw_response(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )
        event_stream = PydanticAITransformation.fake_streaming_from_response(
            response_data=raw_response,
            request_id=request_id,
            chunk_size=chunk_size,
            delay_ms=delay_ms,
        )
        async for event in event_stream:
            yield event

View File

@@ -0,0 +1,530 @@
"""
Transformation layer for Pydantic AI agents.
Pydantic AI agents follow A2A protocol but don't support streaming.
This module provides fake streaming by converting non-streaming responses into streaming chunks.
"""
import asyncio
from typing import Any, AsyncIterator, Dict, cast
from uuid import uuid4
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
)
class PydanticAITransformation:
"""
Transformation layer for Pydantic AI agents.
Handles:
- Direct A2A requests to Pydantic AI endpoints
- Polling for task completion (since Pydantic AI doesn't support streaming)
- Fake streaming by chunking non-streaming responses
"""
@staticmethod
def _remove_none_values(obj: Any) -> Any:
"""
Recursively remove None values from a dict/list structure.
FastA2A/Pydantic AI servers don't accept None values for optional fields -
they expect those fields to be omitted entirely.
Args:
obj: Dict, list, or other value to clean
Returns:
Cleaned object with None values removed
"""
if isinstance(obj, dict):
return {
k: PydanticAITransformation._remove_none_values(v)
for k, v in obj.items()
if v is not None
}
elif isinstance(obj, list):
return [
PydanticAITransformation._remove_none_values(item)
for item in obj
if item is not None
]
else:
return obj
@staticmethod
def _params_to_dict(params: Any) -> Dict[str, Any]:
"""
Convert params to a dict, handling Pydantic models.
Args:
params: Dict or Pydantic model
Returns:
Dict representation of params
"""
if hasattr(params, "model_dump"):
# Pydantic v2 model
return params.model_dump(mode="python", exclude_none=True)
elif hasattr(params, "dict"):
# Pydantic v1 model
return params.dict(exclude_none=True)
elif isinstance(params, dict):
return params
else:
# Try to convert to dict
return dict(params)
    @staticmethod
    async def _poll_for_completion(
        client: AsyncHTTPHandler,
        endpoint: str,
        task_id: str,
        request_id: str,
        max_attempts: int = 30,
        poll_interval: float = 0.5,
    ) -> Dict[str, Any]:
        """
        Poll for task completion using the A2A ``tasks/get`` method.

        Args:
            client: HTTPX async client
            endpoint: API endpoint URL
            task_id: Task ID to poll for
            request_id: JSON-RPC request ID
            max_attempts: Maximum polling attempts
            poll_interval: Seconds between poll attempts

        Returns:
            Completed task response

        Raises:
            Exception: If the task ends in a "failed" or "canceled" state.
            TimeoutError: If the task does not complete within
                max_attempts * poll_interval seconds.
        """
        for attempt in range(max_attempts):
            # Unique JSON-RPC id per poll attempt so responses can be
            # correlated back to the originating request.
            poll_request = {
                "jsonrpc": "2.0",
                "id": f"{request_id}-poll-{attempt}",
                "method": "tasks/get",
                "params": {"id": task_id},
            }
            response = await client.post(
                endpoint,
                json=poll_request,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            poll_data = response.json()
            result = poll_data.get("result", {})
            status = result.get("status", {})
            state = status.get("state", "")
            verbose_logger.debug(
                f"Pydantic AI: Poll attempt {attempt + 1}/{max_attempts}, state={state}"
            )
            if state == "completed":
                return poll_data
            elif state in ("failed", "canceled"):
                raise Exception(f"Task {task_id} ended with state: {state}")
            # Still submitted/working - wait before the next poll.
            await asyncio.sleep(poll_interval)
        raise TimeoutError(
            f"Task {task_id} did not complete within {max_attempts * poll_interval} seconds"
        )
    @staticmethod
    async def _send_and_poll_raw(
        api_base: str,
        request_id: str,
        params: Any,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a request to Pydantic AI agent and return the raw task response.

        This is an internal method used by both non-streaming and streaming handlers.
        Returns the raw Pydantic AI task format with history/artifacts.

        Args:
            api_base: Base URL of the Pydantic AI agent
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            timeout: Request timeout in seconds (applies to each HTTP call)

        Returns:
            Raw Pydantic AI task response (with history/artifacts)
        """
        # Convert params to dict if it's a Pydantic model
        params_dict = PydanticAITransformation._params_to_dict(params)
        # Remove None values - FastA2A doesn't accept null for optional fields
        params_dict = PydanticAITransformation._remove_none_values(params_dict)
        # Ensure the message has 'kind': 'message' as required by FastA2A/Pydantic AI
        if "message" in params_dict:
            params_dict["message"]["kind"] = "message"
        # Build A2A JSON-RPC request using message/send method for FastA2A compatibility
        a2a_request = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": "message/send",
            "params": params_dict,
        }
        # FastA2A uses root endpoint (/) not /messages
        endpoint = api_base.rstrip("/")
        verbose_logger.info(f"Pydantic AI: Sending non-streaming request to {endpoint}")
        # Send request to Pydantic AI agent using shared async HTTP client
        client = get_async_httpx_client(
            llm_provider=cast(Any, "pydantic_ai_agent"),
            params={"timeout": timeout},
        )
        response = await client.post(
            endpoint,
            json=a2a_request,
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        response_data = response.json()
        # Check if task is already completed
        result = response_data.get("result", {})
        status = result.get("status", {})
        state = status.get("state", "")
        if state != "completed":
            # Need to poll for completion.
            # NOTE(review): polling uses _poll_for_completion's defaults
            # (30 attempts x 0.5s) regardless of `timeout` - confirm this is
            # intended. If no task id is present, the initial response is
            # returned as-is even though its state is not "completed".
            task_id = result.get("id")
            if task_id:
                verbose_logger.info(
                    f"Pydantic AI: Task {task_id} submitted, polling for completion..."
                )
                response_data = await PydanticAITransformation._poll_for_completion(
                    client=client,
                    endpoint=endpoint,
                    task_id=task_id,
                    request_id=request_id,
                )
        verbose_logger.info(
            f"Pydantic AI: Received completed response for request_id={request_id}"
        )
        return response_data
@staticmethod
async def send_non_streaming_request(
api_base: str,
request_id: str,
params: Any,
timeout: float = 60.0,
) -> Dict[str, Any]:
"""
Send a non-streaming A2A request to Pydantic AI agent and wait for completion.
Args:
api_base: Base URL of the Pydantic AI agent (e.g., "http://localhost:9999")
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message (dict or Pydantic model)
timeout: Request timeout in seconds
Returns:
Standard A2A non-streaming response format with message
"""
# Get raw task response
raw_response = await PydanticAITransformation._send_and_poll_raw(
api_base=api_base,
request_id=request_id,
params=params,
timeout=timeout,
)
# Transform to standard A2A non-streaming format
return PydanticAITransformation._transform_to_a2a_response(
response_data=raw_response,
request_id=request_id,
)
@staticmethod
async def send_and_get_raw_response(
api_base: str,
request_id: str,
params: Any,
timeout: float = 60.0,
) -> Dict[str, Any]:
"""
Send a request to Pydantic AI agent and return the raw task response.
Used by streaming handler to get raw response for fake streaming.
Args:
api_base: Base URL of the Pydantic AI agent
request_id: A2A JSON-RPC request ID
params: A2A MessageSendParams containing the message
timeout: Request timeout in seconds
Returns:
Raw Pydantic AI task response (with history/artifacts)
"""
return await PydanticAITransformation._send_and_poll_raw(
api_base=api_base,
request_id=request_id,
params=params,
timeout=timeout,
)
@staticmethod
def _transform_to_a2a_response(
response_data: Dict[str, Any],
request_id: str,
) -> Dict[str, Any]:
"""
Transform Pydantic AI task response to standard A2A non-streaming format.
Pydantic AI returns a task with history/artifacts, but the standard A2A
non-streaming format expects:
{
"jsonrpc": "2.0",
"id": "...",
"result": {
"message": {
"role": "agent",
"parts": [{"kind": "text", "text": "..."}],
"messageId": "..."
}
}
}
Args:
response_data: Pydantic AI task response
request_id: Original request ID
Returns:
Standard A2A non-streaming response format
"""
# Extract the agent response text
full_text, message_id, parts = PydanticAITransformation._extract_response_text(
response_data
)
# Build standard A2A message
a2a_message = {
"role": "agent",
"parts": parts if parts else [{"kind": "text", "text": full_text}],
"messageId": message_id,
}
# Return standard A2A non-streaming format
return {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"message": a2a_message,
},
}
@staticmethod
def _extract_response_text(response_data: Dict[str, Any]) -> tuple[str, str, list]:
"""
Extract response text from completed task response.
Pydantic AI returns completed tasks with:
- history: list of messages (user and agent)
- artifacts: list of result artifacts
Args:
response_data: Completed task response
Returns:
Tuple of (full_text, message_id, parts)
"""
result = response_data.get("result", {})
# Try to extract from artifacts first (preferred for results)
artifacts = result.get("artifacts", [])
if artifacts:
for artifact in artifacts:
parts = artifact.get("parts", [])
for part in parts:
if part.get("kind") == "text":
text = part.get("text", "")
if text:
return text, str(uuid4()), parts
# Fall back to history - get the last agent message
history = result.get("history", [])
for msg in reversed(history):
if msg.get("role") == "agent":
parts = msg.get("parts", [])
message_id = msg.get("messageId", str(uuid4()))
full_text = ""
for part in parts:
if part.get("kind") == "text":
full_text += part.get("text", "")
if full_text:
return full_text, message_id, parts
# Fall back to message field (original format)
message = result.get("message", {})
if message:
parts = message.get("parts", [])
message_id = message.get("messageId", str(uuid4()))
full_text = ""
for part in parts:
if part.get("kind") == "text":
full_text += part.get("text", "")
return full_text, message_id, parts
return "", str(uuid4()), []
@staticmethod
async def fake_streaming_from_response(
response_data: Dict[str, Any],
request_id: str,
chunk_size: int = 50,
delay_ms: int = 10,
) -> AsyncIterator[Dict[str, Any]]:
"""
Convert a non-streaming A2A response into fake streaming chunks.
Emits proper A2A streaming events:
1. Task event (kind: "task") - Initial task with status "submitted"
2. Status update (kind: "status-update") - Status "working"
3. Artifact update chunks (kind: "artifact-update") - Content delivery in chunks
4. Status update (kind: "status-update") - Final "completed" status
Args:
response_data: Non-streaming A2A response dict (completed task)
request_id: A2A JSON-RPC request ID
chunk_size: Number of characters per chunk (default: 50)
delay_ms: Delay between chunks in milliseconds (default: 10)
Yields:
A2A streaming response events
"""
# Extract the response text from completed task
full_text, message_id, parts = PydanticAITransformation._extract_response_text(
response_data
)
# Extract input message from raw response for history
result = response_data.get("result", {})
history = result.get("history", [])
input_message = {}
for msg in history:
if msg.get("role") == "user":
input_message = msg
break
# Generate IDs for streaming events
task_id = str(uuid4())
context_id = str(uuid4())
artifact_id = str(uuid4())
input_message_id = input_message.get("messageId", str(uuid4()))
# 1. Emit initial task event (kind: "task", status: "submitted")
# Format matches A2ACompletionBridgeTransformation.create_task_event
task_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"history": [
{
"contextId": context_id,
"kind": "message",
"messageId": input_message_id,
"parts": input_message.get(
"parts", [{"kind": "text", "text": ""}]
),
"role": "user",
"taskId": task_id,
}
],
"id": task_id,
"kind": "task",
"status": {
"state": "submitted",
},
},
}
yield task_event
# 2. Emit status update (kind: "status-update", status: "working")
# Format matches A2ACompletionBridgeTransformation.create_status_update_event
working_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"final": False,
"kind": "status-update",
"status": {
"state": "working",
},
"taskId": task_id,
},
}
yield working_event
# Small delay to simulate processing
await asyncio.sleep(delay_ms / 1000.0)
# 3. Emit artifact update chunks (kind: "artifact-update")
# Format matches A2ACompletionBridgeTransformation.create_artifact_update_event
if full_text:
# Split text into chunks
for i in range(0, len(full_text), chunk_size):
chunk_text = full_text[i : i + chunk_size]
is_last_chunk = (i + chunk_size) >= len(full_text)
artifact_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"kind": "artifact-update",
"taskId": task_id,
"artifact": {
"artifactId": artifact_id,
"parts": [
{
"kind": "text",
"text": chunk_text,
}
],
},
},
}
yield artifact_event
# Add delay between chunks (except for last chunk)
if not is_last_chunk:
await asyncio.sleep(delay_ms / 1000.0)
# 4. Emit final status update (kind: "status-update", status: "completed", final: true)
completed_event = {
"jsonrpc": "2.0",
"id": request_id,
"result": {
"contextId": context_id,
"final": True,
"kind": "status-update",
"status": {
"state": "completed",
},
"taskId": task_id,
},
}
yield completed_event
verbose_logger.info(
f"Pydantic AI: Fake streaming completed for request_id={request_id}"
)