chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,12 @@
## File Structure
### August 27th, 2024
To make it easy to see how calls are transformed for each model/provider,
we are moving all supported litellm providers to a folder structure, where the folder name is the supported litellm provider name.
Each folder will contain a `*_transformation.py` file, which has all the request/response transformation logic, making it easy to see how calls are modified.
E.g. `cohere/`, `bedrock/`.

View File

@@ -0,0 +1,183 @@
import importlib
import os
from typing import TYPE_CHECKING, Dict, Optional, Type
from litellm._logging import verbose_logger
from litellm.types.utils import CallTypes
from . import *
if TYPE_CHECKING:
from litellm.llms.base_llm.guardrail_translation.base_translation import (
BaseTranslation,
)
from litellm.types.utils import ModelInfo, Usage
def get_cost_for_web_search_request(
    custom_llm_provider: str, usage: "Usage", model_info: "ModelInfo"
) -> Optional[float]:
    """
    Return the cost of a web search request for the given provider.

    Args:
        custom_llm_provider: The custom LLM provider.
        usage: The usage object.
        model_info: The model info.

    Returns:
        The web-search cost, 0.0 when the provider accounts for it
        elsewhere, or None when the provider has no web-search pricing.
    """
    # vertex_ai prefix match first — all other branches are exact matches,
    # so the dispatch order does not change which branch fires.
    if custom_llm_provider.startswith("vertex_ai"):
        from .vertex_ai.gemini.cost_calculator import (
            cost_per_web_search_request as cost_per_web_search_request_vertex_ai,
        )

        return cost_per_web_search_request_vertex_ai(usage=usage, model_info=model_info)
    if custom_llm_provider == "gemini":
        from .gemini.cost_calculator import cost_per_web_search_request

        return cost_per_web_search_request(usage=usage, model_info=model_info)
    if custom_llm_provider == "anthropic":
        from .anthropic.cost_calculation import get_cost_for_anthropic_web_search

        return get_cost_for_anthropic_web_search(model_info=model_info, usage=usage)
    if custom_llm_provider == "xai":
        from .xai.cost_calculator import cost_per_web_search_request

        return cost_per_web_search_request(usage=usage, model_info=model_info)
    if custom_llm_provider == "perplexity":
        # Perplexity handles search costs internally in its own cost calculator
        # Return 0.0 to indicate costs are already accounted for
        return 0.0
    # Unknown provider: no web-search pricing available.
    return None
def discover_guardrail_translation_mappings() -> (
    Dict[CallTypes, Type["BaseTranslation"]]
):
    """
    Discover guardrail translation mappings from the llms directory tree.

    Walks every ``guardrail_translation`` package under this directory,
    imports it, and merges any ``guardrail_translation_mappings`` dict it
    exposes. MCP proxy mappings are merged in last, when available.

    Returns:
        Dict[CallTypes, Type[BaseTranslation]]: call type -> translation handler class
    """
    aggregated: Dict[CallTypes, Type["BaseTranslation"]] = {}
    try:
        llms_dir = os.path.dirname(__file__)
        if not os.path.exists(llms_dir):
            verbose_logger.debug("llms directory not found")
            return aggregated
        parent_dir = os.path.dirname(llms_dir)
        for root, dirs, files in os.walk(llms_dir):
            # Prune dunder dirs (e.g. __pycache__) and the abstract base_llm package.
            dirs[:] = [d for d in dirs if not d.startswith("__") and d != "base_llm"]
            if os.path.basename(root) != "guardrail_translation":
                continue
            if "__init__.py" not in files:
                continue
            # Translate the filesystem path into a litellm module path.
            module_path = "litellm." + os.path.relpath(root, parent_dir).replace(
                os.sep, "."
            )
            try:
                verbose_logger.debug(
                    f"Discovering guardrail translations in: {module_path}"
                )
                module = importlib.import_module(module_path)
                mappings = getattr(module, "guardrail_translation_mappings", None)
                if isinstance(mappings, dict):
                    aggregated.update(mappings)
                    verbose_logger.debug(
                        f"Found guardrail_translation_mappings in {module_path}: {list(mappings.keys())}"
                    )
            except ImportError as e:
                verbose_logger.error(f"Could not import {module_path}: {e}")
            except Exception as e:
                verbose_logger.error(f"Error processing {module_path}: {e}")
        # MCP mappings live under the proxy package, outside the llms walk.
        try:
            from litellm.proxy._experimental.mcp_server.guardrail_translation import (
                guardrail_translation_mappings as mcp_guardrail_translation_mappings,
            )

            aggregated.update(mcp_guardrail_translation_mappings)
            verbose_logger.debug(
                "Loaded MCP guardrail translation mappings: %s",
                list(mcp_guardrail_translation_mappings.keys()),
            )
        except ImportError:
            verbose_logger.debug(
                "MCP guardrail translation mappings not available; skipping"
            )
        verbose_logger.debug(
            f"Discovered {len(aggregated)} guardrail translation mappings: {list(aggregated.keys())}"
        )
    except Exception as e:
        verbose_logger.error(f"Error discovering guardrail translation mappings: {e}")
    return aggregated
# Cache the discovered mappings
# Module-level memo for discover_guardrail_translation_mappings(); populated
# lazily on first lookup and then shared by all callers in this process.
endpoint_guardrail_translation_mappings: Optional[
    Dict[CallTypes, Type["BaseTranslation"]]
] = None
def load_guardrail_translation_mappings():
    """
    Return the cached guardrail translation mappings, discovering them on
    first use and storing the result in the module-level cache.
    """
    global endpoint_guardrail_translation_mappings
    if endpoint_guardrail_translation_mappings is not None:
        return endpoint_guardrail_translation_mappings
    endpoint_guardrail_translation_mappings = discover_guardrail_translation_mappings()
    return endpoint_guardrail_translation_mappings
def get_guardrail_translation_mapping(call_type: CallTypes) -> Type["BaseTranslation"]:
    """
    Get the guardrail translation handler for a given call type.

    Args:
        call_type: The type of call (e.g., completion, acompletion, anthropic_messages)

    Returns:
        The translation handler class for the given call type

    Raises:
        ValueError: If no translation mapping exists for the given call type
    """
    # Delegate lazy initialization to the shared loader instead of duplicating
    # the "discover on first access" logic that load_guardrail_translation_mappings
    # already implements.
    mappings = load_guardrail_translation_mappings()
    if call_type not in mappings:
        raise ValueError(
            f"No guardrail translation mapping found for call_type: {call_type}. "
            f"Available mappings: {list(mappings.keys())}"
        )
    # Return the handler class directly
    return mappings[call_type]

View File

@@ -0,0 +1,6 @@
"""
A2A (Agent-to-Agent) Protocol Provider for LiteLLM
"""
from .chat.transformation import A2AConfig
__all__ = ["A2AConfig"]

View File

@@ -0,0 +1,6 @@
"""
A2A Chat Completion Implementation
"""
from .transformation import A2AConfig
__all__ = ["A2AConfig"]

View File

@@ -0,0 +1,155 @@
# A2A Protocol Guardrail Translation Handler
Handler for processing A2A (Agent-to-Agent) Protocol messages with guardrails.
## Overview
This handler processes A2A JSON-RPC 2.0 input/output by:
1. Extracting text from message parts (`kind: "text"`)
2. Applying guardrails to text content
3. Mapping guardrailed text back to original structure
## A2A Protocol Format
### Input Format (JSON-RPC 2.0)
```json
{
"jsonrpc": "2.0",
"id": "request-id",
"method": "message/send",
"params": {
"message": {
"kind": "message",
"messageId": "...",
"role": "user",
"parts": [
{"kind": "text", "text": "Hello, my SSN is 123-45-6789"}
]
},
"metadata": {
"guardrails": ["block-ssn"]
}
}
}
```
### Output Formats
The handler supports multiple A2A response formats:
**Direct message:**
```json
{
"result": {
"kind": "message",
"parts": [{"kind": "text", "text": "Response text"}]
}
}
```
**Nested message:**
```json
{
"result": {
"message": {
"parts": [{"kind": "text", "text": "Response text"}]
}
}
}
```
**Task with artifacts:**
```json
{
"result": {
"kind": "task",
"artifacts": [
{"parts": [{"kind": "text", "text": "Artifact text"}]}
]
}
}
```
**Task with status message:**
```json
{
"result": {
"kind": "task",
"status": {
"message": {
"parts": [{"kind": "text", "text": "Status message"}]
}
}
}
}
```
**Streaming artifact-update:**
```json
{
"result": {
"kind": "artifact-update",
"artifact": {
"parts": [{"kind": "text", "text": "Streaming text"}]
}
}
}
```
## Usage
The handler is automatically discovered and applied when guardrails are used with A2A endpoints.
### Via LiteLLM Proxy
```bash
curl -X POST 'http://localhost:4000/a2a/my-agent' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer your-api-key' \
-d '{
"jsonrpc": "2.0",
"id": "1",
"method": "message/send",
"params": {
"message": {
"kind": "message",
"messageId": "msg-1",
"role": "user",
"parts": [{"kind": "text", "text": "Hello, my SSN is 123-45-6789"}]
},
"metadata": {
"guardrails": ["block-ssn"]
}
}
}'
```
### Specifying Guardrails
Guardrails can be specified in the A2A request via the `metadata.guardrails` field:
```json
{
"params": {
"message": {...},
"metadata": {
"guardrails": ["block-ssn", "pii-filter"]
}
}
}
```
## Extension
Override these methods to customize behavior:
- `_extract_texts_from_result()`: Custom text extraction from A2A responses
- `_extract_texts_from_parts()`: Custom text extraction from message parts
- `_apply_text_to_path()`: Custom application of guardrailed text
## Call Types
This handler is registered for:
- `CallTypes.send_message`: Synchronous A2A message sending
- `CallTypes.asend_message`: Asynchronous A2A message sending

View File

@@ -0,0 +1,11 @@
"""A2A Protocol handler for Unified Guardrails."""
from litellm.llms.a2a.chat.guardrail_translation.handler import A2AGuardrailHandler
from litellm.types.utils import CallTypes
guardrail_translation_mappings = {
CallTypes.send_message: A2AGuardrailHandler,
CallTypes.asend_message: A2AGuardrailHandler,
}
__all__ = ["guardrail_translation_mappings"]

View File

@@ -0,0 +1,428 @@
"""
A2A Protocol Handler for Unified Guardrails
This module provides guardrail translation support for A2A (Agent-to-Agent) Protocol.
It handles both JSON-RPC 2.0 input requests and output responses, extracting text
from message parts and applying guardrails.
A2A Protocol Format:
- Input: JSON-RPC 2.0 with params.message.parts containing text parts
- Output: JSON-RPC 2.0 with result containing message/artifact parts
"""
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import UserAPIKeyAuth
class A2AGuardrailHandler(BaseTranslation):
    """
    Handler for processing A2A Protocol messages with guardrails.
    This class provides methods to:
    1. Process input messages (pre-call hook) - extracts text from A2A message parts
    2. Process output responses (post-call hook) - extracts text from A2A response parts
    A2A Message Format:
    - Input: params.message.parts[].text (where kind == "text")
    - Output: result.message.parts[].text or result.artifacts[].parts[].text
    """
    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
    ) -> Any:
        """
        Process A2A input messages by applying guardrails to text content.
        Extracts text from A2A message parts and applies guardrails.
        Args:
            data: The A2A JSON-RPC 2.0 request data
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
        Returns:
            Modified data with guardrails applied to text content
        """
        # A2A request format: { "params": { "message": { "parts": [...] } } }
        params = data.get("params", {})
        message = params.get("message", {})
        parts = message.get("parts", [])
        if not parts:
            verbose_proxy_logger.debug("A2A: No parts in message, skipping guardrail")
            return data
        texts_to_check: List[str] = []
        text_part_indices: List[int] = []  # Track which parts contain text
        # Step 1: Extract text from all text parts
        # Non-text parts (files, data) are left untouched; empty text is skipped.
        for part_idx, part in enumerate(parts):
            if part.get("kind") == "text":
                text = part.get("text", "")
                if text:
                    texts_to_check.append(text)
                    text_part_indices.append(part_idx)
        # Step 2: Apply guardrail to all texts in batch
        if texts_to_check:
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            # Pass the structured A2A message to guardrails
            inputs["structured_messages"] = [message]
            # Include agent model info if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            # Step 3: Apply guardrailed text back to original parts
            # Length check keeps text/part alignment: if the guardrail returned
            # a different number of texts, the original parts are left as-is.
            if guardrailed_texts and len(guardrailed_texts) == len(text_part_indices):
                for task_idx, part_idx in enumerate(text_part_indices):
                    parts[part_idx]["text"] = guardrailed_texts[task_idx]
        verbose_proxy_logger.debug("A2A: Processed input message: %s", message)
        return data
    async def process_output_response(
        self,
        response: Any,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
        user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
    ) -> Any:
        """
        Process A2A output response by applying guardrails to text content.
        Handles multiple A2A response formats:
        - Direct message: {"result": {"kind": "message", "parts": [...]}}
        - Nested message: {"result": {"message": {"parts": [...]}}}
        - Task with artifacts: {"result": {"kind": "task", "artifacts": [{"parts": [...]}]}}
        - Task with status message: {"result": {"kind": "task", "status": {"message": {"parts": [...]}}}}
        Args:
            response: A2A JSON-RPC 2.0 response dict or object
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata
        Returns:
            Modified response with guardrails applied to text content
        """
        # Handle both dict and Pydantic model responses
        if hasattr(response, "model_dump"):
            # NOTE(review): model_dump() returns a detached copy of the model's
            # data — mutations made to this dict below are not automatically
            # reflected back onto the Pydantic `response` object that gets
            # returned. Confirm that guardrailed text actually propagates for
            # Pydantic responses.
            response_dict = response.model_dump()
            is_pydantic = True
        elif isinstance(response, dict):
            response_dict = response
            is_pydantic = False
        else:
            verbose_proxy_logger.warning(
                "A2A: Unknown response type %s, skipping guardrail", type(response)
            )
            return response
        result = response_dict.get("result", {})
        if not result or not isinstance(result, dict):
            verbose_proxy_logger.debug("A2A: No result in response, skipping guardrail")
            return response
        # Find all text-containing parts in the response
        texts_to_check: List[str] = []
        # Each mapping is (path_to_parts_list, part_index)
        # path_to_parts_list is a tuple of keys to navigate to the parts list
        task_mappings: List[Tuple[Tuple[str, ...], int]] = []
        # Extract texts from all possible locations
        self._extract_texts_from_result(
            result=result,
            texts_to_check=texts_to_check,
            task_mappings=task_mappings,
        )
        if not texts_to_check:
            verbose_proxy_logger.debug("A2A: No text content in response")
            return response
        # Step 2: Apply guardrail to all texts in batch
        # Create a request_data dict with response info and user API key metadata
        request_data: dict = {"response": response_dict}
        # Add user API key metadata with prefixed keys
        user_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata
        inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )
        guardrailed_texts = guardrailed_inputs.get("texts", [])
        # Step 3: Apply guardrailed text back to original response
        # Only write back when the guardrail preserved the text count, so each
        # guardrailed string lands in the part it came from.
        if guardrailed_texts and len(guardrailed_texts) == len(task_mappings):
            for task_idx, (path, part_idx) in enumerate(task_mappings):
                self._apply_text_to_path(
                    result=result,
                    path=path,
                    part_idx=part_idx,
                    text=guardrailed_texts[task_idx],
                )
        verbose_proxy_logger.debug("A2A: Processed output response")
        # Update the original response
        if is_pydantic:
            # For Pydantic models, we need to update the underlying dict
            # and the model will reflect the changes
            response_dict["result"] = result
            return response
        else:
            response["result"] = result
            return response
    async def process_output_streaming_response(
        self,
        responses_so_far: List[Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
        user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
    ) -> List[Any]:
        """
        Process A2A streaming output by applying guardrails to accumulated text.
        responses_so_far can be a list of JSON-RPC 2.0 objects (dict or NDJSON str), e.g.:
        - task with history, status-update, artifact-update (with result.artifact.parts),
        - then status-update (final). Text is extracted from result.artifact.parts,
        result.message.parts, result.parts, etc., concatenated in order, guardrailed once,
        then the combined guardrailed text is written into the first chunk that had text
        and all other text parts in other chunks are cleared (in-place).
        """
        from litellm.llms.a2a.common_utils import extract_text_from_a2a_response
        # Parse each item; keep alignment with responses_so_far (None where unparseable)
        parsed: List[Optional[Dict[str, Any]]] = [None] * len(responses_so_far)
        for i, item in enumerate(responses_so_far):
            if isinstance(item, dict):
                obj = item
            elif isinstance(item, str):
                try:
                    obj = json.loads(item.strip())
                except (json.JSONDecodeError, TypeError):
                    continue
            else:
                continue
            # Only chunks with a dict "result" are candidates for rewriting.
            if isinstance(obj.get("result"), dict):
                parsed[i] = obj
        valid_parsed = [(i, obj) for i, obj in enumerate(parsed) if obj is not None]
        if not valid_parsed:
            return responses_so_far
        # Collect text from each chunk in order (by original index in responses_so_far)
        text_parts: List[str] = []
        chunk_indices_with_text: List[int] = []  # indices into valid_parsed
        for idx, (orig_i, obj) in enumerate(valid_parsed):
            t = extract_text_from_a2a_response(obj)
            if t:
                text_parts.append(t)
                chunk_indices_with_text.append(orig_i)
        combined_text = "".join(text_parts)
        if not combined_text:
            return responses_so_far
        request_data: dict = {"responses_so_far": responses_so_far}
        user_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata
        # Guardrail the concatenated stream text once, not per-chunk.
        inputs = GenericGuardrailAPIInputs(texts=[combined_text])
        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )
        guardrailed_texts = guardrailed_inputs.get("texts", [])
        if not guardrailed_texts:
            return responses_so_far
        guardrailed_text = guardrailed_texts[0]
        # Find first chunk (by original index) that has text; put full guardrailed text there and clear rest
        first_chunk_with_text: Optional[int] = (
            chunk_indices_with_text[0] if chunk_indices_with_text else None
        )
        for orig_i, obj in valid_parsed:
            result = obj.get("result", {})
            if not isinstance(result, dict):
                continue
            # Re-extract per-chunk text locations so they can be overwritten in place.
            texts_in_chunk: List[str] = []
            mappings: List[Tuple[Tuple[str, ...], int]] = []
            self._extract_texts_from_result(
                result=result,
                texts_to_check=texts_in_chunk,
                task_mappings=mappings,
            )
            if not mappings:
                continue
            if orig_i == first_chunk_with_text:
                # Put full guardrailed text in first text part; clear others
                for task_idx, (path, part_idx) in enumerate(mappings):
                    text = guardrailed_text if task_idx == 0 else ""
                    self._apply_text_to_path(
                        result=result,
                        path=path,
                        part_idx=part_idx,
                        text=text,
                    )
            else:
                for path, part_idx in mappings:
                    self._apply_text_to_path(
                        result=result,
                        path=path,
                        part_idx=part_idx,
                        text="",
                    )
        # Write back to responses_so_far where we had NDJSON strings
        # (dict chunks were mutated in place and need no re-serialization).
        for i, item in enumerate(responses_so_far):
            if isinstance(item, str) and parsed[i] is not None:
                responses_so_far[i] = json.dumps(parsed[i]) + "\n"
        return responses_so_far
    def _extract_texts_from_result(
        self,
        result: Dict[str, Any],
        texts_to_check: List[str],
        task_mappings: List[Tuple[Tuple[str, ...], int]],
    ) -> None:
        """
        Extract text from all possible locations in an A2A result.
        Appends found texts to ``texts_to_check`` and their (path, part index)
        locations to ``task_mappings`` so callers can write text back in place.
        Handles multiple response formats:
        1. Direct message with parts: {"parts": [...]}
        2. Nested message: {"message": {"parts": [...]}}
        3. Task with artifacts: {"artifacts": [{"parts": [...]}]}
        4. Task with status message: {"status": {"message": {"parts": [...]}}}
        5. Streaming artifact-update: {"artifact": {"parts": [...]}}
        """
        # Case 1: Direct parts in result (direct message)
        if "parts" in result:
            self._extract_texts_from_parts(
                parts=result["parts"],
                path=("parts",),
                texts_to_check=texts_to_check,
                task_mappings=task_mappings,
            )
        # Case 2: Nested message
        message = result.get("message")
        if message and isinstance(message, dict) and "parts" in message:
            self._extract_texts_from_parts(
                parts=message["parts"],
                path=("message", "parts"),
                texts_to_check=texts_to_check,
                task_mappings=task_mappings,
            )
        # Case 3: Streaming artifact-update (singular artifact)
        artifact = result.get("artifact")
        if artifact and isinstance(artifact, dict) and "parts" in artifact:
            self._extract_texts_from_parts(
                parts=artifact["parts"],
                path=("artifact", "parts"),
                texts_to_check=texts_to_check,
                task_mappings=task_mappings,
            )
        # Case 4: Task with status message
        status = result.get("status", {})
        if isinstance(status, dict):
            status_message = status.get("message")
            if (
                status_message
                and isinstance(status_message, dict)
                and "parts" in status_message
            ):
                self._extract_texts_from_parts(
                    parts=status_message["parts"],
                    path=("status", "message", "parts"),
                    texts_to_check=texts_to_check,
                    task_mappings=task_mappings,
                )
        # Case 5: Task with artifacts (plural, array)
        artifacts = result.get("artifacts", [])
        if artifacts and isinstance(artifacts, list):
            for artifact_idx, art in enumerate(artifacts):
                if isinstance(art, dict) and "parts" in art:
                    # List index is stored as a string so _apply_text_to_path
                    # can distinguish it via str.isdigit() when navigating.
                    self._extract_texts_from_parts(
                        parts=art["parts"],
                        path=("artifacts", str(artifact_idx), "parts"),
                        texts_to_check=texts_to_check,
                        task_mappings=task_mappings,
                    )
    def _extract_texts_from_parts(
        self,
        parts: List[Dict[str, Any]],
        path: Tuple[str, ...],
        texts_to_check: List[str],
        task_mappings: List[Tuple[Tuple[str, ...], int]],
    ) -> None:
        """Extract text from message parts."""
        # Only non-empty "text"-kind parts are recorded; each gets the shared
        # path plus its own index within the parts list.
        for part_idx, part in enumerate(parts):
            if part.get("kind") == "text":
                text = part.get("text", "")
                if text:
                    texts_to_check.append(text)
                    task_mappings.append((path, part_idx))
    def _apply_text_to_path(
        self,
        result: Dict[Union[str, int], Any],
        path: Tuple[str, ...],
        part_idx: int,
        text: str,
    ) -> None:
        """Apply guardrailed text back to the specified path in the result."""
        # Navigate to the parts list
        current = result
        for key in path:
            if key.isdigit():
                # Array index
                # (stringified list indices are produced by _extract_texts_from_result
                # for the "artifacts" case)
                current = current[int(key)]
            else:
                current = current[key]
        # Update the text in the part
        current[part_idx]["text"] = text

View File

@@ -0,0 +1,105 @@
"""
A2A Streaming Response Iterator
"""
from typing import Optional, Union
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
from ..common_utils import extract_text_from_a2a_response
class A2AModelResponseIterator(BaseModelResponseIterator):
    """
    Streaming iterator for A2A JSON-RPC responses.

    Each incoming A2A chunk is converted into an OpenAI-compatible
    ``GenericStreamingChunk``.
    """

    def __init__(
        self,
        streaming_response,
        sync_stream: bool,
        json_mode: Optional[bool] = False,
        model: str = "a2a/agent",
    ):
        super().__init__(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
        self.model = model

    @staticmethod
    def _build_chunk(
        text: str, finish_reason: Optional[str]
    ) -> GenericStreamingChunk:
        """Assemble a GenericStreamingChunk with the A2A defaults (no usage/tools)."""
        return GenericStreamingChunk(
            text=text,
            is_finished=bool(finish_reason),
            finish_reason=finish_reason or "",
            usage=None,
            index=0,
            tool_use=None,
        )

    def chunk_parser(
        self, chunk: dict
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        """
        Parse an A2A streaming chunk into OpenAI streaming format.

        A2A chunk format:
        {
            "jsonrpc": "2.0",
            "id": "request-id",
            "result": {
                "message": {
                    "parts": [{"kind": "text", "text": "content"}]
                }
            }
        }
        Or for tasks:
        {
            "jsonrpc": "2.0",
            "result": {
                "kind": "task",
                "status": {"state": "running"},
                "artifacts": [{"parts": [{"kind": "text", "text": "content"}]}]
            }
        }
        """
        try:
            content = extract_text_from_a2a_response(chunk)
            reason = self._get_finish_reason(chunk)
            return self._build_chunk(text=content, finish_reason=reason)
        except Exception:
            # Any parse failure yields an empty, non-terminal chunk.
            return self._build_chunk(text="", finish_reason=None)

    def _get_finish_reason(self, chunk: dict) -> Optional[str]:
        """Extract finish reason from A2A chunk"""
        result = chunk.get("result", {})
        status = result.get("status") if isinstance(result, dict) else None
        state = status.get("state") if isinstance(status, dict) else None
        if state in ("completed", "failed"):
            # 'failed' also maps to 'stop' (the only valid terminal finish_reason)
            return "stop"
        # Explicit [DONE]-style marker on the chunk itself
        if chunk.get("done") is True:
            return "stop"
        return None

View File

@@ -0,0 +1,373 @@
"""
A2A Protocol Transformation for LiteLLM
"""
import uuid
from typing import Any, Dict, Iterator, List, Optional, Union
import httpx
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse
from ..common_utils import (
A2AError,
convert_messages_to_prompt,
extract_text_from_a2a_response,
)
from .streaming_iterator import A2AModelResponseIterator
class A2AConfig(BaseConfig):
"""
Configuration for A2A (Agent-to-Agent) Protocol.
Handles transformation between OpenAI and A2A JSON-RPC 2.0 formats.
"""
@staticmethod
def resolve_agent_config_from_registry(
model: str,
api_base: Optional[str],
api_key: Optional[str],
headers: Optional[Dict[str, Any]],
optional_params: Dict[str, Any],
) -> tuple[Optional[str], Optional[str], Optional[Dict[str, Any]]]:
"""
Resolve agent configuration from registry if model format is "a2a/<agent-name>".
Extracts agent name from model string and looks up configuration in the
agent registry (if available in proxy context).
Args:
model: Model string (e.g., "a2a/my-agent")
api_base: Explicit api_base (takes precedence over registry)
api_key: Explicit api_key (takes precedence over registry)
headers: Explicit headers (takes precedence over registry)
optional_params: Dict to merge additional litellm_params into
Returns:
Tuple of (api_base, api_key, headers) with registry values filled in
"""
# Extract agent name from model (e.g., "a2a/my-agent" -> "my-agent")
agent_name = model.split("/", 1)[1] if "/" in model else None
# Only lookup if agent name exists and some config is missing
if not agent_name or (
api_base is not None and api_key is not None and headers is not None
):
return api_base, api_key, headers
# Try registry lookup (only available in proxy context)
try:
from litellm.proxy.agent_endpoints.agent_registry import (
global_agent_registry,
)
agent = global_agent_registry.get_agent_by_name(agent_name)
if agent:
# Get api_base from agent card URL
if api_base is None and agent.agent_card_params:
api_base = agent.agent_card_params.get("url")
# Get api_key, headers, and other params from litellm_params
if agent.litellm_params:
if api_key is None:
api_key = agent.litellm_params.get("api_key")
if headers is None:
agent_headers = agent.litellm_params.get("headers")
if agent_headers:
headers = agent_headers
# Merge other litellm_params (timeout, max_retries, etc.)
for key, value in agent.litellm_params.items():
if (
key not in ["api_key", "api_base", "headers", "model"]
and key not in optional_params
):
optional_params[key] = value
except ImportError:
pass # Registry not available (not running in proxy context)
return api_base, api_key, headers
def get_supported_openai_params(self, model: str) -> List[str]:
"""Return list of supported OpenAI parameters"""
return [
"stream",
"temperature",
"max_tokens",
"top_p",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
Map OpenAI parameters to A2A parameters.
For A2A protocol, we need to map the stream parameter so
transform_request can determine which JSON-RPC method to use.
"""
# Map stream parameter
for param, value in non_default_params.items():
if param == "stream" and value is True:
optional_params["stream"] = value
return optional_params
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
"""
Validate environment and set headers for A2A requests.
Args:
headers: Request headers dict
model: Model name
messages: Messages list
optional_params: Optional parameters
litellm_params: LiteLLM parameters
api_key: API key (optional for A2A)
api_base: API base URL
Returns:
Updated headers dict
"""
# Ensure Content-Type is set to application/json for JSON-RPC 2.0
if "content-type" not in headers and "Content-Type" not in headers:
headers["Content-Type"] = "application/json"
# Add Authorization header if API key is provided
if api_key is not None:
headers["Authorization"] = f"Bearer {api_key}"
return headers
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete A2A agent endpoint URL.
A2A agents use JSON-RPC 2.0 at the base URL, not specific paths.
The method (message/send or message/stream) is specified in the
JSON-RPC request body, not in the URL.
Args:
api_base: Base URL of the A2A agent (e.g., "http://0.0.0.0:9999")
api_key: API key (not used for URL construction)
model: Model name (not used for A2A, agent determined by api_base)
optional_params: Optional parameters
litellm_params: LiteLLM parameters
stream: Whether this is a streaming request (affects JSON-RPC method)
Returns:
Complete URL for the A2A endpoint (base URL)
"""
if api_base is None:
raise ValueError("api_base is required for A2A provider")
# A2A uses JSON-RPC 2.0 at the base URL
# Remove trailing slash for consistency
return api_base.rstrip("/")
def transform_request(
self,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
headers: dict,
) -> dict:
"""
Transform OpenAI request to A2A JSON-RPC 2.0 format.
Args:
model: Model name
messages: List of OpenAI messages
optional_params: Optional parameters
litellm_params: LiteLLM parameters
headers: Request headers
Returns:
A2A JSON-RPC 2.0 request dict
"""
# Generate request ID
request_id = str(uuid.uuid4())
if not messages:
raise ValueError("At least one message is required for A2A completion")
# Convert all messages to maintain conversation history
# Use helper to format conversation with role prefixes
full_context = convert_messages_to_prompt(messages)
# Create single A2A message with full conversation context
a2a_message = {
"role": "user",
"parts": [{"kind": "text", "text": full_context}],
"messageId": str(uuid.uuid4()),
}
# Build JSON-RPC 2.0 request
# For A2A protocol, the method is "message/send" for non-streaming
# and "message/stream" for streaming
stream = optional_params.get("stream", False)
method = "message/stream" if stream else "message/send"
request_data = {
"jsonrpc": "2.0",
"id": request_id,
"method": method,
"params": {"message": a2a_message},
}
return request_data
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: Any,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """
    Convert an A2A JSON-RPC 2.0 response into an OpenAI-format ModelResponse.

    Raises:
        A2AError: if the body is not valid JSON or carries a JSON-RPC error.

    Returns:
        The populated ModelResponse.
    """
    try:
        payload = raw_response.json()
    except Exception as e:
        raise A2AError(
            status_code=raw_response.status_code,
            message=f"Failed to parse A2A response: {str(e)}",
            headers=dict(raw_response.headers),
        )

    # A JSON-RPC level failure arrives in the "error" member.
    if "error" in payload:
        rpc_error = payload["error"]
        raise A2AError(
            status_code=raw_response.status_code,
            message=f"A2A error: {rpc_error.get('message', 'Unknown error')}",
            headers=dict(raw_response.headers),
        )

    # Pull the agent's text out of whatever result shape it used.
    content_text = extract_text_from_a2a_response(payload)

    model_response.choices = [
        Choices(
            finish_reason="stop",
            index=0,
            message=Message(content=content_text, role="assistant"),
        )
    ]
    model_response.model = model
    # Prefer the JSON-RPC response id; fall back to a fresh UUID.
    model_response.id = payload.get("id", str(uuid.uuid4()))
    return model_response
def get_model_response_iterator(
    self,
    streaming_response: Union[Iterator, Any],
    sync_stream: bool,
    json_mode: Optional[bool] = False,
) -> BaseModelResponseIterator:
    """
    Return the A2A-specific iterator used to consume streaming responses.

    Args:
        streaming_response: Raw streaming response iterator
        sync_stream: True for synchronous streams, False for async
        json_mode: JSON mode flag (passed through unchanged)
    """
    return A2AModelResponseIterator(
        streaming_response=streaming_response,
        sync_stream=sync_stream,
        json_mode=json_mode,
    )
def _openai_message_to_a2a_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert OpenAI message to A2A message format.
Args:
message: OpenAI message dict
Returns:
A2A message dict
"""
content = message.get("content", "")
role = message.get("role", "user")
return {
"role": role,
"parts": [{"kind": "text", "text": str(content)}],
"messageId": str(uuid.uuid4()),
}
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """Wrap an error message and status into the A2A exception type."""
    # Normalize httpx.Headers into a plain dict before constructing the error.
    if isinstance(headers, httpx.Headers):
        headers = dict(headers)
    return A2AError(
        status_code=status_code,
        message=error_message,
        headers=headers,
    )

View File

@@ -0,0 +1,150 @@
"""
Common utilities for A2A (Agent-to-Agent) Protocol
"""
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues
class A2AError(BaseLLMException):
    """Base exception for A2A protocol errors (HTTP failures or JSON-RPC errors)."""

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            status_code: HTTP status code of the failed request.
            message: Human-readable error description.
            headers: Response headers, if available.
        """
        # Fixed: previously used a mutable default ({}), which is shared
        # across calls; use None and substitute a fresh dict instead.
        super().__init__(
            status_code=status_code,
            message=message,
            headers=headers if headers is not None else {},
        )
def convert_messages_to_prompt(messages: List[AllMessageValues]) -> str:
    """
    Flatten OpenAI-format messages into one "{role}: {content}" prompt.

    Each message with non-empty text becomes a "{role}: {content}" line;
    lines are newline-joined to preserve conversation order. Handles both
    string and list content via LiteLLM's content helper.

    Args:
        messages: List of OpenAI-format messages

    Returns:
        Newline-joined prompt carrying the full conversation context
    """
    lines: List[str] = []
    for message in messages:
        # Extract text regardless of whether content is a str or a parts list.
        text = convert_content_list_to_str(message=message)

        # Role lookup works across pydantic models, dicts, and dict-likes.
        if isinstance(message, BaseModel):
            role = message.model_dump().get("role", "user")
        elif isinstance(message, dict):
            role = message.get("role", "user")
        else:
            role = dict(message).get("role", "user")  # type: ignore

        if text:
            lines.append(f"{role}: {text}")
    return "\n".join(lines)
def extract_text_from_a2a_message(
    message: Dict[str, Any], depth: int = 0, max_depth: int = 10
) -> str:
    """
    Collect the text content from an A2A message's parts.

    Parts with kind "text" contribute their "text" field; parts that carry
    their own nested "parts" are traversed recursively, bounded by max_depth.

    Args:
        message: A2A message dict (may be None)
        depth: Current recursion depth (internal use)
        max_depth: Recursion bound guarding against cyclic/deep structures

    Returns:
        All discovered text fragments joined by single spaces
    """
    if message is None or depth >= max_depth:
        return ""

    collected: List[str] = []
    for part in message.get("parts", []):
        if part.get("kind") == "text":
            collected.append(part.get("text", ""))
        elif "parts" in part:
            # Recurse into nested part containers.
            inner = extract_text_from_a2a_message(part, depth + 1, max_depth)
            if inner:
                collected.append(inner)
    return " ".join(collected)
def extract_text_from_a2a_response(
    response_dict: Dict[str, Any], max_depth: int = 10
) -> str:
    """
    Pull the text content out of an A2A JSON-RPC response's ``result``.

    Supported result shapes, checked in this order:
      1. Direct message:            {"result": {"kind": "message", "parts": [...]}}
      2. Nested message:            {"result": {"message": {"parts": [...]}}}
      3. Streaming artifact-update: {"result": {"kind": "artifact-update", "artifact": {"parts": [...]}}}
      4. Task status message:       {"result": {"kind": "task", "status": {"message": {"parts": [...]}}}}
      5. Task artifacts (array):    {"result": {"kind": "task", "artifacts": [{"parts": [...]}]}}

    Args:
        response_dict: Parsed A2A JSON-RPC response
        max_depth: Recursion bound for nested message parts

    Returns:
        Extracted text, or "" when no recognized shape carries text
    """
    result = response_dict.get("result", {})
    if not isinstance(result, dict):
        return ""

    # 1. The result itself is a message carrying parts.
    if "parts" in result:
        return extract_text_from_a2a_message(result, depth=0, max_depth=max_depth)

    # 2. Message nested under "message".
    nested_message = result.get("message")
    if nested_message:
        return extract_text_from_a2a_message(
            nested_message, depth=0, max_depth=max_depth
        )

    # 3. Streaming artifact-update carries a single "artifact" dict.
    single_artifact = result.get("artifact")
    if single_artifact and isinstance(single_artifact, dict):
        return extract_text_from_a2a_message(
            single_artifact, depth=0, max_depth=max_depth
        )

    # 4. Task status message (common with Gemini A2A agents).
    status = result.get("status", {})
    if isinstance(status, dict):
        status_message = status.get("message")
        if status_message:
            return extract_text_from_a2a_message(
                status_message, depth=0, max_depth=max_depth
            )

    # 5. Task result with an "artifacts" array — only the first entry is used.
    artifact_list = result.get("artifacts", [])
    if artifact_list:
        return extract_text_from_a2a_message(
            artifact_list[0], depth=0, max_depth=max_depth
        )

    return ""

View File

@@ -0,0 +1,70 @@
"""
AI21 Chat Completions API
this is OpenAI compatible - no translation needed / occurs
"""
from typing import Optional, Union
from ...openai_like.chat.transformation import OpenAILikeChatConfig
class AI21ChatConfig(OpenAILikeChatConfig):
    """
    OpenAI-compatible config for AI21's Jamba chat API.

    Reference: https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters
    """

    tools: Optional[list] = None
    response_format: Optional[dict] = None
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    seed: Optional[int] = None
    tool_choice: Optional[str] = None
    user: Optional[str] = None

    def __init__(
        self,
        tools: Optional[list] = None,
        response_format: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        n: Optional[int] = None,
        stream: Optional[bool] = None,
        seed: Optional[int] = None,
        tool_choice: Optional[str] = None,
        user: Optional[str] = None,
    ) -> None:
        # Persist any explicitly-provided values onto the class, mirroring
        # the config pattern used by the other provider configs.
        provided = locals().copy()
        for name, value in provided.items():
            if name == "self" or value is None:
                continue
            setattr(self.__class__, name, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI params AI21 accepts for this model."""
        return [
            "tools",
            "response_format",
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "stop",
            "n",
            "stream",
            "seed",
            "tool_choice",
        ]

View File

@@ -0,0 +1,5 @@
from .image_generation import get_aiml_image_generation_config
__all__ = [
"get_aiml_image_generation_config",
]

View File

@@ -0,0 +1,24 @@
from typing import Optional, Tuple
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.secret_managers.main import get_secret_str
class AIMLChatConfig(OpenAIGPTConfig):
    """OpenAI-compatible chat config for the AI/ML API provider."""

    @property
    def custom_llm_provider(self) -> Optional[str]:
        return "aiml"

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Resolve the API base and key from args, env vars, or defaults.

        AI/ML is OpenAI-compatible, so only the endpoint needs to be set.
        """
        resolved_base = (
            api_base
            or get_secret_str("AIML_API_BASE")
            or "https://api.aimlapi.com/v1"  # default AI/ML API base URL
        )
        resolved_key = api_key or get_secret_str("AIML_API_KEY")
        return resolved_base, resolved_key

View File

@@ -0,0 +1,13 @@
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from .transformation import AimlImageGenerationConfig
__all__ = [
"AimlImageGenerationConfig",
]
def get_aiml_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """Return the AI/ML image-generation config (same config for every model)."""
    return AimlImageGenerationConfig()

View File

@@ -0,0 +1,27 @@
from typing import Any
import litellm
from litellm.types.utils import ImageResponse
def cost_calculator(
    model: str,
    image_response: Any,
) -> float:
    """
    Compute the cost of an AI/ML (flux) image-generation call.

    Cost = per-image output price (from the model registry) * images returned.

    Raises:
        ValueError: if image_response is not an ImageResponse.
    """
    model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider=litellm.LlmProviders.AIML.value,
    )
    per_image_cost: float = model_info.get("output_cost_per_image") or 0.0

    if isinstance(image_response, ImageResponse):
        image_count = len(image_response.data) if image_response.data else 0
        return per_image_cost * image_count

    raise ValueError(
        f"image_response must be of type ImageResponse got type={type(image_response)}"
    )

View File

@@ -0,0 +1,234 @@
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.aiml import AimlImageGenerationRequestParams
from litellm.types.llms.openai import (
AllMessageValues,
OpenAIImageGenerationOptionalParams,
)
from litellm.types.utils import ImageObject, ImageResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AimlImageGenerationConfig(BaseImageGenerationConfig):
    """
    Image-generation config for the AI/ML API.

    Maps OpenAI-style image generation params/requests to the AI/ML API
    format and normalizes its several response shapes back to ImageResponse.
    """

    # Base can be overridden via api_base arg or AIML_API_BASE env var.
    DEFAULT_BASE_URL: str = "https://api.aimlapi.com"
    IMAGE_GENERATION_ENDPOINT: str = "v1/images/generations"

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        """
        OpenAI params accepted by this provider (remapped in map_openai_params).

        https://api.aimlapi.com/v1/images/generations
        """
        return ["n", "response_format", "size"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI param names to AI/ML equivalents.

        n -> num_images, response_format -> output_format,
        size ("WxH" or a predefined name) -> image_size.
        Unsupported params raise unless drop_params is True.
        """
        supported_params = self.get_supported_openai_params(model)
        for k in non_default_params.keys():
            # Only map params the caller hasn't already set explicitly.
            if k not in optional_params.keys():
                if k in supported_params:
                    # Map OpenAI params to AI/ML params
                    if k == "n":
                        optional_params["num_images"] = non_default_params[k]
                    elif k == "response_format":
                        optional_params["output_format"] = non_default_params[k]
                    elif k == "size":
                        # Map OpenAI size format to AI/ML image_size
                        size_value = non_default_params[k]
                        if isinstance(size_value, str):
                            # Handle standard OpenAI sizes like "1024x1024"
                            if "x" in size_value:
                                width, height = map(int, size_value.split("x"))
                                optional_params["image_size"] = {
                                    "width": width,
                                    "height": height,
                                }
                            else:
                                # Pass through predefined sizes
                                optional_params["image_size"] = size_value
                        else:
                            # Non-string sizes (e.g. dicts) are forwarded as-is.
                            optional_params["image_size"] = size_value
                    else:
                        optional_params[k] = non_default_params[k]
                elif drop_params:
                    pass
                else:
                    raise ValueError(
                        f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
                    )
        return optional_params

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Build the full image-generation endpoint URL.

        Resolution order for the base: api_base arg, AIML_API_BASE env var,
        DEFAULT_BASE_URL.
        """
        complete_url: str = (
            api_base or get_secret_str("AIML_API_BASE") or self.DEFAULT_BASE_URL
        )
        complete_url = complete_url.rstrip("/")
        # Strip /v1 suffix if present since IMAGE_GENERATION_ENDPOINT already includes v1
        if complete_url.endswith("/v1"):
            complete_url = complete_url[:-3]
        complete_url = f"{complete_url}/{self.IMAGE_GENERATION_ENDPOINT}"
        return complete_url

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Inject auth/content-type headers; raises ValueError if no key found.

        Key resolution: api_key arg, AIML_API_KEY, AIMLAPI_KEY env vars.
        """
        final_api_key: Optional[str] = (
            api_key
            or get_secret_str("AIML_API_KEY")
            or get_secret_str("AIMLAPI_KEY")  # Alternative name
        )
        if not final_api_key:
            raise ValueError("AIML_API_KEY or AIMLAPI_KEY is not set")
        headers["Authorization"] = f"Bearer {final_api_key}"
        headers["Content-Type"] = "application/json"
        return headers

    def transform_image_generation_request(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the AI/ML image-generation request body.

        optional_params are assumed to already be in AI/ML naming
        (see map_openai_params).
        https://api.aimlapi.com/v1/images/generations
        """
        aiml_image_generation_request_body: AimlImageGenerationRequestParams = (
            AimlImageGenerationRequestParams(
                prompt=prompt,
                model=model,
                **optional_params,
            )
        )
        return dict(aiml_image_generation_request_body)

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """
        Normalize an AI/ML image response to ImageResponse.

        Handles three observed payload shapes (checked in order); width/height/
        content_type fields, when present, are not carried over.
        https://api.aimlapi.com/v1/images/generations
        """
        try:
            response_data = raw_response.json()
        except Exception as e:
            raise self.get_error_class(
                error_message=f"Error transforming image generation response: {e}",
                status_code=raw_response.status_code,
                headers=raw_response.headers,
            )
        if not model_response.data:
            model_response.data = []
        # AI/ML API can return images in multiple formats:
        # 1. Top-level data array with url (OpenAI-like format)
        # 2. output.choices array with image_base64
        # 3. images array with url (and optional width, height, content_type)
        if "data" in response_data and isinstance(response_data["data"], list):
            # Handle OpenAI-like format: {"data": [{"url": "...", "width": 1024, "height": 768, "content_type": "image/jpeg"}]}
            for image in response_data["data"]:
                if "url" in image:
                    model_response.data.append(
                        ImageObject(
                            b64_json=None,
                            url=image["url"],
                            revised_prompt=image.get("revised_prompt"),
                        )
                    )
                elif "b64_json" in image or "image_base64" in image:
                    model_response.data.append(
                        ImageObject(
                            b64_json=image.get("b64_json") or image.get("image_base64"),
                            url=None,
                            revised_prompt=image.get("revised_prompt"),
                        )
                    )
        elif "output" in response_data and "choices" in response_data["output"]:
            for choice in response_data["output"]["choices"]:
                if "image_base64" in choice:
                    model_response.data.append(
                        ImageObject(
                            b64_json=choice["image_base64"],
                            url=None,
                        )
                    )
                elif "url" in choice:
                    model_response.data.append(
                        ImageObject(
                            b64_json=None,
                            url=choice["url"],
                        )
                    )
        elif "images" in response_data:
            # Handle alternative format: {"images": [{"url": "...", "width": 1024, "height": 768, "content_type": "image/jpeg"}]}
            for image in response_data["images"]:
                if "url" in image:
                    model_response.data.append(
                        ImageObject(
                            b64_json=None,
                            url=image["url"],
                        )
                    )
                elif "image_base64" in image:
                    model_response.data.append(
                        ImageObject(
                            b64_json=image["image_base64"],
                            url=None,
                        )
                    )
        return model_response

View File

@@ -0,0 +1,82 @@
"""
*New config* for using aiohttp to make the request to the custom OpenAI-like provider
This leads to 10x higher RPS than httpx
https://github.com/BerriAI/litellm/issues/6592
New config to ensure we introduce this without causing breaking changes for users
"""
from typing import TYPE_CHECKING, Any, List, Optional
from aiohttp import ClientResponse
from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
    """
    aiohttp-based OpenAI-compatible chat config (higher throughput than httpx).
    """

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Ensure - /v1/chat/completions is at the end of the url

        Fixed: the default base previously omitted "/v1", yielding
        "https://api.openai.com/chat/completions", which is not a valid
        OpenAI endpoint.
        """
        if api_base is None:
            api_base = "https://api.openai.com/v1"
        if not api_base.endswith("/chat/completions"):
            api_base += "/chat/completions"
        return api_base

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        # NOTE(review): caller-supplied `headers` are discarded here — only the
        # bearer token is sent. Confirm this is intentional before changing.
        return {"Authorization": f"Bearer {api_key}"}

    async def transform_response(  # type: ignore
        self,
        model: str,
        raw_response: ClientResponse,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Copy the provider's OpenAI-format JSON body onto `model_response`.

        Returns:
            The populated ModelResponse.
        """
        _json_response = await raw_response.json()
        model_response.id = _json_response.get("id")
        # Guard against a missing/null "choices" key instead of raising TypeError.
        model_response.choices = [
            Choices(**choice) for choice in (_json_response.get("choices") or [])
        ]
        model_response.created = _json_response.get("created")
        model_response.model = _json_response.get("model")
        model_response.object = _json_response.get("object")
        model_response.system_fingerprint = _json_response.get("system_fingerprint")
        return model_response

View File

@@ -0,0 +1,115 @@
"""
Translate from OpenAI's `/v1/chat/completions` to Amazon Nova's `/v1/chat/completions`
"""
from typing import Any, List, Optional, Tuple
import httpx
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import (
AllMessageValues,
)
from litellm.types.utils import ModelResponse
from ...openai_like.chat.transformation import OpenAILikeChatConfig
class AmazonNovaChatConfig(OpenAILikeChatConfig):
    """
    OpenAI-compatible chat config for Amazon Nova.

    Resolves Nova's endpoint/key and prefixes the response model name with
    "amazon-nova/" for cost calculation.
    """

    # NOTE(review): several annotations corrected from the original
    # (temperature/top_p were Optional[int], metadata Optional[int],
    # reasoning_effort Optional[list]) — confirm against Nova's API docs.
    max_completion_tokens: Optional[int] = None
    max_tokens: Optional[int] = None
    metadata: Optional[dict] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    tools: Optional[list] = None
    reasoning_effort: Optional[str] = None

    def __init__(
        self,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[list] = None,
        reasoning_effort: Optional[str] = None,
    ) -> None:
        # Persist explicitly-provided values on the class (shared config pattern).
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @property
    def custom_llm_provider(self) -> Optional[str]:
        return "amazon_nova"

    @classmethod
    def get_config(cls):
        return super().get_config()

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Resolve Nova's api_base and api_key from args, env vars, or defaults."""
        # Amazon Nova is openai compatible, we just need to set this to custom_openai and have the api_base be Nova's endpoint
        api_base = (
            api_base
            or get_secret_str("AMAZON_NOVA_API_BASE")
            or "https://api.nova.amazon.com/v1"
        )  # type: ignore
        # Get API key from multiple sources
        key = (
            api_key
            or litellm.amazon_nova_api_key
            or get_secret_str("AMAZON_NOVA_API_KEY")
            or litellm.api_key
        )
        return api_base, key

    def get_supported_openai_params(self, model: str) -> List:
        """OpenAI params Nova accepts (passed through unmodified)."""
        return [
            "top_p",
            "temperature",
            "max_tokens",
            "max_completion_tokens",
            "metadata",
            "stop",
            "stream",
            "stream_options",
            "tools",
            "tool_choice",
            "reasoning_effort",
        ]

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Delegate to the OpenAI-like transform, then tag the model name."""
        model_response = super().transform_response(
            model=model,
            model_response=model_response,
            raw_response=raw_response,
            messages=messages,
            logging_obj=logging_obj,
            request_data=request_data,
            encoding=encoding,
            optional_params=optional_params,
            json_mode=json_mode,
            litellm_params=litellm_params,
            api_key=api_key,
        )
        # Storing amazon_nova in the model response for easier cost calculation later
        setattr(model_response, "model", "amazon-nova/" + model)
        return model_response

View File

@@ -0,0 +1,21 @@
"""
Helper util for handling amazon nova cost calculation
- e.g.: prompt caching
"""
from typing import TYPE_CHECKING, Tuple
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
if TYPE_CHECKING:
from litellm.types.utils import Usage
def cost_per_token(model: str, usage: "Usage") -> Tuple[float, float]:
    """
    Return (prompt_cost, completion_cost) for an Amazon Nova call.

    Delegates to the shared generic per-token calculator — the same path
    Anthropic's cost calculation uses (e.g. for prompt caching).
    """
    return generic_cost_per_token(
        usage=usage, model=model, custom_llm_provider="amazon_nova"
    )

View File

@@ -0,0 +1,15 @@
from typing import Type, Union
from .batches.transformation import AnthropicBatchesConfig
from .chat.transformation import AnthropicConfig
__all__ = ["AnthropicBatchesConfig", "AnthropicConfig"]
def get_anthropic_config(
    url_route: str,
) -> Union[Type[AnthropicBatchesConfig], Type[AnthropicConfig]]:
    """
    Pick the Anthropic provider config class for a given URL route.

    Batch-results routes (containing both "messages/batches" and "results")
    use the batches config; everything else uses the standard chat config.
    """
    is_batch_results_route = (
        "messages/batches" in url_route and "results" in url_route
    )
    return AnthropicBatchesConfig if is_batch_results_route else AnthropicConfig

View File

@@ -0,0 +1,4 @@
from .handler import AnthropicBatchesHandler
from .transformation import AnthropicBatchesConfig
__all__ = ["AnthropicBatchesHandler", "AnthropicBatchesConfig"]

View File

@@ -0,0 +1,167 @@
"""
Anthropic Batches API Handler
"""
import asyncio
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Union
import httpx
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
)
from litellm.types.utils import LiteLLMBatch, LlmProviders
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
from ..common_utils import AnthropicModelInfo
from .transformation import AnthropicBatchesConfig
class AnthropicBatchesHandler:
    """
    Handler for Anthropic Message Batches API.

    Supports:
    - retrieve_batch() - Retrieve batch status and information
    """

    def __init__(self):
        # Helpers for credential resolution and request/response translation.
        self.anthropic_model_info = AnthropicModelInfo()
        self.provider_config = AnthropicBatchesConfig()

    async def aretrieve_batch(
        self,
        batch_id: str,
        api_base: Optional[str],
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        logging_obj: Optional[LiteLLMLoggingObj] = None,
    ) -> LiteLLMBatch:
        """
        Async: Retrieve a batch from Anthropic.

        Args:
            batch_id: The batch ID to retrieve
            api_base: Anthropic API base URL
            api_key: Anthropic API key
            timeout: Request timeout (currently not forwarded to the HTTP call)
            max_retries: Max retry attempts (unused for now)
            logging_obj: Optional logging object

        Raises:
            ValueError: if no API key can be resolved.
            httpx.HTTPStatusError: on non-2xx responses (via raise_for_status).

        Returns:
            LiteLLMBatch: Batch information in OpenAI format
        """
        # Resolve API credentials
        api_base = api_base or self.anthropic_model_info.get_api_base(api_base)
        api_key = api_key or self.anthropic_model_info.get_api_key()
        if not api_key:
            raise ValueError("Missing Anthropic API Key")

        # Create a minimal logging object if not provided
        if logging_obj is None:
            from litellm.litellm_core_utils.litellm_logging import (
                Logging as LiteLLMLoggingObjClass,
            )

            logging_obj = LiteLLMLoggingObjClass(
                model="anthropic/unknown",
                messages=[],
                stream=False,
                call_type="batch_retrieve",
                start_time=None,
                litellm_call_id=f"batch_retrieve_{batch_id}",
                function_id="batch_retrieve",
            )

        # Get the complete URL for batch retrieval
        retrieve_url = self.provider_config.get_retrieve_batch_url(
            api_base=api_base,
            batch_id=batch_id,
            optional_params={},
            litellm_params={},
        )

        # Validate environment and get headers (x-api-key, anthropic-version, ...)
        headers = self.provider_config.validate_environment(
            headers={},
            model="",
            messages=[],
            optional_params={},
            litellm_params={},
            api_key=api_key,
            api_base=api_base,
        )

        logging_obj.pre_call(
            input=batch_id,
            api_key=api_key,
            additional_args={
                "api_base": retrieve_url,
                "headers": headers,
                "complete_input_dict": {},
            },
        )

        # Make the request
        async_client = get_async_httpx_client(llm_provider=LlmProviders.ANTHROPIC)
        response = await async_client.get(url=retrieve_url, headers=headers)
        response.raise_for_status()

        # Transform response to LiteLLM format
        return self.provider_config.transform_retrieve_batch_response(
            model=None,
            raw_response=response,
            logging_obj=logging_obj,
            litellm_params={},
        )

    def retrieve_batch(
        self,
        _is_async: bool,
        batch_id: str,
        api_base: Optional[str],
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        logging_obj: Optional[LiteLLMLoggingObj] = None,
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        """
        Retrieve a batch from Anthropic.

        Args:
            _is_async: Whether to run asynchronously (returns a coroutine)
            batch_id: The batch ID to retrieve
            api_base: Anthropic API base URL
            api_key: Anthropic API key
            timeout: Request timeout
            max_retries: Max retry attempts (unused for now)
            logging_obj: Optional logging object

        Returns:
            LiteLLMBatch or Coroutine: Batch information in OpenAI format
        """
        if _is_async:
            return self.aretrieve_batch(
                batch_id=batch_id,
                api_base=api_base,
                api_key=api_key,
                timeout=timeout,
                max_retries=max_retries,
                logging_obj=logging_obj,
            )
        else:
            # NOTE(review): asyncio.run raises if a loop is already running in
            # this thread — confirm sync callers never invoke this from async code.
            return asyncio.run(
                self.aretrieve_batch(
                    batch_id=batch_id,
                    api_base=api_base,
                    api_key=api_key,
                    timeout=timeout,
                    max_retries=max_retries,
                    logging_obj=logging_obj,
                )
            )

View File

@@ -0,0 +1,312 @@
import json
import time
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
import httpx
from httpx import Headers, Response
from litellm.llms.base_llm.batches.transformation import BaseBatchesConfig
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.openai import AllMessageValues, CreateBatchRequest
from litellm.types.utils import LiteLLMBatch, LlmProviders, ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
class AnthropicBatchesConfig(BaseBatchesConfig):
def __init__(self):
    # Imported lazily inside __init__ — presumably to avoid a circular import
    # between the chat and batches modules; confirm before hoisting to top level.
    from ..chat.transformation import AnthropicConfig
    from ..common_utils import AnthropicModelInfo

    self.anthropic_chat_config = AnthropicConfig()  # initialize once
    self.anthropic_model_info = AnthropicModelInfo()
@property
def custom_llm_provider(self) -> LlmProviders:
    """Return the LLM provider type for this configuration."""
    # Always Anthropic — this config only serves the Anthropic batches API.
    return LlmProviders.ANTHROPIC
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Validate and prepare environment-specific headers and parameters.

    Raises:
        ValueError: if no API key is provided or found in the environment.

    Returns:
        The (mutated) headers dict with auth/version headers applied.
    """
    # Resolve api_key from environment if not provided
    api_key = api_key or self.anthropic_model_info.get_api_key()
    if api_key is None:
        raise ValueError(
            "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
        )
    _headers = {
        "accept": "application/json",
        "anthropic-version": "2023-06-01",
        "content-type": "application/json",
        "x-api-key": api_key,
    }
    # Add beta header for message batches
    # (checked against the caller-supplied headers BEFORE the merge below,
    # so a caller-provided anthropic-beta value is preserved)
    if "anthropic-beta" not in headers:
        headers["anthropic-beta"] = "message-batches-2024-09-24"
    # _headers always overwrite accept/version/content-type/x-api-key.
    headers.update(_headers)
    return headers
def get_complete_batch_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: Dict,
    litellm_params: Dict,
    data: CreateBatchRequest,
) -> str:
    """
    Build the batch-creation URL, appending /v1/messages/batches if missing.
    """
    resolved_base = api_base or self.anthropic_model_info.get_api_base(api_base)
    if resolved_base.endswith("/v1/messages/batches"):
        return resolved_base
    return f"{resolved_base.rstrip('/')}/v1/messages/batches"
def transform_create_batch_request(
    self,
    model: str,
    create_batch_data: CreateBatchRequest,
    optional_params: dict,
    litellm_params: dict,
) -> Union[bytes, str, Dict[str, Any]]:
    """
    Transform a batch-creation request into Anthropic format.

    Placeholder satisfying the abstract base class — batch creation is not
    yet supported for Anthropic.
    """
    raise NotImplementedError("Batch creation not yet implemented for Anthropic")
def transform_create_batch_response(
    self,
    model: Optional[str],
    raw_response: httpx.Response,
    logging_obj: LoggingClass,
    litellm_params: dict,
) -> LiteLLMBatch:
    """
    Transform an Anthropic MessageBatch creation response to LiteLLM format.

    Placeholder satisfying the abstract base class — batch creation is not
    yet supported for Anthropic.
    """
    raise NotImplementedError("Batch creation not yet implemented for Anthropic")
def get_retrieve_batch_url(
    self,
    api_base: Optional[str],
    batch_id: str,
    optional_params: Dict,
    litellm_params: Dict,
) -> str:
    """
    Build the batch-retrieval URL: {api_base}/v1/messages/batches/{batch_id}.

    Args:
        api_base: Base API URL (falls back to the provider default when None)
        batch_id: Batch ID to retrieve
        optional_params: Optional parameters (unused)
        litellm_params: LiteLLM parameters (unused)
    """
    resolved_base = api_base or self.anthropic_model_info.get_api_base(api_base)
    return f"{resolved_base.rstrip('/')}/v1/messages/batches/{batch_id}"
def transform_retrieve_batch_request(
    self,
    batch_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> Union[bytes, str, Dict[str, Any]]:
    """
    Build the retrieval request payload for Anthropic.

    Retrieval is a plain GET whose URL is produced by get_retrieve_batch_url,
    so no body or extra params are required — always returns an empty dict.
    """
    return {}
def transform_retrieve_batch_response(
    self,
    model: Optional[str],
    raw_response: httpx.Response,
    logging_obj: LoggingClass,
    litellm_params: dict,
) -> LiteLLMBatch:
    """Transform Anthropic MessageBatch retrieval response to LiteLLM format.

    Maps Anthropic's processing_status / request_counts / ISO-8601 timestamps
    onto the OpenAI Batch schema.

    Raises:
        ValueError: if the response body is not valid JSON.
    """
    try:
        response_data = raw_response.json()
    except Exception as e:
        raise ValueError(f"Failed to parse Anthropic batch response: {e}")

    # Map Anthropic MessageBatch to OpenAI Batch format
    batch_id = response_data.get("id", "")
    processing_status = response_data.get("processing_status", "in_progress")

    # Map Anthropic processing_status to OpenAI status
    # (unknown statuses fall back to "in_progress")
    status_mapping: Dict[
        str,
        Literal[
            "validating",
            "failed",
            "in_progress",
            "finalizing",
            "completed",
            "expired",
            "cancelling",
            "cancelled",
        ],
    ] = {
        "in_progress": "in_progress",
        "canceling": "cancelling",
        "ended": "completed",
    }
    openai_status = status_mapping.get(processing_status, "in_progress")

    # Parse timestamps: Anthropic returns ISO-8601 strings (e.g. with a "Z"
    # suffix); OpenAI batches use Unix epoch seconds. Unparseable -> None.
    def parse_timestamp(ts_str: Optional[str]) -> Optional[int]:
        if not ts_str:
            return None
        try:
            from datetime import datetime

            dt = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
            return int(dt.timestamp())
        except Exception:
            return None

    created_at = parse_timestamp(response_data.get("created_at"))
    ended_at = parse_timestamp(response_data.get("ended_at"))
    expires_at = parse_timestamp(response_data.get("expires_at"))
    cancel_initiated_at = parse_timestamp(response_data.get("cancel_initiated_at"))
    archived_at = parse_timestamp(response_data.get("archived_at"))

    # Extract request counts; Anthropic splits them into processing/succeeded/
    # errored/canceled/expired, while OpenAI wants total/completed/failed.
    request_counts_data = response_data.get("request_counts", {})
    from openai.types.batch import BatchRequestCounts

    request_counts = BatchRequestCounts(
        total=sum(
            [
                request_counts_data.get("processing", 0),
                request_counts_data.get("succeeded", 0),
                request_counts_data.get("errored", 0),
                request_counts_data.get("canceled", 0),
                request_counts_data.get("expired", 0),
            ]
        ),
        completed=request_counts_data.get("succeeded", 0),
        failed=request_counts_data.get("errored", 0),
    )

    # NOTE(review): input_file_id is the literal string "None" and
    # output_file_id reuses the batch id (Anthropic has no file ids for
    # batches) — confirm downstream consumers expect these placeholders.
    return LiteLLMBatch(
        id=batch_id,
        object="batch",
        endpoint="/v1/messages",
        errors=None,
        input_file_id="None",
        completion_window="24h",
        status=openai_status,
        output_file_id=batch_id,
        error_file_id=None,
        created_at=created_at or int(time.time()),
        in_progress_at=created_at if processing_status == "in_progress" else None,
        expires_at=expires_at,
        finalizing_at=None,
        completed_at=ended_at if processing_status == "ended" else None,
        failed_at=None,
        expired_at=archived_at if archived_at else None,
        cancelling_at=cancel_initiated_at
        if processing_status == "canceling"
        else None,
        cancelled_at=ended_at
        if processing_status == "canceling" and ended_at
        else None,
        request_counts=request_counts,
        metadata={},
    )
def get_error_class(
self, error_message: str, status_code: int, headers: Union[Dict, Headers]
) -> "BaseLLMException":
"""Get the appropriate error class for Anthropic."""
from ..common_utils import AnthropicError
# Convert Dict to Headers if needed
if isinstance(headers, dict):
headers_obj: Optional[Headers] = Headers(headers)
else:
headers_obj = headers if isinstance(headers, Headers) else None
return AnthropicError(
status_code=status_code, message=error_message, headers=headers_obj
)
    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LoggingClass,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Transform a batch-results body (JSONL) into a single ModelResponse.

        Each non-empty line of the raw response is parsed as JSON; lines that
        are not valid JSON are skipped. The per-line usage objects are summed
        into one combined ``usage`` set on ``model_response``.

        NOTE(review): ``model_response`` is mutated in place by each
        ``transform_parsed_response`` call, so its final content reflects the
        last parseable line — confirm this is the intended aggregation.
        """
        from litellm.cost_calculator import BaseTokenUsageProcessor
        from litellm.types.utils import Usage
        response_text = raw_response.text.strip()
        all_usage: List[Usage] = []
        try:
            # Split by newlines and try to parse each line as JSON
            lines = response_text.split("\n")
            for line in lines:
                line = line.strip()
                if not line:
                    continue
                try:
                    response_json = json.loads(line)
                    # Update model_response with the parsed JSON
                    # (Anthropic batch result lines nest the message under
                    # result.message; a missing key raises and aborts.)
                    completion_response = response_json["result"]["message"]
                    transformed_response = (
                        self.anthropic_chat_config.transform_parsed_response(
                            completion_response=completion_response,
                            raw_response=raw_response,
                            model_response=model_response,
                        )
                    )
                    transformed_response_usage = getattr(
                        transformed_response, "usage", None
                    )
                    if transformed_response_usage:
                        all_usage.append(cast(Usage, transformed_response_usage))
                except json.JSONDecodeError:
                    continue
            ## SUM ALL USAGE
            combined_usage = BaseTokenUsageProcessor.combine_usage_objects(all_usage)
            setattr(model_response, "usage", combined_usage)
            return model_response
        except Exception as e:
            raise e

View File

@@ -0,0 +1 @@
from .handler import AnthropicChatCompletion, ModelResponseIterator

View File

@@ -0,0 +1,10 @@
from litellm.llms.anthropic.chat.guardrail_translation.handler import (
AnthropicMessagesHandler,
)
from litellm.types.utils import CallTypes
# Maps LiteLLM call types to the guardrail-translation handler that knows how
# to extract text from, and apply guardrail output back onto, that format.
guardrail_translation_mappings = {
    CallTypes.anthropic_messages: AnthropicMessagesHandler,
}
__all__ = ["guardrail_translation_mappings"]

View File

@@ -0,0 +1,688 @@
"""
Anthropic Message Handler for Unified Guardrails
This module provides a class-based handler for Anthropic-format messages.
The class methods can be overridden for custom behavior.
Pattern Overview:
-----------------
1. Extract text content from messages/responses (both string and list formats)
2. Create async tasks to apply guardrails to each text segment
3. Track mappings to know where each response belongs
4. Apply guardrail responses back to the original structure
"""
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
from litellm._logging import verbose_proxy_logger
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import (
LiteLLMAnthropicMessagesAdapter,
)
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from litellm.types.llms.anthropic import (
AllAnthropicToolsValues,
AnthropicMessagesRequest,
)
from litellm.types.llms.openai import (
ChatCompletionToolCallChunk,
ChatCompletionToolParam,
)
from litellm.types.utils import (
ChatCompletionMessageToolCall,
Choices,
GenericGuardrailAPIInputs,
ModelResponse,
)
if TYPE_CHECKING:
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
)
class AnthropicMessagesHandler(BaseTranslation):
"""
Handler for processing Anthropic messages with guardrails.
This class provides methods to:
1. Process input messages (pre-call hook)
2. Process output responses (post-call hook)
Methods can be overridden to customize behavior for different message formats.
"""
    def __init__(self):
        """Initialize the handler with an Anthropic↔OpenAI message adapter."""
        super().__init__()
        # Adapter used to translate Anthropic tools/messages into OpenAI
        # format before they are handed to guardrails.
        self.adapter = LiteLLMAnthropicMessagesAdapter()
    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Process input messages by applying guardrails to text content.

        Flow: (1) extract all text/images from ``data["messages"]`` while
        recording where each text came from, (2) run the guardrail once over
        the whole batch, (3) write the (possibly modified) texts back into
        the original message structure in place. Returns ``data`` unchanged
        when there are no messages or no extractable text.
        """
        messages = data.get("messages")
        if messages is None:
            return data
        (
            chat_completion_compatible_request,
            _tool_name_mapping,
        ) = LiteLLMAnthropicMessagesAdapter().translate_anthropic_to_openai(
            # Use a shallow copy to avoid mutating request data (pop on litellm_metadata).
            anthropic_message_request=cast(AnthropicMessagesRequest, data.copy())
        )
        structured_messages = chat_completion_compatible_request.get("messages", [])
        texts_to_check: List[str] = []
        images_to_check: List[str] = []
        # Tools are taken from the OpenAI-translated request so guardrails see
        # a uniform tool schema regardless of provider format.
        tools_to_check: List[
            ChatCompletionToolParam
        ] = chat_completion_compatible_request.get("tools", [])
        task_mappings: List[Tuple[int, Optional[int]]] = []
        # Track (message_index, content_index) for each text
        # content_index is None for string content, int for list content
        # Step 1: Extract all text content and images
        for msg_idx, message in enumerate(messages):
            self._extract_input_text_and_images(
                message=message,
                msg_idx=msg_idx,
                texts_to_check=texts_to_check,
                images_to_check=images_to_check,
                task_mappings=task_mappings,
            )
        # Step 2: Apply guardrail to all texts in batch
        if texts_to_check:
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            if images_to_check:
                inputs["images"] = images_to_check
            if tools_to_check:
                inputs["tools"] = tools_to_check
            if structured_messages:
                inputs["structured_messages"] = structured_messages
            # Include model information if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            guardrailed_tools = guardrailed_inputs.get("tools")
            # Guardrail may rewrite the tool list; only overwrite when it
            # explicitly returned one.
            if guardrailed_tools is not None:
                data["tools"] = guardrailed_tools
            # Step 3: Map guardrail responses back to original message structure
            await self._apply_guardrail_responses_to_input(
                messages=messages,
                responses=guardrailed_texts,
                task_mappings=task_mappings,
            )
        verbose_proxy_logger.debug(
            "Anthropic Messages: Processed input messages: %s", messages
        )
        return data
def extract_request_tool_names(self, data: dict) -> List[str]:
"""Extract tool names from Anthropic messages request (tools[].name)."""
names: List[str] = []
for tool in data.get("tools") or []:
if isinstance(tool, dict) and tool.get("name"):
names.append(str(tool["name"]))
return names
    def _extract_input_text_and_images(
        self,
        message: Dict[str, Any],
        msg_idx: int,
        texts_to_check: List[str],
        images_to_check: List[str],
        task_mappings: List[Tuple[int, Optional[int]]],
    ) -> None:
        """
        Extract text content and images from a message.

        Appends found texts to ``texts_to_check`` and records their position
        in ``task_mappings`` as (msg_idx, None) for plain-string content or
        (msg_idx, content_idx) for list content. Image blocks contribute
        their ``source.data`` payload (base64 or URL) to ``images_to_check``
        without a task mapping, since images are not rewritten.

        Override this method to customize text/image extraction logic.
        """
        content = message.get("content", None)
        tools = message.get("tools", None)
        if content is None and tools is None:
            return
        ## CHECK FOR TEXT + IMAGES
        if content is not None and isinstance(content, str):
            # Simple string content
            texts_to_check.append(content)
            task_mappings.append((msg_idx, None))
        elif content is not None and isinstance(content, list):
            # List content (e.g., multimodal with text and images)
            for content_idx, content_item in enumerate(content):
                # Extract text
                text_str = content_item.get("text", None)
                if text_str is not None:
                    texts_to_check.append(text_str)
                    task_mappings.append((msg_idx, int(content_idx)))
                # Extract images
                if content_item.get("type") == "image":
                    source = content_item.get("source", {})
                    if isinstance(source, dict):
                        # Could be base64 or url
                        data = source.get("data")
                        if data:
                            images_to_check.append(data)
    def _extract_input_tools(
        self,
        tools: List[Dict[str, Any]],
        tools_to_check: List[ChatCompletionToolParam],
    ) -> None:
        """
        Extract tools from a message.

        Translates Anthropic-format tool definitions into OpenAI format via
        the adapter and appends them to ``tools_to_check`` in place.
        """
        ## CHECK FOR TOOLS
        if tools is not None and isinstance(tools, list):
            # TRANSFORM ANTHROPIC TOOLS TO OPENAI TOOLS
            openai_tools = self.adapter.translate_anthropic_tools_to_openai(
                tools=cast(List[AllAnthropicToolsValues], tools)
            )
            tools_to_check.extend(openai_tools)  # type: ignore
    async def _apply_guardrail_responses_to_input(
        self,
        messages: List[Dict[str, Any]],
        responses: List[str],
        task_mappings: List[Tuple[int, Optional[int]]],
    ) -> None:
        """
        Apply guardrail responses back to input messages.

        ``responses`` and ``task_mappings`` are parallel lists produced during
        extraction: each mapping is (message_index, content_index), where
        content_index is None for plain-string content. Messages are mutated
        in place; mappings whose content has since become None are skipped.

        Override this method to customize how responses are applied.
        """
        for task_idx, guardrail_response in enumerate(responses):
            mapping = task_mappings[task_idx]
            msg_idx = cast(int, mapping[0])
            content_idx_optional = cast(Optional[int], mapping[1])
            content = messages[msg_idx].get("content", None)
            if content is None:
                continue
            if isinstance(content, str) and content_idx_optional is None:
                # Replace string content with guardrail response
                messages[msg_idx]["content"] = guardrail_response
            elif isinstance(content, list) and content_idx_optional is not None:
                # Replace specific text item in list content
                messages[msg_idx]["content"][content_idx_optional][
                    "text"
                ] = guardrail_response
    async def process_output_response(
        self,
        response: "AnthropicMessagesResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Process output response by applying guardrails to text content and tool calls.
        Args:
            response: Anthropic MessagesResponse object (dict or Pydantic)
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata to pass to guardrails
        Returns:
            Modified response with guardrail applied to content
        Response Format Support:
            - List content: response.content = [
                {"type": "text", "text": "text here"},
                {"type": "tool_use", "id": "...", "name": "...", "input": {...}},
                ...
              ]
        """
        texts_to_check: List[str] = []
        images_to_check: List[str] = []
        tool_calls_to_check: List[ChatCompletionToolCallChunk] = []
        task_mappings: List[Tuple[int, Optional[int]]] = []
        # Track (content_index, None) for each text
        # Handle both dict and object responses
        response_content: List[Any] = []
        if isinstance(response, dict):
            response_content = response.get("content", []) or []
        elif hasattr(response, "content"):
            content = getattr(response, "content", None)
            response_content = content or []
        else:
            response_content = []
        if not response_content:
            return response
        # Step 1: Extract all text content and tool calls from response
        for content_idx, content_block in enumerate(response_content):
            # Handle both dict and Pydantic object content blocks
            block_dict: Dict[str, Any] = {}
            if isinstance(content_block, dict):
                block_type = content_block.get("type")
                block_dict = cast(Dict[str, Any], content_block)
            elif hasattr(content_block, "type"):
                block_type = getattr(content_block, "type", None)
                # Convert Pydantic object to dict for processing
                if hasattr(content_block, "model_dump"):
                    block_dict = content_block.model_dump()
                else:
                    block_dict = {
                        "type": block_type,
                        "text": getattr(content_block, "text", None),
                    }
            else:
                continue
            if block_type in ["text", "tool_use"]:
                self._extract_output_text_and_images(
                    content_block=block_dict,
                    content_idx=content_idx,
                    texts_to_check=texts_to_check,
                    images_to_check=images_to_check,
                    task_mappings=task_mappings,
                    tool_calls_to_check=tool_calls_to_check,
                )
        # Step 2: Apply guardrail to all texts in batch
        if texts_to_check or tool_calls_to_check:
            # Create a request_data dict with response info and user API key metadata
            request_data: dict = {"response": response}
            # Add user API key metadata with prefixed keys
            user_metadata = self.transform_user_api_key_dict_to_metadata(
                user_api_key_dict
            )
            if user_metadata:
                request_data["litellm_metadata"] = user_metadata
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            if images_to_check:
                inputs["images"] = images_to_check
            if tool_calls_to_check:
                inputs["tool_calls"] = tool_calls_to_check
            # Include model information from the response if available
            response_model = None
            if isinstance(response, dict):
                response_model = response.get("model")
            elif hasattr(response, "model"):
                response_model = getattr(response, "model", None)
            if response_model:
                inputs["model"] = response_model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=request_data,
                input_type="response",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            # Step 3: Map guardrail responses back to original response structure
            await self._apply_guardrail_responses_to_output(
                response=response,
                responses=guardrailed_texts,
                task_mappings=task_mappings,
            )
        verbose_proxy_logger.debug(
            "Anthropic Messages: Processed output response: %s", response
        )
        return response
    async def process_output_streaming_response(
        self,
        responses_so_far: List[Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> List[Any]:
        """
        Process output streaming response by applying guardrails to text content.

        While the stream is still open, the accumulated text so far is checked
        on every call; once the stream has ended (non-null stop_reason seen),
        a full ModelResponse is rebuilt from all chunks so tool calls can be
        checked as well. The guardrail may reject by raising; its rewritten
        text is intentionally NOT applied back to the chunks.

        Returns the (unmodified) list of responses so far.
        """
        has_ended = self._check_streaming_has_ended(responses_so_far)
        if has_ended:
            # build the model response from the responses_so_far
            built_response = (
                AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
                    all_chunks=responses_so_far,
                    litellm_logging_obj=cast("LiteLLMLoggingObj", litellm_logging_obj),
                    model="",
                )
            )
            # Check if model_response is valid and has choices before accessing
            if (
                built_response is not None
                and hasattr(built_response, "choices")
                and built_response.choices
            ):
                model_response = cast(ModelResponse, built_response)
                first_choice = cast(Choices, model_response.choices[0])
                tool_calls_list = cast(
                    Optional[List[ChatCompletionMessageToolCall]],
                    first_choice.message.tool_calls,
                )
                string_so_far = first_choice.message.content
                guardrail_inputs = GenericGuardrailAPIInputs()
                if string_so_far:
                    guardrail_inputs["texts"] = [string_so_far]
                if tool_calls_list:
                    guardrail_inputs["tool_calls"] = tool_calls_list
                _guardrailed_inputs = await guardrail_to_apply.apply_guardrail(  # allow rejecting the response, if invalid
                    inputs=guardrail_inputs,
                    request_data={},
                    input_type="response",
                    logging_obj=litellm_logging_obj,
                )
            else:
                verbose_proxy_logger.debug(
                    "Skipping output guardrail - model response has no choices"
                )
            return responses_so_far
        string_so_far = self.get_streaming_string_so_far(responses_so_far)
        _guardrailed_inputs = await guardrail_to_apply.apply_guardrail(  # allow rejecting the response, if invalid
            inputs={"texts": [string_so_far]},
            request_data={},
            input_type="response",
            logging_obj=litellm_logging_obj,
        )
        return responses_so_far
def get_streaming_string_so_far(self, responses_so_far: List[Any]) -> str:
"""
Parse streaming responses and extract accumulated text content.
Handles two formats:
1. Raw bytes in SSE (Server-Sent Events) format from Anthropic API
2. Parsed dict objects (for backwards compatibility)
SSE format example:
b'event: content_block_delta\\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" curious"}}\\n\\n'
Dict format example:
{
"type": "content_block_delta",
"index": 0,
"delta": {
"type": "text_delta",
"text": " curious"
}
}
"""
text_so_far = ""
for response in responses_so_far:
# Handle raw bytes in SSE format
if isinstance(response, bytes):
text_so_far += self._extract_text_from_sse(response)
# Handle already-parsed dict format
elif isinstance(response, dict):
delta = response.get("delta") if response.get("delta") else None
if delta and delta.get("type") == "text_delta":
text = delta.get("text", "")
if text:
text_so_far += text
return text_so_far
    def _extract_text_from_sse(self, sse_bytes: bytes) -> str:
        """
        Extract text content from Server-Sent Events (SSE) format.
        Args:
            sse_bytes: Raw bytes in SSE format (UTF-8; events separated by
                blank lines, each with ``event:`` and ``data:`` lines)
        Returns:
            Accumulated text from all content_block_delta events. Malformed
            JSON payloads are logged and skipped; any other failure is logged
            and an empty/partial string is returned rather than raised.
        """
        text = ""
        try:
            # Decode bytes to string
            sse_string = sse_bytes.decode("utf-8")
            # Split by double newline to get individual events
            events = sse_string.split("\n\n")
            for event in events:
                if not event.strip():
                    continue
                # Parse event lines
                lines = event.strip().split("\n")
                event_type = None
                data_line = None
                for line in lines:
                    if line.startswith("event:"):
                        event_type = line[6:].strip()
                    elif line.startswith("data:"):
                        data_line = line[5:].strip()
                # Only process content_block_delta events
                if event_type == "content_block_delta" and data_line:
                    try:
                        data = json.loads(data_line)
                        delta = data.get("delta", {})
                        if delta.get("type") == "text_delta":
                            text += delta.get("text", "")
                    except json.JSONDecodeError:
                        verbose_proxy_logger.warning(
                            f"Failed to parse JSON from SSE data: {data_line}"
                        )
        except Exception as e:
            verbose_proxy_logger.error(f"Error extracting text from SSE: {e}")
        return text
    def _check_streaming_has_ended(self, responses_so_far: List[Any]) -> bool:
        """
        Check if streaming response has ended by looking for non-null stop_reason.
        Handles two formats:
        1. Raw bytes in SSE (Server-Sent Events) format from Anthropic API
        2. Parsed dict objects (for backwards compatibility)
        SSE format example:
            b'event: message_delta\\ndata: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},...}\\n\\n'
        Dict format example:
            {
                "type": "message_delta",
                "delta": {
                    "stop_reason": "tool_use",
                    "stop_sequence": null
                }
            }
        Returns:
            True if stop_reason is set to a non-null value, indicating stream has ended.
            Parse errors are logged and treated as "not ended".
        """
        for response in responses_so_far:
            # Handle raw bytes in SSE format
            if isinstance(response, bytes):
                try:
                    # Decode bytes to string
                    sse_string = response.decode("utf-8")
                    # Split by double newline to get individual events
                    events = sse_string.split("\n\n")
                    for event in events:
                        if not event.strip():
                            continue
                        # Parse event lines
                        lines = event.strip().split("\n")
                        event_type = None
                        data_line = None
                        for line in lines:
                            if line.startswith("event:"):
                                event_type = line[6:].strip()
                            elif line.startswith("data:"):
                                data_line = line[5:].strip()
                        # Check for message_delta event with stop_reason
                        if event_type == "message_delta" and data_line:
                            try:
                                data = json.loads(data_line)
                                delta = data.get("delta", {})
                                stop_reason = delta.get("stop_reason")
                                if stop_reason is not None:
                                    return True
                            except json.JSONDecodeError:
                                verbose_proxy_logger.warning(
                                    f"Failed to parse JSON from SSE data: {data_line}"
                                )
                except Exception as e:
                    verbose_proxy_logger.error(
                        f"Error checking streaming end in SSE: {e}"
                    )
            # Handle already-parsed dict format
            elif isinstance(response, dict):
                if response.get("type") == "message_delta":
                    delta = response.get("delta", {})
                    stop_reason = delta.get("stop_reason")
                    if stop_reason is not None:
                        return True
        return False
    def _has_text_content(self, response: "AnthropicMessagesResponse") -> bool:
        """
        Check if response has any text content to process.

        Returns True when at least one content block is a dict of
        ``{"type": "text"}`` with a non-empty string ``text`` field.
        NOTE(review): Pydantic content blocks are not detected here — only
        dict blocks; confirm callers always pass dict-shaped content.

        Override this method to customize text content detection.
        """
        if isinstance(response, dict):
            response_content = response.get("content", [])
        else:
            response_content = getattr(response, "content", None) or []
        if not response_content:
            return False
        for content_block in response_content:
            # Check if this is a text block by checking the 'type' field
            if isinstance(content_block, dict) and content_block.get("type") == "text":
                content_text = content_block.get("text")
                if content_text and isinstance(content_text, str):
                    return True
        return False
def _extract_output_text_and_images(
self,
content_block: Dict[str, Any],
content_idx: int,
texts_to_check: List[str],
images_to_check: List[str],
task_mappings: List[Tuple[int, Optional[int]]],
tool_calls_to_check: Optional[List[ChatCompletionToolCallChunk]] = None,
) -> None:
"""
Extract text content, images, and tool calls from a response content block.
Override this method to customize text/image/tool extraction logic.
"""
content_type = content_block.get("type")
# Extract text content
if content_type == "text":
content_text = content_block.get("text")
if content_text and isinstance(content_text, str):
# Simple string content
texts_to_check.append(content_text)
task_mappings.append((content_idx, None))
# Extract tool calls
elif content_type == "tool_use":
tool_call = AnthropicConfig.convert_tool_use_to_openai_format(
anthropic_tool_content=content_block,
index=content_idx,
)
if tool_calls_to_check is None:
tool_calls_to_check = []
tool_calls_to_check.append(tool_call)
    async def _apply_guardrail_responses_to_output(
        self,
        response: "AnthropicMessagesResponse",
        responses: List[str],
        task_mappings: List[Tuple[int, Optional[int]]],
    ) -> None:
        """
        Apply guardrail responses back to output response.

        ``responses`` and ``task_mappings`` are parallel lists produced during
        extraction; each mapping's first element is the content-block index.
        The response is mutated in place. Mappings pointing past the end of
        the content list or at non-text blocks are skipped.

        Override this method to customize how responses are applied.
        """
        for task_idx, guardrail_response in enumerate(responses):
            mapping = task_mappings[task_idx]
            content_idx = cast(int, mapping[0])
            # Handle both dict and object responses
            response_content: List[Any] = []
            if isinstance(response, dict):
                response_content = response.get("content", []) or []
            elif hasattr(response, "content"):
                content = getattr(response, "content", None)
                response_content = content or []
            else:
                continue
            if not response_content:
                continue
            # Get the content block at the index
            if content_idx >= len(response_content):
                continue
            content_block = response_content[content_idx]
            # Verify it's a text block and update the text field
            # Handle both dict and Pydantic object content blocks
            if isinstance(content_block, dict):
                if content_block.get("type") == "text":
                    cast(Dict[str, Any], content_block)["text"] = guardrail_response
            elif (
                hasattr(content_block, "type")
                and getattr(content_block, "type", None) == "text"
            ):
                # Update Pydantic object's text attribute
                if hasattr(content_block, "text"):
                    content_block.text = guardrail_response

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,627 @@
"""
This file contains common utils for anthropic calls.
"""
from typing import Dict, List, Optional, Union
import httpx
import litellm
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_file_ids_from_messages,
)
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo, BaseTokenCounter
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.anthropic import (
ANTHROPIC_HOSTED_TOOLS,
ANTHROPIC_OAUTH_BETA_HEADER,
ANTHROPIC_OAUTH_TOKEN_PREFIX,
AllAnthropicToolsValues,
AnthropicMcpServerTool,
)
from litellm.types.llms.openai import AllMessageValues
def is_anthropic_oauth_key(value: Optional[str]) -> bool:
    """Return True if *value* holds an Anthropic OAuth token (``sk-ant-oat*``).

    Accepts either the raw token or a ``"Bearer <token>"`` header value;
    None is never an OAuth key.
    """
    if value is None:
        return False
    token = value[7:] if value.startswith("Bearer ") else value
    return token.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX)
def _merge_beta_headers(existing: Optional[str], new_beta: str) -> str:
    """Merge a new beta value into an existing comma-separated anthropic-beta header.

    Deduplicates values; the result is sorted so the header is deterministic.
    """
    if not existing:
        return new_beta
    # Parse the existing header into a set of trimmed, non-empty values.
    betas = {b.strip() for b in existing.split(",") if b.strip()}
    betas.add(new_beta)
    return ",".join(sorted(betas))
def optionally_handle_anthropic_oauth(
    headers: dict, api_key: Optional[str]
) -> tuple[dict, Optional[str]]:
    """
    Handle Anthropic OAuth token detection and header setup.
    If an OAuth token is detected in the Authorization header, extracts it
    and sets the required OAuth headers.
    Args:
        headers: Request headers dict (mutated in place)
        api_key: Current API key (may be None)
    Returns:
        Tuple of (updated headers, api_key)
    """
    # Check Authorization header (passthrough / forwarded requests)
    auth_header = headers.get("authorization", "")
    if auth_header and auth_header.startswith(f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"):
        # OAuth flow uses the bearer token, not x-api-key, so drop the latter.
        api_key = auth_header.replace("Bearer ", "")
        headers.pop("x-api-key", None)
        headers["anthropic-beta"] = _merge_beta_headers(
            headers.get("anthropic-beta"), ANTHROPIC_OAUTH_BETA_HEADER
        )
        headers["anthropic-dangerous-direct-browser-access"] = "true"
        return headers, api_key
    # Check api_key directly (standard chat/completion flow)
    if api_key and api_key.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX):
        headers.pop("x-api-key", None)
        headers["authorization"] = f"Bearer {api_key}"
        headers["anthropic-beta"] = _merge_beta_headers(
            headers.get("anthropic-beta"), ANTHROPIC_OAUTH_BETA_HEADER
        )
        headers["anthropic-dangerous-direct-browser-access"] = "true"
    return headers, api_key
class AnthropicError(BaseLLMException):
    """Exception raised for error responses from the Anthropic API."""

    def __init__(
        self,
        status_code: int,
        message,
        headers: Optional[httpx.Headers] = None,
    ):
        super().__init__(status_code=status_code, message=message, headers=headers)
class AnthropicModelInfo(BaseLLMModelInfo):
def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
"""
Return if {"cache_control": ..} in message content block
Used to check if anthropic prompt caching headers need to be set.
"""
for message in messages:
if message.get("cache_control", None) is not None:
return True
_message_content = message.get("content")
if _message_content is not None and isinstance(_message_content, list):
for content in _message_content:
if "cache_control" in content:
return True
return False
def is_file_id_used(self, messages: List[AllMessageValues]) -> bool:
"""
Return if {"source": {"type": "file", "file_id": ..}} in message content block
"""
file_ids = get_file_ids_from_messages(messages)
return len(file_ids) > 0
def is_mcp_server_used(
self, mcp_servers: Optional[List[AnthropicMcpServerTool]]
) -> bool:
if mcp_servers is None:
return False
if mcp_servers:
return True
return False
def is_computer_tool_used(
self, tools: Optional[List[AllAnthropicToolsValues]]
) -> Optional[str]:
"""Returns the computer tool version if used, e.g. 'computer_20250124' or None"""
if tools is None:
return None
for tool in tools:
if "type" in tool and tool["type"].startswith("computer_"):
return tool["type"]
return None
def is_web_search_tool_used(
self, tools: Optional[List[AllAnthropicToolsValues]]
) -> bool:
"""Returns True if web_search tool is used"""
if tools is None:
return False
for tool in tools:
if "type" in tool and tool["type"].startswith(
ANTHROPIC_HOSTED_TOOLS.WEB_SEARCH.value
):
return True
return False
def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
"""
Set to true if media passed into messages.
"""
for message in messages:
if (
"content" in message
and message["content"] is not None
and isinstance(message["content"], list)
):
for content in message["content"]:
if "type" in content and content["type"] != "text":
return True
return False
def is_tool_search_used(self, tools: Optional[List]) -> bool:
"""
Check if tool search tools are present in the tools list.
"""
if not tools:
return False
for tool in tools:
tool_type = tool.get("type", "")
if tool_type in [
"tool_search_tool_regex_20251119",
"tool_search_tool_bm25_20251119",
]:
return True
return False
def is_programmatic_tool_calling_used(self, tools: Optional[List]) -> bool:
"""
Check if programmatic tool calling is being used (tools with allowed_callers field).
Returns True if any tool has allowed_callers containing 'code_execution_20250825'.
"""
if not tools:
return False
for tool in tools:
# Check top-level allowed_callers
allowed_callers = tool.get("allowed_callers", None)
if allowed_callers and isinstance(allowed_callers, list):
if "code_execution_20250825" in allowed_callers:
return True
# Check function.allowed_callers for OpenAI format tools
function = tool.get("function", {})
if isinstance(function, dict):
function_allowed_callers = function.get("allowed_callers", None)
if function_allowed_callers and isinstance(
function_allowed_callers, list
):
if "code_execution_20250825" in function_allowed_callers:
return True
return False
def is_input_examples_used(self, tools: Optional[List]) -> bool:
"""
Check if input_examples is being used in any tools.
Returns True if any tool has input_examples field.
"""
if not tools:
return False
for tool in tools:
# Check top-level input_examples
input_examples = tool.get("input_examples", None)
if (
input_examples
and isinstance(input_examples, list)
and len(input_examples) > 0
):
return True
# Check function.input_examples for OpenAI format tools
function = tool.get("function", {})
if isinstance(function, dict):
function_input_examples = function.get("input_examples", None)
if (
function_input_examples
and isinstance(function_input_examples, list)
and len(function_input_examples) > 0
):
return True
return False
    @staticmethod
    def _is_claude_4_6_model(model: str) -> bool:
        """Check if the model is a Claude 4.6 model (Opus 4.6 or Sonnet 4.6).

        Matches all dash/underscore/dot spellings of the version,
        case-insensitively, anywhere in the model string.
        """
        model_lower = model.lower()
        return any(
            v in model_lower
            for v in (
                "opus-4-6",
                "opus_4_6",
                "opus-4.6",
                "opus_4.6",
                "sonnet-4-6",
                "sonnet_4_6",
                "sonnet-4.6",
                "sonnet_4.6",
            )
        )
    def is_effort_used(
        self, optional_params: Optional[dict], model: Optional[str] = None
    ) -> bool:
        """
        Check if effort parameter is being used and requires a beta header.
        Returns True if effort-related parameters are present and
        the model requires the effort beta header. Claude 4.6 models
        use output_config as a stable API feature — no beta header needed.
        """
        if not optional_params:
            return False
        # Claude 4.6 models use output_config as a stable API feature — no beta header needed
        if model and self._is_claude_4_6_model(model):
            return False
        # Check if reasoning_effort is provided for Claude Opus 4.5
        if model and ("opus-4-5" in model.lower() or "opus_4_5" in model.lower()):
            reasoning_effort = optional_params.get("reasoning_effort")
            if reasoning_effort and isinstance(reasoning_effort, str):
                return True
        # Check if output_config is directly provided (for non-4.6 models)
        output_config = optional_params.get("output_config")
        if output_config and isinstance(output_config, dict):
            effort = output_config.get("effort")
            if effort and isinstance(effort, str):
                return True
        return False
def is_code_execution_tool_used(self, tools: Optional[List]) -> bool:
"""
Check if code execution tool is being used.
Returns True if any tool has type "code_execution_20250825".
"""
if not tools:
return False
for tool in tools:
tool_type = tool.get("type", "")
if tool_type == "code_execution_20250825":
return True
return False
def is_container_with_skills_used(self, optional_params: Optional[dict]) -> bool:
    """
    Return True when `optional_params["container"]["skills"]` is a
    non-empty list (the container-with-skills feature is active).
    """
    if not optional_params:
        return False
    container_cfg = optional_params.get("container")
    if not isinstance(container_cfg, dict) or not container_cfg:
        return False
    skills_list = container_cfg.get("skills")
    return isinstance(skills_list, list) and len(skills_list) > 0
def _get_user_anthropic_beta_headers(
self, anthropic_beta_header: Optional[str]
) -> Optional[List[str]]:
if anthropic_beta_header is None:
return None
return anthropic_beta_header.split(",")
def get_computer_tool_beta_header(self, computer_tool_version: str) -> str:
    """
    Map a computer tool version string to its anthropic-beta header value.

    Args:
        computer_tool_version: e.g. 'computer_20250124' or 'computer_20241022'

    Returns:
        The matching beta header value; unknown versions fall back to the
        2024-10-22 header.
    """
    fallback = "computer-use-2024-10-22"
    known_versions = {
        "computer_20250124": "computer-use-2025-01-24",
        "computer_20241022": "computer-use-2024-10-22",
    }
    return known_versions.get(computer_tool_version, fallback)
def get_anthropic_beta_list(
    self,
    model: str,
    optional_params: Optional[dict] = None,
    computer_tool_used: Optional[str] = None,
    prompt_caching_set: bool = False,
    file_id_used: bool = False,
    mcp_server_used: bool = False,
) -> List[str]:
    """
    Build the de-duplicated list of anthropic-beta header values needed for
    the features that are active on this request.

    Returns:
        List of beta header strings (order unspecified — result of set dedup).
    """
    from litellm.types.llms.anthropic import (
        ANTHROPIC_EFFORT_BETA_HEADER,
    )

    collected: List[str] = []

    # effort-2025-11-24 — only when effort is used AND the model needs the beta
    if self.is_effort_used(optional_params, model):
        collected.append(ANTHROPIC_EFFORT_BETA_HEADER)

    if computer_tool_used:
        collected.append(self.get_computer_tool_beta_header(computer_tool_used))

    # Anthropic no longer requires the prompt-caching beta header:
    # prompt caching works automatically when cache_control is used in messages.
    # Reference: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching

    if file_id_used:
        collected.append("files-api-2025-04-14")
        collected.append("code-execution-2025-05-22")

    if mcp_server_used:
        collected.append("mcp-client-2025-04-04")

    return list(set(collected))
def get_anthropic_headers(
    self,
    api_key: str,
    anthropic_version: Optional[str] = None,
    computer_tool_used: Optional[str] = None,
    prompt_caching_set: bool = False,
    pdf_used: bool = False,
    file_id_used: bool = False,
    mcp_server_used: bool = False,
    web_search_tool_used: bool = False,
    tool_search_used: bool = False,
    programmatic_tool_calling_used: bool = False,
    input_examples_used: bool = False,
    effort_used: bool = False,
    is_vertex_request: bool = False,
    user_anthropic_beta_headers: Optional[List[str]] = None,
    code_execution_tool_used: bool = False,
    container_with_skills_used: bool = False,
) -> dict:
    """
    Build the HTTP headers for an Anthropic API request.

    Collects the anthropic-beta values required by the active features, then
    assembles version/content headers plus either OAuth bearer auth or the
    x-api-key header.

    Args:
        api_key: API key or OAuth token (prefixed with ANTHROPIC_OAUTH_TOKEN_PREFIX).
        anthropic_version: API version header; defaults to "2023-06-01".
        *_used / *_set flags: which features are active on this request.
        is_vertex_request: if True, suppress beta headers except web search.
        user_anthropic_beta_headers: caller-supplied beta values to merge in.

    Returns:
        dict of headers; includes "anthropic-beta" only when betas apply.
        NOTE: the joined beta value order is unspecified (set iteration).
    """
    betas = set()
    # Anthropic no longer requires the prompt-caching beta header
    # Prompt caching now works automatically when cache_control is used in messages
    # Reference: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
    if computer_tool_used:
        beta_header = self.get_computer_tool_beta_header(computer_tool_used)
        betas.add(beta_header)
    # pdf beta intentionally disabled — PDFs are GA and need no beta header
    # if pdf_used:
    #     betas.add("pdfs-2024-09-25")
    if file_id_used:
        betas.add("files-api-2025-04-14")
        betas.add("code-execution-2025-05-22")
    if mcp_server_used:
        betas.add("mcp-client-2025-04-04")
    # Tool search, programmatic tool calling, and input_examples all use the same beta header
    if tool_search_used or programmatic_tool_calling_used or input_examples_used:
        from litellm.types.llms.anthropic import ANTHROPIC_TOOL_SEARCH_BETA_HEADER

        betas.add(ANTHROPIC_TOOL_SEARCH_BETA_HEADER)
    # Effort parameter uses a separate beta header
    if effort_used:
        from litellm.types.llms.anthropic import ANTHROPIC_EFFORT_BETA_HEADER

        betas.add(ANTHROPIC_EFFORT_BETA_HEADER)
    # Code execution tool uses a separate beta header
    if code_execution_tool_used:
        betas.add("code-execution-2025-08-25")
    # Container with skills uses a separate beta header
    if container_with_skills_used:
        betas.add("skills-2025-10-02")
    # OAuth tokens are identified by their prefix and use bearer auth instead of x-api-key
    _is_oauth = api_key and api_key.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX)
    headers = {
        "anthropic-version": anthropic_version or "2023-06-01",
        "accept": "application/json",
        "content-type": "application/json",
    }
    if _is_oauth:
        headers["authorization"] = f"Bearer {api_key}"
        headers["anthropic-dangerous-direct-browser-access"] = "true"
        betas.add(ANTHROPIC_OAUTH_BETA_HEADER)
    else:
        headers["x-api-key"] = api_key
    if user_anthropic_beta_headers is not None:
        betas.update(user_anthropic_beta_headers)

    # Don't send any beta headers to Vertex, except web search which is required
    if is_vertex_request is True:
        # Vertex AI requires web search beta header for web search to work
        if web_search_tool_used:
            from litellm.types.llms.anthropic import ANTHROPIC_BETA_HEADER_VALUES

            headers[
                "anthropic-beta"
            ] = ANTHROPIC_BETA_HEADER_VALUES.WEB_SEARCH_2025_03_05.value
    elif len(betas) > 0:
        headers["anthropic-beta"] = ",".join(betas)

    return headers
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> Dict:
    """
    Validate credentials and return the final request headers.

    Resolves OAuth tokens from the incoming headers, requires an API key,
    detects which Anthropic features the request uses (tools, caching, files,
    MCP, effort, skills, ...), and merges the computed Anthropic headers over
    the caller's headers.

    Raises:
        litellm.AuthenticationError: when no API key can be resolved.
    """
    # Check for Anthropic OAuth token in headers
    headers, api_key = optionally_handle_anthropic_oauth(
        headers=headers, api_key=api_key
    )

    if api_key is None:
        raise litellm.AuthenticationError(
            message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
            llm_provider="anthropic",
            model=model,
        )

    # Feature detection — each flag drives whether a beta header is added below
    tools = optional_params.get("tools")
    prompt_caching_set = self.is_cache_control_set(messages=messages)
    computer_tool_used = self.is_computer_tool_used(tools=tools)
    mcp_server_used = self.is_mcp_server_used(
        mcp_servers=optional_params.get("mcp_servers")
    )
    pdf_used = self.is_pdf_used(messages=messages)
    file_id_used = self.is_file_id_used(messages=messages)
    web_search_tool_used = self.is_web_search_tool_used(tools=tools)
    tool_search_used = self.is_tool_search_used(tools=tools)
    programmatic_tool_calling_used = self.is_programmatic_tool_calling_used(
        tools=tools
    )
    input_examples_used = self.is_input_examples_used(tools=tools)
    effort_used = self.is_effort_used(optional_params=optional_params, model=model)
    code_execution_tool_used = self.is_code_execution_tool_used(tools=tools)
    container_with_skills_used = self.is_container_with_skills_used(
        optional_params=optional_params
    )
    # Preserve any anthropic-beta values the caller already set
    user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
        anthropic_beta_header=headers.get("anthropic-beta")
    )
    anthropic_headers = self.get_anthropic_headers(
        computer_tool_used=computer_tool_used,
        prompt_caching_set=prompt_caching_set,
        pdf_used=pdf_used,
        api_key=api_key,
        file_id_used=file_id_used,
        web_search_tool_used=web_search_tool_used,
        is_vertex_request=optional_params.get("is_vertex_request", False),
        user_anthropic_beta_headers=user_anthropic_beta_headers,
        mcp_server_used=mcp_server_used,
        tool_search_used=tool_search_used,
        programmatic_tool_calling_used=programmatic_tool_calling_used,
        input_examples_used=input_examples_used,
        effort_used=effort_used,
        code_execution_tool_used=code_execution_tool_used,
        container_with_skills_used=container_with_skills_used,
    )

    # Computed Anthropic headers win over any caller-supplied duplicates
    headers = {**headers, **anthropic_headers}

    return headers
@staticmethod
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
    """Resolve the Anthropic API base: explicit arg > ANTHROPIC_API_BASE env > default."""
    from litellm.secret_managers.main import get_secret_str

    if api_base:
        return api_base
    env_base = get_secret_str("ANTHROPIC_API_BASE")
    if env_base:
        return env_base
    return "https://api.anthropic.com"
@staticmethod
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
    """Resolve the Anthropic API key: explicit arg > ANTHROPIC_API_KEY env var."""
    from litellm.secret_managers.main import get_secret_str

    if api_key:
        return api_key
    return get_secret_str("ANTHROPIC_API_KEY")
@staticmethod
def get_base_model(model: Optional[str] = None) -> Optional[str]:
return model.replace("anthropic/", "") if model else None
def get_models(
    self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
    """
    List available Anthropic models via GET {api_base}/v1/models.

    Args:
        api_key: optional key; falls back to ANTHROPIC_API_KEY.
        api_base: optional base URL; falls back to ANTHROPIC_API_BASE / default.

    Returns:
        Model ids prefixed with "anthropic/" for litellm routing.

    Raises:
        ValueError: when no api_base/api_key can be resolved.
        Exception: when the HTTP request returns a non-2xx status.
    """
    api_base = AnthropicModelInfo.get_api_base(api_base)
    api_key = AnthropicModelInfo.get_api_key(api_key)
    if api_base is None or api_key is None:
        raise ValueError(
            "ANTHROPIC_API_BASE or ANTHROPIC_API_KEY is not set. Please set the environment variable, to query Anthropic's `/models` endpoint."
        )
    # Synchronous, module-level client — this is a blocking call
    response = litellm.module_level_client.get(
        url=f"{api_base}/v1/models",
        headers={"x-api-key": api_key, "anthropic-version": "2023-06-01"},
    )

    try:
        response.raise_for_status()
    except httpx.HTTPStatusError:
        # Re-raise with the response body included for easier debugging
        raise Exception(
            f"Failed to fetch models from Anthropic. Status code: {response.status_code}, Response: {response.text}"
        )

    models = response.json()["data"]

    litellm_model_names = []
    for model in models:
        stripped_model_name = model["id"]
        litellm_model_name = "anthropic/" + stripped_model_name
        litellm_model_names.append(litellm_model_name)
    return litellm_model_names
def get_token_counter(self) -> Optional[BaseTokenCounter]:
    """
    Factory method to create an Anthropic token counter.

    Returns:
        AnthropicTokenCounter instance for this provider.
    """
    # Imported lazily to avoid a circular import at module load time
    from litellm.llms.anthropic.count_tokens.token_counter import (
        AnthropicTokenCounter,
    )

    return AnthropicTokenCounter()
def process_anthropic_headers(headers: Union[httpx.Headers, dict]) -> dict:
    """
    Translate Anthropic rate-limit response headers into their OpenAI-style
    equivalents, and additionally expose every provider header under an
    `llm_provider-` prefix.

    Returns:
        dict containing the prefixed provider headers plus any mapped
        x-ratelimit-* keys (mapped keys win on collision).
    """
    rate_limit_map = {
        "anthropic-ratelimit-requests-limit": "x-ratelimit-limit-requests",
        "anthropic-ratelimit-requests-remaining": "x-ratelimit-remaining-requests",
        "anthropic-ratelimit-tokens-limit": "x-ratelimit-limit-tokens",
        "anthropic-ratelimit-tokens-remaining": "x-ratelimit-remaining-tokens",
    }
    openai_headers = {
        openai_key: headers[anthropic_key]
        for anthropic_key, openai_key in rate_limit_map.items()
        if anthropic_key in headers
    }

    prefixed_provider_headers = {
        f"llm_provider-{key}": value for key, value in headers.items()
    }

    return {**prefixed_provider_headers, **openai_headers}

View File

@@ -0,0 +1,5 @@
"""
Anthropic /complete API - uses `llm_http_handler.py` to make httpx requests
Request/Response transformation is handled in `transformation.py`
"""

View File

@@ -0,0 +1,310 @@
"""
Translation logic for anthropic's `/v1/complete` endpoint
Litellm provider slug: `anthropic_text/<model_name>`
"""
import json
import time
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
import httpx
import litellm
from litellm.constants import DEFAULT_MAX_TOKENS
from litellm.litellm_core_utils.prompt_templates.factory import (
custom_prompt,
prompt_factory,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.base_llm.chat.transformation import (
BaseConfig,
BaseLLMException,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionToolCallChunk,
ChatCompletionUsageBlock,
GenericStreamingChunk,
ModelResponse,
Usage,
)
class AnthropicTextError(BaseLLMException):
    """Error raised for failures against Anthropic's legacy /v1/complete endpoint."""

    def __init__(self, status_code, message):
        # status_code: HTTP status from the provider; message: raw error text
        self.status_code = status_code
        self.message = message
        # Synthesize request/response objects so downstream litellm exception
        # mapping (which inspects .request/.response) works uniformly
        self.request = httpx.Request(
            method="POST", url="https://api.anthropic.com/v1/complete"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            message=self.message,
            status_code=self.status_code,
            request=self.request,
            response=self.response,
        )  # Call the base class constructor with the parameters it needs
class AnthropicTextConfig(BaseConfig):
    """
    Reference: https://docs.anthropic.com/claude/reference/complete_post

    to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
    """

    # Default request parameters for /v1/complete; anthropic requires a default
    # max_tokens_to_sample, all other params are optional.
    max_tokens_to_sample: Optional[
        int
    ] = litellm.max_tokens  # anthropic requires a default
    stop_sequences: Optional[list] = None
    temperature: Optional[int] = None  # NOTE(review): likely float in practice
    top_p: Optional[int] = None  # NOTE(review): likely float in practice
    top_k: Optional[int] = None
    metadata: Optional[dict] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[
            int
        ] = DEFAULT_MAX_TOKENS,  # anthropic requires a default
        stop_sequences: Optional[list] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        # litellm config pattern: non-None constructor args are stored as
        # CLASS attributes, so they act as process-wide defaults picked up by
        # get_config() in transform_request.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    # makes headers for API call
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Require an API key and merge the standard Anthropic headers into *headers*."""
        if api_key is None:
            raise ValueError(
                "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
            )
        _headers = {
            "accept": "application/json",
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
            "x-api-key": api_key,
        }
        headers.update(_headers)
        return headers

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the /v1/complete request body: prompt string + merged config defaults."""
        prompt = self._get_anthropic_text_prompt_from_messages(
            messages=messages, model=model
        )

        ## Load Config
        # Class-level defaults (see __init__) fill in anything the caller
        # did not pass explicitly.
        config = litellm.AnthropicTextConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        data = {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }

        return data

    def get_supported_openai_params(self, model: str):
        """
        Anthropic /complete API Ref: https://docs.anthropic.com/en/api/complete
        """
        return [
            "stream",
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "temperature",
            "top_p",
            "extra_headers",
            "user",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Follows the same logic as the AnthropicConfig.map_openai_params method (which is the Anthropic /messages API)

        Note: the only difference is in the get supported openai params method between the AnthropicConfig and AnthropicTextConfig

        API Ref: https://docs.anthropic.com/en/api/complete
        """
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens_to_sample"] = value
            if param == "max_completion_tokens":
                # max_completion_tokens maps to the same Anthropic param;
                # if both are passed, the later iteration order decides
                optional_params["max_tokens_to_sample"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
            if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
                _value = litellm.AnthropicConfig()._map_stop_sequences(value)
                if _value is not None:
                    optional_params["stop_sequences"] = _value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "user":
                # Anthropic expects user identity inside metadata
                optional_params["metadata"] = {"user_id": value}
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,  # NOTE(review): actually a tokenizer object with .encode() — annotation kept as declared upstream
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Convert a raw /v1/complete HTTP response into a litellm ModelResponse."""
        try:
            completion_response = raw_response.json()
        except Exception:
            # Non-JSON body (e.g. HTML error page) — surface raw text
            raise AnthropicTextError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        prompt = self._get_anthropic_text_prompt_from_messages(
            messages=messages, model=model
        )

        if "error" in completion_response:
            raise AnthropicTextError(
                message=str(completion_response["error"]),
                status_code=raw_response.status_code,
            )
        else:
            if len(completion_response["completion"]) > 0:
                model_response.choices[0].message.content = completion_response[  # type: ignore
                    "completion"
                ]
            model_response.choices[0].finish_reason = completion_response["stop_reason"]

        ## CALCULATING USAGE
        # Token counts are approximated with the local encoding, not
        # Anthropic's own tokenizer.
        prompt_tokens = len(
            encoding.encode(prompt)
        )  ##[TODO] use the anthropic tokenizer here
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )  ##[TODO] use the anthropic tokenizer here

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return the provider-specific exception type for error responses."""
        return AnthropicTextError(
            status_code=status_code,
            message=error_message,
        )

    @staticmethod
    def _is_anthropic_text_model(model: str) -> bool:
        # Only the legacy claude-2 / claude-instant-1 models use /v1/complete
        return model == "claude-2" or model == "claude-instant-1"

    def _get_anthropic_text_prompt_from_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> str:
        """Render chat messages into the single prompt string /v1/complete expects."""
        custom_prompt_dict = litellm.custom_prompt_dict
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="anthropic"
            )

        return str(prompt)

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        """Return the streaming-chunk parser for /v1/complete responses."""
        return AnthropicTextCompletionResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
class AnthropicTextCompletionResponseIterator(BaseModelResponseIterator):
    """Parses streamed /v1/complete chunks into GenericStreamingChunk objects."""

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        """
        Convert one raw anthropic text-completion stream chunk into a
        GenericStreamingChunk.

        Bug fix: the previous code normalized stop_reason with `or ""` and
        then checked `is not None` — the check could never fail, so every
        chunk (including intermediate text deltas) was marked finished. A
        chunk is now final only when the provider actually sends a
        `stop_reason`.
        """
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None
            provider_specific_fields = None

            index = int(chunk.get("index", 0))

            _chunk_text = chunk.get("completion", None)
            if _chunk_text is not None and isinstance(_chunk_text, str):
                text = _chunk_text

            # Only mark the stream finished when a stop_reason is present.
            _stop_reason = chunk.get("stop_reason")
            if _stop_reason is not None:
                finish_reason = _stop_reason
                is_finished = True

            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )

            return returned_chunk

        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")

View File

@@ -0,0 +1,132 @@
"""
Helper util for handling anthropic-specific cost calculation
- e.g.: prompt caching
"""
from typing import TYPE_CHECKING, Optional, Tuple
from litellm.litellm_core_utils.llm_cost_calc.utils import (
_get_token_base_cost,
_parse_prompt_tokens_details,
calculate_cache_writing_cost,
generic_cost_per_token,
)
if TYPE_CHECKING:
from litellm.types.utils import ModelInfo, Usage
import litellm
def _compute_cache_only_cost(model_info: "ModelInfo", usage: "Usage") -> float:
    """
    Return only the cache-related portion of the prompt cost (cache read + cache write).

    These costs must NOT be scaled by geo/speed multipliers because the old
    explicit ``fast/`` model entries carried unchanged cache rates while
    multiplying only the regular input/output token costs.

    Args:
        model_info: pricing entry for the model.
        usage: usage block; returns 0.0 when it has no prompt_tokens_details.
    """
    if usage.prompt_tokens_details is None:
        return 0.0

    prompt_tokens_details = _parse_prompt_tokens_details(usage)

    # Only the cache read/write rates are needed; regular token rates are discarded
    (
        _,
        _,
        cache_creation_cost,
        cache_creation_cost_above_1hr,
        cache_read_cost,
    ) = _get_token_base_cost(model_info=model_info, usage=usage)

    cache_cost = float(prompt_tokens_details["cache_hit_tokens"]) * cache_read_cost

    # Cache writes are charged separately, with a higher rate for >1hr TTLs
    if (
        prompt_tokens_details["cache_creation_tokens"]
        or prompt_tokens_details["cache_creation_token_details"] is not None
    ):
        cache_cost += calculate_cache_writing_cost(
            cache_creation_tokens=prompt_tokens_details["cache_creation_tokens"],
            cache_creation_token_details=prompt_tokens_details[
                "cache_creation_token_details"
            ],
            cache_creation_cost_above_1hr=cache_creation_cost_above_1hr,
            cache_creation_cost=cache_creation_cost,
        )

    return cache_cost
def cost_per_token(model: str, usage: "Usage") -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without provider prefix
        - usage: LiteLLM Usage block, containing anthropic caching information

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    prompt_cost, completion_cost = generic_cost_per_token(
        model=model, usage=usage, custom_llm_provider="anthropic"
    )

    # Apply provider_specific_entry multipliers for geo/speed routing
    try:
        model_info = litellm.get_model_info(
            model=model, custom_llm_provider="anthropic"
        )
        provider_specific_entry: dict = model_info.get("provider_specific_entry") or {}

        multiplier = 1.0
        # Geo-routed inference (e.g. region-pinned) may carry a price multiplier
        if (
            hasattr(usage, "inference_geo")
            and usage.inference_geo
            and usage.inference_geo.lower() not in ["global", "not_available"]
        ):
            multiplier *= provider_specific_entry.get(usage.inference_geo.lower(), 1.0)

        # "fast" speed tier may carry its own multiplier
        if hasattr(usage, "speed") and usage.speed == "fast":
            multiplier *= provider_specific_entry.get("fast", 1.0)

        if multiplier != 1.0:
            # Cache read/write costs are excluded from the multiplier — see
            # _compute_cache_only_cost for why.
            cache_cost = _compute_cache_only_cost(model_info=model_info, usage=usage)
            prompt_cost = (prompt_cost - cache_cost) * multiplier + cache_cost
            completion_cost *= multiplier
    except Exception:
        # Deliberate best-effort: if pricing metadata is unavailable, fall
        # back to the unscaled generic cost rather than failing the request.
        pass

    return prompt_cost, completion_cost
def get_cost_for_anthropic_web_search(
    model_info: Optional["ModelInfo"] = None,
    usage: Optional["Usage"] = None,
) -> float:
    """
    Compute the total web-search tool cost for an Anthropic response.

    Returns 0.0 when pricing info or web-search usage is unavailable.
    """
    from litellm.types.utils import SearchContextCostPerQuery

    ## Check if web search requests are in the usage object
    if model_info is None:
        return 0.0

    server_tool_use = usage.server_tool_use if usage is not None else None
    request_count = (
        server_tool_use.web_search_requests if server_tool_use is not None else None
    )
    if request_count is None:
        return 0.0

    ## Get the cost per web search request
    pricing: SearchContextCostPerQuery = (
        model_info.get("search_context_cost_per_query") or SearchContextCostPerQuery()
    )
    per_request_cost = pricing.get("search_context_size_medium", 0.0)
    if not per_request_cost:
        return 0.0

    ## Calculate the total cost
    return per_request_cost * request_count

View File

@@ -0,0 +1,15 @@
"""
Anthropic CountTokens API implementation.
"""
from litellm.llms.anthropic.count_tokens.handler import AnthropicCountTokensHandler
from litellm.llms.anthropic.count_tokens.token_counter import AnthropicTokenCounter
from litellm.llms.anthropic.count_tokens.transformation import (
AnthropicCountTokensConfig,
)
__all__ = [
"AnthropicCountTokensHandler",
"AnthropicCountTokensConfig",
"AnthropicTokenCounter",
]

View File

@@ -0,0 +1,128 @@
"""
Anthropic CountTokens API handler.
Uses httpx for HTTP requests instead of the Anthropic SDK.
"""
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.anthropic.common_utils import AnthropicError
from litellm.llms.anthropic.count_tokens.transformation import (
AnthropicCountTokensConfig,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
class AnthropicCountTokensHandler(AnthropicCountTokensConfig):
    """
    Handler for Anthropic CountTokens API requests.

    Uses httpx for HTTP requests, following the same pattern as BedrockCountTokensHandler.
    """

    async def handle_count_tokens_request(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        api_key: str,
        api_base: Optional[str] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using httpx.

        Args:
            model: The model identifier (e.g., "claude-3-5-sonnet-20241022")
            messages: The messages to count tokens for
            api_key: The Anthropic API key
            api_base: Optional custom API base URL
            timeout: Optional timeout for the request (defaults to litellm.request_timeout)
            tools: Optional tool definitions (affect the token count)
            system: Optional system prompt (affects the token count)

        Returns:
            Dictionary containing token count response

        Raises:
            AnthropicError: If the API request fails
        """
        try:
            # Validate the request
            self.validate_request(model, messages)

            verbose_logger.debug(
                f"Processing Anthropic CountTokens request for model: {model}"
            )

            # Transform request to Anthropic format
            request_body = self.transform_request_to_count_tokens(
                model=model,
                messages=messages,
                tools=tools,
                system=system,
            )

            verbose_logger.debug(f"Transformed request: {request_body}")

            # Get endpoint URL
            endpoint_url = api_base or self.get_anthropic_count_tokens_endpoint()

            verbose_logger.debug(f"Making request to: {endpoint_url}")

            # Get required headers
            headers = self.get_required_headers(api_key)

            # Use LiteLLM's async httpx client
            async_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.ANTHROPIC
            )

            # Use provided timeout or fall back to litellm.request_timeout
            request_timeout = (
                timeout if timeout is not None else litellm.request_timeout
            )

            response = await async_client.post(
                endpoint_url,
                headers=headers,
                json=request_body,
                timeout=request_timeout,
            )

            verbose_logger.debug(f"Response status: {response.status_code}")

            if response.status_code != 200:
                error_text = response.text
                verbose_logger.error(f"Anthropic API error: {error_text}")
                raise AnthropicError(
                    status_code=response.status_code,
                    message=error_text,
                )

            anthropic_response = response.json()
            verbose_logger.debug(f"Anthropic response: {anthropic_response}")

            # Return Anthropic response directly - no transformation needed
            return anthropic_response

        except AnthropicError:
            # Re-raise Anthropic exceptions as-is
            raise
        except httpx.HTTPStatusError as e:
            # HTTP errors - preserve the actual status code
            verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        except Exception as e:
            # Anything else (validation errors, network failures) becomes a 500
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=500,
                message=f"CountTokens processing error: {str(e)}",
            )

View File

@@ -0,0 +1,108 @@
"""
Anthropic Token Counter implementation using the CountTokens API.
"""
import os
from typing import Any, Dict, List, Optional
from litellm._logging import verbose_logger
from litellm.llms.anthropic.count_tokens.handler import AnthropicCountTokensHandler
from litellm.llms.base_llm.base_utils import BaseTokenCounter
from litellm.types.utils import LlmProviders, TokenCountResponse
# Global handler instance - reuse across all token counting requests
anthropic_count_tokens_handler = AnthropicCountTokensHandler()
class AnthropicTokenCounter(BaseTokenCounter):
    """Token counter implementation for Anthropic provider using the CountTokens API."""

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Use the remote CountTokens API only for the anthropic provider."""
        return custom_llm_provider == LlmProviders.ANTHROPIC.value

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count tokens using Anthropic's CountTokens API.

        Args:
            model_to_use: The model identifier
            messages: The messages to count tokens for
            contents: Alternative content format (not used for Anthropic)
            deployment: Deployment configuration containing litellm_params
            request_model: The original request model name
            tools: Optional tool definitions (affect the token count)
            system: Optional system prompt (affects the token count)

        Returns:
            TokenCountResponse with token count, or None if counting fails
        """
        from litellm.llms.anthropic.common_utils import AnthropicError

        if not messages:
            return None

        deployment = deployment or {}
        litellm_params = deployment.get("litellm_params", {})

        # Get Anthropic API key from deployment config or environment
        api_key = litellm_params.get("api_key")
        if not api_key:
            api_key = os.getenv("ANTHROPIC_API_KEY")

        if not api_key:
            verbose_logger.warning("No Anthropic API key found for token counting")
            return None

        try:
            result = await anthropic_count_tokens_handler.handle_count_tokens_request(
                model=model_to_use,
                messages=messages,
                api_key=api_key,
                tools=tools,
                system=system,
            )

            if result is not None:
                return TokenCountResponse(
                    total_tokens=result.get("input_tokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type="anthropic_api",
                    original_response=result,
                )
        except AnthropicError as e:
            # API-level failure: return an error response rather than raising
            verbose_logger.warning(
                f"Anthropic CountTokens API error: status={e.status_code}, message={e.message}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="anthropic_api",
                error=True,
                error_message=e.message,
                status_code=e.status_code,
            )
        except Exception as e:
            # Unexpected failure (network, serialization, ...) — report as 500
            verbose_logger.warning(f"Error calling Anthropic CountTokens API: {e}")
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="anthropic_api",
                error=True,
                error_message=str(e),
                status_code=500,
            )

        # Reached only when the handler returned None with no exception
        return None

View File

@@ -0,0 +1,107 @@
"""
Anthropic CountTokens API transformation logic.
This module handles the transformation of requests to Anthropic's CountTokens API format.
"""
from typing import Any, Dict, List, Optional
from litellm.constants import ANTHROPIC_TOKEN_COUNTING_BETA_VERSION
class AnthropicCountTokensConfig:
    """
    Configuration and request-shaping logic for Anthropic's CountTokens API.

    Anthropic CountTokens API Specification:
    - Endpoint: POST https://api.anthropic.com/v1/messages/count_tokens
    - Beta header required: anthropic-beta: token-counting-2024-11-01
    - Response: {"input_tokens": <number>}
    """

    def get_anthropic_count_tokens_endpoint(self) -> str:
        """Return the CountTokens API endpoint URL."""
        return "https://api.anthropic.com/v1/messages/count_tokens"

    def transform_request_to_count_tokens(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Build the CountTokens request body.

        The optional `system` and `tools` fields are included when provided
        because both contribute to the token count.
        """
        body: Dict[str, Any] = {"model": model, "messages": messages}
        if system is not None:
            body["system"] = system
        if tools is not None:
            body["tools"] = tools
        return body

    def get_required_headers(self, api_key: str) -> Dict[str, str]:
        """
        Return the headers required by the CountTokens API, with OAuth token
        handling applied when applicable.

        Args:
            api_key: The Anthropic API key

        Returns:
            Dictionary of required headers
        """
        from litellm.llms.anthropic.common_utils import (
            optionally_handle_anthropic_oauth,
        )

        request_headers: Dict[str, str] = {
            "Content-Type": "application/json",
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "anthropic-beta": ANTHROPIC_TOKEN_COUNTING_BETA_VERSION,
        }
        request_headers, _ = optionally_handle_anthropic_oauth(
            headers=request_headers, api_key=api_key
        )
        return request_headers

    def validate_request(self, model: str, messages: List[Dict[str, Any]]) -> None:
        """
        Validate the incoming count tokens request.

        Args:
            model: The model name
            messages: The messages to count tokens for

        Raises:
            ValueError: If the request is invalid
        """
        if not model:
            raise ValueError("model parameter is required")

        if not messages:
            raise ValueError("messages parameter is required")

        if not isinstance(messages, list):
            raise ValueError("messages must be a list")

        for index, entry in enumerate(messages):
            if not isinstance(entry, dict):
                raise ValueError(f"Message {index} must be a dictionary")
            if "role" not in entry:
                raise ValueError(f"Message {index} must have a 'role' field")
            if "content" not in entry:
                raise ValueError(f"Message {index} must have a 'content' field")

View File

@@ -0,0 +1,3 @@
from .transformation import LiteLLMAnthropicMessagesAdapter
__all__ = ["LiteLLMAnthropicMessagesAdapter"]

View File

@@ -0,0 +1,345 @@
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Coroutine,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
import litellm
from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import (
AnthropicAdapter,
)
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
)
from litellm.types.utils import ModelResponse
from litellm.utils import get_model_info
if TYPE_CHECKING:
pass
########################################################
# init adapter
# Module-level singleton: the adapter instance is shared by every
# request translation in this module.
ANTHROPIC_ADAPTER = AnthropicAdapter()
########################################################
class LiteLLMMessagesToCompletionTransformationHandler:
    """
    Bridges Anthropic `/v1/messages`-style requests onto
    `litellm.completion` / `litellm.acompletion` for providers that lack a
    native Anthropic Messages implementation, then translates the result
    back into the Anthropic response format via ANTHROPIC_ADAPTER.
    """

    @staticmethod
    def _route_openai_thinking_to_responses_api_if_needed(
        completion_kwargs: Dict[str, Any],
        *,
        thinking: Optional[Dict[str, Any]],
    ) -> None:
        """
        When users call `litellm.anthropic.messages.*` with a non-Anthropic model and
        `thinking={"type": "enabled", ...}`, LiteLLM converts this into OpenAI
        `reasoning_effort`.

        For OpenAI models, Chat Completions typically does not return reasoning text
        (only token accounting). To return a thinking-like content block in the
        Anthropic response format, we route the request through OpenAI's Responses API
        and request a reasoning summary.

        Mutates `completion_kwargs` in place; returns None.
        """
        custom_llm_provider = completion_kwargs.get("custom_llm_provider")
        if custom_llm_provider is None:
            # Provider not explicitly set - best-effort inference from the model string.
            try:
                _, inferred_provider, _, _ = litellm.utils.get_llm_provider(
                    model=cast(str, completion_kwargs.get("model"))
                )
                custom_llm_provider = inferred_provider
            except Exception:
                custom_llm_provider = None
        if custom_llm_provider != "openai":
            return
        if not isinstance(thinking, dict) or thinking.get("type") != "enabled":
            return
        model = completion_kwargs.get("model")
        try:
            model_info = get_model_info(
                model=cast(str, model), custom_llm_provider=custom_llm_provider
            )
            if model_info and model_info.get("supports_reasoning") is False:
                # Model doesn't support reasoning/responses API, don't route
                return
        except Exception:
            # Model-info lookup is best-effort; on failure we still route.
            pass
        if isinstance(model, str) and model and not model.startswith("responses/"):
            # Prefix model with "responses/" to route to OpenAI Responses API
            completion_kwargs["model"] = f"responses/{model}"
        reasoning_effort = completion_kwargs.get("reasoning_effort")
        if isinstance(reasoning_effort, str) and reasoning_effort:
            # String form (e.g. "medium") -> dict form with a summary request.
            completion_kwargs["reasoning_effort"] = {
                "effort": reasoning_effort,
                "summary": "detailed",
            }
        elif isinstance(reasoning_effort, dict):
            if (
                "summary" not in reasoning_effort
                and "generate_summary" not in reasoning_effort
            ):
                # Caller didn't ask for a summary; request one without
                # mutating the caller's dict.
                updated_reasoning_effort = dict(reasoning_effort)
                updated_reasoning_effort["summary"] = "detailed"
                completion_kwargs["reasoning_effort"] = updated_reasoning_effort

    @staticmethod
    def _prepare_completion_kwargs(
        *,
        max_tokens: int,
        messages: List[Dict],
        model: str,
        metadata: Optional[Dict] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: Optional[bool] = False,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking: Optional[Dict] = None,
        tool_choice: Optional[Dict] = None,
        tools: Optional[List[Dict]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        output_format: Optional[Dict] = None,
        extra_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, str]]:
        """Prepare kwargs for litellm.completion/acompletion.

        Builds an Anthropic-format request dict, translates it to OpenAI
        format via ANTHROPIC_ADAPTER, then layers stream options and
        passthrough kwargs on top.

        Returns:
            Tuple of (completion_kwargs, tool_name_mapping)
            - tool_name_mapping maps truncated tool names back to original names
              for tools that exceeded OpenAI's 64-char limit
        """
        from litellm.litellm_core_utils.litellm_logging import (
            Logging as LiteLLMLoggingObject,
        )

        # Only include optional Anthropic fields that were actually provided.
        request_data = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
        }
        if metadata:
            request_data["metadata"] = metadata
        if stop_sequences:
            request_data["stop_sequences"] = stop_sequences
        if system:
            request_data["system"] = system
        if temperature is not None:
            request_data["temperature"] = temperature
        if thinking:
            request_data["thinking"] = thinking
        if tool_choice:
            request_data["tool_choice"] = tool_choice
        if tools:
            request_data["tools"] = tools
        if top_k is not None:
            request_data["top_k"] = top_k
        if top_p is not None:
            request_data["top_p"] = top_p
        if output_format:
            request_data["output_format"] = output_format
        (
            openai_request,
            tool_name_mapping,
        ) = ANTHROPIC_ADAPTER.translate_completion_input_params_with_tool_mapping(
            request_data
        )
        if openai_request is None:
            raise ValueError("Failed to translate request to OpenAI format")
        completion_kwargs: Dict[str, Any] = dict(openai_request)
        if stream:
            completion_kwargs["stream"] = stream
            # Usage must be included in the final stream chunk so it can be
            # mapped into the Anthropic message_delta event.
            completion_kwargs["stream_options"] = {
                "include_usage": True,
            }
        excluded_keys = {"anthropic_messages"}
        extra_kwargs = extra_kwargs or {}
        for key, value in extra_kwargs.items():
            if (
                key == "litellm_logging_obj"
                and value is not None
                and isinstance(value, LiteLLMLoggingObject)
            ):
                from litellm.types.utils import CallTypes

                # Re-tag the logging object: this request is executed as a
                # chat completion even though it arrived as /v1/messages.
                setattr(value, "call_type", CallTypes.completion.value)
                setattr(
                    value, "stream_options", completion_kwargs.get("stream_options")
                )
            # Passthrough kwargs never overwrite translated request fields.
            if (
                key not in excluded_keys
                and key not in completion_kwargs
                and value is not None
            ):
                completion_kwargs[key] = value
        LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed(
            completion_kwargs,
            thinking=thinking,
        )
        return completion_kwargs, tool_name_mapping

    @staticmethod
    async def async_anthropic_messages_handler(
        max_tokens: int,
        messages: List[Dict],
        model: str,
        metadata: Optional[Dict] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: Optional[bool] = False,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking: Optional[Dict] = None,
        tool_choice: Optional[Dict] = None,
        tools: Optional[List[Dict]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        output_format: Optional[Dict] = None,
        **kwargs,
    ) -> Union[AnthropicMessagesResponse, AsyncIterator]:
        """Handle non-Anthropic models asynchronously using the adapter.

        Raises:
            ValueError: If the completion response cannot be translated back
                to the Anthropic format.
        """
        (
            completion_kwargs,
            tool_name_mapping,
        ) = LiteLLMMessagesToCompletionTransformationHandler._prepare_completion_kwargs(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            metadata=metadata,
            stop_sequences=stop_sequences,
            stream=stream,
            system=system,
            temperature=temperature,
            thinking=thinking,
            tool_choice=tool_choice,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            output_format=output_format,
            extra_kwargs=kwargs,
        )
        completion_response = await litellm.acompletion(**completion_kwargs)
        if stream:
            transformed_stream = (
                ANTHROPIC_ADAPTER.translate_completion_output_params_streaming(
                    completion_response,
                    model=model,
                    tool_name_mapping=tool_name_mapping,
                )
            )
            if transformed_stream is not None:
                return transformed_stream
            raise ValueError("Failed to transform streaming response")
        else:
            anthropic_response = ANTHROPIC_ADAPTER.translate_completion_output_params(
                cast(ModelResponse, completion_response),
                tool_name_mapping=tool_name_mapping,
            )
            if anthropic_response is not None:
                return anthropic_response
            raise ValueError("Failed to transform response to Anthropic format")

    @staticmethod
    def anthropic_messages_handler(
        max_tokens: int,
        messages: List[Dict],
        model: str,
        metadata: Optional[Dict] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: Optional[bool] = False,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking: Optional[Dict] = None,
        tool_choice: Optional[Dict] = None,
        tools: Optional[List[Dict]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        output_format: Optional[Dict] = None,
        _is_async: bool = False,
        **kwargs,
    ) -> Union[
        AnthropicMessagesResponse,
        AsyncIterator[Any],
        Coroutine[Any, Any, Union[AnthropicMessagesResponse, AsyncIterator[Any]]],
    ]:
        """Handle non-Anthropic models using the adapter.

        When `_is_async` is True, returns the (unawaited) coroutine from the
        async handler; otherwise runs synchronously via litellm.completion.

        Raises:
            ValueError: If the completion response cannot be translated back
                to the Anthropic format.
        """
        if _is_async is True:
            return LiteLLMMessagesToCompletionTransformationHandler.async_anthropic_messages_handler(
                max_tokens=max_tokens,
                messages=messages,
                model=model,
                metadata=metadata,
                stop_sequences=stop_sequences,
                stream=stream,
                system=system,
                temperature=temperature,
                thinking=thinking,
                tool_choice=tool_choice,
                tools=tools,
                top_k=top_k,
                top_p=top_p,
                output_format=output_format,
                **kwargs,
            )
        (
            completion_kwargs,
            tool_name_mapping,
        ) = LiteLLMMessagesToCompletionTransformationHandler._prepare_completion_kwargs(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            metadata=metadata,
            stop_sequences=stop_sequences,
            stream=stream,
            system=system,
            temperature=temperature,
            thinking=thinking,
            tool_choice=tool_choice,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            output_format=output_format,
            extra_kwargs=kwargs,
        )
        completion_response = litellm.completion(**completion_kwargs)
        if stream:
            transformed_stream = (
                ANTHROPIC_ADAPTER.translate_completion_output_params_streaming(
                    completion_response,
                    model=model,
                    tool_name_mapping=tool_name_mapping,
                )
            )
            if transformed_stream is not None:
                return transformed_stream
            raise ValueError("Failed to transform streaming response")
        else:
            anthropic_response = ANTHROPIC_ADAPTER.translate_completion_output_params(
                cast(ModelResponse, completion_response),
                tool_name_mapping=tool_name_mapping,
            )
            if anthropic_response is not None:
                return anthropic_response
            raise ValueError("Failed to transform response to Anthropic format")

View File

@@ -0,0 +1,488 @@
# What is this?
## Translates OpenAI call to Anthropic `/v1/messages` format
import json
import traceback
from collections import deque
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, Literal, Optional
from litellm import verbose_logger
from litellm._uuid import uuid
from litellm.types.llms.anthropic import UsageDelta
from litellm.types.utils import AdapterCompletionStreamWrapper
if TYPE_CHECKING:
from litellm.types.utils import ModelResponseStream
class AnthropicStreamWrapper(AdapterCompletionStreamWrapper):
    """
    Re-frames an OpenAI-style completion stream as Anthropic Messages events.

    - first chunk return 'message_start'
    - content block must be started and stopped
    - finish_reason must map exactly to anthropic reason, else anthropic client won't be able to parse it.

    Fixes vs. previous revision:
    - sync `__next__` now raises StopIteration (not StopAsyncIteration) on
      unexpected errors, so plain `for` loops terminate cleanly.
    - `chunk_queue` is created per instance in `__init__`; the class-level
      default was a single deque shared by every wrapper instance.
    """

    from litellm.types.llms.anthropic import (
        ContentBlockContentBlockDict,
        ContentBlockStart,
        ContentBlockStartText,
        TextBlock,
    )

    # Streaming state-machine flags. Immutable defaults are safe at class
    # level because instances always assign (never mutate) them.
    sent_first_chunk: bool = False
    sent_content_block_start: bool = False
    sent_content_block_finish: bool = False
    current_content_block_type: Literal["text", "tool_use", "thinking"] = "text"
    sent_last_message: bool = False
    holding_chunk: Optional[Any] = None
    holding_stop_reason_chunk: Optional[Any] = None
    queued_usage_chunk: bool = False
    current_content_block_index: int = 0
    current_content_block_start: ContentBlockContentBlockDict = TextBlock(
        type="text",
        text="",
    )
    chunk_queue: deque = deque()  # Queue for buffering multiple chunks

    def __init__(
        self,
        completion_stream: Any,
        model: str,
        tool_name_mapping: Optional[Dict[str, str]] = None,
    ):
        super().__init__(completion_stream)
        self.model = model
        # Mapping of truncated tool names to original names (for OpenAI's 64-char limit)
        self.tool_name_mapping = tool_name_mapping or {}
        # Per-instance queue: the class-level default deque is shared by all
        # instances, so concurrent streams would interleave each other's chunks.
        self.chunk_queue = deque()

    def _create_initial_usage_delta(self) -> UsageDelta:
        """
        Create the initial UsageDelta for the message_start event.

        Initializes cache token fields (cache_creation_input_tokens, cache_read_input_tokens)
        to 0 to indicate to clients (like Claude Code) that prompt caching is supported.

        The actual cache token values will be provided in the message_delta event at the
        end of the stream, since Bedrock Converse API only returns usage data in the final
        response chunk.

        Returns:
            UsageDelta with all token counts initialized to 0.
        """
        return UsageDelta(
            input_tokens=0,
            output_tokens=0,
            cache_creation_input_tokens=0,
            cache_read_input_tokens=0,
        )

    def __next__(self):
        from .transformation import LiteLLMAnthropicMessagesAdapter

        try:
            # Always return queued chunks first
            if self.chunk_queue:
                return self.chunk_queue.popleft()
            # Queue initial chunks if not sent yet
            if self.sent_first_chunk is False:
                self.sent_first_chunk = True
                self.chunk_queue.append(
                    {
                        "type": "message_start",
                        "message": {
                            "id": "msg_{}".format(uuid.uuid4()),
                            "type": "message",
                            "role": "assistant",
                            "content": [],
                            "model": self.model,
                            "stop_reason": None,
                            "stop_sequence": None,
                            "usage": self._create_initial_usage_delta(),
                        },
                    }
                )
                return self.chunk_queue.popleft()
            if self.sent_content_block_start is False:
                self.sent_content_block_start = True
                self.chunk_queue.append(
                    {
                        "type": "content_block_start",
                        "index": self.current_content_block_index,
                        "content_block": {"type": "text", "text": ""},
                    }
                )
                return self.chunk_queue.popleft()
            for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    raise Exception
                should_start_new_block = self._should_start_new_content_block(chunk)
                if should_start_new_block:
                    self._increment_content_block_index()
                processed_chunk = LiteLLMAnthropicMessagesAdapter().translate_streaming_openai_response_to_anthropic(
                    response=chunk,
                    current_content_block_index=self.current_content_block_index,
                )
                if should_start_new_block and not self.sent_content_block_finish:
                    # Queue the sequence: content_block_stop -> content_block_start
                    # The trigger chunk itself is not emitted as a delta since the
                    # content_block_start already carries the relevant information.
                    self.chunk_queue.append(
                        {
                            "type": "content_block_stop",
                            "index": max(self.current_content_block_index - 1, 0),
                        }
                    )
                    self.chunk_queue.append(
                        {
                            "type": "content_block_start",
                            "index": self.current_content_block_index,
                            "content_block": self.current_content_block_start,
                        }
                    )
                    self.sent_content_block_finish = False
                    return self.chunk_queue.popleft()
                if (
                    processed_chunk["type"] == "message_delta"
                    and self.sent_content_block_finish is False
                ):
                    # Queue both the content_block_stop and the message_delta
                    self.chunk_queue.append(
                        {
                            "type": "content_block_stop",
                            "index": self.current_content_block_index,
                        }
                    )
                    self.sent_content_block_finish = True
                    self.chunk_queue.append(processed_chunk)
                    return self.chunk_queue.popleft()
                elif self.holding_chunk is not None:
                    self.chunk_queue.append(self.holding_chunk)
                    self.chunk_queue.append(processed_chunk)
                    self.holding_chunk = None
                    return self.chunk_queue.popleft()
                else:
                    self.chunk_queue.append(processed_chunk)
                    return self.chunk_queue.popleft()
            # Handle any remaining held chunks after stream ends
            if self.holding_chunk is not None:
                self.chunk_queue.append(self.holding_chunk)
                self.holding_chunk = None
            if not self.sent_last_message:
                self.sent_last_message = True
                self.chunk_queue.append({"type": "message_stop"})
            if self.chunk_queue:
                return self.chunk_queue.popleft()
            raise StopIteration
        except StopIteration:
            if self.chunk_queue:
                return self.chunk_queue.popleft()
            if self.sent_last_message is False:
                self.sent_last_message = True
                return {"type": "message_stop"}
            raise StopIteration
        except Exception as e:
            verbose_logger.error(
                "Anthropic Adapter - {}\n{}".format(e, traceback.format_exc())
            )
            # Fixed: was `raise StopAsyncIteration`, which a synchronous
            # `for` loop cannot consume and would propagate as an error.
            raise StopIteration

    async def __anext__(self):  # noqa: PLR0915
        from .transformation import LiteLLMAnthropicMessagesAdapter

        try:
            # Always return queued chunks first
            if self.chunk_queue:
                return self.chunk_queue.popleft()
            # Queue initial chunks if not sent yet
            if self.sent_first_chunk is False:
                self.sent_first_chunk = True
                self.chunk_queue.append(
                    {
                        "type": "message_start",
                        "message": {
                            "id": "msg_{}".format(uuid.uuid4()),
                            "type": "message",
                            "role": "assistant",
                            "content": [],
                            "model": self.model,
                            "stop_reason": None,
                            "stop_sequence": None,
                            "usage": self._create_initial_usage_delta(),
                        },
                    }
                )
                return self.chunk_queue.popleft()
            if self.sent_content_block_start is False:
                self.sent_content_block_start = True
                self.chunk_queue.append(
                    {
                        "type": "content_block_start",
                        "index": self.current_content_block_index,
                        "content_block": {"type": "text", "text": ""},
                    }
                )
                return self.chunk_queue.popleft()
            async for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    raise Exception
                # Check if we need to start a new content block
                should_start_new_block = self._should_start_new_content_block(chunk)
                if should_start_new_block:
                    self._increment_content_block_index()
                processed_chunk = LiteLLMAnthropicMessagesAdapter().translate_streaming_openai_response_to_anthropic(
                    response=chunk,
                    current_content_block_index=self.current_content_block_index,
                )
                # Check if this is a usage chunk and we have a held stop_reason chunk
                if (
                    self.holding_stop_reason_chunk is not None
                    and getattr(chunk, "usage", None) is not None
                ):
                    # Merge usage into the held stop_reason chunk
                    merged_chunk = self.holding_stop_reason_chunk.copy()
                    if "delta" not in merged_chunk:
                        merged_chunk["delta"] = {}
                    # Add usage to the held chunk; Anthropic reports uncached
                    # input tokens, so subtract the cached portion.
                    uncached_input_tokens = chunk.usage.prompt_tokens or 0
                    if (
                        hasattr(chunk.usage, "prompt_tokens_details")
                        and chunk.usage.prompt_tokens_details
                    ):
                        cached_tokens = (
                            getattr(
                                chunk.usage.prompt_tokens_details, "cached_tokens", 0
                            )
                            or 0
                        )
                        uncached_input_tokens -= cached_tokens
                    usage_dict: UsageDelta = {
                        "input_tokens": uncached_input_tokens,
                        "output_tokens": chunk.usage.completion_tokens or 0,
                    }
                    # Add cache tokens if available (for prompt caching support)
                    if (
                        hasattr(chunk.usage, "_cache_creation_input_tokens")
                        and chunk.usage._cache_creation_input_tokens > 0
                    ):
                        usage_dict[
                            "cache_creation_input_tokens"
                        ] = chunk.usage._cache_creation_input_tokens
                    if (
                        hasattr(chunk.usage, "_cache_read_input_tokens")
                        and chunk.usage._cache_read_input_tokens > 0
                    ):
                        usage_dict[
                            "cache_read_input_tokens"
                        ] = chunk.usage._cache_read_input_tokens
                    merged_chunk["usage"] = usage_dict
                    # Queue the merged chunk and reset
                    self.chunk_queue.append(merged_chunk)
                    self.queued_usage_chunk = True
                    self.holding_stop_reason_chunk = None
                    return self.chunk_queue.popleft()
                # Check if this processed chunk has a stop_reason - hold it for next chunk
                if not self.queued_usage_chunk:
                    if should_start_new_block and not self.sent_content_block_finish:
                        # Queue the sequence: content_block_stop -> content_block_start
                        # The trigger chunk itself is not emitted as a delta since the
                        # content_block_start already carries the relevant information.
                        # 1. Stop current content block
                        self.chunk_queue.append(
                            {
                                "type": "content_block_stop",
                                "index": max(self.current_content_block_index - 1, 0),
                            }
                        )
                        # 2. Start new content block
                        self.chunk_queue.append(
                            {
                                "type": "content_block_start",
                                "index": self.current_content_block_index,
                                "content_block": self.current_content_block_start,
                            }
                        )
                        # Reset state for new block
                        self.sent_content_block_finish = False
                        # Return the first queued item
                        return self.chunk_queue.popleft()
                    if (
                        processed_chunk["type"] == "message_delta"
                        and self.sent_content_block_finish is False
                    ):
                        # Queue both the content_block_stop and the holding chunk
                        self.chunk_queue.append(
                            {
                                "type": "content_block_stop",
                                "index": self.current_content_block_index,
                            }
                        )
                        self.sent_content_block_finish = True
                        if (
                            processed_chunk.get("delta", {}).get("stop_reason")
                            is not None
                        ):
                            self.holding_stop_reason_chunk = processed_chunk
                        else:
                            self.chunk_queue.append(processed_chunk)
                        return self.chunk_queue.popleft()
                    elif self.holding_chunk is not None:
                        # Queue both chunks
                        self.chunk_queue.append(self.holding_chunk)
                        self.chunk_queue.append(processed_chunk)
                        self.holding_chunk = None
                        return self.chunk_queue.popleft()
                    else:
                        # Queue the current chunk
                        self.chunk_queue.append(processed_chunk)
                        return self.chunk_queue.popleft()
            # Handle any remaining held chunks after stream ends
            if not self.queued_usage_chunk:
                if self.holding_stop_reason_chunk is not None:
                    self.chunk_queue.append(self.holding_stop_reason_chunk)
                    self.holding_stop_reason_chunk = None
            if self.holding_chunk is not None:
                self.chunk_queue.append(self.holding_chunk)
                self.holding_chunk = None
            if not self.sent_last_message:
                self.sent_last_message = True
                self.chunk_queue.append({"type": "message_stop"})
            # Return queued items if any
            if self.chunk_queue:
                return self.chunk_queue.popleft()
            raise StopIteration
        except StopIteration:
            # Handle any remaining queued chunks before stopping
            if self.chunk_queue:
                return self.chunk_queue.popleft()
            # Handle any held stop_reason chunk
            if self.holding_stop_reason_chunk is not None:
                return self.holding_stop_reason_chunk
            if not self.sent_last_message:
                self.sent_last_message = True
                return {"type": "message_stop"}
            raise StopAsyncIteration

    def anthropic_sse_wrapper(self) -> Iterator[bytes]:
        """
        Convert AnthropicStreamWrapper dict chunks to Server-Sent Events format.

        Similar to the Bedrock bedrock_sse_wrapper implementation.
        This wrapper ensures dict chunks are SSE formatted with both event and data lines.
        """
        for chunk in self:
            if isinstance(chunk, dict):
                event_type: str = str(chunk.get("type", "message"))
                payload = f"event: {event_type}\ndata: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            else:
                # For non-dict chunks, forward the original value unchanged
                yield chunk

    async def async_anthropic_sse_wrapper(self) -> AsyncIterator[bytes]:
        """
        Async version of anthropic_sse_wrapper.
        Convert AnthropicStreamWrapper dict chunks to Server-Sent Events format.
        """
        async for chunk in self:
            if isinstance(chunk, dict):
                event_type: str = str(chunk.get("type", "message"))
                payload = f"event: {event_type}\ndata: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            else:
                # For non-dict chunks, forward the original value unchanged
                yield chunk

    def _increment_content_block_index(self):
        """Advance the Anthropic content-block index (one per content block)."""
        self.current_content_block_index += 1

    def _should_start_new_content_block(self, chunk: "ModelResponseStream") -> bool:
        """
        Determine if we should start a new content block based on the processed chunk.

        A new block starts when the chunk's block type differs from the current
        one (e.g. text -> tool_use), or when a new tool-call name appears
        (parallel tool calls). Also updates `current_content_block_type` /
        `current_content_block_start` when a new block begins.
        """
        from .transformation import LiteLLMAnthropicMessagesAdapter

        # Finish chunks never open a new block.
        if chunk.choices[0].finish_reason is not None:
            return False
        (
            block_type,
            content_block_start,
        ) = LiteLLMAnthropicMessagesAdapter()._translate_streaming_openai_chunk_to_anthropic_content_block(
            choices=chunk.choices  # type: ignore
        )
        # Restore original tool name if it was truncated for OpenAI's 64-char limit
        if block_type == "tool_use":
            # Type narrowing: content_block_start is ToolUseBlock when block_type is "tool_use"
            from typing import cast

            from litellm.types.llms.anthropic import ToolUseBlock

            tool_block = cast(ToolUseBlock, content_block_start)
            if tool_block.get("name"):
                truncated_name = tool_block["name"]
                original_name = self.tool_name_mapping.get(
                    truncated_name, truncated_name
                )
                tool_block["name"] = original_name
        if block_type != self.current_content_block_type:
            self.current_content_block_type = block_type
            self.current_content_block_start = content_block_start
            return True
        # For parallel tool calls, we'll necessarily have a new content block
        # if we get a function name since it signals a new tool call
        if block_type == "tool_use":
            from typing import cast

            from litellm.types.llms.anthropic import ToolUseBlock

            tool_block = cast(ToolUseBlock, content_block_start)
            if tool_block.get("name"):
                self.current_content_block_type = block_type
                self.current_content_block_start = content_block_start
                return True
        return False

View File

@@ -0,0 +1,51 @@
# Anthropic Messages Pass-Through Architecture
## Request Flow
```mermaid
flowchart TD
A[litellm.anthropic.messages.acreate] --> B{Provider?}
B -->|anthropic| C[AnthropicMessagesConfig]
B -->|azure_ai| D[AzureAnthropicMessagesConfig]
B -->|bedrock invoke| E[BedrockAnthropicMessagesConfig]
B -->|vertex_ai| F[VertexAnthropicMessagesConfig]
B -->|Other providers| G[LiteLLMAnthropicMessagesAdapter]
C --> H[Direct Anthropic API]
D --> I[Azure AI Foundry API]
E --> J[Bedrock Invoke API]
F --> K[Vertex AI API]
G --> L[translate_anthropic_to_openai]
L --> M[litellm.completion]
M --> N[Provider API]
N --> O[translate_openai_response_to_anthropic]
O --> P[Anthropic Response Format]
H --> P
I --> P
J --> P
K --> P
```
## Adapter Flow (Non-Native Providers)
```mermaid
sequenceDiagram
participant User
participant Handler as anthropic_messages_handler
participant Adapter as LiteLLMAnthropicMessagesAdapter
participant LiteLLM as litellm.completion
participant Provider as Provider API
User->>Handler: Anthropic Messages Request
Handler->>Adapter: translate_anthropic_to_openai()
Note over Adapter: messages, tools, thinking,<br/>output_format → response_format
Adapter->>LiteLLM: OpenAI Format Request
LiteLLM->>Provider: Provider-specific Request
Provider->>LiteLLM: Provider Response
LiteLLM->>Adapter: OpenAI Format Response
Adapter->>Handler: translate_openai_response_to_anthropic()
Handler->>User: Anthropic Messages Response
```

View File

@@ -0,0 +1,251 @@
"""
Fake Streaming Iterator for Anthropic Messages
This module provides a fake streaming iterator that converts non-streaming
Anthropic Messages responses into proper streaming format.
Used when WebSearch interception converts stream=True to stream=False but
the LLM doesn't make a tool call, and we need to return a stream to the user.
"""
import json
from typing import Any, Dict, List, cast
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
)
class FakeAnthropicMessagesStreamIterator:
    """
    Fake streaming iterator for Anthropic Messages responses.

    Used when we need to convert a non-streaming response to a streaming format,
    such as when WebSearch interception converts stream=True to stream=False but
    the LLM doesn't make a tool call.

    This creates a proper Anthropic-style streaming response with multiple events:
    - message_start
    - content_block_start (for each content block)
    - content_block_delta (for text content, chunked)
    - content_block_stop
    - message_delta (for usage)
    - message_stop

    Supports both sync and async iteration; each item is a pre-encoded SSE
    frame (bytes) of the form ``event: <type>\\ndata: <json>\\n\\n``.
    """

    def __init__(self, response: AnthropicMessagesResponse):
        # All SSE frames are materialized eagerly from the complete response;
        # iteration just replays them in order.
        self.response = response
        self.chunks = self._create_streaming_chunks()
        self.current_index = 0

    def _create_streaming_chunks(self) -> List[bytes]:
        """Convert the non-streaming response to streaming chunks (SSE frames as bytes)."""
        chunks = []
        # Cast response to dict for easier access
        response_dict = cast(Dict[str, Any], self.response)
        # 1. message_start event
        usage = response_dict.get("usage", {})
        message_start = {
            "type": "message_start",
            "message": {
                "id": response_dict.get("id"),
                "type": "message",
                "role": response_dict.get("role", "assistant"),
                "model": response_dict.get("model"),
                "content": [],
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {
                    "input_tokens": usage.get("input_tokens", 0) if usage else 0,
                    "output_tokens": 0,
                },
            },
        }
        chunks.append(
            f"event: message_start\ndata: {json.dumps(message_start)}\n\n".encode()
        )
        # 2-4. For each content block, send start/delta/stop events
        content_blocks = response_dict.get("content", [])
        if content_blocks:
            for index, block in enumerate(content_blocks):
                # Cast block to dict for easier access
                block_dict = cast(Dict[str, Any], block)
                block_type = block_dict.get("type")
                if block_type == "text":
                    # content_block_start
                    content_block_start = {
                        "type": "content_block_start",
                        "index": index,
                        "content_block": {"type": "text", "text": ""},
                    }
                    chunks.append(
                        f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n".encode()
                    )
                    # content_block_delta (send full text as one delta for simplicity)
                    text = block_dict.get("text", "")
                    content_block_delta = {
                        "type": "content_block_delta",
                        "index": index,
                        "delta": {"type": "text_delta", "text": text},
                    }
                    chunks.append(
                        f"event: content_block_delta\ndata: {json.dumps(content_block_delta)}\n\n".encode()
                    )
                    # content_block_stop
                    content_block_stop = {"type": "content_block_stop", "index": index}
                    chunks.append(
                        f"event: content_block_stop\ndata: {json.dumps(content_block_stop)}\n\n".encode()
                    )
                elif block_type == "thinking":
                    # content_block_start for thinking
                    content_block_start = {
                        "type": "content_block_start",
                        "index": index,
                        "content_block": {
                            "type": "thinking",
                            "thinking": "",
                            "signature": "",
                        },
                    }
                    chunks.append(
                        f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n".encode()
                    )
                    # content_block_delta for thinking text
                    thinking_text = block_dict.get("thinking", "")
                    if thinking_text:
                        content_block_delta = {
                            "type": "content_block_delta",
                            "index": index,
                            "delta": {
                                "type": "thinking_delta",
                                "thinking": thinking_text,
                            },
                        }
                        chunks.append(
                            f"event: content_block_delta\ndata: {json.dumps(content_block_delta)}\n\n".encode()
                        )
                    # content_block_delta for signature (if present)
                    signature = block_dict.get("signature", "")
                    if signature:
                        signature_delta = {
                            "type": "content_block_delta",
                            "index": index,
                            "delta": {
                                "type": "signature_delta",
                                "signature": signature,
                            },
                        }
                        chunks.append(
                            f"event: content_block_delta\ndata: {json.dumps(signature_delta)}\n\n".encode()
                        )
                    # content_block_stop
                    content_block_stop = {"type": "content_block_stop", "index": index}
                    chunks.append(
                        f"event: content_block_stop\ndata: {json.dumps(content_block_stop)}\n\n".encode()
                    )
                elif block_type == "redacted_thinking":
                    # content_block_start for redacted_thinking
                    content_block_start = {
                        "type": "content_block_start",
                        "index": index,
                        "content_block": {"type": "redacted_thinking"},
                    }
                    chunks.append(
                        f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n".encode()
                    )
                    # content_block_stop (no delta for redacted thinking)
                    content_block_stop = {"type": "content_block_stop", "index": index}
                    chunks.append(
                        f"event: content_block_stop\ndata: {json.dumps(content_block_stop)}\n\n".encode()
                    )
                elif block_type == "tool_use":
                    # content_block_start
                    content_block_start = {
                        "type": "content_block_start",
                        "index": index,
                        "content_block": {
                            "type": "tool_use",
                            "id": block_dict.get("id"),
                            "name": block_dict.get("name"),
                            "input": {},
                        },
                    }
                    chunks.append(
                        f"event: content_block_start\ndata: {json.dumps(content_block_start)}\n\n".encode()
                    )
                    # content_block_delta (send input as JSON delta)
                    input_data = block_dict.get("input", {})
                    content_block_delta = {
                        "type": "content_block_delta",
                        "index": index,
                        "delta": {
                            "type": "input_json_delta",
                            "partial_json": json.dumps(input_data),
                        },
                    }
                    chunks.append(
                        f"event: content_block_delta\ndata: {json.dumps(content_block_delta)}\n\n".encode()
                    )
                    # content_block_stop
                    content_block_stop = {"type": "content_block_stop", "index": index}
                    chunks.append(
                        f"event: content_block_stop\ndata: {json.dumps(content_block_stop)}\n\n".encode()
                    )
        # 5. message_delta event (with final usage and stop_reason)
        message_delta = {
            "type": "message_delta",
            "delta": {
                "stop_reason": response_dict.get("stop_reason"),
                "stop_sequence": response_dict.get("stop_sequence"),
            },
            "usage": {"output_tokens": usage.get("output_tokens", 0) if usage else 0},
        }
        chunks.append(
            f"event: message_delta\ndata: {json.dumps(message_delta)}\n\n".encode()
        )
        # 6. message_stop event
        message_stop = {"type": "message_stop", "usage": usage if usage else {}}
        chunks.append(
            f"event: message_stop\ndata: {json.dumps(message_stop)}\n\n".encode()
        )
        return chunks

    def __aiter__(self):
        # Async iteration replays the same pre-built chunks.
        return self

    async def __anext__(self):
        if self.current_index >= len(self.chunks):
            raise StopAsyncIteration
        chunk = self.chunks[self.current_index]
        self.current_index += 1
        return chunk

    def __iter__(self):
        # Sync iteration shares the same cursor as async iteration.
        return self

    def __next__(self):
        if self.current_index >= len(self.chunks):
            raise StopIteration
        chunk = self.chunks[self.current_index]
        self.current_index += 1
        return chunk

View File

@@ -0,0 +1,362 @@
"""
- call /messages on Anthropic API
- Make streaming + non-streaming request - just pass it through direct to Anthropic. No need to do anything special here
- Ensure requests are logged in the DB - stream + non-stream
"""
import asyncio
import contextvars
from functools import partial
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Union
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.anthropic_messages.transformation import (
BaseAnthropicMessagesConfig,
)
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.types.llms.anthropic_messages.anthropic_request import AnthropicMetadata
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import ProviderConfigManager, client
from ..adapters.handler import LiteLLMMessagesToCompletionTransformationHandler
from ..responses_adapters.handler import LiteLLMMessagesToResponsesAPIHandler
from .utils import AnthropicMessagesRequestUtils, mock_response
# Providers that are routed directly to the OpenAI Responses API instead of
# going through chat/completions.
_RESPONSES_API_PROVIDERS = frozenset({"openai"})


def _should_route_to_responses_api(custom_llm_provider: Optional[str]) -> bool:
    """Return True when the provider should use the Responses API path.

    Set ``litellm.use_chat_completions_url_for_anthropic_messages = True`` to
    opt out and route OpenAI/Azure requests through chat/completions instead.
    """
    # The global opt-out flag forces every provider down the chat/completions path.
    forced_chat_completions = bool(
        litellm.use_chat_completions_url_for_anthropic_messages
    )
    is_responses_provider = custom_llm_provider in _RESPONSES_API_PROVIDERS
    return is_responses_provider and not forced_chat_completions
####### ENVIRONMENT VARIABLES ###################
# Initialize any necessary instances or variables here
# Shared HTTP handler used for all direct /v1/messages provider calls below.
base_llm_http_handler = BaseLLMHTTPHandler()
#################################################
async def _execute_pre_request_hooks(
    model: str,
    messages: List[Dict],
    tools: Optional[List[Dict]],
    stream: Optional[bool],
    custom_llm_provider: Optional[str],
    **kwargs,
) -> Dict:
    """
    Execute pre-request hooks from CustomLogger callbacks.

    Allows CustomLoggers to modify request parameters before the API call.
    Used for WebSearch tool conversion, stream modification, etc.

    Args:
        model: Model name
        messages: List of messages
        tools: Optional tools list
        stream: Optional stream flag
        custom_llm_provider: Provider name (if not set, will be extracted from model)
        **kwargs: Additional request parameters

    Returns:
        Dict containing all (potentially modified) request parameters including
        ``tools`` and ``stream``.
    """
    # Derive the provider from the model string when the caller didn't supply one.
    if not custom_llm_provider:
        try:
            _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
        except Exception:
            # Extraction failed — hooks still run, just without provider context.
            pass

    request_kwargs: Dict = {
        "tools": tools,
        "stream": stream,
        "litellm_params": {
            "custom_llm_provider": custom_llm_provider,
        },
        **kwargs,
    }

    if not litellm.callbacks:
        return request_kwargs

    from litellm.integrations.custom_logger import CustomLogger as _CustomLogger

    # Only CustomLogger instances expose the pre-request hook; skip everything else.
    custom_loggers = (
        callback for callback in litellm.callbacks
        if isinstance(callback, _CustomLogger)
    )
    for hook_logger in custom_loggers:
        updated_kwargs = await hook_logger.async_pre_request_hook(
            model, messages, request_kwargs
        )
        # A hook may return None to indicate "no changes".
        if updated_kwargs is not None:
            request_kwargs = updated_kwargs
    return request_kwargs
@client
async def anthropic_messages(
    max_tokens: int,
    messages: List[Dict],
    model: str,
    metadata: Optional[Dict] = None,
    stop_sequences: Optional[List[str]] = None,
    stream: Optional[bool] = False,
    system: Optional[str] = None,
    temperature: Optional[float] = None,
    thinking: Optional[Dict] = None,
    tool_choice: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    client: Optional[AsyncHTTPHandler] = None,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[AnthropicMessagesResponse, AsyncIterator]:
    """
    Async: Make llm api request in Anthropic /messages API spec

    Runs CustomLogger pre-request hooks (which may rewrite ``tools``/``stream``),
    then dispatches the sync handler on the default thread-pool executor so its
    blocking setup work stays off the event loop. Returns either a complete
    AnthropicMessagesResponse or an async iterator of streaming chunks.
    """
    # Execute pre-request hooks to allow CustomLoggers to modify request
    request_kwargs = await _execute_pre_request_hooks(
        model=model,
        messages=messages,
        tools=tools,
        stream=stream,
        custom_llm_provider=custom_llm_provider,
        **kwargs,
    )
    # Extract modified parameters (hooks may have swapped tools or flipped stream).
    tools = request_kwargs.pop("tools", tools)
    stream = request_kwargs.pop("stream", stream)
    # Remove litellm_params from kwargs (only needed for hooks)
    request_kwargs.pop("litellm_params", None)
    # Merge back any other modifications
    kwargs.update(request_kwargs)

    loop = asyncio.get_event_loop()
    # Signal the sync handler to return a coroutine/async iterator instead of blocking.
    kwargs["is_async"] = True

    func = partial(
        anthropic_messages_handler,
        max_tokens=max_tokens,
        messages=messages,
        model=model,
        metadata=metadata,
        stop_sequences=stop_sequences,
        stream=stream,
        system=system,
        temperature=temperature,
        thinking=thinking,
        tool_choice=tool_choice,
        tools=tools,
        top_k=top_k,
        top_p=top_p,
        api_key=api_key,
        api_base=api_base,
        client=client,
        custom_llm_provider=custom_llm_provider,
        **kwargs,
    )

    # Copy the current context so contextvars (e.g. request-scoped metadata)
    # survive the hop onto the executor thread.
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    init_response = await loop.run_in_executor(None, func_with_context)

    # The handler may hand back a coroutine (async path) or a finished value.
    if asyncio.iscoroutine(init_response):
        response = await init_response
    else:
        response = init_response
    return response
def validate_anthropic_api_metadata(metadata: Optional[Dict] = None) -> Optional[Dict]:
    """
    Validate Anthropic API metadata - This is done to ensure only allowed `metadata` fields are passed to Anthropic API

    If there are any litellm specific metadata fields, use `litellm_metadata` key to pass them.
    """
    if metadata is None:
        return None
    # The pydantic model rejects any field Anthropic does not accept; exclude_none
    # keeps the serialized payload minimal.
    return AnthropicMetadata(**metadata).model_dump(exclude_none=True)
def anthropic_messages_handler(
    max_tokens: int,
    messages: List[Dict],
    model: str,
    metadata: Optional[Dict] = None,
    stop_sequences: Optional[List[str]] = None,
    stream: Optional[bool] = False,
    system: Optional[str] = None,
    temperature: Optional[float] = None,
    thinking: Optional[Dict] = None,
    tool_choice: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    container: Optional[Dict] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    client: Optional[AsyncHTTPHandler] = None,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[
    AnthropicMessagesResponse,
    AsyncIterator[Any],
    Coroutine[Any, Any, Union[AnthropicMessagesResponse, AsyncIterator[Any]]],
]:
    """
    Makes Anthropic `/v1/messages` API calls In the Anthropic API Spec

    Routing logic:
      1. mock_response set -> return a canned AnthropicMessagesResponse.
      2. Provider has a native /v1/messages config -> direct HTTP handler.
      3. Otherwise -> adapt via the Responses API (OpenAI) or chat/completions.

    Args:
        container: Container config with skills for code execution
    """
    from litellm.types.utils import LlmProviders

    metadata = validate_anthropic_api_metadata(metadata)
    # NOTE: locals() is captured here deliberately — it snapshots every named
    # parameter (including `container`) so they can later be filtered into the
    # Anthropic optional-params dict.
    local_vars = locals()
    is_async = kwargs.pop("is_async", False)
    # Use provided client or create a new one
    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore
    # Store original model name before get_llm_provider strips the provider prefix
    # This is needed by agentic hooks (e.g., websearch_interception) to make follow-up requests
    original_model = model
    litellm_params = GenericLiteLLMParams(
        **kwargs,
        api_key=api_key,
        api_base=api_base,
        custom_llm_provider=custom_llm_provider,
    )
    (
        model,
        custom_llm_provider,
        dynamic_api_key,
        dynamic_api_base,
    ) = litellm.get_llm_provider(
        model=model,
        custom_llm_provider=custom_llm_provider,
        api_base=litellm_params.api_base,
        api_key=litellm_params.api_key,
    )
    # Store agentic loop params in logging object for agentic hooks
    # This provides original request context needed for follow-up calls
    if litellm_logging_obj is not None:
        litellm_logging_obj.model_call_details["agentic_loop_params"] = {
            "model": original_model,
            "custom_llm_provider": custom_llm_provider,
        }
        # Check if stream was converted for WebSearch interception
        # This is set in the async wrapper above when stream=True is converted to stream=False
        # (kept inside the None-guard above since it writes model_call_details)
        if kwargs.get("_websearch_interception_converted_stream", False):
            litellm_logging_obj.model_call_details[
                "websearch_interception_converted_stream"
            ] = True
    if litellm_params.mock_response and isinstance(litellm_params.mock_response, str):
        return mock_response(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            mock_response=litellm_params.mock_response,
        )
    anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = None
    if custom_llm_provider is not None and custom_llm_provider in [
        provider.value for provider in LlmProviders
    ]:
        anthropic_messages_provider_config = (
            ProviderConfigManager.get_provider_anthropic_messages_config(
                model=model,
                provider=litellm.LlmProviders(custom_llm_provider),
            )
        )
    if anthropic_messages_provider_config is None:
        # Route to Responses API for OpenAI / Azure, chat/completions for everything else.
        _shared_kwargs = dict(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            metadata=metadata,
            stop_sequences=stop_sequences,
            stream=stream,
            system=system,
            temperature=temperature,
            thinking=thinking,
            tool_choice=tool_choice,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            _is_async=is_async,
            api_key=api_key,
            api_base=api_base,
            client=client,
            custom_llm_provider=custom_llm_provider,
            **kwargs,
        )
        if _should_route_to_responses_api(custom_llm_provider):
            return LiteLLMMessagesToResponsesAPIHandler.anthropic_messages_handler(
                **_shared_kwargs
            )
        return (
            LiteLLMMessagesToCompletionTransformationHandler.anthropic_messages_handler(
                **_shared_kwargs
            )
        )
    if custom_llm_provider is None:
        raise ValueError(
            f"custom_llm_provider is required for Anthropic messages, passed in model={model}, custom_llm_provider={custom_llm_provider}"
        )
    # Fold any remaining kwargs into the parameter snapshot, then keep only the
    # fields the Anthropic /v1/messages spec actually accepts.
    local_vars.update(kwargs)
    anthropic_messages_optional_request_params = (
        AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param(
            params=local_vars
        )
    )
    return base_llm_http_handler.anthropic_messages_handler(
        model=model,
        messages=messages,
        anthropic_messages_provider_config=anthropic_messages_provider_config,
        anthropic_messages_optional_request_params=dict(
            anthropic_messages_optional_request_params
        ),
        _is_async=is_async,
        client=client,
        custom_llm_provider=custom_llm_provider,
        litellm_params=litellm_params,
        logging_obj=litellm_logging_obj,
        api_key=api_key,
        api_base=api_base,
        stream=stream,
        kwargs=kwargs,
    )

View File

@@ -0,0 +1,108 @@
import asyncio
import json
from datetime import datetime
from typing import Any, AsyncIterator, List, Union
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy.pass_through_endpoints.success_handler import (
PassThroughEndpointLogging,
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
# Module-level success handler shared by all streaming iterators in this file;
# used to log pass-through /v1/messages traffic once a stream completes.
GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ = PassThroughEndpointLogging()
class BaseAnthropicMessagesStreamingIterator:
    """
    Base class for Anthropic Messages streaming iterators that provides common logic
    for streaming response handling and logging.
    """

    def __init__(
        self,
        litellm_logging_obj: LiteLLMLoggingObj,
        request_body: dict,
    ):
        # Logging object tied to the originating request; handed to the
        # pass-through logging handler after the stream finishes.
        self.litellm_logging_obj = litellm_logging_obj
        self.request_body = request_body
        # Captured at construction so logged duration covers the whole stream.
        self.start_time = datetime.now()

    async def _handle_streaming_logging(self, collected_chunks: List[bytes]):
        """Handle the logging after all chunks have been collected."""
        from litellm.proxy.pass_through_endpoints.streaming_handler import (
            PassThroughStreamingHandler,
        )

        end_time = datetime.now()
        # Fire-and-forget: logging must not delay delivery of the final chunk.
        asyncio.create_task(
            PassThroughStreamingHandler._route_streaming_logging_to_handler(
                litellm_logging_obj=self.litellm_logging_obj,
                passthrough_success_handler_obj=GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ,
                url_route="/v1/messages",
                request_body=self.request_body or {},
                endpoint_type=EndpointType.ANTHROPIC,
                start_time=self.start_time,
                raw_bytes=collected_chunks,
                end_time=end_time,
            )
        )

    def get_async_streaming_response_iterator(
        self,
        httpx_response,
        request_body: dict,
        litellm_logging_obj: LiteLLMLoggingObj,
    ) -> AsyncIterator:
        """Helper function to handle Anthropic streaming responses using the existing logging handlers"""
        from litellm.proxy.pass_through_endpoints.streaming_handler import (
            PassThroughStreamingHandler,
        )

        # Use the existing streaming handler for Anthropic
        return PassThroughStreamingHandler.chunk_processor(
            response=httpx_response,
            request_body=request_body,
            litellm_logging_obj=litellm_logging_obj,
            endpoint_type=EndpointType.ANTHROPIC,
            start_time=self.start_time,
            passthrough_success_handler_obj=GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ,
            url_route="/v1/messages",
        )

    def _convert_chunk_to_sse_format(self, chunk: Union[dict, Any]) -> bytes:
        """
        Convert a chunk to Server-Sent Events format.

        This method should be overridden by subclasses if they need custom
        chunk formatting logic.
        """
        if isinstance(chunk, dict):
            # Anthropic SSE frames use the chunk's "type" field as the event name.
            event_type: str = str(chunk.get("type", "message"))
            payload = f"event: {event_type}\n" f"data: {json.dumps(chunk)}\n\n"
            return payload.encode()
        else:
            # For non-dict chunks, return as is
            # NOTE(review): pass-through assumes non-dict chunks are already
            # encoded bytes despite the `bytes` annotation — confirm.
            return chunk

    async def async_sse_wrapper(
        self,
        completion_stream: AsyncIterator[
            Union[bytes, GenericStreamingChunk, ModelResponseStream, dict]
        ],
    ) -> AsyncIterator[bytes]:
        """
        Generic async SSE wrapper that converts streaming chunks to SSE format
        and handles logging.

        This method provides the common logic for both Anthropic and Bedrock implementations.
        """
        collected_chunks = []
        async for chunk in completion_stream:
            encoded_chunk = self._convert_chunk_to_sse_format(chunk)
            collected_chunks.append(encoded_chunk)
            yield encoded_chunk

        # Handle logging after all chunks are processed
        await self._handle_streaming_logging(collected_chunks)

View File

@@ -0,0 +1,308 @@
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import verbose_logger
from litellm.llms.base_llm.anthropic_messages.transformation import (
BaseAnthropicMessagesConfig,
)
from litellm.types.llms.anthropic import (
ANTHROPIC_BETA_HEADER_VALUES,
AnthropicMessagesRequest,
)
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
)
from litellm.types.llms.anthropic_tool_search import get_tool_search_beta_header
from litellm.types.router import GenericLiteLLMParams
from ...common_utils import (
AnthropicError,
AnthropicModelInfo,
optionally_handle_anthropic_oauth,
)
# Fallbacks used when the caller supplies no api_base / anthropic-version header.
DEFAULT_ANTHROPIC_API_BASE = "https://api.anthropic.com"
DEFAULT_ANTHROPIC_API_VERSION = "2023-06-01"
class AnthropicMessagesConfig(BaseAnthropicMessagesConfig):
    """Pass-through config for Anthropic's native /v1/messages API.

    Requests and responses are already in the Anthropic spec, so the
    "transformations" below are limited to validation, header injection
    (anthropic-beta), and filtering of litellm-internal billing metadata.
    """

    def get_supported_anthropic_messages_params(self, model: str) -> list:
        # Params forwarded verbatim to Anthropic; anything else is dropped upstream.
        return [
            "messages",
            "model",
            "system",
            "max_tokens",
            "stop_sequences",
            "temperature",
            "top_p",
            "top_k",
            "tools",
            "tool_choice",
            "thinking",
            "context_management",
            "output_format",
            "inference_geo",
            "speed",
            "output_config",
            # TODO: Add Anthropic `metadata` support
            # "metadata",
        ]

    @staticmethod
    def _filter_billing_headers_from_system(system_param):
        """
        Filter out x-anthropic-billing-header metadata from system parameter.

        Args:
            system_param: Can be a string or a list of system message content blocks

        Returns:
            Filtered system parameter (string or list), or None if all content was filtered
        """
        if isinstance(system_param, str):
            # If it's a string and starts with billing header, filter it out
            if system_param.startswith("x-anthropic-billing-header:"):
                return None
            return system_param
        elif isinstance(system_param, list):
            # Filter list of system content blocks
            filtered_list = []
            for content_block in system_param:
                if isinstance(content_block, dict):
                    text = content_block.get("text", "")
                    content_type = content_block.get("type", "")
                    # Skip text blocks that start with billing header
                    if content_type == "text" and text.startswith(
                        "x-anthropic-billing-header:"
                    ):
                        continue
                    filtered_list.append(content_block)
                else:
                    # Keep non-dict items as-is
                    filtered_list.append(content_block)
            return filtered_list if len(filtered_list) > 0 else None
        else:
            # Unknown shape: pass through untouched.
            return system_param

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        # Append the messages route only when the caller hasn't already done so.
        api_base = api_base or DEFAULT_ANTHROPIC_API_BASE
        if not api_base.endswith("/v1/messages"):
            api_base = f"{api_base}/v1/messages"
        return api_base

    def validate_anthropic_messages_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[dict, Optional[str]]:
        """Resolve auth + default headers; returns (headers, api_base)."""
        import os

        # Check for Anthropic OAuth token in Authorization header
        headers, api_key = optionally_handle_anthropic_oauth(
            headers=headers, api_key=api_key
        )

        if api_key is None:
            api_key = os.getenv("ANTHROPIC_API_KEY")
        # Don't clobber auth the caller already supplied.
        if "x-api-key" not in headers and "authorization" not in headers and api_key:
            headers["x-api-key"] = api_key
        if "anthropic-version" not in headers:
            headers["anthropic-version"] = DEFAULT_ANTHROPIC_API_VERSION
        if "content-type" not in headers:
            headers["content-type"] = "application/json"

        headers = self._update_headers_with_anthropic_beta(
            headers=headers,
            optional_params=optional_params,
        )

        return headers, api_base

    def transform_anthropic_messages_request(
        self,
        model: str,
        messages: List[Dict],
        anthropic_messages_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        No transformation is needed for Anthropic messages

        This takes in a request in the Anthropic /v1/messages API spec -> transforms it to /v1/messages API spec (i.e) no transformation is needed
        """
        max_tokens = anthropic_messages_optional_request_params.pop("max_tokens", None)
        if max_tokens is None:
            raise AnthropicError(
                message="max_tokens is required for Anthropic /v1/messages API",
                status_code=400,
            )

        # Filter out x-anthropic-billing-header from system messages
        system_param = anthropic_messages_optional_request_params.get("system")
        if system_param is not None:
            filtered_system = self._filter_billing_headers_from_system(system_param)
            if filtered_system is not None and len(filtered_system) > 0:
                anthropic_messages_optional_request_params["system"] = filtered_system
            else:
                # Remove system parameter if all content was filtered out
                anthropic_messages_optional_request_params.pop("system", None)

        # Transform context_management from OpenAI format to Anthropic format if needed
        context_management_param = anthropic_messages_optional_request_params.get(
            "context_management"
        )
        if context_management_param is not None:
            from litellm.llms.anthropic.chat.transformation import AnthropicConfig

            transformed_context_management = (
                AnthropicConfig.map_openai_context_management_to_anthropic(
                    context_management_param
                )
            )
            if transformed_context_management is not None:
                anthropic_messages_optional_request_params[
                    "context_management"
                ] = transformed_context_management

        ####### get required params for all anthropic messages requests ######
        verbose_logger.debug(f"TRANSFORMATION DEBUG - Messages: {messages}")
        anthropic_messages_request: AnthropicMessagesRequest = AnthropicMessagesRequest(
            messages=messages,
            max_tokens=max_tokens,
            model=model,
            **anthropic_messages_optional_request_params,
        )
        return dict(anthropic_messages_request)

    def transform_anthropic_messages_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> AnthropicMessagesResponse:
        """
        No transformation is needed for Anthropic messages, since we want the response in the Anthropic /v1/messages API spec
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            # Surface non-JSON bodies (e.g. HTML error pages) with the HTTP status.
            raise AnthropicError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        return AnthropicMessagesResponse(**raw_response_json)

    def get_async_streaming_response_iterator(
        self,
        model: str,
        httpx_response: httpx.Response,
        request_body: dict,
        litellm_logging_obj: LiteLLMLoggingObj,
    ) -> AsyncIterator:
        """Helper function to handle Anthropic streaming responses using the existing logging handlers"""
        from litellm.llms.anthropic.experimental_pass_through.messages.streaming_iterator import (
            BaseAnthropicMessagesStreamingIterator,
        )

        # Use the shared streaming handler for Anthropic
        handler = BaseAnthropicMessagesStreamingIterator(
            litellm_logging_obj=litellm_logging_obj,
            request_body=request_body,
        )
        return handler.get_async_streaming_response_iterator(
            httpx_response=httpx_response,
            request_body=request_body,
            litellm_logging_obj=litellm_logging_obj,
        )

    @staticmethod
    def _update_headers_with_anthropic_beta(
        headers: dict,
        optional_params: dict,
        custom_llm_provider: str = "anthropic",
    ) -> dict:
        """
        Auto-inject anthropic-beta headers based on features used.

        Handles:
        - context_management: adds 'context-management-2025-06-27'
        - tool_search: adds provider-specific tool search header
        - output_format: adds 'structured-outputs-2025-11-13'
        - speed: adds 'fast-mode-2026-02-01'

        Args:
            headers: Request headers dict
            optional_params: Optional parameters including tools, context_management, output_format, speed
            custom_llm_provider: Provider name for looking up correct tool search header
        """
        beta_values: set = set()

        # Get existing beta headers if any
        existing_beta = headers.get("anthropic-beta")
        if existing_beta:
            beta_values.update(b.strip() for b in existing_beta.split(","))

        # Check for context management
        context_management_param = optional_params.get("context_management")
        if context_management_param is not None:
            # Check edits array for compact_20260112 type
            edits = context_management_param.get("edits", [])
            has_compact = False
            has_other = False
            for edit in edits:
                edit_type = edit.get("type", "")
                if edit_type == "compact_20260112":
                    has_compact = True
                else:
                    has_other = True
            # Add compact header if any compact edits exist
            if has_compact:
                beta_values.add(ANTHROPIC_BETA_HEADER_VALUES.COMPACT_2026_01_12.value)
            # Add context management header if any other edits exist
            if has_other:
                beta_values.add(
                    ANTHROPIC_BETA_HEADER_VALUES.CONTEXT_MANAGEMENT_2025_06_27.value
                )

        # Check for structured outputs
        if optional_params.get("output_format") is not None:
            beta_values.add(
                ANTHROPIC_BETA_HEADER_VALUES.STRUCTURED_OUTPUT_2025_09_25.value
            )

        # Check for fast mode
        if optional_params.get("speed") == "fast":
            beta_values.add(ANTHROPIC_BETA_HEADER_VALUES.FAST_MODE_2026_02_01.value)

        # Check for tool search tools
        tools = optional_params.get("tools")
        if tools:
            anthropic_model_info = AnthropicModelInfo()
            if anthropic_model_info.is_tool_search_used(tools):
                # Use provider-specific tool search header
                tool_search_header = get_tool_search_beta_header(custom_llm_provider)
                beta_values.add(tool_search_header)

        # Sorted join keeps the header value deterministic across calls.
        if beta_values:
            headers["anthropic-beta"] = ",".join(sorted(beta_values))

        return headers

View File

@@ -0,0 +1,75 @@
from typing import Any, Dict, List, cast, get_type_hints
from litellm.types.llms.anthropic import AnthropicMessagesRequestOptionalParams
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
)
class AnthropicMessagesRequestUtils:
    @staticmethod
    def get_requested_anthropic_messages_optional_param(
        params: Dict[str, Any],
    ) -> AnthropicMessagesRequestOptionalParams:
        """
        Filter parameters to only include those defined in AnthropicMessagesRequestOptionalParams.

        Args:
            params: Dictionary of parameters to filter

        Returns:
            AnthropicMessagesRequestOptionalParams instance with only the valid parameters
        """
        # The TypedDict's annotations define the whitelist of accepted keys.
        allowed_keys = set(get_type_hints(AnthropicMessagesRequestOptionalParams))
        requested = {
            key: value
            for key, value in params.items()
            if key in allowed_keys and value is not None
        }
        return cast(AnthropicMessagesRequestOptionalParams, requested)
def mock_response(
    model: str,
    messages: List[Dict],
    max_tokens: int,
    mock_response: str = "Hi! My name is Claude.",
    **kwargs,
) -> AnthropicMessagesResponse:
    """
    Mock response for Anthropic messages

    Sentinel strings matching a litellm exception name raise that exception;
    any other value is echoed back as the assistant's text content.
    """
    from litellm.exceptions import (
        ContextWindowExceededError,
        InternalServerError,
        RateLimitError,
    )

    # Sentinel -> (exception class, canned message) dispatch table.
    error_dispatch = {
        "litellm.InternalServerError": (
            InternalServerError,
            "this is a mock internal server error",
        ),
        "litellm.ContextWindowExceededError": (
            ContextWindowExceededError,
            "this is a mock context window exceeded error",
        ),
        "litellm.RateLimitError": (
            RateLimitError,
            "this is a mock rate limit error",
        ),
    }
    dispatch_entry = error_dispatch.get(mock_response)
    if dispatch_entry is not None:
        error_cls, error_message = dispatch_entry
        raise error_cls(
            message=error_message,
            llm_provider="anthropic",
            model=model,
        )

    return AnthropicMessagesResponse(
        **{
            "content": [{"text": mock_response, "type": "text"}],
            "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
            "model": "claude-sonnet-4-20250514",
            "role": "assistant",
            "stop_reason": "end_turn",
            "stop_sequence": None,
            "type": "message",
            "usage": {"input_tokens": 2095, "output_tokens": 503},
        }
    )

View File

@@ -0,0 +1,3 @@
from .transformation import LiteLLMAnthropicToResponsesAPIAdapter
__all__ = ["LiteLLMAnthropicToResponsesAPIAdapter"]

View File

@@ -0,0 +1,239 @@
"""
Handler for the Anthropic v1/messages -> OpenAI Responses API path.
Used when the target model is an OpenAI or Azure model.
"""
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Union
import litellm
from litellm.types.llms.anthropic import AnthropicMessagesRequest
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
)
from litellm.types.llms.openai import ResponsesAPIResponse
from .streaming_iterator import AnthropicResponsesStreamWrapper
from .transformation import LiteLLMAnthropicToResponsesAPIAdapter
# Single stateless adapter instance reused for every request/response translation.
_ADAPTER = LiteLLMAnthropicToResponsesAPIAdapter()
def _build_responses_kwargs(
    *,
    max_tokens: int,
    messages: List[Dict],
    model: str,
    context_management: Optional[Dict] = None,
    metadata: Optional[Dict] = None,
    output_config: Optional[Dict] = None,
    stop_sequences: Optional[List[str]] = None,
    stream: Optional[bool] = False,
    system: Optional[str] = None,
    temperature: Optional[float] = None,
    thinking: Optional[Dict] = None,
    tool_choice: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    output_format: Optional[Dict] = None,
    extra_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Build the kwargs dict to pass directly to litellm.responses() / litellm.aresponses().

    NOTE(review): ``stop_sequences`` and ``top_k`` are accepted here but never
    added to the request — presumably because the Responses API has no
    equivalent parameters. Confirm this is intentional rather than an omission.
    """
    # Build a typed AnthropicMessagesRequest for the adapter
    request_data: Dict[str, Any] = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
    }
    # Truthiness checks: None and empty dict/list both mean "not supplied".
    if context_management:
        request_data["context_management"] = context_management
    if output_config:
        request_data["output_config"] = output_config
    if metadata:
        request_data["metadata"] = metadata
    if system:
        request_data["system"] = system
    if temperature is not None:
        request_data["temperature"] = temperature
    if thinking:
        request_data["thinking"] = thinking
    if tool_choice:
        request_data["tool_choice"] = tool_choice
    if tools:
        request_data["tools"] = tools
    if top_p is not None:
        request_data["top_p"] = top_p
    if output_format:
        request_data["output_format"] = output_format

    anthropic_request = AnthropicMessagesRequest(**request_data)  # type: ignore[typeddict-item]
    responses_kwargs = _ADAPTER.translate_request(anthropic_request)

    if stream:
        responses_kwargs["stream"] = True

    # Forward litellm-specific kwargs (api_key, api_base, logging obj, etc.)
    excluded = {"anthropic_messages"}
    for key, value in (extra_kwargs or {}).items():
        if key == "litellm_logging_obj" and value is not None:
            from litellm.litellm_core_utils.litellm_logging import (
                Logging as LiteLLMLoggingObject,
            )
            from litellm.types.utils import CallTypes

            if isinstance(value, LiteLLMLoggingObject):
                # Reclassify as acompletion so the success handler doesn't try to
                # validate the Responses API event as an AnthropicResponse.
                # (Mirrors the pattern used in LiteLLMMessagesToCompletionTransformationHandler.)
                setattr(value, "call_type", CallTypes.acompletion.value)
            responses_kwargs[key] = value
        elif key not in excluded and key not in responses_kwargs and value is not None:
            # Adapter-produced keys win; caller extras only fill gaps.
            responses_kwargs[key] = value

    return responses_kwargs
class LiteLLMMessagesToResponsesAPIHandler:
    """
    Handles Anthropic /v1/messages requests for OpenAI / Azure models by
    calling litellm.responses() / litellm.aresponses() directly and translating
    the response back to Anthropic format.
    """

    @staticmethod
    async def async_anthropic_messages_handler(
        max_tokens: int,
        messages: List[Dict],
        model: str,
        context_management: Optional[Dict] = None,
        metadata: Optional[Dict] = None,
        output_config: Optional[Dict] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: Optional[bool] = False,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking: Optional[Dict] = None,
        tool_choice: Optional[Dict] = None,
        tools: Optional[List[Dict]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        output_format: Optional[Dict] = None,
        **kwargs,
    ) -> Union[AnthropicMessagesResponse, AsyncIterator]:
        """Async path: call aresponses() and translate back to Anthropic format."""
        responses_kwargs = _build_responses_kwargs(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            context_management=context_management,
            metadata=metadata,
            output_config=output_config,
            stop_sequences=stop_sequences,
            stream=stream,
            system=system,
            temperature=temperature,
            thinking=thinking,
            tool_choice=tool_choice,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            output_format=output_format,
            extra_kwargs=kwargs,
        )

        result = await litellm.aresponses(**responses_kwargs)

        if stream:
            # Re-emit Responses API events as Anthropic SSE frames.
            wrapper = AnthropicResponsesStreamWrapper(
                responses_stream=result, model=model
            )
            return wrapper.async_anthropic_sse_wrapper()

        if not isinstance(result, ResponsesAPIResponse):
            raise ValueError(f"Expected ResponsesAPIResponse, got {type(result)}")
        return _ADAPTER.translate_response(result)

    @staticmethod
    def anthropic_messages_handler(
        max_tokens: int,
        messages: List[Dict],
        model: str,
        context_management: Optional[Dict] = None,
        metadata: Optional[Dict] = None,
        output_config: Optional[Dict] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: Optional[bool] = False,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking: Optional[Dict] = None,
        tool_choice: Optional[Dict] = None,
        tools: Optional[List[Dict]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        output_format: Optional[Dict] = None,
        _is_async: bool = False,
        **kwargs,
    ) -> Union[
        AnthropicMessagesResponse,
        AsyncIterator[Any],
        Coroutine[Any, Any, Union[AnthropicMessagesResponse, AsyncIterator[Any]]],
    ]:
        """Entry point; returns a coroutine when ``_is_async`` is True, else runs sync."""
        if _is_async:
            return (
                LiteLLMMessagesToResponsesAPIHandler.async_anthropic_messages_handler(
                    max_tokens=max_tokens,
                    messages=messages,
                    model=model,
                    context_management=context_management,
                    metadata=metadata,
                    output_config=output_config,
                    stop_sequences=stop_sequences,
                    stream=stream,
                    system=system,
                    temperature=temperature,
                    thinking=thinking,
                    tool_choice=tool_choice,
                    tools=tools,
                    top_k=top_k,
                    top_p=top_p,
                    output_format=output_format,
                    **kwargs,
                )
            )

        # Sync path
        responses_kwargs = _build_responses_kwargs(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            context_management=context_management,
            metadata=metadata,
            output_config=output_config,
            stop_sequences=stop_sequences,
            stream=stream,
            system=system,
            temperature=temperature,
            thinking=thinking,
            tool_choice=tool_choice,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            output_format=output_format,
            extra_kwargs=kwargs,
        )

        result = litellm.responses(**responses_kwargs)

        if stream:
            # NOTE(review): the sync path also returns an *async* SSE wrapper here,
            # so sync streaming callers receive an async generator — confirm
            # callers on this path can consume it.
            wrapper = AnthropicResponsesStreamWrapper(
                responses_stream=result, model=model
            )
            return wrapper.async_anthropic_sse_wrapper()

        if not isinstance(result, ResponsesAPIResponse):
            raise ValueError(f"Expected ResponsesAPIResponse, got {type(result)}")
        return _ADAPTER.translate_response(result)

View File

@@ -0,0 +1,344 @@
# What is this?
## Translates OpenAI call to Anthropic `/v1/messages` format
import json
import traceback
from collections import deque
from typing import Any, AsyncIterator, Dict
from litellm import verbose_logger
from litellm._uuid import uuid
class AnthropicResponsesStreamWrapper:
"""
Wraps a Responses API streaming iterator and re-emits events in Anthropic SSE format.
Responses API event flow (relevant subset):
response.created -> message_start
response.output_item.added -> content_block_start (if message/function_call)
response.output_text.delta -> content_block_delta (text_delta)
response.reasoning_summary_text.delta -> content_block_delta (thinking_delta)
response.function_call_arguments.delta -> content_block_delta (input_json_delta)
response.output_item.done -> content_block_stop
response.completed -> message_delta + message_stop
"""
    def __init__(
        self,
        responses_stream: Any,
        model: str,
    ) -> None:
        # Underlying Responses API stream being re-emitted as Anthropic SSE events.
        self.responses_stream = responses_stream
        self.model = model
        # Synthetic Anthropic message id used in the message_start event.
        self._message_id: str = f"msg_{uuid.uuid4()}"
        # Last assigned content-block index; -1 means none allocated yet.
        self._current_block_index: int = -1
        # Map item_id -> content_block_index so we can stop the right block later
        self._item_id_to_block_index: Dict[str, int] = {}
        # Track open function_call items by item_id so we can emit tool_use start
        self._pending_tool_ids: Dict[
            str, str
        ] = {}  # item_id -> call_id / name accumulator
        # One-shot flags so message_start / message_stop are emitted exactly once.
        self._sent_message_start = False
        self._sent_message_stop = False
        # FIFO of Anthropic-format chunks produced by _process_event, drained by the wrapper.
        self._chunk_queue: deque = deque()
def _make_message_start(self) -> Dict[str, Any]:
return {
"type": "message_start",
"message": {
"id": self._message_id,
"type": "message",
"role": "assistant",
"content": [],
"model": self.model,
"stop_reason": None,
"stop_sequence": None,
"usage": {
"input_tokens": 0,
"output_tokens": 0,
"cache_creation_input_tokens": 0,
"cache_read_input_tokens": 0,
},
},
}
def _next_block_index(self) -> int:
self._current_block_index += 1
return self._current_block_index
def _process_event(self, event: Any) -> None: # noqa: PLR0915
"""Convert one Responses API event into zero or more Anthropic chunks queued for emission."""
event_type = getattr(event, "type", None)
if event_type is None and isinstance(event, dict):
event_type = event.get("type")
if event_type is None:
return
# ---- message_start ----
if event_type == "response.created":
self._sent_message_start = True
self._chunk_queue.append(self._make_message_start())
return
# ---- content_block_start for a new output message item ----
if event_type == "response.output_item.added":
item = getattr(event, "item", None) or (
event.get("item") if isinstance(event, dict) else None
)
if item is None:
return
item_type = getattr(item, "type", None) or (
item.get("type") if isinstance(item, dict) else None
)
item_id = getattr(item, "id", None) or (
item.get("id") if isinstance(item, dict) else None
)
if item_type == "message":
block_idx = self._next_block_index()
if item_id:
self._item_id_to_block_index[item_id] = block_idx
self._chunk_queue.append(
{
"type": "content_block_start",
"index": block_idx,
"content_block": {"type": "text", "text": ""},
}
)
elif item_type == "function_call":
call_id = (
getattr(item, "call_id", None)
or (item.get("call_id") if isinstance(item, dict) else None)
or ""
)
name = (
getattr(item, "name", None)
or (item.get("name") if isinstance(item, dict) else None)
or ""
)
block_idx = self._next_block_index()
if item_id:
self._item_id_to_block_index[item_id] = block_idx
self._pending_tool_ids[item_id] = call_id
self._chunk_queue.append(
{
"type": "content_block_start",
"index": block_idx,
"content_block": {
"type": "tool_use",
"id": call_id,
"name": name,
"input": {},
},
}
)
elif item_type == "reasoning":
block_idx = self._next_block_index()
if item_id:
self._item_id_to_block_index[item_id] = block_idx
self._chunk_queue.append(
{
"type": "content_block_start",
"index": block_idx,
"content_block": {"type": "thinking", "thinking": ""},
}
)
return
# ---- text delta ----
if event_type == "response.output_text.delta":
item_id = getattr(event, "item_id", None) or (
event.get("item_id") if isinstance(event, dict) else None
)
delta = getattr(event, "delta", "") or (
event.get("delta", "") if isinstance(event, dict) else ""
)
block_idx = (
self._item_id_to_block_index.get(item_id, self._current_block_index)
if item_id
else self._current_block_index
)
self._chunk_queue.append(
{
"type": "content_block_delta",
"index": block_idx,
"delta": {"type": "text_delta", "text": delta},
}
)
return
# ---- reasoning summary text delta ----
if event_type == "response.reasoning_summary_text.delta":
item_id = getattr(event, "item_id", None) or (
event.get("item_id") if isinstance(event, dict) else None
)
delta = getattr(event, "delta", "") or (
event.get("delta", "") if isinstance(event, dict) else ""
)
block_idx = (
self._item_id_to_block_index.get(item_id, self._current_block_index)
if item_id
else self._current_block_index
)
self._chunk_queue.append(
{
"type": "content_block_delta",
"index": block_idx,
"delta": {"type": "thinking_delta", "thinking": delta},
}
)
return
# ---- function call arguments delta ----
if event_type == "response.function_call_arguments.delta":
item_id = getattr(event, "item_id", None) or (
event.get("item_id") if isinstance(event, dict) else None
)
delta = getattr(event, "delta", "") or (
event.get("delta", "") if isinstance(event, dict) else ""
)
block_idx = (
self._item_id_to_block_index.get(item_id, self._current_block_index)
if item_id
else self._current_block_index
)
self._chunk_queue.append(
{
"type": "content_block_delta",
"index": block_idx,
"delta": {"type": "input_json_delta", "partial_json": delta},
}
)
return
# ---- output item done -> content_block_stop ----
if event_type == "response.output_item.done":
item = getattr(event, "item", None) or (
event.get("item") if isinstance(event, dict) else None
)
item_id = (
getattr(item, "id", None)
or (item.get("id") if isinstance(item, dict) else None)
if item
else None
)
block_idx = (
self._item_id_to_block_index.get(item_id, self._current_block_index)
if item_id
else self._current_block_index
)
self._chunk_queue.append(
{
"type": "content_block_stop",
"index": block_idx,
}
)
return
# ---- response completed -> message_delta + message_stop ----
if event_type in (
"response.completed",
"response.failed",
"response.incomplete",
):
response_obj = getattr(event, "response", None) or (
event.get("response") if isinstance(event, dict) else None
)
stop_reason = "end_turn"
input_tokens = 0
output_tokens = 0
cache_creation_tokens = 0
cache_read_tokens = 0
if response_obj is not None:
status = getattr(response_obj, "status", None)
if status == "incomplete":
stop_reason = "max_tokens"
usage = getattr(response_obj, "usage", None)
if usage is not None:
input_tokens = getattr(usage, "input_tokens", 0) or 0
output_tokens = getattr(usage, "output_tokens", 0) or 0
cache_creation_tokens = getattr(usage, "input_tokens_details", None) # type: ignore[assignment]
cache_read_tokens = getattr(usage, "output_tokens_details", None) # type: ignore[assignment]
# Prefer direct cache fields if present
cache_creation_tokens = int(
getattr(usage, "cache_creation_input_tokens", 0) or 0
)
cache_read_tokens = int(
getattr(usage, "cache_read_input_tokens", 0) or 0
)
# Check if tool_use was in the output to override stop_reason
if response_obj is not None:
output = getattr(response_obj, "output", []) or []
for out_item in output:
out_type = getattr(out_item, "type", None) or (
out_item.get("type") if isinstance(out_item, dict) else None
)
if out_type == "function_call":
stop_reason = "tool_use"
break
usage_delta: Dict[str, Any] = {
"input_tokens": input_tokens,
"output_tokens": output_tokens,
}
if cache_creation_tokens:
usage_delta["cache_creation_input_tokens"] = cache_creation_tokens
if cache_read_tokens:
usage_delta["cache_read_input_tokens"] = cache_read_tokens
self._chunk_queue.append(
{
"type": "message_delta",
"delta": {"stop_reason": stop_reason, "stop_sequence": None},
"usage": usage_delta,
}
)
self._chunk_queue.append({"type": "message_stop"})
self._sent_message_stop = True
return
def __aiter__(self) -> "AnthropicResponsesStreamWrapper":
return self
async def __anext__(self) -> Dict[str, Any]:
# Return any queued chunks first
if self._chunk_queue:
return self._chunk_queue.popleft()
# Emit message_start if not yet done (fallback if response.created wasn't fired)
if not self._sent_message_start:
self._sent_message_start = True
self._chunk_queue.append(self._make_message_start())
return self._chunk_queue.popleft()
# Consume the upstream stream
try:
async for event in self.responses_stream:
self._process_event(event)
if self._chunk_queue:
return self._chunk_queue.popleft()
except StopAsyncIteration:
pass
except Exception as e:
verbose_logger.error(
f"AnthropicResponsesStreamWrapper error: {e}\n{traceback.format_exc()}"
)
# Drain any remaining queued chunks
if self._chunk_queue:
return self._chunk_queue.popleft()
raise StopAsyncIteration
async def async_anthropic_sse_wrapper(self) -> AsyncIterator[bytes]:
"""Yield SSE-encoded bytes for each Anthropic event chunk."""
async for chunk in self:
if isinstance(chunk, dict):
event_type: str = str(chunk.get("type", "message"))
payload = f"event: {event_type}\ndata: {json.dumps(chunk)}\n\n"
yield payload.encode()
else:
yield chunk

View File

@@ -0,0 +1,488 @@
"""
Transformation layer: Anthropic /v1/messages <-> OpenAI Responses API.
This module owns all format conversions for the direct v1/messages -> Responses API
path used for OpenAI and Azure models.
"""
import json
from typing import Any, Dict, List, Optional, Union, cast
from litellm.types.llms.anthropic import (
AllAnthropicToolsValues,
AnthopicMessagesAssistantMessageParam,
AnthropicFinishReason,
AnthropicMessagesRequest,
AnthropicMessagesToolChoice,
AnthropicMessagesUserMessageParam,
AnthropicResponseContentBlockText,
AnthropicResponseContentBlockThinking,
AnthropicResponseContentBlockToolUse,
)
from litellm.types.llms.anthropic_messages.anthropic_response import (
AnthropicMessagesResponse,
AnthropicUsage,
)
from litellm.types.llms.openai import ResponsesAPIResponse
class LiteLLMAnthropicToResponsesAPIAdapter:
    """
    Converts Anthropic /v1/messages requests to OpenAI Responses API format and
    converts Responses API responses back to Anthropic format.
    """
    # ------------------------------------------------------------------ #
    # Request translation: Anthropic -> Responses API                    #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _translate_anthropic_image_source_to_url(source: dict) -> Optional[str]:
        """Convert Anthropic image source to a URL string.

        base64 sources become a ``data:`` URL; url sources pass through.
        Returns None for empty data or unrecognized source types.
        """
        source_type = source.get("type")
        if source_type == "base64":
            media_type = source.get("media_type", "image/jpeg")
            data = source.get("data", "")
            return f"data:{media_type};base64,{data}" if data else None
        elif source_type == "url":
            return source.get("url")
        return None
    def translate_messages_to_responses_input(  # noqa: PLR0915
        self,
        messages: List[
            Union[
                AnthropicMessagesUserMessageParam,
                AnthopicMessagesAssistantMessageParam,
            ]
        ],
    ) -> List[Dict[str, Any]]:
        """
        Convert Anthropic messages list to Responses API `input` items.
        Mapping:
            user text       -> message(role=user, input_text)
            user image      -> message(role=user, input_image)
            user tool_result -> function_call_output
            assistant text  -> message(role=assistant, output_text)
            assistant tool_use -> function_call
        """
        input_items: List[Dict[str, Any]] = []
        for m in messages:
            role = m["role"]
            content = m.get("content")
            if role == "user":
                if isinstance(content, str):
                    input_items.append(
                        {
                            "type": "message",
                            "role": "user",
                            "content": [{"type": "input_text", "text": content}],
                        }
                    )
                elif isinstance(content, list):
                    user_parts: List[Dict[str, Any]] = []
                    for block in content:
                        if not isinstance(block, dict):
                            continue
                        btype = block.get("type")
                        if btype == "text":
                            user_parts.append(
                                {"type": "input_text", "text": block.get("text", "")}
                            )
                        elif btype == "image":
                            url = self._translate_anthropic_image_source_to_url(
                                block.get("source", {})
                            )
                            if url:
                                user_parts.append(
                                    {"type": "input_image", "image_url": url}
                                )
                        elif btype == "tool_result":
                            tool_use_id = block.get("tool_use_id", "")
                            inner = block.get("content")
                            # tool_result content may be a string, a list of
                            # text blocks, or absent; flatten to a single string.
                            if inner is None:
                                output_text = ""
                            elif isinstance(inner, str):
                                output_text = inner
                            elif isinstance(inner, list):
                                parts = [
                                    c.get("text", "")
                                    for c in inner
                                    if isinstance(c, dict) and c.get("type") == "text"
                                ]
                                output_text = "\n".join(parts)
                            else:
                                output_text = str(inner)
                            # tool_result is a top-level item, not inside the message
                            # NOTE(review): tool_results are therefore emitted before
                            # the user message that carried them in the same turn.
                            input_items.append(
                                {
                                    "type": "function_call_output",
                                    "call_id": tool_use_id,
                                    "output": output_text,
                                }
                            )
                    if user_parts:
                        input_items.append(
                            {
                                "type": "message",
                                "role": "user",
                                "content": user_parts,
                            }
                        )
            elif role == "assistant":
                if isinstance(content, str):
                    input_items.append(
                        {
                            "type": "message",
                            "role": "assistant",
                            "content": [{"type": "output_text", "text": content}],
                        }
                    )
                elif isinstance(content, list):
                    asst_parts: List[Dict[str, Any]] = []
                    for block in content:
                        if not isinstance(block, dict):
                            continue
                        btype = block.get("type")
                        if btype == "text":
                            asst_parts.append(
                                {"type": "output_text", "text": block.get("text", "")}
                            )
                        elif btype == "tool_use":
                            # tool_use becomes a top-level function_call item
                            input_items.append(
                                {
                                    "type": "function_call",
                                    "call_id": block.get("id", ""),
                                    "name": block.get("name", ""),
                                    "arguments": json.dumps(block.get("input", {})),
                                }
                            )
                        elif btype == "thinking":
                            # NOTE(review): prior "thinking" blocks are replayed as
                            # plain output_text — the Responses API input has no
                            # dedicated slot for past reasoning; lossy by design.
                            thinking_text = block.get("thinking", "")
                            if thinking_text:
                                asst_parts.append(
                                    {"type": "output_text", "text": thinking_text}
                                )
                    if asst_parts:
                        input_items.append(
                            {
                                "type": "message",
                                "role": "assistant",
                                "content": asst_parts,
                            }
                        )
        return input_items
    def translate_tools_to_responses_api(
        self,
        tools: List[AllAnthropicToolsValues],
    ) -> List[Dict[str, Any]]:
        """Convert Anthropic tool definitions to Responses API function tools.

        Anthropic web_search tools (matched by type prefix or by name) map to
        the built-in ``web_search_preview`` tool; everything else becomes a
        function tool carrying the original description and input schema.
        """
        result: List[Dict[str, Any]] = []
        for tool in tools:
            tool_dict = cast(Dict[str, Any], tool)
            tool_type = tool_dict.get("type", "")
            tool_name = tool_dict.get("name", "")
            # web_search tool
            if (
                isinstance(tool_type, str) and tool_type.startswith("web_search")
            ) or tool_name == "web_search":
                result.append({"type": "web_search_preview"})
                continue
            func_tool: Dict[str, Any] = {"type": "function", "name": tool_name}
            if "description" in tool_dict:
                func_tool["description"] = tool_dict["description"]
            if "input_schema" in tool_dict:
                func_tool["parameters"] = tool_dict["input_schema"]
            result.append(func_tool)
        return result
    @staticmethod
    def translate_tool_choice_to_responses_api(
        tool_choice: AnthropicMessagesToolChoice,
    ) -> Dict[str, Any]:
        """Convert Anthropic tool_choice to Responses API tool_choice.

        "any" -> "required", "tool" -> named function; anything else
        (including "auto") falls back to {"type": "auto"}.
        """
        tc_type = tool_choice.get("type")
        if tc_type == "any":
            return {"type": "required"}
        elif tc_type == "tool":
            return {"type": "function", "name": tool_choice.get("name", "")}
        return {"type": "auto"}
    @staticmethod
    def translate_context_management_to_responses_api(
        context_management: Dict[str, Any],
    ) -> Optional[List[Dict[str, Any]]]:
        """
        Convert Anthropic context_management dict to OpenAI Responses API array format.
        Anthropic format: {"edits": [{"type": "compact_20260112", "trigger": {"type": "input_tokens", "value": 150000}}]}
        OpenAI format: [{"type": "compaction", "compact_threshold": 150000}]

        Returns None when the input is malformed or no supported edits exist.
        """
        if not isinstance(context_management, dict):
            return None
        edits = context_management.get("edits", [])
        if not isinstance(edits, list):
            return None
        result: List[Dict[str, Any]] = []
        for edit in edits:
            if not isinstance(edit, dict):
                continue
            edit_type = edit.get("type", "")
            if edit_type == "compact_20260112":
                entry: Dict[str, Any] = {"type": "compaction"}
                trigger = edit.get("trigger")
                if isinstance(trigger, dict) and trigger.get("value") is not None:
                    entry["compact_threshold"] = int(trigger["value"])
                result.append(entry)
        return result if result else None
    @staticmethod
    def translate_thinking_to_reasoning(
        thinking: Dict[str, Any]
    ) -> Optional[Dict[str, Any]]:
        """
        Convert Anthropic thinking param to Responses API reasoning param.
        thinking.budget_tokens maps to reasoning effort:
            >= 10000 -> high, >= 5000 -> medium, >= 2000 -> low, < 2000 -> minimal
        """
        if not isinstance(thinking, dict) or thinking.get("type") != "enabled":
            return None
        budget = thinking.get("budget_tokens", 0)
        if budget >= 10000:
            effort = "high"
        elif budget >= 5000:
            effort = "medium"
        elif budget >= 2000:
            effort = "low"
        else:
            effort = "minimal"
        return {"effort": effort, "summary": "detailed"}
    def translate_request(
        self,
        anthropic_request: AnthropicMessagesRequest,
    ) -> Dict[str, Any]:
        """
        Translate a full Anthropic /v1/messages request dict to
        litellm.responses() / litellm.aresponses() kwargs.
        """
        model: str = anthropic_request["model"]
        messages_list = cast(
            List[
                Union[
                    AnthropicMessagesUserMessageParam,
                    AnthopicMessagesAssistantMessageParam,
                ]
            ],
            anthropic_request["messages"],
        )
        responses_kwargs: Dict[str, Any] = {
            "model": model,
            "input": self.translate_messages_to_responses_input(messages_list),
        }
        # system -> instructions
        system = anthropic_request.get("system")
        if system:
            if isinstance(system, str):
                responses_kwargs["instructions"] = system
            elif isinstance(system, list):
                # System may be a list of text blocks; join non-empty texts.
                text_parts = [
                    b.get("text", "")
                    for b in system
                    if isinstance(b, dict) and b.get("type") == "text"
                ]
                responses_kwargs["instructions"] = "\n".join(filter(None, text_parts))
        # max_tokens -> max_output_tokens
        max_tokens = anthropic_request.get("max_tokens")
        if max_tokens:
            responses_kwargs["max_output_tokens"] = max_tokens
        # temperature / top_p passed through
        if "temperature" in anthropic_request:
            responses_kwargs["temperature"] = anthropic_request["temperature"]
        if "top_p" in anthropic_request:
            responses_kwargs["top_p"] = anthropic_request["top_p"]
        # tools
        tools = anthropic_request.get("tools")
        if tools:
            responses_kwargs["tools"] = self.translate_tools_to_responses_api(
                cast(List[AllAnthropicToolsValues], tools)
            )
        # tool_choice
        tool_choice = anthropic_request.get("tool_choice")
        if tool_choice:
            responses_kwargs[
                "tool_choice"
            ] = self.translate_tool_choice_to_responses_api(
                cast(AnthropicMessagesToolChoice, tool_choice)
            )
        # thinking -> reasoning
        thinking = anthropic_request.get("thinking")
        if isinstance(thinking, dict):
            reasoning = self.translate_thinking_to_reasoning(thinking)
            if reasoning:
                responses_kwargs["reasoning"] = reasoning
        # output_format / output_config.format -> text format
        # output_format: {"type": "json_schema", "schema": {...}}
        # output_config: {"format": {"type": "json_schema", "schema": {...}}}
        output_format: Any = anthropic_request.get("output_format")
        output_config = anthropic_request.get("output_config")
        if not isinstance(output_format, dict) and isinstance(output_config, dict):
            output_format = output_config.get("format")  # type: ignore[assignment]
        if (
            isinstance(output_format, dict)
            and output_format.get("type") == "json_schema"
        ):
            schema = output_format.get("schema")
            if schema:
                responses_kwargs["text"] = {
                    "format": {
                        "type": "json_schema",
                        "name": "structured_output",
                        "schema": schema,
                        "strict": True,
                    }
                }
        # context_management: Anthropic dict -> OpenAI array
        context_management = anthropic_request.get("context_management")
        if isinstance(context_management, dict):
            openai_cm = self.translate_context_management_to_responses_api(
                context_management
            )
            if openai_cm is not None:
                responses_kwargs["context_management"] = openai_cm
        # metadata user_id -> user
        # Truncated to 64 chars to satisfy OpenAI's `user` field length limit.
        metadata = anthropic_request.get("metadata")
        if isinstance(metadata, dict) and "user_id" in metadata:
            responses_kwargs["user"] = str(metadata["user_id"])[:64]
        return responses_kwargs
    # ------------------------------------------------------------------ #
    # Response translation: Responses API -> Anthropic                   #
    # ------------------------------------------------------------------ #
    def translate_response(
        self,
        response: ResponsesAPIResponse,
    ) -> AnthropicMessagesResponse:
        """
        Translate an OpenAI ResponsesAPIResponse to AnthropicMessagesResponse.

        Output items may arrive as typed openai objects or plain dicts; both
        shapes are handled. Any function_call in the output forces
        stop_reason="tool_use"; an "incomplete" status forces "max_tokens".
        """
        from openai.types.responses import (
            ResponseFunctionToolCall,
            ResponseOutputMessage,
            ResponseReasoningItem,
        )
        from litellm.types.llms.openai import ResponseAPIUsage
        content: List[Dict[str, Any]] = []
        stop_reason: AnthropicFinishReason = "end_turn"
        for item in response.output:
            if isinstance(item, ResponseReasoningItem):
                # Reasoning summaries become Anthropic "thinking" blocks; the
                # Responses API provides no signature, so it is left as None.
                for summary in item.summary:
                    text = getattr(summary, "text", "")
                    if text:
                        content.append(
                            AnthropicResponseContentBlockThinking(
                                type="thinking",
                                thinking=text,
                                signature=None,
                            ).model_dump()
                        )
            elif isinstance(item, ResponseOutputMessage):
                for part in item.content:
                    if getattr(part, "type", None) == "output_text":
                        content.append(
                            AnthropicResponseContentBlockText(
                                type="text", text=getattr(part, "text", "")
                            ).model_dump()
                        )
            elif isinstance(item, ResponseFunctionToolCall):
                # Arguments arrive as a JSON string; fall back to {} on bad JSON.
                try:
                    input_data = json.loads(item.arguments) if item.arguments else {}
                except (json.JSONDecodeError, TypeError):
                    input_data = {}
                content.append(
                    AnthropicResponseContentBlockToolUse(
                        type="tool_use",
                        id=item.call_id or item.id or "",
                        name=item.name,
                        input=input_data,
                    ).model_dump()
                )
                stop_reason = "tool_use"
            elif isinstance(item, dict):
                item_type = item.get("type")
                if item_type == "message":
                    for part in item.get("content", []):
                        if isinstance(part, dict) and part.get("type") == "output_text":
                            content.append(
                                AnthropicResponseContentBlockText(
                                    type="text", text=part.get("text", "")
                                ).model_dump()
                            )
                elif item_type == "function_call":
                    try:
                        input_data = json.loads(item.get("arguments", "{}"))
                    except (json.JSONDecodeError, TypeError):
                        input_data = {}
                    content.append(
                        AnthropicResponseContentBlockToolUse(
                            type="tool_use",
                            id=item.get("call_id") or item.get("id", ""),
                            name=item.get("name", ""),
                            input=input_data,
                        ).model_dump()
                    )
                    stop_reason = "tool_use"
        # status -> stop_reason override
        if response.status == "incomplete":
            stop_reason = "max_tokens"
        # usage
        raw_usage: Optional[ResponseAPIUsage] = response.usage
        input_tokens = int(getattr(raw_usage, "input_tokens", 0) or 0)
        output_tokens = int(getattr(raw_usage, "output_tokens", 0) or 0)
        anthropic_usage = AnthropicUsage(
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )
        return AnthropicMessagesResponse(
            id=response.id,
            type="message",
            role="assistant",
            model=response.model or "unknown-model",
            stop_sequence=None,
            usage=anthropic_usage,  # type: ignore
            content=content,  # type: ignore
            stop_reason=stop_reason,
        )

View File

@@ -0,0 +1,4 @@
from .handler import AnthropicFilesHandler
from .transformation import AnthropicFilesConfig
__all__ = ["AnthropicFilesHandler", "AnthropicFilesConfig"]

View File

@@ -0,0 +1,366 @@
import asyncio
import json
import time
from typing import Any, Coroutine, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
)
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.types.llms.openai import (
FileContentRequest,
HttpxBinaryResponseContent,
OpenAIBatchResult,
OpenAIChatCompletionResponse,
OpenAIErrorBody,
)
from litellm.types.utils import CallTypes, LlmProviders, ModelResponse
from ..chat.transformation import AnthropicConfig
from ..common_utils import AnthropicModelInfo
# Map Anthropic error types to HTTP status codes
# Used when translating Anthropic batch error results to OpenAI-style batch
# responses; unknown error types fall back to 500 at the call site.
ANTHROPIC_ERROR_STATUS_CODE_MAP = {
    "invalid_request_error": 400,
    "authentication_error": 401,
    "permission_error": 403,
    "not_found_error": 404,
    "rate_limit_error": 429,
    "api_error": 500,
    "overloaded_error": 503,
    "timeout_error": 504,
}
class AnthropicFilesHandler:
    """
    Handles Anthropic Files API operations.

    Currently supports:
    - file_content() for retrieving Anthropic Message Batch results
    """

    def __init__(self):
        self.anthropic_model_info = AnthropicModelInfo()

    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = 600.0,
        max_retries: Optional[int] = None,
    ) -> HttpxBinaryResponseContent:
        """
        Async: Retrieve file content from Anthropic.

        For batch results, the file_id should be the batch_id.
        This will call Anthropic's /v1/messages/batches/{batch_id}/results endpoint.

        Args:
            file_content_request: Contains file_id (batch_id for batch results)
            api_base: Anthropic API base URL
            api_key: Anthropic API key
            timeout: Request timeout
            max_retries: Max retry attempts (unused for now)

        Returns:
            HttpxBinaryResponseContent: Binary content wrapped in compatible response format

        Raises:
            ValueError: if file_id is missing or no API key can be resolved.
        """
        file_id = file_content_request.get("file_id")
        if not file_id:
            raise ValueError("file_id is required in file_content_request")
        # Extract batch_id from file_id
        # Handle both formats: "anthropic_batch_results:{batch_id}" or just "{batch_id}"
        if file_id.startswith("anthropic_batch_results:"):
            batch_id = file_id.replace("anthropic_batch_results:", "", 1)
        else:
            batch_id = file_id
        # Get Anthropic API credentials
        api_base = self.anthropic_model_info.get_api_base(api_base)
        api_key = api_key or self.anthropic_model_info.get_api_key()
        if not api_key:
            raise ValueError("Missing Anthropic API Key")
        # Construct the Anthropic batch results URL
        results_url = f"{api_base.rstrip('/')}/v1/messages/batches/{batch_id}/results"
        # Prepare headers
        headers = {
            "accept": "application/json",
            "anthropic-version": "2023-06-01",
            "x-api-key": api_key,
        }
        # Make the request to Anthropic
        # NOTE(review): `timeout` is accepted but not currently applied to this
        # GET call — the shared client's default timeout is used; confirm whether
        # it should be forwarded here.
        async_client = get_async_httpx_client(llm_provider=LlmProviders.ANTHROPIC)
        anthropic_response = await async_client.get(url=results_url, headers=headers)
        anthropic_response.raise_for_status()
        # Transform Anthropic batch results to OpenAI format
        transformed_content = self._transform_anthropic_batch_results_to_openai_format(
            anthropic_response.content
        )
        # Create a new response with transformed content
        transformed_response = httpx.Response(
            status_code=anthropic_response.status_code,
            headers=anthropic_response.headers,
            content=transformed_content,
            request=anthropic_response.request,
        )
        # Return the transformed response content
        return HttpxBinaryResponseContent(response=transformed_response)

    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = 600.0,
        max_retries: Optional[int] = None,
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        """
        Retrieve file content from Anthropic.

        For batch results, the file_id should be the batch_id.
        This will call Anthropic's /v1/messages/batches/{batch_id}/results endpoint.

        Args:
            _is_async: Whether to run asynchronously
            file_content_request: Contains file_id (batch_id for batch results)
            api_base: Anthropic API base URL
            api_key: Anthropic API key
            timeout: Request timeout
            max_retries: Max retry attempts (unused for now)

        Returns:
            HttpxBinaryResponseContent or Coroutine: Binary content wrapped in compatible response format
        """
        if _is_async:
            # BUGFIX: forward `timeout` on the async path too — it was previously
            # dropped here while the sync path below passed it through.
            return self.afile_content(
                file_content_request=file_content_request,
                api_base=api_base,
                api_key=api_key,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            return asyncio.run(
                self.afile_content(
                    file_content_request=file_content_request,
                    api_base=api_base,
                    api_key=api_key,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            )

    def _transform_anthropic_batch_results_to_openai_format(
        self, anthropic_content: bytes
    ) -> bytes:
        """
        Transform Anthropic batch results JSONL to OpenAI batch results JSONL format.

        Anthropic format:
        {
            "custom_id": "...",
            "result": {
                "type": "succeeded",
                "message": { ... }  // Anthropic message format
            }
        }

        OpenAI format:
        {
            "custom_id": "...",
            "response": {
                "status_code": 200,
                "request_id": "...",
                "body": { ... }  // OpenAI chat completion format
            }
        }

        On any transformation failure the original bytes are returned unchanged
        (best-effort behavior, logged via verbose_logger).
        """
        try:
            anthropic_config = AnthropicConfig()
            transformed_lines = []
            # Parse JSONL content
            content_str = anthropic_content.decode("utf-8")
            for line in content_str.strip().split("\n"):
                if not line.strip():
                    continue
                anthropic_result = json.loads(line)
                custom_id = anthropic_result.get("custom_id", "")
                result = anthropic_result.get("result", {})
                result_type = result.get("type", "")
                # Transform based on result type
                if result_type == "succeeded":
                    # Transform Anthropic message to OpenAI format
                    anthropic_message = result.get("message", {})
                    if anthropic_message:
                        openai_response_body = (
                            self._transform_anthropic_message_to_openai_format(
                                anthropic_message=anthropic_message,
                                anthropic_config=anthropic_config,
                            )
                        )
                        # Create OpenAI batch result format
                        openai_result: OpenAIBatchResult = {
                            "custom_id": custom_id,
                            "response": {
                                "status_code": 200,
                                "request_id": anthropic_message.get("id", ""),
                                "body": openai_response_body,
                            },
                        }
                        transformed_lines.append(json.dumps(openai_result))
                elif result_type == "errored":
                    # Handle error case: map Anthropic error type to an HTTP status
                    error = result.get("error", {})
                    error_obj = error.get("error", {})
                    error_message = error_obj.get("message", "Unknown error")
                    error_type = error_obj.get("type", "api_error")
                    status_code = ANTHROPIC_ERROR_STATUS_CODE_MAP.get(error_type, 500)
                    error_body_errored: OpenAIErrorBody = {
                        "error": {
                            "message": error_message,
                            "type": error_type,
                        }
                    }
                    openai_result_errored: OpenAIBatchResult = {
                        "custom_id": custom_id,
                        "response": {
                            "status_code": status_code,
                            "request_id": error.get("request_id", ""),
                            "body": error_body_errored,
                        },
                    }
                    transformed_lines.append(json.dumps(openai_result_errored))
                elif result_type in ["canceled", "expired"]:
                    # Handle canceled/expired cases as 400-level request errors
                    error_body_canceled: OpenAIErrorBody = {
                        "error": {
                            "message": f"Batch request was {result_type}",
                            "type": "invalid_request_error",
                        }
                    }
                    openai_result_canceled: OpenAIBatchResult = {
                        "custom_id": custom_id,
                        "response": {
                            "status_code": 400,
                            "request_id": "",
                            "body": error_body_canceled,
                        },
                    }
                    transformed_lines.append(json.dumps(openai_result_canceled))
            # Join lines and encode back to bytes
            transformed_content = "\n".join(transformed_lines)
            if transformed_lines:
                transformed_content += "\n"  # Add trailing newline for JSONL format
            return transformed_content.encode("utf-8")
        except Exception as e:
            verbose_logger.error(
                f"Error transforming Anthropic batch results to OpenAI format: {e}"
            )
            # Return original content if transformation fails
            return anthropic_content

    def _transform_anthropic_message_to_openai_format(
        self, anthropic_message: dict, anthropic_config: AnthropicConfig
    ) -> OpenAIChatCompletionResponse:
        """
        Transform a single Anthropic message to OpenAI chat completion format.

        Falls back to a minimal zero-usage error-shaped response if the
        AnthropicConfig transformation raises.
        """
        try:
            # Create a mock httpx.Response for transformation
            mock_response = httpx.Response(
                status_code=200,
                content=json.dumps(anthropic_message).encode("utf-8"),
            )
            # Create a ModelResponse object
            model_response = ModelResponse()
            # Initialize with required fields - will be populated by transform_parsed_response
            model_response.choices = [
                litellm.Choices(
                    finish_reason="stop",
                    index=0,
                    message=litellm.Message(content="", role="assistant"),
                )
            ]  # type: ignore
            # Create a logging object for transformation
            logging_obj = Logging(
                model=anthropic_message.get("model", "claude-3-5-sonnet-20241022"),
                messages=[{"role": "user", "content": "batch_request"}],
                stream=False,
                call_type=CallTypes.aretrieve_batch,
                start_time=time.time(),
                litellm_call_id="batch_" + str(uuid.uuid4()),
                function_id="batch_processing",
                litellm_trace_id=str(uuid.uuid4()),
                kwargs={"optional_params": {}},
            )
            logging_obj.optional_params = {}
            # Transform using AnthropicConfig
            transformed_response = anthropic_config.transform_parsed_response(
                completion_response=anthropic_message,
                raw_response=mock_response,
                model_response=model_response,
                json_mode=False,
                prefix_prompt=None,
            )
            # Convert ModelResponse to OpenAI format dict - it's already in OpenAI format
            openai_body: OpenAIChatCompletionResponse = transformed_response.model_dump(
                exclude_none=True
            )
            # Ensure id comes from anthropic_message if not set
            if not openai_body.get("id"):
                openai_body["id"] = anthropic_message.get("id", "")
            return openai_body
        except Exception as e:
            verbose_logger.error(
                f"Error transforming Anthropic message to OpenAI format: {e}"
            )
            # Return a basic error response if transformation fails
            error_response: OpenAIChatCompletionResponse = {
                "id": anthropic_message.get("id", ""),
                "object": "chat.completion",
                "created": int(time.time()),
                "model": anthropic_message.get("model", ""),
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": ""},
                        "finish_reason": "error",
                    }
                ],
                "usage": {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                },
            }
            return error_response

View File

@@ -0,0 +1,307 @@
"""
Anthropic Files API transformation config.
Implements BaseFilesConfig for Anthropic's Files API (beta).
Reference: https://docs.anthropic.com/en/docs/build-with-claude/files
Anthropic Files API endpoints:
- POST /v1/files - Upload a file
- GET /v1/files - List files
- GET /v1/files/{file_id} - Retrieve file metadata
- DELETE /v1/files/{file_id} - Delete a file
- GET /v1/files/{file_id}/content - Download file content
"""
import calendar
import time
from typing import Any, Dict, List, Optional, Union, cast
import httpx
from openai.types.file_deleted import FileDeleted
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.files.transformation import (
BaseFilesConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import (
CreateFileRequest,
FileContentRequest,
HttpxBinaryResponseContent,
OpenAICreateFileRequestOptionalParams,
OpenAIFileObject,
)
from litellm.types.utils import LlmProviders
from ..common_utils import AnthropicError, AnthropicModelInfo
ANTHROPIC_FILES_API_BASE = "https://api.anthropic.com"
ANTHROPIC_FILES_BETA_HEADER = "files-api-2025-04-14"
class AnthropicFilesConfig(BaseFilesConfig):
    """
    Transformation config for Anthropic Files API.
    Anthropic uses:
    - x-api-key header for authentication
    - anthropic-beta: files-api-2025-04-14 header
    - multipart/form-data for file uploads
    - purpose="messages" (Anthropic-specific, not for batches/fine-tuning)
    """
    def __init__(self):
        # Stateless config object: every method is a pure transformation.
        pass
    @property
    def custom_llm_provider(self) -> LlmProviders:
        """Provider identifier used by LiteLLM's provider-config registry."""
        return LlmProviders.ANTHROPIC
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """Return the files-collection endpoint (POST /v1/files, GET /v1/files)."""
        api_base = AnthropicModelInfo.get_api_base(api_base) or ANTHROPIC_FILES_API_BASE
        return f"{api_base.rstrip('/')}/v1/files"
    def get_error_class(
        self,
        error_message: str,
        status_code: int,
        headers: Union[dict, httpx.Headers],
    ) -> BaseLLMException:
        """Map an HTTP error into the provider-specific AnthropicError."""
        # NOTE(review): typing.cast is a runtime no-op, so a plain dict is
        # forwarded unchanged here — confirm AnthropicError accepts dict headers.
        return AnthropicError(
            status_code=status_code,
            message=error_message,
            headers=cast(httpx.Headers, headers) if isinstance(headers, dict) else headers,
        )
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: list,
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Attach the auth + beta headers required by the (beta) Files API.
        The API key is resolved from the argument or the ANTHROPIC_API_KEY
        environment variable.
        Raises:
            ValueError: if no API key can be resolved.
        """
        api_key = AnthropicModelInfo.get_api_key(api_key)
        if not api_key:
            raise ValueError(
                "Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
            )
        headers.update(
            {
                "x-api-key": api_key,
                "anthropic-version": "2023-06-01",
                # Files API is gated behind this beta opt-in header.
                "anthropic-beta": ANTHROPIC_FILES_BETA_HEADER,
            }
        )
        return headers
    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAICreateFileRequestOptionalParams]:
        """Only `purpose` from the OpenAI create-file spec applies to Anthropic."""
        return ["purpose"]
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """No parameter translation needed; optional params pass through as-is."""
        return optional_params
    def transform_create_file_request(
        self,
        model: str,
        create_file_data: CreateFileRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> dict:
        """
        Transform to multipart form data for Anthropic file upload.
        Anthropic expects: POST /v1/files with multipart form-data
        - file: the file content
        - purpose: "messages" (defaults to "messages" if not provided)
        Raises:
            ValueError: if no file payload was supplied.
        """
        file_data = create_file_data.get("file")
        if file_data is None:
            raise ValueError("File data is required")
        extracted = extract_file_data(file_data)
        # Multipart uploads need a filename; synthesize one if absent.
        filename = extracted["filename"] or f"file_{int(time.time())}"
        content = extracted["content"]
        content_type = extracted.get("content_type", "application/octet-stream")
        purpose = create_file_data.get("purpose", "messages")
        return {
            # (filename, content, content_type) triple is the httpx/requests
            # multipart file-field convention.
            "file": (filename, content, content_type),
            # (None, value) sends `purpose` as a plain form field, not a file.
            "purpose": (None, purpose),
        }
    def transform_create_file_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> OpenAIFileObject:
        """
        Transform Anthropic file response to OpenAI format.
        Anthropic response:
        {
            "id": "file-xxx",
            "type": "file",
            "filename": "document.pdf",
            "mime_type": "application/pdf",
            "size_bytes": 12345,
            "created_at": "2025-01-01T00:00:00Z"
        }
        """
        response_json = raw_response.json()
        return self._parse_anthropic_file(response_json)
    def transform_retrieve_file_request(
        self,
        file_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Build (url, query_params) for GET /v1/files/{file_id}."""
        api_base = (
            AnthropicModelInfo.get_api_base(litellm_params.get("api_base"))
            or ANTHROPIC_FILES_API_BASE
        )
        return f"{api_base.rstrip('/')}/v1/files/{file_id}", {}
    def transform_retrieve_file_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> OpenAIFileObject:
        """Parse a single-file metadata response into OpenAI format."""
        response_json = raw_response.json()
        return self._parse_anthropic_file(response_json)
    def transform_delete_file_request(
        self,
        file_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Build (url, query_params) for DELETE /v1/files/{file_id}."""
        api_base = (
            AnthropicModelInfo.get_api_base(litellm_params.get("api_base"))
            or ANTHROPIC_FILES_API_BASE
        )
        return f"{api_base.rstrip('/')}/v1/files/{file_id}", {}
    def transform_delete_file_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> FileDeleted:
        """Map Anthropic's delete response onto OpenAI's FileDeleted shape."""
        response_json = raw_response.json()
        file_id = response_json.get("id", "")
        # deleted=True is assumed: a failed delete raises via the HTTP error
        # path before this transform is reached.
        return FileDeleted(
            id=file_id,
            deleted=True,
            object="file",
        )
    def transform_list_files_request(
        self,
        purpose: Optional[str],
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Build (url, query_params) for GET /v1/files, optionally filtered by purpose."""
        api_base = (
            AnthropicModelInfo.get_api_base(litellm_params.get("api_base"))
            or ANTHROPIC_FILES_API_BASE
        )
        url = f"{api_base.rstrip('/')}/v1/files"
        params: Dict[str, Any] = {}
        if purpose:
            params["purpose"] = purpose
        return url, params
    def transform_list_files_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> List[OpenAIFileObject]:
        """
        Anthropic list response:
        {
            "data": [...],
            "has_more": false,
            "first_id": "...",
            "last_id": "..."
        }
        """
        response_json = raw_response.json()
        files_data = response_json.get("data", [])
        return [self._parse_anthropic_file(f) for f in files_data]
    def transform_file_content_request(
        self,
        file_content_request: FileContentRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> tuple[str, dict]:
        """Build (url, query_params) for GET /v1/files/{file_id}/content."""
        file_id = file_content_request.get("file_id")
        api_base = (
            AnthropicModelInfo.get_api_base(litellm_params.get("api_base"))
            or ANTHROPIC_FILES_API_BASE
        )
        return f"{api_base.rstrip('/')}/v1/files/{file_id}/content", {}
    def transform_file_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> HttpxBinaryResponseContent:
        """Wrap the raw binary download without buffering or decoding it."""
        return HttpxBinaryResponseContent(response=raw_response)
    @staticmethod
    def _parse_anthropic_file(file_data: dict) -> OpenAIFileObject:
        """Parse Anthropic file object into OpenAI format."""
        created_at_str = file_data.get("created_at", "")
        if created_at_str:
            try:
                # Parse the first 19 chars ("YYYY-MM-DDTHH:MM:SS") as UTC;
                # calendar.timegm avoids applying the local timezone offset.
                created_at = int(
                    calendar.timegm(
                        time.strptime(
                            created_at_str.replace("Z", "+00:00")[:19],
                            "%Y-%m-%dT%H:%M:%S",
                        )
                    )
                )
            except (ValueError, TypeError):
                # Unparseable timestamp: fall back to "now" rather than fail.
                created_at = int(time.time())
        else:
            created_at = int(time.time())
        return OpenAIFileObject(
            id=file_data.get("id", ""),
            # Anthropic reports size_bytes; accept `bytes` as a fallback key.
            bytes=file_data.get("size_bytes", file_data.get("bytes", 0)),
            created_at=created_at,
            filename=file_data.get("filename", ""),
            object="file",
            purpose=file_data.get("purpose", "messages"),
            status="uploaded",
            status_details=None,
        )

View File

@@ -0,0 +1,5 @@
"""Anthropic Skills API integration.

Re-exports the provider transformation config so callers can import it as
``from litellm.llms.anthropic.skills import AnthropicSkillsConfig``.
"""
from .transformation import AnthropicSkillsConfig
# Public API of this package.
__all__ = ["AnthropicSkillsConfig"]

View File

@@ -0,0 +1,279 @@
# Anthropic Skills API Integration
This module provides comprehensive support for the Anthropic Skills API through LiteLLM.
## Features
The Skills API allows you to:
- **Create skills**: Define reusable AI capabilities
- **List skills**: Browse all available skills
- **Get skills**: Retrieve detailed information about a specific skill
- **Delete skills**: Remove skills that are no longer needed
## Quick Start
### Prerequisites
Set your Anthropic API key:
```python
import os
os.environ["ANTHROPIC_API_KEY"] = "your-api-key-here"
```
### Basic Usage
#### Create a Skill
```python
import litellm
# Create a skill with files
# Note: All files must be in the same top-level directory
# and must include a SKILL.md file at the root
skill = litellm.create_skill(
files=[
# List of file objects to upload
# Must include SKILL.md
],
display_title="Python Code Generator",
custom_llm_provider="anthropic"
)
print(f"Created skill: {skill.id}")
# Asynchronous version
skill = await litellm.acreate_skill(
files=[...], # Your files here
display_title="Python Code Generator",
custom_llm_provider="anthropic"
)
```
#### List Skills
```python
# List all skills
skills = litellm.list_skills(
custom_llm_provider="anthropic"
)
for skill in skills.data:
print(f"{skill.display_title}: {skill.id}")
# With pagination and filtering
skills = litellm.list_skills(
limit=20,
source="custom", # Filter by 'custom' or 'anthropic'
custom_llm_provider="anthropic"
)
# Get next page if available
if skills.has_more:
next_page = litellm.list_skills(
page=skills.next_page,
custom_llm_provider="anthropic"
)
```
#### Get a Skill
```python
skill = litellm.get_skill(
skill_id="skill_abc123",
custom_llm_provider="anthropic"
)
print(f"Skill: {skill.display_title}")
print(f"Created: {skill.created_at}")
print(f"Latest version: {skill.latest_version}")
print(f"Source: {skill.source}")
```
#### Delete a Skill
```python
result = litellm.delete_skill(
skill_id="skill_abc123",
custom_llm_provider="anthropic"
)
print(f"Deleted skill {result.id}, type: {result.type}")
```
## API Reference
### `create_skill()`
Create a new skill.
**Parameters:**
- `files` (List[Any], optional): Files to upload for the skill. All files must be in the same top-level directory and must include a SKILL.md file at the root.
- `display_title` (str, optional): Display title for the skill
- `custom_llm_provider` (str, optional): Provider name (default: "anthropic")
- `extra_headers` (dict, optional): Additional HTTP headers
- `timeout` (float, optional): Request timeout
**Returns:**
- `Skill`: The created skill object
**Async version:** `acreate_skill()`
### `list_skills()`
List all skills.
**Parameters:**
- `limit` (int, optional): Number of results to return per page (max 100, default 20)
- `page` (str, optional): Pagination token for fetching a specific page of results
- `source` (str, optional): Filter skills by source ('custom' or 'anthropic')
- `custom_llm_provider` (str, optional): Provider name (default: "anthropic")
- `extra_headers` (dict, optional): Additional HTTP headers
- `timeout` (float, optional): Request timeout
**Returns:**
- `ListSkillsResponse`: Object containing a list of skills and pagination info
**Async version:** `alist_skills()`
### `get_skill()`
Get a specific skill by ID.
**Parameters:**
- `skill_id` (str, required): The skill ID
- `custom_llm_provider` (str, optional): Provider name (default: "anthropic")
- `extra_headers` (dict, optional): Additional HTTP headers
- `timeout` (float, optional): Request timeout
**Returns:**
- `Skill`: The requested skill object
**Async version:** `aget_skill()`
### `delete_skill()`
Delete a skill.
**Parameters:**
- `skill_id` (str, required): The skill ID to delete
- `custom_llm_provider` (str, optional): Provider name (default: "anthropic")
- `extra_headers` (dict, optional): Additional HTTP headers
- `timeout` (float, optional): Request timeout
**Returns:**
- `DeleteSkillResponse`: Object with `id` and `type` fields
**Async version:** `adelete_skill()`
## Response Types
### `Skill`
Represents a skill from the Anthropic Skills API.
**Fields:**
- `id` (str): Unique identifier
- `created_at` (str): ISO 8601 timestamp
- `display_title` (str, optional): Display title
- `latest_version` (str, optional): Latest version identifier
- `source` (str): Source ("custom" or "anthropic")
- `type` (str): Object type (always "skill")
- `updated_at` (str): ISO 8601 timestamp
### `ListSkillsResponse`
Response from listing skills.
**Fields:**
- `data` (List[Skill]): List of skills
- `next_page` (str, optional): Pagination token for the next page
- `has_more` (bool): Whether more skills are available
### `DeleteSkillResponse`
Response from deleting a skill.
**Fields:**
- `id` (str): The deleted skill ID
- `type` (str): Deleted object type (always "skill_deleted")
## Architecture
The Skills API implementation follows LiteLLM's standard patterns:
1. **Type Definitions** (`litellm/types/llms/anthropic_skills.py`)
- Pydantic models for request/response types
- TypedDict definitions for request parameters
2. **Base Configuration** (`litellm/llms/base_llm/skills/transformation.py`)
- Abstract base class `BaseSkillsAPIConfig`
- Defines transformation interface for provider-specific implementations
3. **Provider Implementation** (`litellm/llms/anthropic/skills/transformation.py`)
- `AnthropicSkillsConfig` - Anthropic-specific transformations
- Handles API authentication, URL construction, and response mapping
4. **Main Handler** (`litellm/skills/main.py`)
- Public API functions (sync and async)
- Request validation and routing
- Error handling
5. **HTTP Handlers** (`litellm/llms/custom_httpx/llm_http_handler.py`)
- Low-level HTTP request/response handling
- Connection pooling and retry logic
## Beta API Support
The Skills API is in beta. The beta header (`skills-2025-10-02`) is automatically added by the Anthropic provider configuration. You can customize it if needed:
```python
skill = litellm.create_skill(
display_title="My Skill",
extra_headers={
"anthropic-beta": "skills-2025-10-02" # Or any other beta version
},
custom_llm_provider="anthropic"
)
```
The default beta version is configured in `litellm.constants.ANTHROPIC_SKILLS_API_BETA_VERSION`.
## Error Handling
All Skills API functions follow LiteLLM's standard error handling:
```python
import litellm
try:
skill = litellm.create_skill(
display_title="My Skill",
custom_llm_provider="anthropic"
)
except litellm.exceptions.AuthenticationError as e:
print(f"Authentication failed: {e}")
except litellm.exceptions.RateLimitError as e:
print(f"Rate limit exceeded: {e}")
except litellm.exceptions.APIError as e:
print(f"API error: {e}")
```
## Contributing
To add support for Skills API to a new provider:
1. Create provider-specific configuration class inheriting from `BaseSkillsAPIConfig`
2. Implement all abstract methods for request/response transformations
3. Register the config in `ProviderConfigManager.get_provider_skills_api_config()`
4. Add appropriate tests
## Related Documentation
- [Anthropic Skills API Documentation](https://platform.claude.com/docs/en/api/beta/skills/create)
- [LiteLLM Responses API](../../../responses/)
- [Provider Configuration System](../../base_llm/)
## Support
For issues or questions:
- GitHub Issues: https://github.com/BerriAI/litellm/issues
- Discord: https://discord.gg/wuPM9dRgDw

View File

@@ -0,0 +1,204 @@
"""
Anthropic Skills API configuration and transformations
"""
from typing import Any, Dict, Optional, Tuple
import httpx
from litellm._logging import verbose_logger
from litellm.llms.base_llm.skills.transformation import (
BaseSkillsAPIConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.anthropic_skills import (
CreateSkillRequest,
DeleteSkillResponse,
ListSkillsParams,
ListSkillsResponse,
Skill,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
class AnthropicSkillsConfig(BaseSkillsAPIConfig):
    """Anthropic-specific Skills API configuration.

    Implements authentication headers, URL construction, and request/response
    transformations for the (beta) Anthropic Skills API.
    """

    @property
    def custom_llm_provider(self) -> LlmProviders:
        """Provider identifier used by LiteLLM's provider-config registry."""
        return LlmProviders.ANTHROPIC

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Add Anthropic-specific headers (auth, API version, beta opt-in).

        The API key is resolved from ``litellm_params`` or the
        ``ANTHROPIC_API_KEY`` environment variable.

        Raises:
            ValueError: if no API key can be resolved.
        """
        from litellm.llms.anthropic.common_utils import AnthropicModelInfo

        # Resolve API key (explicit param wins over environment).
        api_key = None
        if litellm_params:
            api_key = litellm_params.api_key
        api_key = AnthropicModelInfo.get_api_key(api_key)
        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY is required for Skills API")
        # Add required headers
        headers["x-api-key"] = api_key
        headers["anthropic-version"] = "2023-06-01"
        # Merge in the Skills beta flag.
        # BUG FIX: the previous implementation could leave a Python list as the
        # header value (when a string header lacked the skills version, or when
        # a list was passed through unchanged). HTTP header values must be
        # strings; Anthropic expects multiple beta flags as a single
        # comma-separated string.
        from litellm.constants import ANTHROPIC_SKILLS_API_BETA_VERSION

        existing_beta = headers.get("anthropic-beta")
        if existing_beta is None:
            headers["anthropic-beta"] = ANTHROPIC_SKILLS_API_BETA_VERSION
        elif isinstance(existing_beta, (list, tuple)):
            beta_values = [str(v) for v in existing_beta]
            if ANTHROPIC_SKILLS_API_BETA_VERSION not in beta_values:
                beta_values.append(ANTHROPIC_SKILLS_API_BETA_VERSION)
            headers["anthropic-beta"] = ",".join(beta_values)
        elif ANTHROPIC_SKILLS_API_BETA_VERSION not in existing_beta:
            headers["anthropic-beta"] = (
                f"{existing_beta},{ANTHROPIC_SKILLS_API_BETA_VERSION}"
            )
        headers["content-type"] = "application/json"
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        endpoint: str,
        skill_id: Optional[str] = None,
    ) -> str:
        """Get complete URL for Anthropic Skills API.

        When ``skill_id`` is provided, returns the item URL
        ``/v1/skills/{skill_id}`` (``endpoint`` is ignored); otherwise returns
        the collection URL ``/v1/{endpoint}``.
        """
        from litellm.llms.anthropic.common_utils import AnthropicModelInfo

        if api_base is None:
            api_base = AnthropicModelInfo.get_api_base()
        if skill_id:
            return f"{api_base}/v1/skills/{skill_id}"
        return f"{api_base}/v1/{endpoint}"

    def transform_create_skill_request(
        self,
        create_request: CreateSkillRequest,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """Transform create skill request for Anthropic.

        Anthropic accepts the request body directly; only ``None`` values are
        stripped so optional fields are omitted rather than sent as null.
        """
        verbose_logger.debug("Transforming create skill request: %s", create_request)
        request_body = {k: v for k, v in create_request.items() if v is not None}
        return request_body

    def transform_create_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Skill:
        """Transform Anthropic response to a `Skill` object."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming create skill response: %s", response_json)
        return Skill(**response_json)

    def transform_list_skills_request(
        self,
        list_params: ListSkillsParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Transform list skills request for Anthropic.

        Returns:
            (url, query_params) for GET /v1/skills; supports ``limit``,
            ``page`` (pagination token) and ``source`` filters.
        """
        from litellm.llms.anthropic.common_utils import AnthropicModelInfo

        api_base = AnthropicModelInfo.get_api_base(
            litellm_params.api_base if litellm_params else None
        )
        url = self.get_complete_url(api_base=api_base, endpoint="skills")
        # Build query parameters; falsy values (None, 0, "") are omitted.
        query_params: Dict[str, Any] = {}
        if "limit" in list_params and list_params["limit"]:
            query_params["limit"] = list_params["limit"]
        if "page" in list_params and list_params["page"]:
            query_params["page"] = list_params["page"]
        if "source" in list_params and list_params["source"]:
            query_params["source"] = list_params["source"]
        verbose_logger.debug(
            "List skills request made to Anthropic Skills endpoint with params: %s",
            query_params,
        )
        return url, query_params

    def transform_list_skills_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ListSkillsResponse:
        """Transform Anthropic response to `ListSkillsResponse`."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming list skills response: %s", response_json)
        return ListSkillsResponse(**response_json)

    def transform_get_skill_request(
        self,
        skill_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Build (url, headers) for GET /v1/skills/{skill_id}."""
        url = self.get_complete_url(
            api_base=api_base, endpoint="skills", skill_id=skill_id
        )
        verbose_logger.debug("Get skill request - URL: %s", url)
        return url, headers

    def transform_get_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Skill:
        """Transform Anthropic response to a `Skill` object."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming get skill response: %s", response_json)
        return Skill(**response_json)

    def transform_delete_skill_request(
        self,
        skill_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Build (url, headers) for DELETE /v1/skills/{skill_id}."""
        url = self.get_complete_url(
            api_base=api_base, endpoint="skills", skill_id=skill_id
        )
        verbose_logger.debug("Delete skill request - URL: %s", url)
        return url, headers

    def transform_delete_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteSkillResponse:
        """Transform Anthropic response to `DeleteSkillResponse`."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming delete skill response: %s", response_json)
        return DeleteSkillResponse(**response_json)

View File

@@ -0,0 +1,392 @@
"""
AWS Polly Text-to-Speech transformation
Maps OpenAI TTS spec to AWS Polly SynthesizeSpeech API
Reference: https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
"""
import json
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
import httpx
from litellm.llms.base_llm.text_to_speech.transformation import (
BaseTextToSpeechConfig,
TextToSpeechRequestData,
)
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import HttpxBinaryResponseContent
else:
LiteLLMLoggingObj = Any
HttpxBinaryResponseContent = Any
class AWSPollyTextToSpeechConfig(BaseTextToSpeechConfig, BaseAWSLLM):
    """
    Configuration for AWS Polly Text-to-Speech
    Reference: https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
    """
    def __init__(self):
        # Both bases carry their own init logic; call each explicitly rather
        # than relying on cooperative super() across the MRO.
        BaseTextToSpeechConfig.__init__(self)
        BaseAWSLLM.__init__(self)
    # Default settings
    DEFAULT_VOICE = "Joanna"
    DEFAULT_ENGINE = "neural"
    DEFAULT_OUTPUT_FORMAT = "mp3"
    DEFAULT_REGION = "us-east-1"
    # Voice name mappings from OpenAI voices to Polly voices
    VOICE_MAPPINGS = {
        "alloy": "Joanna",  # US English female
        "echo": "Matthew",  # US English male
        "fable": "Amy",  # British English female
        "onyx": "Brian",  # British English male
        "nova": "Ivy",  # US English female (child)
        "shimmer": "Kendra",  # US English female
    }
    # Response format mappings from OpenAI to Polly
    FORMAT_MAPPINGS = {
        "mp3": "mp3",
        "opus": "ogg_vorbis",
        "aac": "mp3",  # Polly doesn't support AAC, use MP3
        "flac": "mp3",  # Polly doesn't support FLAC, use MP3
        "wav": "pcm",
        "pcm": "pcm",
    }
    # Valid Polly engines
    VALID_ENGINES = {"standard", "neural", "long-form", "generative"}
    def dispatch_text_to_speech(
        self,
        model: str,
        input: str,
        voice: Optional[Union[str, Dict]],
        optional_params: Dict,
        litellm_params_dict: Dict,
        logging_obj: "LiteLLMLoggingObj",
        timeout: Union[float, httpx.Timeout],
        extra_headers: Optional[Dict[str, Any]],
        base_llm_http_handler: Any,
        aspeech: bool,
        api_base: Optional[str],
        api_key: Optional[str],
        **kwargs: Any,
    ) -> Union[
        "HttpxBinaryResponseContent",
        Coroutine[Any, Any, "HttpxBinaryResponseContent"],
    ]:
        """
        Dispatch method to handle AWS Polly TTS requests
        This method encapsulates AWS-specific credential resolution and parameter handling
        Args:
            base_llm_http_handler: The BaseLLMHTTPHandler instance from main.py
        Returns:
            The handler's response; a coroutine when `aspeech` is True.
        """
        # Get AWS region from kwargs or environment
        aws_region_name = kwargs.get(
            "aws_region_name"
        ) or self._get_aws_region_name_for_polly(optional_params=optional_params)
        # Convert voice to string if it's a dict
        voice_str: Optional[str] = None
        if isinstance(voice, str):
            voice_str = voice
        elif isinstance(voice, dict):
            # Dict form: only the "name" key is used; other keys are ignored.
            voice_str = voice.get("name") if voice else None
        # Update litellm_params with resolved values
        # Note: AWS credentials (aws_access_key_id, aws_secret_access_key, etc.)
        # are already in litellm_params_dict via get_litellm_params() in main.py
        litellm_params_dict["aws_region_name"] = aws_region_name
        litellm_params_dict["api_base"] = api_base
        litellm_params_dict["api_key"] = api_key
        # Call the text_to_speech_handler
        response = base_llm_http_handler.text_to_speech_handler(
            model=model,
            input=input,
            voice=voice_str,
            text_to_speech_provider_config=self,
            text_to_speech_optional_params=optional_params,
            custom_llm_provider="aws_polly",
            litellm_params=litellm_params_dict,
            logging_obj=logging_obj,
            timeout=timeout,
            extra_headers=extra_headers,
            client=None,
            _is_async=aspeech,
        )
        return response
    def _get_aws_region_name_for_polly(self, optional_params: Dict) -> str:
        """Get AWS region name for Polly API calls."""
        # Explicit param wins; otherwise fall back to the shared AWS resolver.
        aws_region_name = optional_params.get("aws_region_name")
        if aws_region_name is None:
            aws_region_name = self.get_aws_region_name_for_non_llm_api_calls()
        return aws_region_name
    def get_supported_openai_params(self, model: str) -> list:
        """
        AWS Polly TTS supports these OpenAI parameters
        """
        # NOTE(review): "speed" is advertised here but never mapped in
        # map_openai_params — Polly only supports rate via SSML <prosody>, so
        # the param is currently dropped silently. Confirm intent.
        return ["voice", "response_format", "speed"]
    def map_openai_params(
        self,
        model: str,
        optional_params: Dict,
        voice: Optional[Union[str, Dict]] = None,
        drop_params: bool = False,
        kwargs: Dict = {},
    ) -> Tuple[Optional[str], Dict]:
        """
        Map OpenAI parameters to AWS Polly parameters
        Returns:
            (mapped_voice, mapped_params) — the Polly VoiceId (or None) and the
            Polly-specific parameter dict.
        """
        mapped_params = {}
        # Map voice - support both native Polly voices and OpenAI voice mappings
        mapped_voice: Optional[str] = None
        if isinstance(voice, str):
            if voice in self.VOICE_MAPPINGS:
                # OpenAI voice -> Polly voice
                mapped_voice = self.VOICE_MAPPINGS[voice]
            else:
                # Assume it's already a Polly voice name
                mapped_voice = voice
        # Map response format
        if "response_format" in optional_params:
            format_name = optional_params["response_format"]
            if format_name in self.FORMAT_MAPPINGS:
                mapped_params["output_format"] = self.FORMAT_MAPPINGS[format_name]
            else:
                # Unknown format: pass through and let Polly validate it.
                mapped_params["output_format"] = format_name
        else:
            mapped_params["output_format"] = self.DEFAULT_OUTPUT_FORMAT
        # Extract engine from model name (e.g., "aws_polly/neural" -> "neural")
        engine = self._extract_engine_from_model(model)
        mapped_params["engine"] = engine
        # Pass through Polly-specific parameters (use AWS API casing)
        if "language_code" in kwargs:
            mapped_params["LanguageCode"] = kwargs["language_code"]
        if "lexicon_names" in kwargs:
            mapped_params["LexiconNames"] = kwargs["lexicon_names"]
        if "sample_rate" in kwargs:
            mapped_params["SampleRate"] = kwargs["sample_rate"]
        return mapped_voice, mapped_params
    def _extract_engine_from_model(self, model: str) -> str:
        """
        Extract engine from model name.
        Examples:
        - aws_polly/neural -> neural
        - aws_polly/standard -> standard
        - aws_polly/long-form -> long-form
        - aws_polly -> neural (default)
        """
        if "/" in model:
            parts = model.split("/")
            if len(parts) >= 2:
                engine = parts[1].lower()
                if engine in self.VALID_ENGINES:
                    return engine
        # Unknown or missing suffix: fall back to the default engine.
        return self.DEFAULT_ENGINE
    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate AWS environment and set up headers.
        AWS SigV4 signing will be done in transform_text_to_speech_request.
        """
        # Copy so the caller's dict is not mutated.
        validated_headers = headers.copy()
        validated_headers["Content-Type"] = "application/json"
        return validated_headers
    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for AWS Polly SynthesizeSpeech request
        Polly endpoint format:
        https://polly.{region}.amazonaws.com/v1/speech
        """
        if api_base is not None:
            # Custom endpoint (e.g. VPC endpoint / localstack) overrides region.
            return api_base.rstrip("/") + "/v1/speech"
        aws_region_name = litellm_params.get("aws_region_name", self.DEFAULT_REGION)
        return f"https://polly.{aws_region_name}.amazonaws.com/v1/speech"
    def is_ssml_input(self, input: str) -> bool:
        """
        Returns True if input is SSML, False otherwise.
        Based on AWS Polly SSML requirements - must contain <speak> tag.
        """
        return "<speak>" in input or "<speak " in input
    def _sign_polly_request(
        self,
        request_body: Dict[str, Any],
        endpoint_url: str,
        litellm_params: Dict,
    ) -> Tuple[Dict[str, str], str]:
        """
        Sign the AWS Polly request using SigV4.
        Returns:
            Tuple of (signed_headers, json_body_string)
        Raises:
            ImportError: if botocore/boto3 is not installed.
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError(
                "Missing boto3 to call AWS Polly. Run 'pip install boto3'."
            )
        # Get AWS region
        aws_region_name = litellm_params.get("aws_region_name", self.DEFAULT_REGION)
        # Get AWS credentials
        credentials = self.get_credentials(
            aws_access_key_id=litellm_params.get("aws_access_key_id"),
            aws_secret_access_key=litellm_params.get("aws_secret_access_key"),
            aws_session_token=litellm_params.get("aws_session_token"),
            aws_region_name=aws_region_name,
            aws_session_name=litellm_params.get("aws_session_name"),
            aws_profile_name=litellm_params.get("aws_profile_name"),
            aws_role_name=litellm_params.get("aws_role_name"),
            aws_web_identity_token=litellm_params.get("aws_web_identity_token"),
            aws_sts_endpoint=litellm_params.get("aws_sts_endpoint"),
            aws_external_id=litellm_params.get("aws_external_id"),
        )
        # Serialize request body to JSON once: SigV4 signs these exact bytes,
        # so the same string must be sent on the wire unchanged.
        json_body = json.dumps(request_body)
        # Create headers for signing
        headers = {
            "Content-Type": "application/json",
        }
        # Create AWS request for signing
        aws_request = AWSRequest(
            method="POST",
            url=endpoint_url,
            data=json_body,
            headers=headers,
        )
        # Sign the request
        SigV4Auth(credentials, "polly", aws_region_name).add_auth(aws_request)
        # Return signed headers and body
        return dict(aws_request.headers), json_body
    def transform_text_to_speech_request(
        self,
        model: str,
        input: str,
        voice: Optional[str],
        optional_params: Dict,
        litellm_params: Dict,
        headers: dict,
    ) -> TextToSpeechRequestData:
        """
        Transform OpenAI TTS request to AWS Polly SynthesizeSpeech format.
        Supports:
        - Native Polly voices (Joanna, Matthew, etc.)
        - OpenAI voice mapping (alloy, echo, etc.)
        - SSML input (auto-detected via <speak> tag)
        - Multiple engines (neural, standard, long-form, generative)
        Returns:
            TextToSpeechRequestData: Contains signed request for Polly API
        """
        # Get voice (already mapped in main.py, or use default)
        polly_voice = voice or self.DEFAULT_VOICE
        # Get output format
        output_format = optional_params.get("output_format", self.DEFAULT_OUTPUT_FORMAT)
        # Get engine
        engine = optional_params.get("engine", self.DEFAULT_ENGINE)
        # Build request body
        request_body: Dict[str, Any] = {
            "Engine": engine,
            "OutputFormat": output_format,
            "Text": input,
            "VoiceId": polly_voice,
        }
        # Auto-detect SSML
        if self.is_ssml_input(input):
            request_body["TextType"] = "ssml"
        else:
            request_body["TextType"] = "text"
        # Add optional Polly parameters (already in AWS casing from map_openai_params)
        for key in ["LanguageCode", "LexiconNames", "SampleRate"]:
            if key in optional_params:
                request_body[key] = optional_params[key]
        # Get endpoint URL
        endpoint_url = self.get_complete_url(
            model=model,
            api_base=litellm_params.get("api_base"),
            litellm_params=litellm_params,
        )
        # Sign the request with AWS SigV4
        signed_headers, json_body = self._sign_polly_request(
            request_body=request_body,
            endpoint_url=endpoint_url,
            litellm_params=litellm_params,
        )
        # Return as ssml_body so the handler uses data= instead of json=
        # This preserves the exact JSON string that was signed
        return TextToSpeechRequestData(
            ssml_body=json_body,
            headers=signed_headers,
        )
    def transform_text_to_speech_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: "LiteLLMLoggingObj",
    ) -> "HttpxBinaryResponseContent":
        """
        Transform AWS Polly response to standard format.
        Polly returns the audio data directly in the response body.
        """
        from litellm.types.llms.openai import HttpxBinaryResponseContent
        return HttpxBinaryResponseContent(raw_response)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,198 @@
from litellm._uuid import uuid
from typing import Any, Coroutine, Optional, Union
from openai import AsyncAzureOpenAI, AzureOpenAI
from pydantic import BaseModel
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
from litellm.types.utils import FileTypes
from litellm.utils import (
TranscriptionResponse,
convert_to_model_response_object,
extract_duration_from_srt_or_vtt,
)
from .azure import AzureChatCompletion
from .common_utils import AzureOpenAIError
class AzureAudioTranscription(AzureChatCompletion):
def audio_transcriptions(
self,
model: str,
audio_file: FileTypes,
optional_params: dict,
logging_obj: Any,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
client=None,
azure_ad_token: Optional[str] = None,
atranscription: bool = False,
litellm_params: Optional[dict] = None,
) -> Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]:
data = {"model": model, "file": audio_file, **optional_params}
if atranscription is True:
return self.async_audio_transcriptions(
audio_file=audio_file,
data=data,
model_response=model_response,
timeout=timeout,
api_key=api_key,
api_base=api_base,
client=client,
max_retries=max_retries,
logging_obj=logging_obj,
model=model,
litellm_params=litellm_params,
)
azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
_is_async=False,
client=client,
litellm_params=litellm_params,
)
if not isinstance(azure_client, AzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="azure_client is not an instance of AzureOpenAI",
)
## LOGGING
logging_obj.pre_call(
input=f"audio_file_{uuid.uuid4()}",
api_key=azure_client.api_key,
additional_args={
"headers": {"Authorization": f"Bearer {azure_client.api_key}"},
"api_base": azure_client._base_url._uri_reference,
"atranscription": True,
"complete_input_dict": data,
},
)
response = azure_client.audio.transcriptions.create(
**data, timeout=timeout # type: ignore
)
if isinstance(response, BaseModel):
stringified_response = response.model_dump()
else:
stringified_response = TranscriptionResponse(text=response).model_dump()
## LOGGING
logging_obj.post_call(
input=get_audio_file_name(audio_file),
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
hidden_params = {"model": model, "custom_llm_provider": "azure"}
final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(
self,
audio_file: FileTypes,
model: str,
data: dict,
model_response: TranscriptionResponse,
timeout: float,
logging_obj: Any,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
max_retries=None,
litellm_params: Optional[dict] = None,
) -> TranscriptionResponse:
response = None
try:
async_azure_client = self.get_azure_openai_client(
api_version=api_version,
api_base=api_base,
api_key=api_key,
model=model,
_is_async=True,
client=client,
litellm_params=litellm_params,
)
if not isinstance(async_azure_client, AsyncAzureOpenAI):
raise AzureOpenAIError(
status_code=500,
message="async_azure_client is not an instance of AsyncAzureOpenAI",
)
## LOGGING
logging_obj.pre_call(
input=f"audio_file_{uuid.uuid4()}",
api_key=async_azure_client.api_key,
additional_args={
"headers": {
"Authorization": f"Bearer {async_azure_client.api_key}"
},
"api_base": async_azure_client._base_url._uri_reference,
"atranscription": True,
"complete_input_dict": data,
},
)
raw_response = (
await async_azure_client.audio.transcriptions.with_raw_response.create(
**data, timeout=timeout
)
) # type: ignore
headers = dict(raw_response.headers)
response = raw_response.parse()
if isinstance(response, BaseModel):
stringified_response = response.model_dump()
else:
stringified_response = TranscriptionResponse(text=response).model_dump()
duration = extract_duration_from_srt_or_vtt(response)
stringified_response["_audio_transcription_duration"] = duration
## LOGGING
logging_obj.post_call(
input=get_audio_file_name(audio_file),
api_key=api_key,
additional_args={
"headers": {
"Authorization": f"Bearer {async_azure_client.api_key}"
},
"api_base": async_azure_client._base_url._uri_reference,
"atranscription": True,
"complete_input_dict": data,
},
original_response=stringified_response,
)
hidden_params = {"model": model, "custom_llm_provider": "azure"}
response = convert_to_model_response_object(
_response_headers=headers,
response_object=stringified_response,
model_response_object=model_response,
hidden_params=hidden_params,
response_type="audio_transcription",
)
if not isinstance(response, TranscriptionResponse):
raise AzureOpenAIError(
status_code=500,
message="response is not an instance of TranscriptionResponse",
)
return response
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
)
raise e

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,233 @@
"""
Azure Batches API Handler
"""
from typing import Any, Coroutine, Optional, Union, cast
import httpx
from openai import AsyncOpenAI, OpenAI
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
from litellm.types.llms.openai import (
CancelBatchRequest,
CreateBatchRequest,
RetrieveBatchRequest,
)
from litellm.types.utils import LiteLLMBatch
from ..common_utils import BaseAzureLLM
class AzureBatchesAPI(BaseAzureLLM):
    """
    Azure support for the OpenAI Batches API.

    Exposes sync entry points (``create_batch``, ``retrieve_batch``,
    ``cancel_batch``, ``list_batches``) which dispatch to their async
    counterparts when ``_is_async`` is True.
    """

    def __init__(self) -> None:
        super().__init__()

    def _resolve_azure_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]],
        _is_async: bool,
        litellm_params: Optional[dict],
    ) -> Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]:
        """Build (or reuse) an Azure/OpenAI SDK client; raise if none can be created."""
        resolved = self.get_azure_openai_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params or {},
        )
        if resolved is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )
        return resolved

    async def acreate_batch(
        self,
        create_batch_data: CreateBatchRequest,
        azure_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> LiteLLMBatch:
        """Async: submit a batch job and wrap the SDK response in LiteLLMBatch."""
        raw = await azure_client.batches.create(**create_batch_data)  # type: ignore[arg-type]
        return LiteLLMBatch(**raw.model_dump())

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        """Create a batch job; returns a coroutine when ``_is_async`` is True."""
        batch_client = self._resolve_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params,
        )
        if _is_async is True:
            if not isinstance(batch_client, (AsyncAzureOpenAI, AsyncOpenAI)):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.acreate_batch(  # type: ignore
                create_batch_data=create_batch_data, azure_client=batch_client
            )
        sync_client = cast(Union[AzureOpenAI, OpenAI], batch_client)
        raw = sync_client.batches.create(**create_batch_data)  # type: ignore[arg-type]
        return LiteLLMBatch(**raw.model_dump())

    async def aretrieve_batch(
        self,
        retrieve_batch_data: RetrieveBatchRequest,
        client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> LiteLLMBatch:
        """Async: fetch a batch job by id and wrap it in LiteLLMBatch."""
        raw = await client.batches.retrieve(**retrieve_batch_data)  # type: ignore[arg-type]
        return LiteLLMBatch(**raw.model_dump())

    def retrieve_batch(
        self,
        _is_async: bool,
        retrieve_batch_data: RetrieveBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ):
        """Retrieve a batch job; returns a coroutine when ``_is_async`` is True."""
        batch_client = self._resolve_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params,
        )
        if _is_async is True:
            if not isinstance(batch_client, (AsyncAzureOpenAI, AsyncOpenAI)):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.aretrieve_batch(  # type: ignore
                retrieve_batch_data=retrieve_batch_data, client=batch_client
            )
        raw = cast(Union[AzureOpenAI, OpenAI], batch_client).batches.retrieve(
            **retrieve_batch_data
        )
        return LiteLLMBatch(**raw.model_dump())

    async def acancel_batch(
        self,
        cancel_batch_data: CancelBatchRequest,
        client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> LiteLLMBatch:
        """Async: cancel a batch job and wrap the SDK response in LiteLLMBatch."""
        raw = await client.batches.cancel(**cancel_batch_data)
        return LiteLLMBatch(**raw.model_dump())

    def cancel_batch(
        self,
        _is_async: bool,
        cancel_batch_data: CancelBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ):
        """Cancel a batch job; returns a coroutine when ``_is_async`` is True."""
        batch_client = self._resolve_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params,
        )
        if _is_async is True:
            if not isinstance(batch_client, (AsyncAzureOpenAI, AsyncOpenAI)):
                raise ValueError(
                    "Azure client is not an instance of AsyncAzureOpenAI or AsyncOpenAI. Make sure you passed an async client."
                )
            return self.acancel_batch(  # type: ignore
                cancel_batch_data=cancel_batch_data, client=batch_client
            )
        # Sync path: verify the client type explicitly instead of casting.
        if not isinstance(batch_client, (AzureOpenAI, OpenAI)):
            raise ValueError(
                "Azure client is not an instance of AzureOpenAI or OpenAI. Make sure you passed a sync client."
            )
        raw = batch_client.batches.cancel(**cancel_batch_data)
        return LiteLLMBatch(**raw.model_dump())

    async def alist_batches(
        self,
        client: Union[AsyncAzureOpenAI, AsyncOpenAI],
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        """Async: list batch jobs, returning the raw SDK page object."""
        return await client.batches.list(after=after, limit=limit)  # type: ignore

    def list_batches(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        after: Optional[str] = None,
        limit: Optional[int] = None,
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ):
        """List batch jobs; returns a coroutine when ``_is_async`` is True."""
        batch_client = self._resolve_azure_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params,
        )
        if _is_async is True:
            if not isinstance(batch_client, (AsyncAzureOpenAI, AsyncOpenAI)):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.alist_batches(  # type: ignore
                client=batch_client, after=after, limit=limit
            )
        return batch_client.batches.list(after=after, limit=limit)  # type: ignore

View File

@@ -0,0 +1,160 @@
"""Support for Azure OpenAI gpt-5 model family."""
from typing import List
import litellm
from litellm.exceptions import UnsupportedParamsError
from litellm.llms.openai.chat.gpt_5_transformation import (
OpenAIGPT5Config,
_get_effort_level,
)
from litellm.types.llms.openai import AllMessageValues
from .gpt_transformation import AzureOpenAIConfig
class AzureOpenAIGPT5Config(AzureOpenAIConfig, OpenAIGPT5Config):
    """Azure specific handling for gpt-5 models."""

    # Routing prefix users can put on Azure deployment names to force gpt-5
    # handling (e.g. "gpt5_series/my-deployment").
    GPT5_SERIES_ROUTE = "gpt5_series/"

    @classmethod
    def _supports_reasoning_effort_level(cls, model: str, level: str) -> bool:
        """Override to handle gpt5_series/ prefix used for Azure routing.

        The parent class calls ``_supports_factory(model, custom_llm_provider=None)``
        which fails to resolve ``gpt5_series/gpt-5.1`` to the correct Azure model
        entry. Strip the prefix and prepend ``azure/`` so the lookup finds
        ``azure/gpt-5.1`` in model_prices_and_context_window.json.
        """
        if model.startswith(cls.GPT5_SERIES_ROUTE):
            model = "azure/" + model[len(cls.GPT5_SERIES_ROUTE) :]
        elif not model.startswith("azure/"):
            # Ensure the provider-qualified form so the cost-map lookup resolves.
            model = "azure/" + model
        return super()._supports_reasoning_effort_level(model, level)

    @classmethod
    def is_model_gpt_5_model(cls, model: str) -> bool:
        """Check if the Azure model string refers to a gpt-5 variant.

        Accepts both explicit gpt-5 model names and the ``gpt5_series/`` prefix
        used for manual routing.
        """
        # gpt-5-chat* is a chat model and shouldn't go through GPT-5 reasoning restrictions.
        return (
            "gpt-5" in model and "gpt-5-chat" not in model
        ) or "gpt5_series" in model

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Get supported parameters for Azure OpenAI GPT-5 models.

        Azure OpenAI GPT-5.2/5.4 models support logprobs, unlike OpenAI's GPT-5.
        This overrides the parent class to add logprobs support back for gpt-5.2+.

        Reference:
        - Tested with Azure OpenAI GPT-5.2 (api-version: 2025-01-01-preview)
        - Azure returns logprobs successfully despite Microsoft's general
          documentation stating reasoning models don't support it.
        """
        # Call the OpenAI (not Azure) base explicitly to get the gpt-5 param set.
        params = OpenAIGPT5Config.get_supported_openai_params(self, model=model)
        # Azure supports tool_choice for GPT-5 deployments, but the base GPT-5 config
        # can drop it when the deployment name isn't in the OpenAI model registry.
        if "tool_choice" not in params:
            params.append("tool_choice")
        # Only gpt-5.2+ has been verified to support logprobs on Azure.
        # The base OpenAI class includes logprobs for gpt-5.1+, but Azure
        # hasn't verified support for gpt-5.1, so remove them unless gpt-5.2/5.4+.
        # NOTE(review): is_model_gpt_5_2_model is inherited (defined outside this
        # file, presumably on OpenAIGPT5Config) — confirm its matching rules.
        if self._supports_reasoning_effort_level(
            model, "none"
        ) and not self.is_model_gpt_5_2_model(model):
            params = [p for p in params if p not in ["logprobs", "top_logprobs"]]
        elif self.is_model_gpt_5_2_model(model):
            # NOTE(review): may append duplicates if the base list already
            # contains these — callers appear to treat this as a set; verify.
            azure_supported_params = ["logprobs", "top_logprobs"]
            params.extend(azure_supported_params)
        return params

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
        api_version: str = "",
    ) -> dict:
        """Map OpenAI params to Azure gpt-5, handling reasoning_effort='none'.

        Returns the mapped ``optional_params`` dict produced by the base
        gpt-5 config, after dropping or rejecting unsupported combinations.

        Raises:
            UnsupportedParamsError: if reasoning_effort='none' is requested for
                a model that doesn't support it and drop_params is not enabled.
        """
        # The effective effort may live in either dict depending on the caller.
        reasoning_effort_value = non_default_params.get(
            "reasoning_effort"
        ) or optional_params.get("reasoning_effort")
        effective_effort = _get_effort_level(reasoning_effort_value)

        # gpt-5.1/5.2/5.4 support reasoning_effort='none', but other gpt-5 models don't
        # See: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/reasoning
        supports_none = self._supports_reasoning_effort_level(model, "none")
        if effective_effort == "none" and not supports_none:
            if litellm.drop_params is True or (
                drop_params is not None and drop_params is True
            ):
                # Copy before popping so the caller's dicts aren't mutated.
                non_default_params = non_default_params.copy()
                optional_params = optional_params.copy()
                if (
                    _get_effort_level(non_default_params.get("reasoning_effort"))
                    == "none"
                ):
                    non_default_params.pop("reasoning_effort")
                if _get_effort_level(optional_params.get("reasoning_effort")) == "none":
                    optional_params.pop("reasoning_effort")
            else:
                raise UnsupportedParamsError(
                    status_code=400,
                    message=(
                        "Azure OpenAI does not support reasoning_effort='none' for this model. "
                        "Supported values are: 'low', 'medium', and 'high'. "
                        "To drop this parameter, set `litellm.drop_params=True` or for proxy:\n\n"
                        "`litellm_settings:\n  drop_params: true`\n"
                        "Issue: https://github.com/BerriAI/litellm/issues/16704"
                    ),
                )
        result = OpenAIGPT5Config.map_openai_params(
            self,
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
        # Only drop reasoning_effort='none' for models that don't support it
        # (the base mapping may have re-introduced it).
        result_effort = _get_effort_level(result.get("reasoning_effort"))
        if result_effort == "none" and not supports_none:
            result.pop("reasoning_effort")
        # Azure Chat Completions: gpt-5.4+ does not support tools + reasoning together.
        # Drop reasoning_effort when both are present (OpenAI routes to Responses API; Azure does not).
        # NOTE(review): is_model_gpt_5_4_plus_model is inherited from the parent config.
        if self.is_model_gpt_5_4_plus_model(model):
            has_tools = bool(
                non_default_params.get("tools") or optional_params.get("tools")
            )
            if has_tools and result_effort not in (None, "none"):
                result.pop("reasoning_effort", None)
        return result

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Strip the gpt5_series/ routing prefix before delegating to Azure's transform."""
        model = model.replace(self.GPT5_SERIES_ROUTE, "")
        return super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

View File

@@ -0,0 +1,334 @@
from typing import TYPE_CHECKING, Any, List, Optional, Union
from httpx._models import Headers, Response
import litellm
from litellm.litellm_core_utils.prompt_templates.factory import (
convert_to_azure_openai_messages,
)
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.types.llms.azure import (
API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT,
API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT,
)
from litellm.types.utils import ModelResponse
from ....exceptions import UnsupportedParamsError
from ....types.llms.openai import AllMessageValues
from ...base_llm.chat.transformation import BaseConfig
from ..common_utils import AzureOpenAIError
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
LoggingClass = LiteLLMLoggingObj
else:
LoggingClass = Any
class AzureOpenAIConfig(BaseConfig):
    """
    Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions

    The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. Below are the parameters::

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.

    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
    ) -> None:
        # NOTE: values are stored on the *class* (shared across instances), not
        # on the instance — this is the established pattern for litellm configs.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the config dict collected by the BaseConfig machinery."""
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> List[str]:
        """List of OpenAI-compatible params Azure chat completions accepts."""
        # NOTE(review): "tools" and "tool_choice" appear twice in this list;
        # harmless for membership checks but could be deduplicated.
        return [
            "temperature",
            "n",
            "stream",
            "stream_options",
            "stop",
            "max_tokens",
            "max_completion_tokens",
            "tools",
            "tool_choice",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "user",
            "function_call",
            "functions",
            "tools",
            "tool_choice",
            "top_p",
            "logprobs",
            "top_logprobs",
            "response_format",
            "seed",
            "extra_headers",
            "parallel_tool_calls",
            "prediction",
            "modalities",
            "audio",
            "web_search_options",
            "prompt_cache_key",
            "store",
        ]

    def _is_response_format_supported_model(self, model: str) -> bool:
        """
        Determines if the model supports response_format.

        - Handles Azure deployment names (e.g., azure/gpt-4.1-suffix)
        - Normalizes model names (e.g., gpt-4-1 -> gpt-4.1)
        - Strips deployment-specific suffixes
        - Passes provider to supports_response_schema
        - Backwards compatible with previous model name patterns
        """
        import re

        # Normalize model name: e.g., gpt-3-5-turbo -> gpt-3.5-turbo
        normalized_model = re.sub(r"(\d)-(\d)", r"\1.\2", model)

        # gpt-3.5 family (incl. Azure's "gpt-35" spelling) predates response_format.
        if "gpt-3.5" in normalized_model or "gpt-35" in model:
            return False
        return True

    def _is_response_format_supported_api_version(
        self, api_version_year: str, api_version_month: str
    ) -> bool:
        """
        - check if api_version is supported for response_format
        - returns True if the API version is equal to or newer than the supported version
        """
        api_year = int(api_version_year)
        api_month = int(api_version_month)
        supported_year = int(API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT)
        supported_month = int(API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT)

        # If the year is greater than supported year, it's definitely supported
        if api_year > supported_year:
            return True
        # If the year is less than supported year, it's not supported
        elif api_year < supported_year:
            return False
        # If same year, check if month is >= supported month
        else:
            return api_month >= supported_month

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
        api_version: str = "",
    ) -> dict:
        """Map OpenAI params onto ``optional_params`` for Azure, enforcing
        api-version gates for tool_choice and response_format.

        Raises:
            UnsupportedParamsError: when a param is unsupported for the given
                api_version and drop_params is not enabled.
        """
        supported_openai_params = self.get_supported_openai_params(model)

        # Azure api_version looks like "2024-05-01-preview"; split into parts.
        api_version_times = api_version.split("-")
        if len(api_version_times) >= 3:
            api_version_year = api_version_times[0]
            api_version_month = api_version_times[1]
            api_version_day = api_version_times[2]
        else:
            api_version_year = None
            api_version_month = None
            api_version_day = None
        for param, value in non_default_params.items():
            if param == "tool_choice":
                """
                This parameter requires API version 2023-12-01-preview or later

                tool_choice='required' is not supported as of 2024-05-01-preview
                """
                ## check if api version supports this param ##
                if (
                    api_version_year is None
                    or api_version_month is None
                    or api_version_day is None
                ):
                    # Unparseable/missing api_version: pass the param through as-is.
                    optional_params["tool_choice"] = value
                else:
                    # NOTE: these are lexicographic *string* comparisons — valid
                    # because years are 4-digit and months/days are zero-padded.
                    if (
                        api_version_year < "2023"
                        or (api_version_year == "2023" and api_version_month < "12")
                        or (
                            api_version_year == "2023"
                            and api_version_month == "12"
                            and api_version_day < "01"
                        )
                    ):
                        if litellm.drop_params is True or (
                            drop_params is not None and drop_params is True
                        ):
                            pass
                        else:
                            raise UnsupportedParamsError(
                                status_code=400,
                                message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
                            )
                    elif value == "required" and (
                        api_version_year == "2024" and api_version_month <= "05"
                    ):  ## check if tool_choice value is supported ##
                        if litellm.drop_params is True or (
                            drop_params is not None and drop_params is True
                        ):
                            pass
                        else:
                            raise UnsupportedParamsError(
                                status_code=400,
                                message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
                            )
                    else:
                        optional_params["tool_choice"] = value
            elif param == "response_format" and isinstance(value, dict):
                _is_response_format_supported_model = (
                    self._is_response_format_supported_model(model)
                )

                if api_version_year is None or api_version_month is None:
                    # Unknown api_version: optimistically assume support.
                    is_response_format_supported_api_version = True
                else:
                    is_response_format_supported_api_version = (
                        self._is_response_format_supported_api_version(
                            api_version_year, api_version_month
                        )
                    )
                is_response_format_supported = (
                    is_response_format_supported_api_version
                    and _is_response_format_supported_model
                )

                # When unsupported, the base helper converts response_format
                # into a tool call instead of sending it directly.
                optional_params = self._add_response_format_to_tools(
                    optional_params=optional_params,
                    value=value,
                    is_response_format_supported=is_response_format_supported,
                )
            elif param == "tools" and isinstance(value, list):
                # Merge with any tools already added (e.g. by response_format above).
                optional_params.setdefault("tools", [])
                optional_params["tools"].extend(value)
            elif param in supported_openai_params:
                optional_params[param] = value

        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the Azure chat-completions request body."""
        messages = convert_to_azure_openai_messages(messages)
        return {
            "model": model,
            "messages": messages,
            **optional_params,
        }

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LoggingClass,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # Intentionally unimplemented: see the raised message below.
        raise NotImplementedError(
            "Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK."
        )

    def get_mapped_special_auth_params(self) -> dict:
        """Map generic auth param names to their Azure-specific equivalents."""
        return {"token": "azure_ad_token"}

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        """Copy special auth params (e.g. 'token') into optional_params under Azure names."""
        for param, value in non_default_params.items():
            if param == "token":
                optional_params["azure_ad_token"] = value
        return optional_params

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
        """
        return ["europe", "sweden", "switzerland", "france", "uk"]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
        """
        return [
            "us",
            "eastus",
            "eastus2",
            "eastus2euap",
            "eastus3",
            "southcentralus",
            "westus",
            "westus2",
            "westus3",
            "westus4",
        ]

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap provider errors in the Azure-specific exception type."""
        return AzureOpenAIError(
            message=error_message, status_code=status_code, headers=headers
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        # Intentionally unimplemented: see the raised message below.
        raise NotImplementedError(
            "Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."
        )

View File

@@ -0,0 +1,77 @@
"""
Handler file for calls to Azure OpenAI's o1/o3 family of models
Written separately to handle faking streaming for o1 and o3 models.
"""
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
import httpx
from litellm.types.utils import ModelResponse
from ...openai.openai import OpenAIChatCompletion
from ..common_utils import BaseAzureLLM
if TYPE_CHECKING:
from aiohttp import ClientSession
class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
    """Azure chat-completion handler for the o-series model family.

    Resolves an Azure OpenAI client up front (sync or async depending on
    ``acompletion``), then delegates the actual call to the shared OpenAI
    completion handler.
    """

    def completion(
        self,
        model_response: ModelResponse,
        timeout: Union[float, httpx.Timeout],
        optional_params: dict,
        litellm_params: dict,
        logging_obj: Any,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        dynamic_params: Optional[bool] = None,
        azure_ad_token: Optional[str] = None,
        acompletion: bool = False,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
        organization: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        drop_params: Optional[bool] = None,
        shared_session: Optional["ClientSession"] = None,
    ):
        """Run an (a)sync completion against Azure using the OpenAI handler."""
        # Build (or reuse) the Azure SDK client — async when acompletion=True.
        azure_client = self.get_azure_openai_client(
            litellm_params=litellm_params,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=acompletion,
        )
        # Everything else is identical to the OpenAI flow; hand over the
        # resolved Azure client so the parent doesn't rebuild one.
        return super().completion(
            model_response=model_response,
            timeout=timeout,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            model=model,
            messages=messages,
            print_verbose=print_verbose,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            dynamic_params=dynamic_params,
            azure_ad_token=azure_ad_token,
            acompletion=acompletion,
            logger_fn=logger_fn,
            headers=headers,
            custom_prompt_dict=custom_prompt_dict,
            client=azure_client,
            organization=organization,
            custom_llm_provider=custom_llm_provider,
            drop_params=drop_params,
            shared_session=shared_session,
        )

View File

@@ -0,0 +1,123 @@
"""
Support for o1 and o3 model families
https://platform.openai.com/docs/guides/reasoning
Translations handled by LiteLLM:
- modalities: image => drop param (if user opts in to dropping param)
- role: system ==> translate to role 'user'
- streaming => faked by LiteLLM
- Tools, response_format => drop param (if user opts in to dropping param)
- Logprobs => drop param (if user opts in to dropping param)
- Temperature => drop param (if user opts in to dropping param)
"""
from typing import List, Optional
import litellm
from litellm import verbose_logger
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import get_model_info, supports_reasoning
from ...openai.chat.o_series_transformation import OpenAIOSeriesConfig
class AzureOpenAIO1Config(OpenAIOSeriesConfig):
    """Azure-specific parameter handling for the o-series (o1/o3/o4) models."""

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the Azure O-Series models
        """
        base_params = litellm.OpenAIGPTConfig().get_supported_openai_params(
            model=model
        )
        base_params.extend(self._get_o_series_only_params(model))
        # Params the o-series endpoints reject outright.
        unsupported = {
            "logprobs",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "top_logprobs",
        }
        return [param for param in base_params if param not in unsupported]

    def _get_o_series_only_params(self, model: str) -> list:
        """
        Helper function to get the o-series only params for the model

        - reasoning_effort
        """
        # Case 1: model is in the litellm cost map — trust its reasoning flag.
        if model in litellm.model_list_set:
            return ["reasoning_effort"] if supports_reasoning(model) else []
        # Case 2: unrecognized model — assume reasoning support. This matters
        # because users commonly use custom Azure deployment names for
        # o-series models.
        return ["reasoning_effort"]

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Currently no Azure O Series models support native streaming.
        """
        if stream is not True:
            return False
        # o3 models support streaming - https://github.com/BerriAI/litellm/issues/8274
        if model and "o3" in model:
            return False
        if model is not None:
            try:
                # allow user to override default with model_info={"supports_native_streaming": true}
                model_info = get_model_info(
                    model=model, custom_llm_provider=custom_llm_provider
                )
                if model_info.get("supports_native_streaming") is True:
                    return False
            except Exception as e:
                verbose_logger.debug(
                    f"Error getting model info in AzureOpenAIO1Config: {e}"
                )
        return True

    def is_o_series_model(self, model: str) -> bool:
        """True when the model name (or routing prefix) marks an o-series model."""
        return any(tag in model for tag in ("o1", "o3", "o4", "o_series/"))

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Strip the o_series/ routing prefix (handles o_series/my-random-deployment-name)."""
        deployment_model = model.replace("o_series/", "")
        return super().transform_request(
            deployment_model, messages, optional_params, litellm_params, headers
        )

View File

@@ -0,0 +1,844 @@
import json
import os
from typing import Any, Callable, Dict, Literal, NamedTuple, Optional, Union, cast
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
import litellm
from litellm._logging import verbose_logger
from litellm.caching.caching import DualCache
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.openai.common_utils import BaseOpenAILLM
from litellm.secret_managers.get_azure_ad_token_provider import (
get_azure_ad_token_provider,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import _add_path_to_api_base
azure_ad_cache = DualCache()
class AzureOpenAIError(BaseLLMException):
def __init__(
self,
status_code,
message,
request: Optional[httpx.Request] = None,
response: Optional[httpx.Response] = None,
headers: Optional[Union[httpx.Headers, dict]] = None,
body: Optional[dict] = None,
):
super().__init__(
status_code=status_code,
message=message,
request=request,
response=response,
headers=headers,
body=body,
)
def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
openai_headers = {}
if "x-ratelimit-limit-requests" in headers:
openai_headers["x-ratelimit-limit-requests"] = headers[
"x-ratelimit-limit-requests"
]
if "x-ratelimit-remaining-requests" in headers:
openai_headers["x-ratelimit-remaining-requests"] = headers[
"x-ratelimit-remaining-requests"
]
if "x-ratelimit-limit-tokens" in headers:
openai_headers["x-ratelimit-limit-tokens"] = headers["x-ratelimit-limit-tokens"]
if "x-ratelimit-remaining-tokens" in headers:
openai_headers["x-ratelimit-remaining-tokens"] = headers[
"x-ratelimit-remaining-tokens"
]
llm_response_headers = {
"{}-{}".format("llm_provider", k): v for k, v in headers.items()
}
return {**llm_response_headers, **openai_headers}
def get_azure_ad_token_from_entra_id(
    tenant_id: str,
    client_id: str,
    client_secret: str,
    scope: str = "https://cognitiveservices.azure.com/.default",
) -> Callable[[], str]:
    """
    Get Azure AD token provider from `client_id`, `client_secret`, and `tenant_id`

    Each credential may also be given as an ``os.environ/<VAR>`` reference, in
    which case it is resolved through the secret manager.

    Args:
        tenant_id: str
        client_id: str
        client_secret: str
        scope: str

    Returns:
        callable that returns a bearer token.

    Raises:
        ValueError: if any of the three credentials resolves to None.
    """
    from azure.identity import ClientSecretCredential, get_bearer_token_provider

    verbose_logger.debug("Getting Azure AD Token from Entra ID")

    def _resolve(value: str) -> Optional[str]:
        # "os.environ/<VAR>" references are resolved via the secret manager.
        if value.startswith("os.environ/"):
            return get_secret_str(value)
        return value

    _tenant_id = _resolve(tenant_id)
    _client_id = _resolve(client_id)
    _client_secret = _resolve(client_secret)
    # SECURITY: never emit the client secret in clear text — log a mask instead.
    verbose_logger.debug(
        "tenant_id %s, client_id %s, client_secret %s",
        _tenant_id,
        _client_id,
        "*****" if _client_secret else None,
    )
    if _tenant_id is None or _client_id is None or _client_secret is None:
        raise ValueError("tenant_id, client_id, and client_secret must be provided")
    credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)
    verbose_logger.debug("credential %s", credential)
    token_provider = get_bearer_token_provider(credential, scope)
    verbose_logger.debug("token_provider %s", token_provider)
    return token_provider
def get_azure_ad_token_from_username_password(
    client_id: str,
    azure_username: str,
    azure_password: str,
    scope: str = "https://cognitiveservices.azure.com/.default",
) -> Callable[[], str]:
    """
    Get Azure AD token provider from `client_id`, `azure_username`, and `azure_password`

    Args:
        client_id: str
        azure_username: str
        azure_password: str
        scope: str

    Returns:
        callable that returns a bearer token.
    """
    from azure.identity import UsernamePasswordCredential, get_bearer_token_provider

    # SECURITY: never emit the password in clear text — log a mask instead.
    verbose_logger.debug(
        "client_id %s, azure_username %s, azure_password %s",
        client_id,
        azure_username,
        "*****" if azure_password else None,
    )
    credential = UsernamePasswordCredential(
        client_id=client_id,
        username=azure_username,
        password=azure_password,
    )
    verbose_logger.debug("credential %s", credential)
    token_provider = get_bearer_token_provider(credential, scope)
    verbose_logger.debug("token_provider %s", token_provider)
    return token_provider
def get_azure_ad_token_from_oidc(
    azure_ad_token: str,
    azure_client_id: Optional[str] = None,
    azure_tenant_id: Optional[str] = None,
    scope: Optional[str] = None,
) -> str:
    """
    Exchange an OIDC token (federated credential) for an Azure AD access token.

    The ``azure_ad_token`` argument is an ``oidc/...`` secret reference that is
    resolved through the secret manager, then exchanged against the Entra ID
    ``/oauth2/v2.0/token`` endpoint via the client-credentials flow with a
    JWT client assertion. Results are cached in ``azure_ad_cache`` for the
    token's ``expires_in`` lifetime.

    Args:
        azure_ad_token: secret reference to the OIDC token (str)
        azure_client_id: Optional[str] — falls back to AZURE_CLIENT_ID env var
        azure_tenant_id: Optional[str] — falls back to AZURE_TENANT_ID env var
        scope: Optional[str] — defaults to the Cognitive Services scope

    Returns:
        `azure_ad_token_access_token` - str

    Raises:
        AzureOpenAIError: on missing client/tenant id (422), unresolvable OIDC
            token (401), a non-200 token-endpoint response, or a malformed
            token response (422).
    """
    if scope is None:
        scope = "https://cognitiveservices.azure.com/.default"
    # Authority host is overridable for sovereign clouds (e.g. Azure Gov/China).
    azure_authority_host = os.getenv(
        "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
    )
    azure_client_id = azure_client_id or os.getenv("AZURE_CLIENT_ID")
    azure_tenant_id = azure_tenant_id or os.getenv("AZURE_TENANT_ID")
    if azure_client_id is None or azure_tenant_id is None:
        raise AzureOpenAIError(
            status_code=422,
            message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
        )
    oidc_token = get_secret_str(azure_ad_token)
    if oidc_token is None:
        raise AzureOpenAIError(
            status_code=401,
            message="OIDC token could not be retrieved from secret manager.",
        )
    # Cache key covers every input that influences the issued token.
    azure_ad_token_cache_key = json.dumps(
        {
            "azure_client_id": azure_client_id,
            "azure_tenant_id": azure_tenant_id,
            "azure_authority_host": azure_authority_host,
            "oidc_token": oidc_token,
        }
    )
    # Return a previously-issued token if it has not expired out of the cache.
    azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
    if azure_ad_token_access_token is not None:
        return azure_ad_token_access_token
    client = litellm.module_level_client
    # Client-credentials flow with a federated JWT assertion instead of a secret.
    req_token = client.post(
        f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
        data={
            "client_id": azure_client_id,
            "grant_type": "client_credentials",
            "scope": scope,
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": oidc_token,
        },
    )
    if req_token.status_code != 200:
        raise AzureOpenAIError(
            status_code=req_token.status_code,
            message=req_token.text,
        )
    azure_ad_token_json = req_token.json()
    azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
    azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)
    if azure_ad_token_access_token is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token access_token not returned"
        )
    if azure_ad_token_expires_in is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token expires_in not returned"
        )
    # Cache with TTL equal to the token lifetime so we never serve a stale token.
    azure_ad_cache.set_cache(
        key=azure_ad_token_cache_key,
        value=azure_ad_token_access_token,
        ttl=azure_ad_token_expires_in,
    )
    return azure_ad_token_access_token
def select_azure_base_url_or_endpoint(azure_client_params: dict):
    """Decide whether the configured URL is an ``azure_endpoint`` or a full ``base_url``.

    The OpenAI SDK treats a URL that already contains ``/openai/deployments``
    as a complete base_url rather than a bare resource endpoint (see
    https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192).
    Mutates and returns *azure_client_params*: when a deployment path is
    present, the value is moved from the ``azure_endpoint`` key to ``base_url``.
    """
    endpoint = azure_client_params.get("azure_endpoint", None)
    if endpoint is None:
        return azure_client_params
    if "/openai/deployments" not in endpoint:
        return azure_client_params
    # Deployment path present -> the SDK expects this value under base_url.
    azure_client_params["base_url"] = azure_client_params.pop("azure_endpoint")
    return azure_client_params
def get_azure_ad_token(
    litellm_params: GenericLiteLLMParams,
) -> Optional[str]:
    """
    Get Azure AD token from various authentication methods.

    This function tries different methods to obtain an Azure AD token:
    1. From an existing token provider
    2. From Entra ID using tenant_id, client_id, and client_secret
    3. From username and password
    4. From OIDC token
    5. From a service principal with secret workflow
    6. From DefaultAzureCredential

    Args:
        litellm_params: Dictionary containing authentication parameters
            - azure_ad_token_provider: Optional callable that returns a token
            - azure_ad_token: Optional existing token
            - tenant_id: Optional Azure tenant ID
            - client_id: Optional Azure client ID
            - client_secret: Optional Azure client secret
            - azure_username: Optional Azure username
            - azure_password: Optional Azure password

    Returns:
        Azure AD token as string if successful, None otherwise

    Raises:
        TypeError: if a token provider returns a non-string value.
        RuntimeError: if a token provider raises while being executed.
    """
    # Extract parameters
    # Use `or` instead of default parameter to handle cases where key exists but value is None
    azure_ad_token_provider = litellm_params.get("azure_ad_token_provider")
    azure_ad_token = litellm_params.get("azure_ad_token") or get_secret_str(
        "AZURE_AD_TOKEN"
    )
    tenant_id = litellm_params.get("tenant_id") or os.getenv("AZURE_TENANT_ID")
    client_id = litellm_params.get("client_id") or os.getenv("AZURE_CLIENT_ID")
    client_secret = litellm_params.get("client_secret") or os.getenv(
        "AZURE_CLIENT_SECRET"
    )
    azure_username = litellm_params.get("azure_username") or os.getenv("AZURE_USERNAME")
    azure_password = litellm_params.get("azure_password") or os.getenv("AZURE_PASSWORD")
    scope = litellm_params.get("azure_scope") or os.getenv(
        "AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
    )
    # Defensive: os.getenv above already supplies a default, but keep a guard.
    if scope is None:
        scope = "https://cognitiveservices.azure.com/.default"
    # Try to get token provider from Entra ID
    if azure_ad_token_provider is None and tenant_id and client_id and client_secret:
        verbose_logger.debug(
            "Using Azure AD Token Provider from Entra ID for Azure Auth"
        )
        azure_ad_token_provider = get_azure_ad_token_from_entra_id(
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
            scope=scope,
        )
    # Try to get token provider from username and password
    if (
        azure_ad_token_provider is None
        and azure_username
        and azure_password
        and client_id
    ):
        verbose_logger.debug("Using Azure Username and Password for Azure Auth")
        azure_ad_token_provider = get_azure_ad_token_from_username_password(
            azure_username=azure_username,
            azure_password=azure_password,
            client_id=client_id,
            scope=scope,
        )
    # Try to get token from OIDC
    # NOTE: the OIDC branch requires client_id AND tenant_id in addition to the
    # "oidc/" prefixed token reference.
    if (
        client_id
        and tenant_id
        and azure_ad_token
        and azure_ad_token.startswith("oidc/")
    ):
        verbose_logger.debug("Using Azure OIDC Token for Azure Auth")
        azure_ad_token = get_azure_ad_token_from_oidc(
            azure_ad_token=azure_ad_token,
            azure_client_id=client_id,
            azure_tenant_id=tenant_id,
            scope=scope,
        )
    # Try to get token provider from service principal or DefaultAzureCredential
    elif (
        azure_ad_token_provider is None
        and litellm.enable_azure_ad_token_refresh is True
    ):
        verbose_logger.debug(
            "Using Azure AD token provider based on Service Principal with Secret workflow or DefaultAzureCredential for Azure Auth"
        )
        try:
            azure_ad_token_provider = get_azure_ad_token_provider(azure_scope=scope)
        except ValueError:
            verbose_logger.debug("Azure AD Token Provider could not be used.")
        except Exception as e:
            verbose_logger.error(
                f"Error calling Azure AD token provider: {str(e)}. Follow docs - https://docs.litellm.ai/docs/providers/azure/#azure-ad-token-refresh---defaultazurecredential"
            )
            raise e
        #########################################################
        # If litellm.enable_azure_ad_token_refresh is True and no other token provider is available,
        # try to get DefaultAzureCredential provider
        #########################################################
        if azure_ad_token_provider is None and azure_ad_token is None:
            azure_ad_token_provider = (
                BaseAzureLLM._try_get_default_azure_credential_provider(
                    scope=scope,
                )
            )
    # Execute the token provider to get the token if available
    if azure_ad_token_provider and callable(azure_ad_token_provider):
        try:
            token = azure_ad_token_provider()
            if not isinstance(token, str):
                verbose_logger.error(
                    f"Azure AD token provider returned non-string value: {type(token)}"
                )
                raise TypeError(f"Azure AD token must be a string, got {type(token)}")
            else:
                azure_ad_token = token
        except TypeError:
            # Re-raise TypeError directly
            raise
        except Exception as e:
            verbose_logger.error(f"Error calling Azure AD token provider: {str(e)}")
            raise RuntimeError(f"Failed to get Azure AD token: {str(e)}") from e
    return azure_ad_token
class BaseAzureLLM(BaseOpenAILLM):
    """Shared base for Azure OpenAI integrations.

    Centralizes client construction (sync/async, classic Azure vs. v1 API),
    Azure AD authentication resolution, client caching, and URL building.
    """

    @staticmethod
    def _try_get_default_azure_credential_provider(
        scope: str,
    ) -> Optional[Callable[[], str]]:
        """
        Try to get DefaultAzureCredential provider

        Args:
            scope: Azure scope for the token

        Returns:
            Token provider callable if DefaultAzureCredential is enabled and available, None otherwise
        """
        from litellm.types.secret_managers.get_azure_ad_token_provider import (
            AzureCredentialType,
        )
        verbose_logger.debug("Attempting to use DefaultAzureCredential for Azure Auth")
        try:
            azure_ad_token_provider = get_azure_ad_token_provider(
                azure_scope=scope,
                azure_credential=AzureCredentialType.DefaultAzureCredential,
            )
            verbose_logger.debug(
                "Successfully obtained Azure AD token provider using DefaultAzureCredential"
            )
            return azure_ad_token_provider
        except Exception as e:
            # Best-effort fallback: failure here just means no default credential.
            verbose_logger.debug(f"DefaultAzureCredential failed: {str(e)}")
            return None

    def get_azure_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str] = None,
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
        _is_async: bool = False,
        model: Optional[str] = None,
    ) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]]:
        """Return an (Async)AzureOpenAI or (Async)OpenAI client for this call.

        Reuses a cached client when one exists for the same initialization
        params; otherwise builds one from `initialize_azure_sdk_client`. For
        the Azure "v1" API versions ("v1"/"latest"/"preview"), a standard
        OpenAI client pointed at `<api_base>/openai/v1/` is used instead of
        the AzureOpenAI client. If *client* is supplied, it is used as-is
        (only the api-version query param may be defaulted onto it).
        """
        openai_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None
        # locals() snapshot doubles as the cache key material.
        client_initialization_params: dict = locals()
        client_initialization_params["is_async"] = _is_async
        if client is None:
            cached_client = self.get_cached_openai_client(
                client_initialization_params=client_initialization_params,
                client_type="azure",
            )
            if cached_client:
                if isinstance(
                    cached_client, (AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI)
                ):
                    return cached_client
            azure_client_params = self.initialize_azure_sdk_client(
                litellm_params=litellm_params or {},
                api_key=api_key,
                api_base=api_base,
                model_name=model,
                api_version=api_version,
                is_async=_is_async,
            )
            # For Azure v1 API, use standard OpenAI client instead of AzureOpenAI
            # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#api-specs
            if self._is_azure_v1_api_version(api_version):
                # Extract only params that OpenAI client accepts
                # Always use /openai/v1/ regardless of whether user passed "v1", "latest", or "preview"
                v1_params = {
                    "api_key": azure_client_params.get("api_key"),
                    "base_url": f"{api_base}/openai/v1/",
                }
                if "timeout" in azure_client_params:
                    v1_params["timeout"] = azure_client_params["timeout"]
                if "max_retries" in azure_client_params:
                    v1_params["max_retries"] = azure_client_params["max_retries"]
                if "http_client" in azure_client_params:
                    v1_params["http_client"] = azure_client_params["http_client"]
                verbose_logger.debug(
                    f"Using Azure v1 API with base_url: {v1_params['base_url']}"
                )
                if _is_async is True:
                    openai_client = AsyncOpenAI(**v1_params)  # type: ignore
                else:
                    openai_client = OpenAI(**v1_params)  # type: ignore
            else:
                # Traditional Azure API uses AzureOpenAI client
                if _is_async is True:
                    openai_client = AsyncAzureOpenAI(**azure_client_params)
                else:
                    openai_client = AzureOpenAI(**azure_client_params)  # type: ignore
        else:
            openai_client = client
            if (
                api_version is not None
                and isinstance(openai_client, (AzureOpenAI, AsyncAzureOpenAI))
                and isinstance(openai_client._custom_query, dict)
            ):
                # set api_version to version passed by user
                openai_client._custom_query.setdefault("api-version", api_version)
        # save client in-memory cache
        self.set_cached_openai_client(
            openai_client=openai_client,
            client_initialization_params=client_initialization_params,
            client_type="azure",
        )
        return openai_client

    def initialize_azure_sdk_client(
        self,
        litellm_params: dict,
        api_key: Optional[str],
        api_base: Optional[str],
        model_name: Optional[str],
        api_version: Optional[str],
        is_async: bool,
    ) -> dict:
        """Build the kwargs dict used to construct an (Async)AzureOpenAI client.

        Resolves auth in priority order: api_key > Entra ID (tenant/client/
        secret) > username+password > OIDC token > service-principal token
        refresh. Also attaches the http client, retries/timeout, and decides
        between `azure_endpoint` and `base_url`.
        """
        azure_ad_token_provider = litellm_params.get("azure_ad_token_provider")
        # If we have api_key, then we have higher priority
        azure_ad_token = litellm_params.get("azure_ad_token")
        # litellm_params sometimes contains the key, but the value is None
        # We should respect environment variables in this case
        tenant_id = self._resolve_env_var(
            litellm_params, "tenant_id", "AZURE_TENANT_ID"
        )
        client_id = self._resolve_env_var(
            litellm_params, "client_id", "AZURE_CLIENT_ID"
        )
        client_secret = self._resolve_env_var(
            litellm_params, "client_secret", "AZURE_CLIENT_SECRET"
        )
        azure_username = self._resolve_env_var(
            litellm_params, "azure_username", "AZURE_USERNAME"
        )
        azure_password = self._resolve_env_var(
            litellm_params, "azure_password", "AZURE_PASSWORD"
        )
        scope = self._resolve_env_var(litellm_params, "azure_scope", "AZURE_SCOPE")
        if scope is None:
            scope = "https://cognitiveservices.azure.com/.default"
        max_retries = litellm_params.get("max_retries")
        timeout = litellm_params.get("timeout")
        if (
            not api_key
            and azure_ad_token_provider is None
            and tenant_id
            and client_id
            and client_secret
        ):
            verbose_logger.debug(
                "Using Azure AD Token Provider from Entra ID for Azure Auth"
            )
            azure_ad_token_provider = get_azure_ad_token_from_entra_id(
                tenant_id=tenant_id,
                client_id=client_id,
                client_secret=client_secret,
                scope=scope,
            )
        if (
            azure_ad_token_provider is None
            and azure_username
            and azure_password
            and client_id
        ):
            verbose_logger.debug("Using Azure Username and Password for Azure Auth")
            azure_ad_token_provider = get_azure_ad_token_from_username_password(
                azure_username=azure_username,
                azure_password=azure_password,
                client_id=client_id,
                scope=scope,
            )
        if azure_ad_token is not None and azure_ad_token.startswith("oidc/"):
            verbose_logger.debug("Using Azure OIDC Token for Azure Auth")
            azure_ad_token = get_azure_ad_token_from_oidc(
                azure_ad_token=azure_ad_token,
                azure_client_id=client_id,
                azure_tenant_id=tenant_id,
                scope=scope,
            )
        elif (
            not api_key
            and azure_ad_token_provider is None
            and litellm.enable_azure_ad_token_refresh is True
        ):
            verbose_logger.debug(
                "Using Azure AD token provider based on Service Principal with Secret workflow for Azure Auth"
            )
            try:
                azure_ad_token_provider = get_azure_ad_token_provider(
                    azure_scope=scope,
                )
            except ValueError:
                verbose_logger.debug("Azure AD Token Provider could not be used.")
        if api_version is None:
            api_version = os.getenv(
                "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
            )
        _api_key = api_key
        if _api_key is not None and isinstance(_api_key, str):
            # only show first 5 chars of api_key
            _api_key = _api_key[:8] + "*" * 15
        verbose_logger.debug(
            f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
        )
        azure_client_params = {
            "api_key": api_key,
            "azure_endpoint": api_base,
            "api_version": api_version,
            "azure_ad_token": azure_ad_token,
            "azure_ad_token_provider": azure_ad_token_provider,
        }
        # init http client + SSL Verification settings
        if is_async is True:
            azure_client_params["http_client"] = self._get_async_http_client()
        else:
            azure_client_params["http_client"] = self._get_sync_http_client()
        if max_retries is not None:
            azure_client_params["max_retries"] = max_retries
        if timeout is not None:
            azure_client_params["timeout"] = timeout
        if azure_ad_token_provider is not None:
            azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
        # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
        # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client
        azure_client_params = select_azure_base_url_or_endpoint(
            azure_client_params=azure_client_params
        )
        return azure_client_params

    def _init_azure_client_for_cloudflare_ai_gateway(
        self,
        api_base: str,
        model: str,
        api_version: str,
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        api_key: Optional[str],
        azure_ad_token: Optional[str],
        azure_ad_token_provider: Optional[Callable[[], str]],
        acompletion: bool,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    ) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
        """Build an Azure client routed through a Cloudflare AI Gateway.

        Cloudflare expects the model name appended to the gateway URL, so the
        client is constructed with ``base_url`` rather than ``azure_endpoint``.
        A caller-supplied *client* is returned unchanged.
        """
        ## build base url - assume api base includes resource name
        tenant_id = litellm_params.get("tenant_id", os.getenv("AZURE_TENANT_ID"))
        client_id = litellm_params.get("client_id", os.getenv("AZURE_CLIENT_ID"))
        scope = litellm_params.get(
            "azure_scope",
            os.getenv("AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"),
        )
        if client is None:
            if not api_base.endswith("/"):
                api_base += "/"
            api_base += f"{model}"
            azure_client_params: Dict[str, Any] = {
                "api_version": api_version,
                "base_url": f"{api_base}",
                "http_client": litellm.client_session,
                "max_retries": max_retries,
                "timeout": timeout,
            }
            if api_key is not None:
                azure_client_params["api_key"] = api_key
            elif azure_ad_token is not None:
                if azure_ad_token.startswith("oidc/"):
                    azure_ad_token = get_azure_ad_token_from_oidc(
                        azure_ad_token=azure_ad_token,
                        azure_client_id=client_id,
                        azure_tenant_id=tenant_id,
                        scope=scope,
                    )
                azure_client_params["azure_ad_token"] = azure_ad_token
            if azure_ad_token_provider is not None:
                azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
            if acompletion is True:
                client = AsyncAzureOpenAI(**azure_client_params)  # type: ignore
            else:
                client = AzureOpenAI(**azure_client_params)  # type: ignore
        return client

    @staticmethod
    def _base_validate_azure_environment(
        headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Populate auth headers for a raw Azure HTTP call.

        Priority: existing ``api-key`` header > resolved API key >
        Azure AD bearer token. Mutates and returns *headers*.
        """
        litellm_params = litellm_params or GenericLiteLLMParams()
        # Check if api-key is already in headers; if so, use it
        if "api-key" in headers:
            return headers
        api_key = (
            litellm_params.api_key
            or litellm.api_key
            or litellm.azure_key
            or get_secret_str("AZURE_OPENAI_API_KEY")
            or get_secret_str("AZURE_API_KEY")
        )
        if api_key:
            headers["api-key"] = api_key
            return headers
        ### Fallback to Azure AD token-based authentication if no API key is available
        ### Retrieves Azure AD token and adds it to the Authorization header
        azure_ad_token = get_azure_ad_token(litellm_params)
        if azure_ad_token:
            headers["Authorization"] = f"Bearer {azure_ad_token}"
        return headers

    @staticmethod
    def _get_base_azure_url(
        api_base: Optional[str],
        litellm_params: Optional[Union[GenericLiteLLMParams, Dict[str, Any]]],
        route: Union[Literal["/openai/responses", "/openai/vector_stores"], str],
        default_api_version: Optional[Union[str, Literal["latest", "preview"]]] = None,
    ) -> str:
        """
        Get the base Azure URL for the given route and API version.

        Args:
            api_base: The base URL of the Azure API.
            litellm_params: The litellm parameters.
            route: The route to the API.
            default_api_version: The default API version to use if no api_version is provided. If 'latest', it will use `openai/v1/...` route.

        Raises:
            ValueError: if no api_base can be resolved from args/globals/env.
        """
        api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
        if api_base is None:
            raise ValueError(
                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
            )
        original_url = httpx.URL(api_base)
        # Extract api_version or use default
        litellm_params = litellm_params or {}
        api_version = (
            cast(Optional[str], litellm_params.get("api_version"))
            or default_api_version
        )
        # Create a new dictionary with existing params
        query_params = dict(original_url.params)
        # Add api_version if needed
        if "api-version" not in query_params and api_version:
            query_params["api-version"] = api_version
        # Add the path to the base URL
        if route not in api_base:
            new_url = _add_path_to_api_base(api_base=api_base, ending_path=route)
        else:
            new_url = api_base
        if BaseAzureLLM._is_azure_v1_api_version(api_version):
            # ensure the request go to /openai/v1 and not just /openai
            if "/openai/v1" not in new_url:
                parsed_url = httpx.URL(new_url)
                new_url = str(
                    parsed_url.copy_with(
                        path=parsed_url.path.replace("/openai", "/openai/v1")
                    )
                )
        # Use the new query_params dictionary
        final_url = httpx.URL(new_url).copy_with(params=query_params)
        return str(final_url)

    @staticmethod
    def _is_azure_v1_api_version(api_version: Optional[str]) -> bool:
        """True when *api_version* selects the Azure v1 (OpenAI-compatible) API."""
        if api_version is None:
            return False
        return api_version in {"preview", "latest", "v1"}

    def _resolve_env_var(
        self, litellm_params: Dict[str, Any], param_key: str, env_var_key: str
    ) -> Optional[str]:
        """Resolve the environment variable for a given parameter key.

        The logic here is different from `params.get(key, os.getenv(env_var))` because
        litellm_params may contain the key with a None value, in which case we want
        to fallback to the environment variable.
        """
        param_value = litellm_params.get(param_key)
        if param_value is not None:
            return param_value
        return os.getenv(env_var_key)
class AzureCredentials(NamedTuple):
    """Resolved Azure connection settings: endpoint URL, API key, API version."""

    # Azure endpoint URL, e.g. https://<resource>.openai.azure.com (or None).
    api_base: Optional[str]
    # Azure OpenAI API key (or None when AD-token auth is used instead).
    api_key: Optional[str]
    # Azure API version string, e.g. "2024-02-01" (or None for the default).
    api_version: Optional[str]
def get_azure_credentials(
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    api_version: Optional[str] = None,
) -> AzureCredentials:
    """Resolve Azure credentials from params, litellm globals, and env vars.

    Each field falls back in order: explicit argument -> litellm module
    global -> secret manager / environment variable. Any field may resolve
    to ``None`` when no source provides a value.
    """
    return AzureCredentials(
        api_base=api_base or litellm.api_base or get_secret_str("AZURE_API_BASE"),
        api_key=(
            api_key
            or litellm.api_key
            or litellm.azure_key
            or get_secret_str("AZURE_OPENAI_API_KEY")
            or get_secret_str("AZURE_API_KEY")
        ),
        api_version=(
            api_version or litellm.api_version or get_secret_str("AZURE_API_VERSION")
        ),
    )

View File

@@ -0,0 +1,379 @@
from typing import Any, Callable, Optional
from openai import AsyncAzureOpenAI, AzureOpenAI
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
from litellm.utils import CustomStreamWrapper, ModelResponse, TextCompletionResponse
from ...openai.completion.transformation import OpenAITextCompletionConfig
from ..common_utils import AzureOpenAIError, BaseAzureLLM
# Shared, stateless transformation config reused for every Azure text-completion
# call (converts raw completion responses into chat-model response objects).
openai_text_completion_config = OpenAITextCompletionConfig()
class AzureTextCompletion(BaseAzureLLM):
    """Handler for Azure OpenAI legacy text-completion (`/completions`) calls.

    Converts chat-style ``messages`` into a prompt via ``prompt_factory`` and
    dispatches to sync/async, streaming/non-streaming variants. Responses are
    converted back into chat-model response objects through the shared
    ``openai_text_completion_config``.
    """

    def __init__(self) -> None:
        super().__init__()

    def validate_environment(self, api_key, azure_ad_token):
        """Build minimal auth headers: prefer `api-key`, else a bearer token."""
        headers = {
            "content-type": "application/json",
        }
        if api_key is not None:
            headers["api-key"] = api_key
        elif azure_ad_token is not None:
            headers["Authorization"] = f"Bearer {azure_ad_token}"
        return headers

    def completion(  # noqa: PLR0915
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        api_key: Optional[str],
        api_base: str,
        api_version: str,
        api_type: str,
        azure_ad_token: Optional[str],
        azure_ad_token_provider: Optional[Callable],
        print_verbose: Callable,
        timeout,
        logging_obj,
        optional_params,
        litellm_params,
        logger_fn,
        acompletion: bool = False,
        headers: Optional[dict] = None,
        client=None,
    ):
        """Entry point for Azure text completion.

        Dispatches to `acompletion` / `async_streaming` / `streaming` based on
        the `acompletion` flag and `optional_params["stream"]`; otherwise runs
        a synchronous non-streaming call. Raises AzureOpenAIError on failure.
        """
        try:
            if model is None or messages is None:
                raise AzureOpenAIError(
                    status_code=422, message="Missing model or messages"
                )
            max_retries = optional_params.pop("max_retries", 2)
            # Flatten chat messages into a single legacy-completions prompt.
            prompt = prompt_factory(
                messages=messages, model=model, custom_llm_provider="azure_text"
            )
            ### CHECK IF CLOUDFLARE AI GATEWAY ###
            ### if so - set the model as part of the base url
            if api_base is not None and "gateway.ai.cloudflare.com" in api_base:
                ## build base url - assume api base includes resource name
                client = self._init_azure_client_for_cloudflare_ai_gateway(
                    api_key=api_key,
                    api_version=api_version,
                    api_base=api_base,
                    model=model,
                    client=client,
                    max_retries=max_retries,
                    timeout=timeout,
                    azure_ad_token=azure_ad_token,
                    azure_ad_token_provider=azure_ad_token_provider,
                    acompletion=acompletion,
                    litellm_params=litellm_params,
                )
                # Model is already encoded in the gateway URL, so omit it here.
                data = {"model": None, "prompt": prompt, **optional_params}
            else:
                data = {
                    "model": model,  # type: ignore
                    "prompt": prompt,
                    **optional_params,
                }
            if acompletion is True:
                if optional_params.get("stream", False):
                    return self.async_streaming(
                        logging_obj=logging_obj,
                        api_base=api_base,
                        data=data,
                        model=model,
                        api_key=api_key,
                        api_version=api_version,
                        azure_ad_token=azure_ad_token,
                        timeout=timeout,
                        client=client,
                        litellm_params=litellm_params,
                    )
                else:
                    return self.acompletion(
                        api_base=api_base,
                        data=data,
                        model_response=model_response,
                        api_key=api_key,
                        api_version=api_version,
                        model=model,
                        azure_ad_token=azure_ad_token,
                        timeout=timeout,
                        client=client,
                        logging_obj=logging_obj,
                        max_retries=max_retries,
                        litellm_params=litellm_params,
                    )
            elif "stream" in optional_params and optional_params["stream"] is True:
                return self.streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    data=data,
                    model=model,
                    api_key=api_key,
                    api_version=api_version,
                    azure_ad_token=azure_ad_token,
                    timeout=timeout,
                    client=client,
                )
            else:
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=api_key,
                    additional_args={
                        "headers": {
                            "api_key": api_key,
                            "azure_ad_token": azure_ad_token,
                        },
                        "api_version": api_version,
                        "api_base": api_base,
                        "complete_input_dict": data,
                    },
                )
                if not isinstance(max_retries, int):
                    raise AzureOpenAIError(
                        status_code=422, message="max retries must be an int"
                    )
                # init AzureOpenAI Client
                azure_client = self.get_azure_openai_client(
                    api_key=api_key,
                    api_base=api_base,
                    api_version=api_version,
                    client=client,
                    litellm_params=litellm_params,
                    _is_async=False,
                    model=model,
                )
                if not isinstance(azure_client, AzureOpenAI):
                    raise AzureOpenAIError(
                        status_code=500,
                        message="azure_client is not an instance of AzureOpenAI",
                    )
                raw_response = azure_client.completions.with_raw_response.create(
                    **data, timeout=timeout
                )
                response = raw_response.parse()
                stringified_response = response.model_dump()
                ## LOGGING
                logging_obj.post_call(
                    input=prompt,
                    api_key=api_key,
                    original_response=stringified_response,
                    additional_args={
                        "headers": headers,
                        "api_version": api_version,
                        "api_base": api_base,
                    },
                )
                return (
                    openai_text_completion_config.convert_to_chat_model_response_object(
                        response_object=TextCompletionResponse(**stringified_response),
                        model_response_object=model_response,
                    )
                )
        except AzureOpenAIError as e:
            raise e
        except Exception as e:
            # Normalize any SDK/transport error into AzureOpenAIError,
            # preserving status code and response headers when present.
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise AzureOpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )

    async def acompletion(
        self,
        api_key: Optional[str],
        api_version: str,
        model: str,
        api_base: str,
        data: dict,
        timeout: Any,
        model_response: ModelResponse,
        logging_obj: Any,
        max_retries: int,
        azure_ad_token: Optional[str] = None,
        client=None,  # this is the AsyncAzureOpenAI
        litellm_params: dict = {},
    ):
        """Async, non-streaming text completion against Azure."""
        response = None
        try:
            # init AzureOpenAI Client
            # setting Azure client
            azure_client = self.get_azure_openai_client(
                api_version=api_version,
                api_base=api_base,
                api_key=api_key,
                model=model,
                _is_async=True,
                client=client,
                litellm_params=litellm_params,
            )
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise AzureOpenAIError(
                    status_code=500,
                    message="azure_client is not an instance of AsyncAzureOpenAI",
                )
            ## LOGGING
            logging_obj.pre_call(
                input=data["prompt"],
                api_key=azure_client.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                    "api_base": azure_client._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )
            raw_response = await azure_client.completions.with_raw_response.create(
                **data, timeout=timeout
            )
            response = raw_response.parse()
            return openai_text_completion_config.convert_to_chat_model_response_object(
                response_object=response.model_dump(),
                model_response_object=model_response,
            )
        except AzureOpenAIError as e:
            raise e
        except Exception as e:
            # Normalize any SDK/transport error into AzureOpenAIError.
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise AzureOpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )

    def streaming(
        self,
        logging_obj,
        api_base: str,
        api_key: Optional[str],
        api_version: str,
        data: dict,
        model: str,
        timeout: Any,
        azure_ad_token: Optional[str] = None,
        client=None,
        litellm_params: dict = {},
    ):
        """Sync streaming text completion; returns a CustomStreamWrapper."""
        max_retries = data.pop("max_retries", 2)
        if not isinstance(max_retries, int):
            raise AzureOpenAIError(
                status_code=422, message="max retries must be an int"
            )
        # init AzureOpenAI Client
        azure_client = self.get_azure_openai_client(
            api_version=api_version,
            api_base=api_base,
            api_key=api_key,
            model=model,
            _is_async=False,
            client=client,
            litellm_params=litellm_params,
        )
        if not isinstance(azure_client, AzureOpenAI):
            raise AzureOpenAIError(
                status_code=500,
                message="azure_client is not an instance of AzureOpenAI",
            )
        ## LOGGING
        logging_obj.pre_call(
            input=data["prompt"],
            api_key=azure_client.api_key,
            additional_args={
                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                "api_base": azure_client._base_url._uri_reference,
                "acompletion": True,
                "complete_input_dict": data,
            },
        )
        raw_response = azure_client.completions.with_raw_response.create(
            **data, timeout=timeout
        )
        response = raw_response.parse()
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="azure_text",
            logging_obj=logging_obj,
        )
        return streamwrapper

    async def async_streaming(
        self,
        logging_obj,
        api_base: str,
        api_key: Optional[str],
        api_version: str,
        data: dict,
        model: str,
        timeout: Any,
        azure_ad_token: Optional[str] = None,
        client=None,
        litellm_params: dict = {},
    ):
        """Async streaming text completion; returns a CustomStreamWrapper."""
        try:
            # init AzureOpenAI Client
            azure_client = self.get_azure_openai_client(
                api_version=api_version,
                api_base=api_base,
                api_key=api_key,
                model=model,
                _is_async=True,
                client=client,
                litellm_params=litellm_params,
            )
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise AzureOpenAIError(
                    status_code=500,
                    message="azure_client is not an instance of AsyncAzureOpenAI",
                )
            ## LOGGING
            logging_obj.pre_call(
                input=data["prompt"],
                api_key=azure_client.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                    "api_base": azure_client._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )
            raw_response = await azure_client.completions.with_raw_response.create(
                **data, timeout=timeout
            )
            response = raw_response.parse()
            # return response
            streamwrapper = CustomStreamWrapper(
                completion_stream=response,
                model=model,
                custom_llm_provider="azure_text",
                logging_obj=logging_obj,
            )
            return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
        except Exception as e:
            # Normalize any SDK/transport error into AzureOpenAIError.
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise AzureOpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )

View File

@@ -0,0 +1,53 @@
from typing import Optional, Union
from ...openai.completion.transformation import OpenAITextCompletionConfig
class AzureOpenAITextConfig(OpenAITextCompletionConfig):
    """
    Configuration for Azure OpenAI's text-completion interface.

    Mirrors `OpenAITextCompletionConfig` — all tunables (`frequency_penalty`,
    `logit_bias`, `max_tokens`, `n`, `presence_penalty`, `stop`,
    `temperature`, `top_p`) are optional and simply forwarded to the parent.
    Reference: https://platform.openai.com/docs/api-reference/chat/create
    """

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
    ) -> None:
        # Collect everything and hand off to the OpenAI base config unchanged.
        forwarded = dict(
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
        )
        super().__init__(**forwarded)

View File

@@ -0,0 +1,50 @@
"""
Helper util for handling azure openai-specific cost calculation
- e.g.: prompt caching, audio tokens
"""
from typing import Optional, Tuple
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
from litellm.types.utils import Usage
from litellm.utils import get_model_info
def cost_per_token(
    model: str, usage: Usage, response_time_ms: Optional[float] = 0.0
) -> Tuple[float, float]:
    """
    Compute (prompt_cost_usd, completion_cost_usd) for an Azure model.

    Args:
        model: model name without the provider prefix.
        usage: litellm Usage block (carries caching/audio token counts).
        response_time_ms: wall-clock response time, used only for
            per-second-priced (TTS/audio) models.

    Returns:
        Tuple[float, float]: prompt cost and completion cost in USD.
    """
    ## GET MODEL INFO
    model_info = get_model_info(model=model, custom_llm_provider="azure")

    # Per-second pricing (e.g. TTS): bill on response time, not tokens.
    per_second_rate = model_info.get("output_cost_per_second")
    if per_second_rate is not None and response_time_ms is not None:
        verbose_logger.debug(
            f"For model={model} - output_cost_per_second: {model_info.get('output_cost_per_second')}; response time: {response_time_ms}"
        )
        return 0.0, per_second_rate * response_time_ms / 1000

    # Generic token-based path handles text/audio/cached/reasoning tokens.
    return generic_cost_per_token(
        model=model,
        usage=usage,
        custom_llm_provider="azure",
    )

View File

@@ -0,0 +1,91 @@
from typing import Any, Dict, Optional, Tuple
from litellm.exceptions import ContentPolicyViolationError
class AzureOpenAIExceptionMapping:
    """
    Class for creating Azure OpenAI specific exceptions
    """

    @staticmethod
    def create_content_policy_violation_error(
        message: str,
        model: str,
        extra_information: str,
        original_exception: Exception,
    ) -> ContentPolicyViolationError:
        """
        Create a content policy violation error

        NOTE(review): despite the name and return annotation, this method
        `raise`s the error instead of returning it — verify callers expect
        the raise before changing either side.
        """
        azure_error, inner_error = AzureOpenAIExceptionMapping._extract_azure_error(
            original_exception
        )
        # Prefer the provider message/type/code when present.
        provider_message = (
            azure_error.get("message") if isinstance(azure_error, dict) else None
        ) or message
        provider_type = (
            azure_error.get("type") if isinstance(azure_error, dict) else None
        )
        provider_code = (
            azure_error.get("code") if isinstance(azure_error, dict) else None
        )
        # Keep the OpenAI-style body fields populated so downstream (proxy + SDK)
        # can surface `type` / `code` correctly.
        openai_style_body: Dict[str, Any] = {
            "message": provider_message,
            "type": provider_type or "invalid_request_error",
            "code": provider_code or "content_policy_violation",
            "param": None,
        }
        raise ContentPolicyViolationError(
            message=provider_message,
            llm_provider="azure",
            model=model,
            litellm_debug_info=extra_information,
            response=getattr(original_exception, "response", None),
            provider_specific_fields={
                # Preserve legacy key for backward compatibility.
                "innererror": inner_error,
                # Prefer Azure's current naming.
                "inner_error": inner_error,
                # Include the full Azure error object for clients that want it.
                "azure_error": azure_error or None,
            },
            body=openai_style_body,
        )

    @staticmethod
    def _extract_azure_error(
        original_exception: Exception,
    ) -> Tuple[Dict[str, Any], Optional[dict]]:
        """Extract Azure OpenAI error payload and inner error details.
        Azure error formats can vary by endpoint/version. Common shapes:
        - {"innererror": {...}} (legacy)
        - {"error": {"code": "...", "message": "...", "type": "...", "inner_error": {...}}}
        - {"code": "...", "message": "...", "type": "..."} (already flattened)
        """
        body_dict = getattr(original_exception, "body", None) or {}
        # Non-dict bodies (strings, bytes) carry nothing we can parse.
        if not isinstance(body_dict, dict):
            return {}, None
        # Some SDKs place the payload under "error".
        azure_error: Dict[str, Any]
        if isinstance(body_dict.get("error"), dict):
            azure_error = body_dict.get("error", {})  # type: ignore[assignment]
        else:
            azure_error = body_dict
        # Check both naming variants at both nesting levels, newest first.
        inner_error = (
            azure_error.get("inner_error")
            or azure_error.get("innererror")
            or body_dict.get("innererror")
            or body_dict.get("inner_error")
        )
        return azure_error, inner_error

View File

@@ -0,0 +1,308 @@
from typing import Any, Coroutine, Optional, Union, cast
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from openai.types.file_deleted import FileDeleted
from litellm._logging import verbose_logger
from litellm.types.llms.openai import *
from ..common_utils import BaseAzureLLM
class AzureOpenAIFilesAPI(BaseAzureLLM):
    """
    AzureOpenAI methods to support for batches
    - create_file()
    - retrieve_file()
    - list_files()
    - delete_file()
    - file_content()
    - update_file()
    """

    def __init__(self) -> None:
        super().__init__()

    def _resolve_client(
        self,
        _is_async: bool,
        api_base: Optional[str],
        api_key: Optional[str],
        api_version: Optional[str],
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]],
        litellm_params: Optional[dict],
    ) -> Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]:
        """
        Initialize (or reuse) an Azure OpenAI client and validate it.

        Centralizes the client construction + validation previously duplicated
        across every public method of this class.

        Raises:
            ValueError: if no client could be initialized, or if an async call
                path was handed a sync client.
        """
        openai_client = self.get_azure_openai_client(
            litellm_params=litellm_params or {},
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )
        if _is_async is True and not isinstance(
            openai_client, (AsyncAzureOpenAI, AsyncOpenAI)
        ):
            raise ValueError(
                "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
            )
        return openai_client

    @staticmethod
    def _prepare_create_file_data(
        create_file_data: CreateFileRequest,
    ) -> dict[str, Any]:
        """
        Prepare create_file_data for OpenAI SDK.
        Removes expires_after if None to match SDK's Omit pattern.
        SDK expects file_create_params.ExpiresAfter | Omit, but FileExpiresAfter works at runtime.
        """
        data = dict(create_file_data)
        if data.get("expires_after") is None:
            data.pop("expires_after", None)
        return data

    async def acreate_file(
        self,
        create_file_data: CreateFileRequest,
        openai_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> OpenAIFileObject:
        """Upload a file asynchronously and normalize the SDK response."""
        verbose_logger.debug("create_file_data=%s", create_file_data)
        response = await openai_client.files.create(**self._prepare_create_file_data(create_file_data))  # type: ignore[arg-type]
        verbose_logger.debug("create_file_response=%s", response)
        return OpenAIFileObject(**response.model_dump())

    def create_file(
        self,
        _is_async: bool,
        create_file_data: CreateFileRequest,
        api_base: Optional[str],
        api_key: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
        """Upload a file; returns a coroutine when `_is_async` is True."""
        openai_client = self._resolve_client(
            _is_async=_is_async,
            api_base=api_base,
            api_key=api_key,
            api_version=api_version,
            client=client,
            litellm_params=litellm_params,
        )
        if _is_async is True:
            return self.acreate_file(
                create_file_data=create_file_data,
                openai_client=cast(
                    Union[AsyncAzureOpenAI, AsyncOpenAI], openai_client
                ),
            )
        response = cast(Union[AzureOpenAI, OpenAI], openai_client).files.create(**self._prepare_create_file_data(create_file_data))  # type: ignore[arg-type]
        return OpenAIFileObject(**response.model_dump())

    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        openai_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> HttpxBinaryResponseContent:
        """Download a file's content asynchronously."""
        response = await openai_client.files.content(**file_content_request)
        return HttpxBinaryResponseContent(response=response.response)

    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: Optional[str],
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        api_version: Optional[str] = None,
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        """Download a file's content; returns a coroutine when `_is_async` is True."""
        openai_client = self._resolve_client(
            _is_async=_is_async,
            api_base=api_base,
            api_key=api_key,
            api_version=api_version,
            client=client,
            litellm_params=litellm_params,
        )
        if _is_async is True:
            return self.afile_content(  # type: ignore
                file_content_request=file_content_request,
                openai_client=openai_client,
            )
        response = cast(Union[AzureOpenAI, OpenAI], openai_client).files.content(
            **file_content_request
        )
        return HttpxBinaryResponseContent(response=response.response)

    async def aretrieve_file(
        self,
        file_id: str,
        openai_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> FileObject:
        """Fetch a file's metadata asynchronously."""
        response = await openai_client.files.retrieve(file_id=file_id)
        return response

    def retrieve_file(
        self,
        _is_async: bool,
        file_id: str,
        api_base: Optional[str],
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        api_version: Optional[str] = None,
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ):
        """Fetch a file's metadata; returns a coroutine when `_is_async` is True."""
        openai_client = self._resolve_client(
            _is_async=_is_async,
            api_base=api_base,
            api_key=api_key,
            api_version=api_version,
            client=client,
            litellm_params=litellm_params,
        )
        if _is_async is True:
            return self.aretrieve_file(  # type: ignore
                file_id=file_id,
                openai_client=openai_client,
            )
        response = openai_client.files.retrieve(file_id=file_id)
        return response

    async def adelete_file(
        self,
        file_id: str,
        openai_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> FileDeleted:
        """Delete a file asynchronously, normalizing Azure's empty response."""
        response = await openai_client.files.delete(file_id=file_id)
        if not isinstance(response, FileDeleted):  # azure returns an empty string
            return FileDeleted(id=file_id, deleted=True, object="file")
        return response

    def delete_file(
        self,
        _is_async: bool,
        file_id: str,
        api_base: Optional[str],
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str] = None,
        api_version: Optional[str] = None,
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ):
        """Delete a file; returns a coroutine when `_is_async` is True."""
        openai_client = self._resolve_client(
            _is_async=_is_async,
            api_base=api_base,
            api_key=api_key,
            api_version=api_version,
            client=client,
            litellm_params=litellm_params,
        )
        if _is_async is True:
            return self.adelete_file(  # type: ignore
                file_id=file_id,
                openai_client=openai_client,
            )
        response = openai_client.files.delete(file_id=file_id)
        if not isinstance(response, FileDeleted):  # azure returns an empty string
            return FileDeleted(id=file_id, deleted=True, object="file")
        return response

    async def alist_files(
        self,
        openai_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
        purpose: Optional[str] = None,
    ):
        """List files asynchronously, optionally filtered by purpose."""
        if isinstance(purpose, str):
            response = await openai_client.files.list(purpose=purpose)
        else:
            response = await openai_client.files.list()
        return response

    def list_files(
        self,
        _is_async: bool,
        api_base: Optional[str],
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        purpose: Optional[str] = None,
        api_version: Optional[str] = None,
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ):
        """List files; returns a coroutine when `_is_async` is True."""
        openai_client = self._resolve_client(
            _is_async=_is_async,
            api_base=api_base,
            api_key=api_key,
            api_version=api_version,
            client=client,
            litellm_params=litellm_params,
        )
        if _is_async is True:
            return self.alist_files(  # type: ignore
                purpose=purpose,
                openai_client=openai_client,
            )
        if isinstance(purpose, str):
            response = openai_client.files.list(purpose=purpose)
        else:
            response = openai_client.files.list()
        return response

View File

@@ -0,0 +1,40 @@
from typing import Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
    """
    Azure OpenAI fine-tuning support; reuses OpenAIFineTuningAPI logic but
    swaps in Azure-specific client construction.
    """

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
        _is_async: bool = False,
        api_version: Optional[str] = None,
        litellm_params: Optional[dict] = None,
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
        # Override to use Azure-specific client initialization. Any provided
        # OpenAI/AsyncOpenAI client is discarded so an Azure client is built.
        reusable_client = (
            None if isinstance(client, (OpenAI, AsyncOpenAI)) else client
        )
        return self.get_azure_openai_client(
            litellm_params=litellm_params or {},
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=reusable_client,
            _is_async=_is_async,
        )

View File

@@ -0,0 +1,83 @@
from typing import Optional, cast
import httpx
import litellm
from litellm.llms.openai.image_edit.transformation import OpenAIImageEditConfig
from litellm.secret_managers.main import get_secret_str
from litellm.utils import _add_path_to_api_base
class AzureImageEditConfig(OpenAIImageEditConfig):
    """Azure-specific transformation for the image-edits endpoint."""

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Resolve the Azure API key (arg > litellm config > env) and attach it."""
        resolved_key = (
            api_key
            or litellm.api_key
            or litellm.azure_key
            or get_secret_str("AZURE_OPENAI_API_KEY")
            or get_secret_str("AZURE_API_KEY")
        )
        headers.update({"Authorization": f"Bearer {resolved_key}"})
        return headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Build the full Azure image-edits URL.

        Accepts either a bare resource base URL
        ("https://<resource>.openai.azure.com") or one already containing
        "/openai/deployments/<deployment>/images/edits"; ensures the
        "api-version" query parameter is present when one is supplied via
        litellm_params.

        Returns:
            The complete request URL string.
        """
        resolved_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
        if resolved_base is None:
            raise ValueError(
                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={resolved_base}`"
            )

        # Preserve any query params already on the base URL.
        query_params = dict(httpx.URL(resolved_base).params)

        # Append api-version only when missing and configured.
        api_version = cast(Optional[str], litellm_params.get("api_version"))
        if api_version and "api-version" not in query_params:
            query_params["api-version"] = api_version

        # Add the deployment path only when the base URL doesn't carry it yet.
        if "/openai/deployments/" in resolved_base:
            target_url = resolved_base
        else:
            target_url = _add_path_to_api_base(
                api_base=resolved_base,
                ending_path=f"/openai/deployments/{model}/images/edits",
            )

        return str(httpx.URL(target_url).copy_with(params=query_params))

View File

@@ -0,0 +1,29 @@
from litellm._logging import verbose_logger
from litellm.llms.base_llm.image_generation.transformation import (
BaseImageGenerationConfig,
)
from .dall_e_2_transformation import AzureDallE2ImageGenerationConfig
from .dall_e_3_transformation import AzureDallE3ImageGenerationConfig
from .gpt_transformation import AzureGPTImageGenerationConfig
__all__ = [
"AzureDallE2ImageGenerationConfig",
"AzureDallE3ImageGenerationConfig",
"AzureGPTImageGenerationConfig",
]
def get_azure_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """
    Select the Azure image-generation config for a model name.

    The name is lower-cased and stripped of '-' and '_' so "dall-e-2",
    "dalle_2" and "DALLE2" all match. An empty name defaults to dall-e-2;
    anything unrecognized falls back to the gpt-image-1 config.
    """
    normalized = model.lower().replace("-", "").replace("_", "")
    if not normalized or "dalle2" in normalized:  # empty model is dall-e-2
        return AzureDallE2ImageGenerationConfig()
    if "dalle3" in normalized:
        return AzureDallE3ImageGenerationConfig()
    verbose_logger.debug(
        f"Using AzureGPTImageGenerationConfig for model: {normalized}. This follows the gpt-image-1 model format."
    )
    return AzureGPTImageGenerationConfig()

View File

@@ -0,0 +1,9 @@
from litellm.llms.openai.image_generation import DallE2ImageGenerationConfig
class AzureDallE2ImageGenerationConfig(DallE2ImageGenerationConfig):
    """
    Azure dall-e-2 image generation config.

    Inherits all behavior from the OpenAI DallE2ImageGenerationConfig; exists
    so Azure can be dispatched as a distinct provider config.
    """

    pass

View File

@@ -0,0 +1,9 @@
from litellm.llms.openai.image_generation import DallE3ImageGenerationConfig
class AzureDallE3ImageGenerationConfig(DallE3ImageGenerationConfig):
    """
    Azure dall-e-3 image generation config.

    Inherits all behavior from the OpenAI DallE3ImageGenerationConfig; exists
    so Azure can be dispatched as a distinct provider config.
    """

    pass

View File

@@ -0,0 +1,9 @@
from litellm.llms.openai.image_generation import GPTImageGenerationConfig
class AzureGPTImageGenerationConfig(GPTImageGenerationConfig):
    """
    Azure gpt-image-1 image generation config.

    Inherits all behavior from the OpenAI GPTImageGenerationConfig; exists
    so Azure can be dispatched as a distinct provider config.
    """

    pass

View File

@@ -0,0 +1,85 @@
from typing import TYPE_CHECKING, List, Optional, Tuple
import httpx
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import GenericLiteLLMParams
if TYPE_CHECKING:
from httpx import URL
class AzurePassthroughConfig(BasePassthroughConfig):
    """Passthrough routing config for Azure OpenAI endpoints.

    Builds the target Azure URL for proxied requests and delegates auth
    header construction to the shared BaseAzureLLM helpers.
    """

    def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
        # A request is treated as streaming iff the body carries a "stream" key.
        return "stream" in request_data

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        endpoint: str,
        request_query_params: Optional[dict],
        litellm_params: dict,
    ) -> Tuple["URL", str]:
        """Build the full Azure target URL for a passthrough request.

        Returns (complete_url, base_target_url). Raises when no Azure
        api_base can be resolved.
        """
        base_target_url = self.get_api_base(api_base)
        if base_target_url is None:
            raise Exception("Azure api base not found")
        litellm_metadata = litellm_params.get("litellm_metadata") or {}
        model_group = litellm_metadata.get("model_group")
        # Swap the router's model-group alias in the path for the real
        # deployment name before forwarding.
        if model_group and model_group in endpoint:
            endpoint = endpoint.replace(model_group, model)
        complete_url = BaseAzureLLM._get_base_azure_url(
            api_base=base_target_url,
            litellm_params=litellm_params,
            route=endpoint,
            default_api_version=litellm_params.get("api_version"),
        )
        return (
            httpx.URL(complete_url),
            base_target_url,
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        # An explicitly passed api_key overrides whatever litellm_params carries.
        return BaseAzureLLM._base_validate_azure_environment(
            headers=headers,
            litellm_params=GenericLiteLLMParams(
                **{**litellm_params, "api_key": api_key}
            ),
        )

    @staticmethod
    def get_api_base(
        api_base: Optional[str] = None,
    ) -> Optional[str]:
        # Falls back to the AZURE_API_BASE secret/env var.
        return api_base or get_secret_str("AZURE_API_BASE")

    @staticmethod
    def get_api_key(
        api_key: Optional[str] = None,
    ) -> Optional[str]:
        # Falls back to the AZURE_API_KEY secret/env var.
        return api_key or get_secret_str("AZURE_API_KEY")

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        # Azure passthrough uses the deployment name as-is.
        return model

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        # NOTE(review): pure delegation to the base implementation — confirm
        # whether this override is still needed.
        return super().get_models(api_key, api_base)

View File

@@ -0,0 +1,126 @@
"""
This file contains the calling Azure OpenAI's `/openai/realtime` endpoint.
This requires websockets, and is currently only supported on LiteLLM Proxy.
"""
from typing import Any, Optional, cast
from litellm._logging import verbose_proxy_logger
from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
from ....llms.custom_httpx.http_handler import get_shared_realtime_ssl_context
from ..azure import AzureChatCompletion
# BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
async def forward_messages(client_ws: Any, backend_ws: Any):
    """Relay messages from the backend realtime websocket to the client.

    Runs until the backend closes the connection; the close is swallowed
    rather than raised so the caller's task ends cleanly.
    """
    import websockets

    try:
        while True:
            message = await backend_ws.recv()
            await client_ws.send_text(message)
    except websockets.exceptions.ConnectionClosed:  # type: ignore
        pass
class AzureOpenAIRealtime(AzureChatCompletion):
    """Bridges a client WebSocket with Azure OpenAI's `/openai/realtime` endpoint."""

    def _construct_url(
        self,
        api_base: str,
        model: str,
        api_version: Optional[str],
        realtime_protocol: Optional[str] = None,
    ) -> str:
        """
        Construct Azure realtime WebSocket URL.
        Args:
            api_base: Azure API base URL (will be converted from https:// to wss://)
            model: Model deployment name
            api_version: Azure API version
            realtime_protocol: Protocol version to use:
                - "GA" or "v1": Uses /openai/v1/realtime (GA path)
                - "beta" or None: Uses /openai/realtime (beta path, default)
        Returns:
            WebSocket URL string
        Examples:
            beta/default: "wss://.../openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview"
            GA/v1: "wss://.../openai/v1/realtime?model=gpt-realtime-deployment"
        """
        api_base = api_base.replace("https://", "wss://")
        # Determine path based on realtime_protocol (case-insensitive)
        _is_ga = realtime_protocol is not None and realtime_protocol.upper() in (
            "GA",
            "V1",
        )
        if _is_ga:
            path = "/openai/v1/realtime"
            # GA path identifies the deployment via the `model` query param.
            return f"{api_base}{path}?model={model}"
        else:
            # Default to beta path for backwards compatibility
            path = "/openai/realtime"
            return f"{api_base}{path}?api-version={api_version}&deployment={model}"

    async def async_realtime(
        self,
        model: str,
        websocket: Any,
        logging_obj: LiteLLMLogging,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        client: Optional[Any] = None,
        timeout: Optional[float] = None,
        realtime_protocol: Optional[str] = None,
        user_api_key_dict: Optional[Any] = None,
        litellm_metadata: Optional[dict] = None,
    ):
        """Connect to Azure's realtime WebSocket and bidirectionally forward traffic.

        Raises ValueError when api_base is missing, or when api_version is
        missing on the beta (non-GA) protocol path.
        """
        import websockets
        from websockets.asyncio.client import ClientConnection

        if api_base is None:
            raise ValueError("api_base is required for Azure OpenAI calls")
        # api_version is only optional when using the GA ("v1") protocol path.
        if api_version is None and (
            realtime_protocol is None or realtime_protocol.upper() not in ("GA", "V1")
        ):
            raise ValueError("api_version is required for Azure OpenAI calls")

        url = self._construct_url(
            api_base, model, api_version, realtime_protocol=realtime_protocol
        )

        try:
            ssl_context = get_shared_realtime_ssl_context()
            async with websockets.connect(  # type: ignore
                url,
                additional_headers={
                    "api-key": api_key,  # type: ignore
                },
                max_size=REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
                ssl=ssl_context,
            ) as backend_ws:
                realtime_streaming = RealTimeStreaming(
                    websocket,
                    cast(ClientConnection, backend_ws),
                    logging_obj,
                    user_api_key_dict=user_api_key_dict,
                    request_data={"litellm_metadata": litellm_metadata or {}},
                )
                await realtime_streaming.bidirectional_forward()
        except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
            # Propagate the backend's HTTP status to the client socket.
            await websocket.close(code=e.status_code, reason=str(e))
        except Exception:
            # Never let an unexpected error escape the proxy's websocket task.
            verbose_proxy_logger.exception(
                "Error in AzureOpenAIRealtime.async_realtime"
            )
            pass

View File

@@ -0,0 +1,46 @@
"""Azure OpenAI realtime HTTP transformation config (client_secrets + realtime_calls)."""
from typing import Optional
import litellm
from litellm.llms.base_llm.realtime.http_transformation import BaseRealtimeHTTPConfig
from litellm.secret_managers.main import get_secret_str
class AzureRealtimeHTTPConfig(BaseRealtimeHTTPConfig):
    """Azure OpenAI realtime HTTP config (client_secrets + realtime_calls)."""

    def get_api_base(self, api_base: Optional[str], **kwargs) -> str:
        """Resolve the Azure API base: arg > litellm config > env; "" if unset."""
        resolved = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
        return resolved or ""

    def get_api_key(self, api_key: Optional[str], **kwargs) -> str:
        """Resolve the Azure API key: arg > litellm config > env; "" if unset."""
        resolved = api_key or litellm.api_key or get_secret_str("AZURE_API_KEY")
        return resolved or ""

    def get_complete_url(
        self, api_base: Optional[str], model: str, api_version: Optional[str] = None
    ) -> str:
        """URL for minting ephemeral realtime client secrets."""
        root = self.get_api_base(api_base).rstrip("/")
        version = api_version or get_secret_str("AZURE_API_VERSION") or "2024-12-17"
        return "{}/openai/realtime/client_secrets?api-version={}".format(root, version)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Return headers augmented with the Azure api-key and JSON content type."""
        merged = dict(headers)
        merged["api-key"] = api_key or ""
        merged["Content-Type"] = "application/json"
        return merged

    def get_realtime_calls_url(
        self, api_base: Optional[str], model: str, api_version: Optional[str] = None
    ) -> str:
        """URL for the realtime calls endpoint."""
        root = self.get_api_base(api_base).rstrip("/")
        version = api_version or get_secret_str("AZURE_API_VERSION") or "2024-12-17"
        return "{}/openai/realtime/calls?api-version={}".format(root, version)

    def get_realtime_calls_headers(self, ephemeral_key: str) -> dict:
        """Headers for a realtime call authenticated with an ephemeral key."""
        return {
            "api-key": ephemeral_key,
        }

View File

@@ -0,0 +1,94 @@
"""
Support for Azure OpenAI O-series models (o1, o3, etc.) in Responses API
https://platform.openai.com/docs/guides/reasoning
Translations handled by LiteLLM:
- temperature => drop param (if user opts in to dropping param)
- Other parameters follow base Azure OpenAI Responses API behavior
"""
from typing import TYPE_CHECKING, Any, Dict
from litellm._logging import verbose_logger
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
from litellm.utils import supports_reasoning
from .transformation import AzureOpenAIResponsesAPIConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AzureOpenAIOSeriesResponsesAPIConfig(AzureOpenAIResponsesAPIConfig):
    """
    Configuration for Azure OpenAI O-series models (o1, o3, ...) in the
    Responses API.

    O-series models do not accept the `temperature` parameter in the
    responses API, so it is removed from the supported set and dropped from
    requests when `drop_params` is enabled.
    """

    def get_supported_openai_params(self, model: str) -> list:
        """Supported params for O-series = base Azure params minus `temperature`."""
        return [
            param
            for param in super().get_supported_openai_params(model)
            if param != "temperature"
        ]

    def map_openai_params(
        self,
        response_api_optional_params: ResponsesAPIOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """
        Map OpenAI params for the O-series responses API.

        When `drop_params` is True, `temperature` is removed since O-series
        models reject it.
        """
        mapped = dict(response_api_optional_params)
        if drop_params and "temperature" in mapped:
            verbose_logger.debug(
                f"Dropping unsupported parameter 'temperature' for Azure OpenAI O-series responses API model {model}"
            )
            del mapped["temperature"]
        return mapped

    def is_o_series_model(self, model: str) -> bool:
        """
        Return True when the name marks an o_series deployment, or when the
        model is known to support reasoning (O-series specific).
        """
        return "o_series" in model.lower() or supports_reasoning(model)

View File

@@ -0,0 +1,359 @@
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
import httpx
from openai.types.responses import ResponseReasoningItem
from litellm._logging import verbose_logger
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.types.llms.openai import *
from litellm.types.responses.main import *
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
# Parameters not supported by Azure Responses API
AZURE_UNSUPPORTED_PARAMS = ["context_management"]
@property
def custom_llm_provider(self) -> LlmProviders:
    # Identifies this config as the Azure provider for routing/logging.
    return LlmProviders.AZURE
def get_supported_openai_params(self, model: str) -> list:
    """
    Azure Responses API does not support context_management (compaction).
    """
    excluded = set(self.AZURE_UNSUPPORTED_PARAMS)
    return [
        param
        for param in super().get_supported_openai_params(model)
        if param not in excluded
    ]
def validate_environment(
    self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
) -> dict:
    """Build Azure auth headers via the shared Azure helper.

    Delegates to BaseAzureLLM so credential resolution (api-key header,
    Azure AD tokens, litellm_credential_name) lives in one place.
    `model` is unused here but required by the base-class signature.
    """
    return BaseAzureLLM._base_validate_azure_environment(
        headers=headers, litellm_params=litellm_params
    )
def get_stripped_model_name(self, model: str) -> str:
    """Strip litellm routing markers from the model name.

    Removes the "responses/" marker and, for O-series models, the
    "o_series/" marker, leaving the bare Azure deployment name.
    """
    stripped = model
    if "responses/" in stripped:
        stripped = stripped.replace("responses/", "")
    if "o_series" in stripped:
        stripped = stripped.replace("o_series/", "")
    return stripped
def _handle_reasoning_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
    """
    Handle reasoning items to filter out the status field.

    Issue: https://github.com/BerriAI/litellm/issues/13484
    Azure OpenAI API does not accept 'status' field in reasoning input items.

    Args:
        item: a single input item dict; only items with type == "reasoning"
            are modified, everything else is returned untouched.

    Returns:
        The item with 'status' (and None-valued fields) removed when it is a
        reasoning item; the original item otherwise.
    """
    if item.get("type") == "reasoning":
        try:
            # Ensure required fields are present for ResponseReasoningItem
            item_data = dict(item)
            if "summary" not in item_data:
                # Synthesize a summary from reasoning_content (truncated to
                # 100 chars) since the pydantic model requires the field.
                item_data["summary"] = (
                    item_data.get("reasoning_content", "")[:100] + "..."
                    if len(item_data.get("reasoning_content", "")) > 100
                    else item_data.get("reasoning_content", "")
                )
            # Create ResponseReasoningItem object from the item data
            reasoning_item = ResponseReasoningItem(**item_data)
            # Convert back to dict with exclude_none=True to exclude None fields
            dict_reasoning_item = reasoning_item.model_dump(exclude_none=True)
            dict_reasoning_item.pop("status", None)
            return dict_reasoning_item
        except Exception as e:
            # Pydantic validation can fail on unexpected fields; fall back to
            # a best-effort manual filter rather than erroring the request.
            verbose_logger.debug(
                f"Failed to create ResponseReasoningItem, falling back to manual filtering: {e}"
            )
            # Fallback: manually filter out known None fields
            filtered_item = {
                k: v
                for k, v in item.items()
                if v is not None
                or k not in {"status", "content", "encrypted_content"}
            }
            return filtered_item
    return item
def _validate_input_param(
    self, input: Union[str, ResponseInputParam]
) -> Union[str, ResponseInputParam]:
    """
    Override parent method to also filter out 'status' field from message items.
    Azure OpenAI API does not accept 'status' field in input messages.

    Returns:
        The parent-validated input, with 'status' stripped from every
        dict item whose type is "message"; plain strings pass through.
    """
    from typing import cast

    # First call parent's validation
    validated_input = super()._validate_input_param(input)
    # Then filter out status from message items
    if isinstance(validated_input, list):
        filtered_input: List[Any] = []
        for item in validated_input:
            if isinstance(item, dict) and item.get("type") == "message":
                # Filter out status field from message items
                filtered_item = {k: v for k, v in item.items() if k != "status"}
                filtered_input.append(filtered_item)
            else:
                # Non-message items (reasoning, tool calls, ...) pass through
                filtered_input.append(item)
        return cast(ResponseInputParam, filtered_input)
    # String input needs no per-item filtering
    return validated_input
def transform_responses_api_request(
    self,
    model: str,
    input: Union[str, ResponseInputParam],
    response_api_optional_request_params: Dict,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Dict:
    """Transform the request for Azure.

    Inputs are already OpenAI-spec; this only strips litellm routing
    prefixes from the model name and flattens tool definitions.
    """
    # Remove "responses/" / "o_series/" markers before sending the
    # deployment name to Azure.
    stripped_model_name = self.get_stripped_model_name(model)
    # Azure Responses API requires flattened tools (params at top level, not nested in 'function')
    if "tools" in response_api_optional_request_params and isinstance(
        response_api_optional_request_params["tools"], list
    ):
        new_tools: List[Dict[str, Any]] = []
        for tool in response_api_optional_request_params["tools"]:
            if isinstance(tool, dict) and "function" in tool:
                # deepcopy so the caller's tool definition is not mutated
                new_tool: Dict[str, Any] = deepcopy(tool)
                function_data = new_tool.pop("function")
                new_tool.update(function_data)
                new_tools.append(new_tool)
            else:
                # Already flat (or not a dict) - pass through unchanged
                new_tools.append(tool)
        response_api_optional_request_params["tools"] = new_tools
    return super().transform_responses_api_request(
        model=stripped_model_name,
        input=input,
        response_api_optional_request_params=response_api_optional_request_params,
        litellm_params=litellm_params,
        headers=headers,
    )
def get_complete_url(
    self,
    api_base: Optional[str],
    litellm_params: dict,
) -> str:
    """
    Constructs a complete URL for the API request.

    Args:
    - api_base: Base URL, e.g.,
        "https://litellm8397336933.openai.azure.com"
        OR
        "https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
    - litellm_params: may carry the api_version to use.

    Returns:
    - A complete URL string, e.g.,
        "https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
    """
    from litellm.constants import AZURE_DEFAULT_RESPONSES_API_VERSION

    # Shared Azure helper appends the route and the api-version query param.
    return BaseAzureLLM._get_base_azure_url(
        api_base=api_base,
        litellm_params=litellm_params,
        route="/openai/responses",
        default_api_version=AZURE_DEFAULT_RESPONSES_API_VERSION,
    )
#########################################################
########## DELETE RESPONSE API TRANSFORMATION ##############
#########################################################
def _construct_url_for_response_id_in_path(
self, api_base: str, response_id: str
) -> str:
"""
Constructs a URL for the API request with the response_id in the path.
"""
from urllib.parse import urlparse, urlunparse
# Parse the URL to separate its components
parsed_url = urlparse(api_base)
# Insert the response_id at the end of the path component
# Remove trailing slash if present to avoid double slashes
path = parsed_url.path.rstrip("/")
new_path = f"{path}/{response_id}"
# Reconstruct the URL with all original components but with the modified path
constructed_url = urlunparse(
(
parsed_url.scheme, # http, https
parsed_url.netloc, # domain name, port
new_path, # path with response_id added
parsed_url.params, # parameters
parsed_url.query, # query string
parsed_url.fragment, # fragment
)
)
return constructed_url
def transform_delete_response_api_request(
    self,
    response_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform the delete response API request into a URL and data

    Azure OpenAI API expects the following request:
    - DELETE /openai/responses/{response_id}?api-version=xxx

    This function handles URLs with query parameters by inserting the response_id
    at the correct location (before any query parameters).

    Returns:
        (delete_url, request_body) - the body is always empty for DELETE.
    """
    delete_url = self._construct_url_for_response_id_in_path(
        api_base=api_base, response_id=response_id
    )
    data: Dict = {}
    verbose_logger.debug(f"delete response url={delete_url}")
    return delete_url, data
#########################################################
########## GET RESPONSE API TRANSFORMATION ###############
#########################################################
def transform_get_response_api_request(
    self,
    response_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform the get response API request into a URL and data.

    OpenAI API expects the following request:
    - GET /v1/responses/{response_id}
    """
    get_url = self._construct_url_for_response_id_in_path(
        api_base=api_base, response_id=response_id
    )
    verbose_logger.debug(f"get response url={get_url}")
    # GET requests carry no body
    return get_url, {}
def transform_list_input_items_request(
    self,
    response_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
    after: Optional[str] = None,
    before: Optional[str] = None,
    include: Optional[List[str]] = None,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
) -> Tuple[str, Dict]:
    """
    Transform the list-input-items request into a URL and query params.

    Azure OpenAI API expects:
    - GET /openai/responses/{response_id}/input_items?api-version=xxx

    Bug fix: "/input_items" is now appended to the URL *path* instead of to
    the end of the full URL, so api_bases that carry a query string (e.g.
    "?api-version=...") no longer end up with "/input_items" after the
    query - consistent with the delete/get/cancel transformations.
    """
    from urllib.parse import urlparse, urlunparse

    # Helper inserts response_id into the path, preserving any query string.
    base_url = self._construct_url_for_response_id_in_path(
        api_base=api_base, response_id=response_id
    )
    parsed = urlparse(base_url)
    url = urlunparse(parsed._replace(path=f"{parsed.path}/input_items"))
    params: Dict[str, Any] = {}
    if after is not None:
        params["after"] = after
    if before is not None:
        params["before"] = before
    if include:
        # Azure expects a comma-separated list, not repeated params
        params["include"] = ",".join(include)
    if limit is not None:
        params["limit"] = limit
    if order is not None:
        params["order"] = order
    verbose_logger.debug(f"list input items url={url}")
    return url, params
#########################################################
########## CANCEL RESPONSE API TRANSFORMATION ##########
#########################################################
def transform_cancel_response_api_request(
    self,
    response_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Transform the cancel response API request into a URL and data

    Azure OpenAI API expects the following request:
    - POST /openai/responses/{response_id}/cancel?api-version=xxx

    This function handles URLs with query parameters by inserting the response_id
    at the correct location (before any query parameters).

    NOTE: does not reuse _construct_url_for_response_id_in_path because the
    "/cancel" suffix must also land inside the path segment, before the query.
    """
    from urllib.parse import urlparse, urlunparse

    # Parse the URL to separate its components
    parsed_url = urlparse(api_base)
    # Insert the response_id and /cancel at the end of the path component
    # Remove trailing slash if present to avoid double slashes
    path = parsed_url.path.rstrip("/")
    new_path = f"{path}/{response_id}/cancel"
    # Reconstruct the URL with all original components but with the modified path
    cancel_url = urlunparse(
        (
            parsed_url.scheme,  # http, https
            parsed_url.netloc,  # domain name, port
            new_path,  # path with response_id and /cancel added
            parsed_url.params,  # parameters
            parsed_url.query,  # query string
            parsed_url.fragment,  # fragment
        )
    )
    data: Dict = {}
    verbose_logger.debug(f"cancel response url={cancel_url}")
    return cancel_url, data
def transform_cancel_response_api_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> ResponsesAPIResponse:
    """
    Transform the cancel response API response into a ResponsesAPIResponse

    Raises:
        AzureOpenAIError: when the response body is not valid JSON; the raw
            text and HTTP status code are surfaced to the caller.
    """
    try:
        raw_response_json = raw_response.json()
    except Exception:
        # Imported lazily to avoid a circular import at module load time
        from litellm.llms.azure.chat.gpt_transformation import AzureOpenAIError

        raise AzureOpenAIError(
            message=raw_response.text, status_code=raw_response.status_code
        )
    return ResponsesAPIResponse(**raw_response_json)

View File

@@ -0,0 +1,7 @@
"""Azure Text-to-Speech module"""
from .transformation import AzureAVATextToSpeechConfig
__all__ = [
"AzureAVATextToSpeechConfig",
]

View File

@@ -0,0 +1,504 @@
"""
Azure AVA (Cognitive Services) Text-to-Speech transformation
Maps OpenAI TTS spec to Azure Cognitive Services TTS API
"""
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
from urllib.parse import urlparse
import httpx
import litellm
from litellm.llms.base_llm.text_to_speech.transformation import (
BaseTextToSpeechConfig,
TextToSpeechRequestData,
)
from litellm.secret_managers.main import get_secret_str
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import HttpxBinaryResponseContent
else:
LiteLLMLoggingObj = Any
HttpxBinaryResponseContent = Any
class AzureAVATextToSpeechConfig(BaseTextToSpeechConfig):
    """
    Configuration for Azure AVA (Cognitive Services) Text-to-Speech

    Maps the OpenAI TTS spec (voice names, formats, speed) onto Azure's
    SSML-based REST API.

    Reference: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/rest-text-to-speech
    """

    # Azure endpoint domains
    DEFAULT_VOICE = "en-US-AriaNeural"  # used when the caller passes no voice
    COGNITIVE_SERVICES_DOMAIN = "api.cognitive.microsoft.com"
    TTS_SPEECH_DOMAIN = "tts.speech.microsoft.com"
    TTS_ENDPOINT_PATH = "/cognitiveservices/v1"

    # Voice name mappings from OpenAI voices to Azure voices
    VOICE_MAPPINGS = {
        "alloy": "en-US-JennyNeural",
        "echo": "en-US-GuyNeural",
        "fable": "en-GB-RyanNeural",
        "onyx": "en-US-DavisNeural",
        "nova": "en-US-AmberNeural",
        "shimmer": "en-US-AriaNeural",
    }

    # Response format mappings from OpenAI to Azure
    FORMAT_MAPPINGS = {
        "mp3": "audio-24khz-48kbitrate-mono-mp3",
        "opus": "ogg-48khz-16bit-mono-opus",
        "aac": "audio-24khz-48kbitrate-mono-mp3",  # Azure doesn't have AAC, use MP3
        "flac": "audio-24khz-48kbitrate-mono-mp3",  # Azure doesn't have FLAC, use MP3
        "wav": "riff-24khz-16bit-mono-pcm",
        "pcm": "raw-24khz-16bit-mono-pcm",
    }
def dispatch_text_to_speech(
    self,
    model: str,
    input: str,
    voice: Optional[Union[str, Dict]],
    optional_params: Dict,
    litellm_params_dict: Dict,
    logging_obj: "LiteLLMLoggingObj",
    timeout: Union[float, httpx.Timeout],
    extra_headers: Optional[Dict[str, Any]],
    base_llm_http_handler: Any,
    aspeech: bool,
    api_base: Optional[str],
    api_key: Optional[str],
    **kwargs: Any,
) -> Union[
    "HttpxBinaryResponseContent",
    Coroutine[Any, Any, "HttpxBinaryResponseContent"],
]:
    """
    Dispatch method to handle Azure AVA TTS requests

    This method encapsulates Azure-specific credential resolution and parameter handling

    Args:
        base_llm_http_handler: The BaseLLMHTTPHandler instance from main.py
        aspeech: when True, the handler returns a coroutine (async path).
    """
    # Resolve api_base from multiple sources
    # (explicit arg > litellm_params > module-level setting > env var)
    api_base = (
        api_base
        or litellm_params_dict.get("api_base")
        or litellm.api_base
        or get_secret_str("AZURE_API_BASE")
    )
    # Resolve api_key from multiple sources (Azure-specific)
    api_key = (
        api_key
        or litellm_params_dict.get("api_key")
        or litellm.api_key
        or litellm.azure_key
        or get_secret_str("AZURE_OPENAI_API_KEY")
        or get_secret_str("AZURE_API_KEY")
    )
    # Convert voice to string if it's a dict (for Azure AVA, voice must be a string)
    voice_str: Optional[str] = None
    if isinstance(voice, str):
        voice_str = voice
    elif isinstance(voice, dict):
        # Extract voice name from dict if needed
        voice_str = voice.get("name") if voice else None
    # Propagate the resolved credentials so the downstream handler sees them
    litellm_params_dict.update(
        {
            "api_key": api_key,
            "api_base": api_base,
        }
    )
    # Call the text_to_speech_handler
    response = base_llm_http_handler.text_to_speech_handler(
        model=model,
        input=input,
        voice=voice_str,
        text_to_speech_provider_config=self,
        text_to_speech_optional_params=optional_params,
        custom_llm_provider="azure",
        litellm_params=litellm_params_dict,
        logging_obj=logging_obj,
        timeout=timeout,
        extra_headers=extra_headers,
        client=None,
        _is_async=aspeech,
    )
    return response
def get_supported_openai_params(self, model: str) -> list:
    """
    OpenAI-spec parameters honoured by Azure AVA TTS.

    Note: Azure also accepts SSML-only knobs (style, styledegree, role)
    via kwargs, but those are not part of the OpenAI spec.
    """
    supported = ["voice", "response_format", "speed"]
    return supported
def _convert_speed_to_azure_rate(self, speed: float) -> str:
"""
Convert OpenAI speed value to Azure SSML prosody rate percentage
Args:
speed: OpenAI speed value (0.25-4.0, default 1.0)
Returns:
Azure rate string with percentage (e.g., "+50%", "-50%", "+0%")
Examples:
speed=1.0 -> "+0%" (default)
speed=2.0 -> "+100%"
speed=0.5 -> "-50%"
"""
rate_percentage = int((speed - 1.0) * 100)
return f"{rate_percentage:+d}%"
def _build_express_as_element(
self,
content: str,
style: Optional[str] = None,
styledegree: Optional[str] = None,
role: Optional[str] = None,
) -> str:
"""
Build mstts:express-as element with optional style, styledegree, and role attributes
Args:
content: The inner content to wrap
style: Speaking style (e.g., "cheerful", "sad", "angry")
styledegree: Style intensity (0.01 to 2)
role: Voice role (e.g., "Girl", "Boy", "SeniorFemale", "SeniorMale")
Returns:
Content wrapped in mstts:express-as if any attributes provided, otherwise raw content
"""
if not (style or styledegree or role):
return content
express_as_attrs = []
if style:
express_as_attrs.append(f"style='{style}'")
if styledegree:
express_as_attrs.append(f"styledegree='{styledegree}'")
if role:
express_as_attrs.append(f"role='{role}'")
express_as_attrs_str = " ".join(express_as_attrs)
return f"<mstts:express-as {express_as_attrs_str}>{content}</mstts:express-as>"
def _get_voice_language(
self,
voice_name: Optional[str],
explicit_lang: Optional[str] = None,
) -> Optional[str]:
"""
Get the language for the voice element's xml:lang attribute
Args:
voice_name: The Azure voice name (e.g., "en-US-AriaNeural")
explicit_lang: Explicitly provided language code (takes precedence)
Returns:
Language code if available (e.g., "es-ES"), or None
Examples:
- explicit_lang="es-ES""es-ES" (explicit takes precedence)
- voice_name="en-US-AriaNeural", explicit_lang=None → None (use default from voice)
- voice_name="en-US-AvaMultilingualNeural", explicit_lang="fr-FR""fr-FR"
"""
# If explicit language is provided, use it (for multilingual voices)
if explicit_lang:
return explicit_lang
# For non-multilingual voices, we don't need to set xml:lang on the voice element
# The voice name already encodes the language (e.g., en-US-AriaNeural)
# Only return a language if explicitly set
return None
def map_openai_params(
self,
model: str,
optional_params: Dict,
voice: Optional[Union[str, Dict]] = None,
drop_params: bool = False,
kwargs: Dict = {},
) -> Tuple[Optional[str], Dict]:
"""
Map OpenAI parameters to Azure AVA TTS parameters
"""
mapped_params = {}
##########################################################
# Map voice
# OpenAI uses voice as a required param, hence not in optional_params
##########################################################
# If it's already an Azure voice, use it directly
mapped_voice: Optional[str] = None
if isinstance(voice, str):
if voice in self.VOICE_MAPPINGS:
mapped_voice = self.VOICE_MAPPINGS[voice]
else:
# Assume it's already an Azure voice name
mapped_voice = voice
# Map response format
if "response_format" in optional_params:
format_name = optional_params["response_format"]
if format_name in self.FORMAT_MAPPINGS:
mapped_params["output_format"] = self.FORMAT_MAPPINGS[format_name]
else:
# Try to use it directly as Azure format
mapped_params["output_format"] = format_name
else:
# Default to MP3
mapped_params["output_format"] = "audio-24khz-48kbitrate-mono-mp3"
# Map speed (OpenAI: 0.25-4.0, Azure: prosody rate)
if "speed" in optional_params:
speed = optional_params["speed"]
if speed is not None:
mapped_params["rate"] = self._convert_speed_to_azure_rate(speed=speed)
# Pass through Azure-specific SSML parameters
if "style" in kwargs:
mapped_params["style"] = kwargs["style"]
if "styledegree" in kwargs:
mapped_params["styledegree"] = kwargs["styledegree"]
if "role" in kwargs:
mapped_params["role"] = kwargs["role"]
if "lang" in kwargs:
mapped_params["lang"] = kwargs["lang"]
return mapped_voice, mapped_params
def validate_environment(
    self,
    headers: dict,
    model: str,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """
    Attach Azure authentication and content headers, returning a copy of
    *headers* (the input dict is not mutated).

    Azure AVA TTS accepts either an Ocp-Apim-Subscription-Key header or an
    Authorization: Bearer token; the bearer token (when used) is added
    later by the request handler.
    """
    validated = dict(headers)
    if api_key:
        # Subscription-key auth path
        validated["Ocp-Apim-Subscription-Key"] = api_key
    # The request body is SSML
    validated["Content-Type"] = "application/ssml+xml"
    validated["User-Agent"] = "litellm"
    return validated
def get_complete_url(
    self,
    model: str,
    api_base: Optional[str],
    litellm_params: dict,
) -> str:
    """
    Get the complete URL for Azure AVA TTS request

    Azure TTS endpoint format:
    https://{region}.tts.speech.microsoft.com/cognitiveservices/v1

    Raises:
        ValueError: when api_base is missing (the region cannot be inferred).
    """
    if api_base is None:
        raise ValueError(
            f"api_base is required for Azure AVA TTS. "
            f"Format: https://{{region}}.{self.COGNITIVE_SERVICES_DOMAIN} or "
            f"https://{{region}}.{self.TTS_SPEECH_DOMAIN}"
        )
    # Remove trailing slash and parse URL
    api_base = api_base.rstrip("/")
    parsed_url = urlparse(api_base)
    hostname = parsed_url.hostname or ""
    # Check if it's a Cognitive Services endpoint (convert to TTS endpoint)
    if self._is_cognitive_services_endpoint(hostname=hostname):
        region = self._extract_region_from_hostname(
            hostname=hostname, domain=self.COGNITIVE_SERVICES_DOMAIN
        )
        return self._build_tts_url(region=region)
    # Check if it's already a TTS endpoint
    if self._is_tts_endpoint(hostname=hostname):
        # Append the API path only when it is not already present
        if not api_base.endswith(self.TTS_ENDPOINT_PATH):
            return f"{api_base}{self.TTS_ENDPOINT_PATH}"
        return api_base
    # Assume it's a custom endpoint, append the path
    return f"{api_base}{self.TTS_ENDPOINT_PATH}"
def _is_cognitive_services_endpoint(self, hostname: str) -> bool:
    """Check if hostname is a Cognitive Services endpoint
    (the bare domain or any regional subdomain of it)."""
    return hostname == self.COGNITIVE_SERVICES_DOMAIN or hostname.endswith(
        f".{self.COGNITIVE_SERVICES_DOMAIN}"
    )
def _is_tts_endpoint(self, hostname: str) -> bool:
    """Check if hostname is a TTS endpoint
    (the bare domain or any regional subdomain of it)."""
    return hostname == self.TTS_SPEECH_DOMAIN or hostname.endswith(
        f".{self.TTS_SPEECH_DOMAIN}"
    )
def _extract_region_from_hostname(self, hostname: str, domain: str) -> str:
"""
Extract region from hostname
Examples:
eastus.api.cognitive.microsoft.com -> eastus
api.cognitive.microsoft.com -> ""
"""
if hostname.endswith(f".{domain}"):
return hostname[: -len(f".{domain}")]
return ""
def _build_tts_url(self, region: str) -> str:
    """Build the complete TTS URL with region"""
    if region:
        return f"https://{region}.{self.TTS_SPEECH_DOMAIN}{self.TTS_ENDPOINT_PATH}"
    # No region extracted: fall back to the bare TTS domain
    return f"https://{self.TTS_SPEECH_DOMAIN}{self.TTS_ENDPOINT_PATH}"
def is_ssml_input(self, input: str) -> bool:
    """
    Heuristic SSML detection.

    Per https://www.w3.org/TR/speech-synthesis/ every SSML document has a
    <speak> root element, so presence of that tag marks the input as SSML.
    """
    markers = ("<speak>", "<speak ")
    return any(marker in input for marker in markers)
def transform_text_to_speech_request(
    self,
    model: str,
    input: str,
    voice: Optional[str],
    optional_params: Dict,
    litellm_params: Dict,
    headers: dict,
) -> TextToSpeechRequestData:
    """
    Transform OpenAI TTS request to Azure AVA TTS SSML format

    Note: optional_params should already be mapped via map_openai_params in main.py

    Supports Azure-specific SSML features:
    - style: Speaking style (e.g., "cheerful", "sad", "angry")
    - styledegree: Style intensity (0.01 to 2)
    - role: Voice role (e.g., "Girl", "Boy", "SeniorFemale", "SeniorMale")
    - lang: Language code for multilingual voices (e.g., "es-ES", "fr-FR")

    Auto-detects SSML:
    - If input contains <speak>, it's passed through as-is without transformation

    Fix: both return paths now construct TextToSpeechRequestData; the plain
    dict literal previously returned on the plain-text path bypassed the
    declared return type.

    Returns:
        TextToSpeechRequestData: Contains SSML body and Azure-specific headers
    """
    # Get voice (already mapped in main.py, or use default)
    azure_voice = voice or self.DEFAULT_VOICE
    # Get output format (already mapped in main.py)
    output_format = optional_params.get(
        "output_format", "audio-24khz-48kbitrate-mono-mp3"
    )
    headers["X-Microsoft-OutputFormat"] = output_format
    # Auto-detect SSML: if input contains <speak>, pass it through as-is
    if self.is_ssml_input(input=input):
        return TextToSpeechRequestData(
            ssml_body=input,
            headers=headers,
        )
    # Build SSML from plain text
    rate = optional_params.get("rate", "0%")
    style = optional_params.get("style")
    styledegree = optional_params.get("styledegree")
    role = optional_params.get("role")
    lang = optional_params.get("lang")
    # Escape XML special characters in input text ("&" must be replaced first)
    escaped_input = (
        input.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&apos;")
    )
    # Determine if we need mstts namespace (for express-as element)
    use_mstts = style or role or styledegree
    # Build the xmlns attributes
    if use_mstts:
        xmlns = "xmlns='http://www.w3.org/2001/10/synthesis' xmlns:mstts='https://www.w3.org/2001/mstts'"
    else:
        xmlns = "xmlns='http://www.w3.org/2001/10/synthesis'"
    # Build the inner content with prosody
    prosody_content = f"<prosody rate='{rate}'>{escaped_input}</prosody>"
    # Wrap in mstts:express-as if style or role is specified
    voice_content = self._build_express_as_element(
        content=prosody_content,
        style=style,
        styledegree=styledegree,
        role=role,
    )
    # Build voice element with optional xml:lang attribute
    voice_lang = self._get_voice_language(
        voice_name=azure_voice,
        explicit_lang=lang,
    )
    voice_lang_attr = f" xml:lang='{voice_lang}'" if voice_lang else ""
    ssml_body = f"""<speak version='1.0' {xmlns} xml:lang='en-US'>
    <voice name='{azure_voice}'{voice_lang_attr}>
        {voice_content}
    </voice>
</speak>"""
    return TextToSpeechRequestData(
        ssml_body=ssml_body,
        headers=headers,
    )
def transform_text_to_speech_response(
    self,
    model: str,
    raw_response: httpx.Response,
    logging_obj: "LiteLLMLoggingObj",
) -> "HttpxBinaryResponseContent":
    """
    Transform Azure AVA TTS response to standard format

    Azure returns the audio data directly in the response body

    Args:
        model: model name (unused; kept for interface parity).
        raw_response: the raw HTTP response whose body is the audio bytes.
        logging_obj: litellm logging object (unused here).
    """
    from litellm.types.llms.openai import HttpxBinaryResponseContent

    # Azure returns audio data directly in the response body
    # Wrap it in HttpxBinaryResponseContent for consistent return type
    return HttpxBinaryResponseContent(raw_response)

View File

@@ -0,0 +1,25 @@
from typing import Optional
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.vector_stores.transformation import OpenAIVectorStoreConfig
from litellm.types.router import GenericLiteLLMParams
class AzureOpenAIVectorStoreConfig(OpenAIVectorStoreConfig):
    """Azure OpenAI vector-store config.

    Reuses the OpenAI-spec transformations and overrides only URL
    construction and authentication for Azure endpoints.
    """

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """Build the Azure vector-stores endpoint URL (route + api-version)."""
        return BaseAzureLLM._get_base_azure_url(
            api_base=api_base,
            litellm_params=litellm_params,
            route="/openai/vector_stores",
        )

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Resolve Azure credentials (api-key header / AD token) via the shared helper."""
        return BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params
        )

View File

@@ -0,0 +1,93 @@
from typing import TYPE_CHECKING, Any, Dict, Optional
from litellm.types.videos.main import VideoCreateOptionalRequestParams
from litellm.types.router import GenericLiteLLMParams
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.llms.openai.videos.transformation import OpenAIVideoConfig
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from ...base_llm.videos.transformation import BaseVideoConfig as _BaseVideoConfig
from ...base_llm.chat.transformation import BaseLLMException as _BaseLLMException
LiteLLMLoggingObj = _LiteLLMLoggingObj
BaseVideoConfig = _BaseVideoConfig
BaseLLMException = _BaseLLMException
else:
LiteLLMLoggingObj = Any
BaseVideoConfig = Any
BaseLLMException = Any
class AzureVideoConfig(OpenAIVideoConfig):
    """
    Configuration class for Azure OpenAI video generation.

    Reuses the OpenAI video transformations; overrides authentication and
    URL construction for Azure endpoints.
    """

    def __init__(self):
        super().__init__()

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the list of supported OpenAI parameters for video generation.
        """
        return [
            "model",
            "prompt",
            "input_reference",
            "seconds",
            "size",
            "user",
            "extra_headers",
        ]

    def map_openai_params(
        self,
        video_create_optional_params: VideoCreateOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """No mapping applied since inputs are in OpenAI spec already"""
        return dict(video_create_optional_params)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        litellm_params: Optional[GenericLiteLLMParams] = None,
    ) -> dict:
        """
        Validate Azure environment and set up authentication headers.

        Uses _base_validate_azure_environment to properly handle credentials from litellm_credential_name.
        """
        # If litellm_params is provided, use it; otherwise create a new one
        if litellm_params is None:
            litellm_params = GenericLiteLLMParams()
        # An explicitly passed api_key is used only when litellm_params has none
        if api_key and not litellm_params.api_key:
            litellm_params.api_key = api_key
        # Use the base Azure validation method which properly handles:
        # 1. Credentials from litellm_credential_name via litellm_params
        # 2. Sets the correct "api-key" header (not "Authorization: Bearer")
        return BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params
        )

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Constructs a complete URL for the API request.
        """
        return BaseAzureLLM._get_base_azure_url(
            api_base=api_base,
            litellm_params=litellm_params,
            route="/openai/v1/videos",
            default_api_version="",
        )

View File

@@ -0,0 +1 @@
`/chat/completion` calls routed via `openai.py`.

View File

@@ -0,0 +1,11 @@
from litellm.llms.azure_ai.agents.handler import azure_ai_agents_handler
from litellm.llms.azure_ai.agents.transformation import (
AzureAIAgentsConfig,
AzureAIAgentsError,
)
__all__ = [
"AzureAIAgentsConfig",
"AzureAIAgentsError",
"azure_ai_agents_handler",
]

View File

@@ -0,0 +1,659 @@
"""
Handler for Azure Foundry Agent Service API.
This handler executes the multi-step agent flow:
1. Create thread (or use existing)
2. Add messages to thread
3. Create and poll a run
4. Retrieve the assistant's response messages
Model format: azure_ai/agents/<agent_id>
API Base format: https://<AIFoundryResourceName>.services.ai.azure.com/api/projects/<ProjectName>
Authentication: Uses Azure AD Bearer tokens (not API keys)
Get token via: az account get-access-token --resource 'https://ai.azure.com'
Supports both polling-based and native streaming (SSE) modes.
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
"""
import asyncio
import json
import time
import uuid
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
List,
Optional,
Tuple,
)
import httpx
from litellm._logging import verbose_logger
from litellm.llms.azure_ai.agents.transformation import (
AzureAIAgentsConfig,
AzureAIAgentsError,
)
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
HTTPHandler = Any
AsyncHTTPHandler = Any
class AzureAIAgentsHandler:
"""
Handler for Azure AI Agent Service.
Executes the complete agent flow which requires multiple API calls.
"""
def __init__(self):
    # Shared transformation/config helper used by every agent API call.
    self.config = AzureAIAgentsConfig()
# -------------------------------------------------------------------------
# URL Builders
# -------------------------------------------------------------------------
# Azure Foundry Agents API uses /assistants, /threads, etc. directly
# See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
# -------------------------------------------------------------------------
def _build_thread_url(self, api_base: str, api_version: str) -> str:
return f"{api_base}/threads?api-version={api_version}"
def _build_messages_url(
    self, api_base: str, thread_id: str, api_version: str
) -> str:
    """URL for adding a message to *thread_id*.

    NOTE(review): identical to _build_list_messages_url (the endpoint serves
    both POST and GET); kept separate for call-site clarity.
    """
    return f"{api_base}/threads/{thread_id}/messages?api-version={api_version}"
def _build_runs_url(self, api_base: str, thread_id: str, api_version: str) -> str:
    """URL for creating a run on an existing thread."""
    return f"{api_base}/threads/{thread_id}/runs?api-version={api_version}"
def _build_run_status_url(
    self, api_base: str, thread_id: str, run_id: str, api_version: str
) -> str:
    """URL for polling the status of a single run."""
    return f"{api_base}/threads/{thread_id}/runs/{run_id}?api-version={api_version}"
def _build_list_messages_url(
    self, api_base: str, thread_id: str, api_version: str
) -> str:
    """URL for listing a thread's messages (same endpoint as _build_messages_url)."""
    return f"{api_base}/threads/{thread_id}/messages?api-version={api_version}"
def _build_create_thread_and_run_url(self, api_base: str, api_version: str) -> str:
"""URL for the create-thread-and-run endpoint (supports streaming)."""
return f"{api_base}/threads/runs?api-version={api_version}"
# -------------------------------------------------------------------------
# Response Helpers
# -------------------------------------------------------------------------
def _extract_content_from_messages(self, messages_data: dict) -> str:
"""Extract assistant content from the messages response."""
for msg in messages_data.get("data", []):
if msg.get("role") == "assistant":
for content_item in msg.get("content", []):
if content_item.get("type") == "text":
return content_item.get("text", {}).get("value", "")
return ""
def _build_model_response(
    self,
    model: str,
    content: str,
    model_response: ModelResponse,
    thread_id: str,
    messages: List[Dict[str, Any]],
) -> ModelResponse:
    """Build the ModelResponse from agent output.

    Args:
        model: model name echoed back on the response.
        content: assistant text extracted from the thread.
        model_response: pre-allocated response object to populate (mutated).
        thread_id: stored in _hidden_params for conversation continuity.
        messages: original request messages, used only for token estimation.
    """
    from litellm.types.utils import Choices, Message, Usage

    model_response.choices = [
        Choices(
            finish_reason="stop",
            index=0,
            message=Message(content=content, role="assistant"),
        )
    ]
    model_response.model = model
    # Store thread_id for conversation continuity
    if (
        not hasattr(model_response, "_hidden_params")
        or model_response._hidden_params is None
    ):
        model_response._hidden_params = {}
    model_response._hidden_params["thread_id"] = thread_id
    # Estimate token usage
    # NOTE(review): the agent API does not report usage here, so tokens are
    # approximated with the gpt-3.5-turbo tokenizer.
    try:
        from litellm.utils import token_counter

        prompt_tokens = token_counter(model="gpt-3.5-turbo", messages=messages)
        completion_tokens = token_counter(
            model="gpt-3.5-turbo", text=content, count_response_tokens=True
        )
        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )
    except Exception as e:
        # Best-effort: a usage-estimation failure must not fail the call
        verbose_logger.warning(f"Failed to calculate token usage: {str(e)}")
    return model_response
def _prepare_completion_params(
    self,
    model: str,
    api_base: str,
    api_key: str,
    optional_params: dict,
    headers: Optional[dict],
) -> tuple:
    """Assemble the shared request state used by every completion path.

    Azure Foundry Agents API uses Bearer token authentication:
    - Authorization: Bearer <token> (Azure AD token from 'az account get-access-token --resource https://ai.azure.com')
    See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart

    Returns:
        Tuple of (headers, api_version, agent_id, thread_id, api_base).
    """
    request_headers = headers if headers is not None else {}
    request_headers["Content-Type"] = "application/json"
    # Azure Foundry Agents uses Bearer token authentication; the api_key
    # here is expected to be an Azure AD token rather than a classic key.
    if api_key:
        request_headers["Authorization"] = f"Bearer {api_key}"

    resolved_version = optional_params.get(
        "api_version", self.config.DEFAULT_API_VERSION
    )
    resolved_agent = self.config._get_agent_id(model, optional_params)
    existing_thread = optional_params.get("thread_id")
    normalized_base = api_base.rstrip("/")

    verbose_logger.debug(
        f"Azure AI Agents completion - api_base: {normalized_base}, agent_id: {resolved_agent}"
    )
    return (
        request_headers,
        resolved_version,
        resolved_agent,
        existing_thread,
        normalized_base,
    )
def _check_response(
    self, response: httpx.Response, expected_codes: List[int], error_msg: str
):
    """Raise AzureAIAgentsError unless the status code is one of expected_codes."""
    if response.status_code in expected_codes:
        return
    raise AzureAIAgentsError(
        status_code=response.status_code,
        message=f"{error_msg}: {response.text}",
    )
# -------------------------------------------------------------------------
# Sync Completion
# -------------------------------------------------------------------------
def completion(
    self,
    model: str,
    messages: List[Dict[str, Any]],
    api_base: str,
    api_key: str,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: float,
    client: Optional[HTTPHandler] = None,
    headers: Optional[dict] = None,
) -> ModelResponse:
    """Execute synchronous completion using Azure Agent Service.

    Builds a sync HTTP client (unless one is supplied), prepares the shared
    auth headers / agent id / thread id, runs the thread->run->poll flow,
    and packs the assistant text into ``model_response``.

    Args:
        model: Model string, e.g. "azure_ai/agents/<agent_id>".
        messages: OpenAI-style chat messages.
        api_base: Azure Foundry project base URL.
        api_key: Azure AD Bearer token.
        model_response: Response object populated in place.
        client: Optional pre-built HTTPHandler (created if None).
        headers: Optional extra headers (auth headers are added).
    """
    from litellm.llms.custom_httpx.http_handler import _get_httpx_client

    if client is None:
        client = _get_httpx_client(
            params={"ssl_verify": litellm_params.get("ssl_verify", None)}
        )
    (
        headers,
        api_version,
        agent_id,
        thread_id,
        api_base,
    ) = self._prepare_completion_params(
        model, api_base, api_key, optional_params, headers
    )

    # Small closure so the flow helper doesn't need client/header plumbing.
    # NOTE(review): an empty dict payload is sent with no body because
    # `{}` is falsy — confirm the threads endpoint accepts an empty POST.
    def make_request(
        method: str, url: str, json_data: Optional[dict] = None
    ) -> httpx.Response:
        if method == "GET":
            return client.get(url=url, headers=headers)
        return client.post(
            url=url,
            headers=headers,
            data=json.dumps(json_data) if json_data else None,
        )

    # Execute the agent flow
    thread_id, content = self._execute_agent_flow_sync(
        make_request=make_request,
        api_base=api_base,
        api_version=api_version,
        agent_id=agent_id,
        thread_id=thread_id,
        messages=messages,
        optional_params=optional_params,
    )
    return self._build_model_response(
        model, content, model_response, thread_id, messages
    )
def _execute_agent_flow_sync(
    self,
    make_request: Callable,
    api_base: str,
    api_version: str,
    agent_id: str,
    thread_id: Optional[str],
    messages: List[Dict[str, Any]],
    optional_params: dict,
) -> Tuple[str, str]:
    """Execute the agent flow synchronously. Returns (thread_id, content).

    Implements the Assistants-style protocol against Azure Foundry:
    create thread -> add messages -> create run -> poll run -> read messages.

    Args:
        make_request: Callable(method, url, json_data) -> httpx.Response,
            pre-bound to the HTTP client and auth headers.
        api_base: Normalized project base URL (no trailing slash).
        api_version: API version query-string value.
        agent_id: The agent (assistant) to run.
        thread_id: Existing thread to continue, or None to create one.
        messages: Chat messages; only "user"/"system" roles are forwarded.
        optional_params: May carry "instructions" to override the agent's.

    Raises:
        AzureAIAgentsError: on a non-success HTTP status, a
            failed/cancelled/expired run (500), or polling timeout (408).
    """
    # Step 1: Create thread if not provided
    if not thread_id:
        verbose_logger.debug(
            f"Creating thread at: {self._build_thread_url(api_base, api_version)}"
        )
        response = make_request(
            "POST", self._build_thread_url(api_base, api_version), {}
        )
        self._check_response(response, [200, 201], "Failed to create thread")
        thread_id = response.json()["id"]
        verbose_logger.debug(f"Created thread: {thread_id}")
    # At this point thread_id is guaranteed to be a string
    assert thread_id is not None
    # Step 2: Add messages to thread
    # NOTE(review): system messages are forwarded with role "user" — confirm
    # this downgrade is intended for the threads/messages endpoint.
    for msg in messages:
        if msg.get("role") in ["user", "system"]:
            url = self._build_messages_url(api_base, thread_id, api_version)
            response = make_request(
                "POST", url, {"role": "user", "content": msg.get("content", "")}
            )
            self._check_response(response, [200, 201], "Failed to add message")
    # Step 3: Create run
    run_payload = {"assistant_id": agent_id}
    if "instructions" in optional_params:
        run_payload["instructions"] = optional_params["instructions"]
    response = make_request(
        "POST", self._build_runs_url(api_base, thread_id, api_version), run_payload
    )
    self._check_response(response, [200, 201], "Failed to create run")
    run_id = response.json()["id"]
    verbose_logger.debug(f"Created run: {run_id}")
    # Step 4: Poll for completion (bounded by config.MAX_POLL_ATTEMPTS)
    status_url = self._build_run_status_url(
        api_base, thread_id, run_id, api_version
    )
    for _ in range(self.config.MAX_POLL_ATTEMPTS):
        response = make_request("GET", status_url)
        self._check_response(response, [200], "Failed to get run status")
        status = response.json().get("status")
        verbose_logger.debug(f"Run status: {status}")
        if status == "completed":
            break
        elif status in ["failed", "cancelled", "expired"]:
            error_msg = (
                response.json()
                .get("last_error", {})
                .get("message", "Unknown error")
            )
            raise AzureAIAgentsError(
                status_code=500, message=f"Run {status}: {error_msg}"
            )
        time.sleep(self.config.POLL_INTERVAL_SECONDS)
    else:
        # for/else: the loop exhausted without `break` -> polling timed out.
        raise AzureAIAgentsError(
            status_code=408, message="Run timed out waiting for completion"
        )
    # Step 5: Get messages
    response = make_request(
        "GET", self._build_list_messages_url(api_base, thread_id, api_version)
    )
    self._check_response(response, [200], "Failed to get messages")
    content = self._extract_content_from_messages(response.json())
    return thread_id, content
# -------------------------------------------------------------------------
# Async Completion
# -------------------------------------------------------------------------
async def acompletion(
    self,
    model: str,
    messages: List[Dict[str, Any]],
    api_base: str,
    api_key: str,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: float,
    client: Optional[AsyncHTTPHandler] = None,
    headers: Optional[dict] = None,
) -> ModelResponse:
    """Execute asynchronous completion using Azure Agent Service.

    Async mirror of ``completion``: builds an async HTTP client (unless one
    is supplied), prepares shared request state, runs the async
    thread->run->poll flow, and packs the result into ``model_response``.

    Args:
        model: Model string, e.g. "azure_ai/agents/<agent_id>".
        messages: OpenAI-style chat messages.
        api_base: Azure Foundry project base URL.
        api_key: Azure AD Bearer token.
        model_response: Response object populated in place.
        client: Optional pre-built AsyncHTTPHandler (created if None).
        headers: Optional extra headers (auth headers are added).
    """
    import litellm
    from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

    if client is None:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.AZURE_AI,
            params={"ssl_verify": litellm_params.get("ssl_verify", None)},
        )
    (
        headers,
        api_version,
        agent_id,
        thread_id,
        api_base,
    ) = self._prepare_completion_params(
        model, api_base, api_key, optional_params, headers
    )

    # Async closure mirroring the sync make_request helper.
    async def make_request(
        method: str, url: str, json_data: Optional[dict] = None
    ) -> httpx.Response:
        if method == "GET":
            return await client.get(url=url, headers=headers)
        return await client.post(
            url=url,
            headers=headers,
            data=json.dumps(json_data) if json_data else None,
        )

    # Execute the agent flow
    thread_id, content = await self._execute_agent_flow_async(
        make_request=make_request,
        api_base=api_base,
        api_version=api_version,
        agent_id=agent_id,
        thread_id=thread_id,
        messages=messages,
        optional_params=optional_params,
    )
    return self._build_model_response(
        model, content, model_response, thread_id, messages
    )
async def _execute_agent_flow_async(
    self,
    make_request: Callable,
    api_base: str,
    api_version: str,
    agent_id: str,
    thread_id: Optional[str],
    messages: List[Dict[str, Any]],
    optional_params: dict,
) -> Tuple[str, str]:
    """Execute the agent flow asynchronously. Returns (thread_id, content).

    Async mirror of ``_execute_agent_flow_sync``; the step sequence and
    error handling are identical, with ``asyncio.sleep`` between polls.

    Args:
        make_request: Async callable(method, url, json_data) -> httpx.Response.
        api_base: Normalized project base URL (no trailing slash).
        api_version: API version query-string value.
        agent_id: The agent (assistant) to run.
        thread_id: Existing thread to continue, or None to create one.
        messages: Chat messages; only "user"/"system" roles are forwarded.
        optional_params: May carry "instructions" to override the agent's.

    Raises:
        AzureAIAgentsError: on a non-success HTTP status, a
            failed/cancelled/expired run (500), or polling timeout (408).
    """
    # Step 1: Create thread if not provided
    if not thread_id:
        verbose_logger.debug(
            f"Creating thread at: {self._build_thread_url(api_base, api_version)}"
        )
        response = await make_request(
            "POST", self._build_thread_url(api_base, api_version), {}
        )
        self._check_response(response, [200, 201], "Failed to create thread")
        thread_id = response.json()["id"]
        verbose_logger.debug(f"Created thread: {thread_id}")
    # At this point thread_id is guaranteed to be a string
    assert thread_id is not None
    # Step 2: Add messages to thread
    # NOTE(review): system messages are forwarded with role "user" — confirm
    # this downgrade is intended for the threads/messages endpoint.
    for msg in messages:
        if msg.get("role") in ["user", "system"]:
            url = self._build_messages_url(api_base, thread_id, api_version)
            response = await make_request(
                "POST", url, {"role": "user", "content": msg.get("content", "")}
            )
            self._check_response(response, [200, 201], "Failed to add message")
    # Step 3: Create run
    run_payload = {"assistant_id": agent_id}
    if "instructions" in optional_params:
        run_payload["instructions"] = optional_params["instructions"]
    response = await make_request(
        "POST", self._build_runs_url(api_base, thread_id, api_version), run_payload
    )
    self._check_response(response, [200, 201], "Failed to create run")
    run_id = response.json()["id"]
    verbose_logger.debug(f"Created run: {run_id}")
    # Step 4: Poll for completion (bounded by config.MAX_POLL_ATTEMPTS)
    status_url = self._build_run_status_url(
        api_base, thread_id, run_id, api_version
    )
    for _ in range(self.config.MAX_POLL_ATTEMPTS):
        response = await make_request("GET", status_url)
        self._check_response(response, [200], "Failed to get run status")
        status = response.json().get("status")
        verbose_logger.debug(f"Run status: {status}")
        if status == "completed":
            break
        elif status in ["failed", "cancelled", "expired"]:
            error_msg = (
                response.json()
                .get("last_error", {})
                .get("message", "Unknown error")
            )
            raise AzureAIAgentsError(
                status_code=500, message=f"Run {status}: {error_msg}"
            )
        await asyncio.sleep(self.config.POLL_INTERVAL_SECONDS)
    else:
        # for/else: the loop exhausted without `break` -> polling timed out.
        raise AzureAIAgentsError(
            status_code=408, message="Run timed out waiting for completion"
        )
    # Step 5: Get messages
    response = await make_request(
        "GET", self._build_list_messages_url(api_base, thread_id, api_version)
    )
    self._check_response(response, [200], "Failed to get messages")
    content = self._extract_content_from_messages(response.json())
    return thread_id, content
# -------------------------------------------------------------------------
# Streaming Completion (Native SSE)
# -------------------------------------------------------------------------
async def acompletion_stream(
    self,
    model: str,
    messages: List[Dict[str, Any]],
    api_base: str,
    api_key: str,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: float,
    headers: Optional[dict] = None,
) -> AsyncIterator:
    """Execute async streaming completion using Azure Agent Service with native SSE.

    Uses the create-thread-and-run endpoint with ``stream: true`` and
    yields OpenAI-compatible chunks produced by ``_process_sse_stream``.

    Args:
        model: Model string, e.g. "azure_ai/agents/<agent_id>".
        messages: OpenAI-style chat messages; user/system entries are
            embedded in the new thread's message list.
        api_base: Azure Foundry project base URL.
        api_key: Azure AD Bearer token.
        headers: Optional extra headers (auth headers are added).

    Raises:
        AzureAIAgentsError: if the streaming request is rejected.
    """
    import litellm
    from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

    (
        headers,
        api_version,
        agent_id,
        thread_id,
        api_base,
    ) = self._prepare_completion_params(
        model, api_base, api_key, optional_params, headers
    )
    # Build payload for create-thread-and-run with streaming
    thread_messages = []
    for msg in messages:
        if msg.get("role") in ["user", "system"]:
            thread_messages.append(
                {"role": "user", "content": msg.get("content", "")}
            )
    payload: Dict[str, Any] = {
        "assistant_id": agent_id,
        "stream": True,
    }
    # Add thread with messages if we don't have an existing thread
    # NOTE(review): when an existing thread_id IS provided, neither the
    # thread_id nor the new messages are included in this payload, so the
    # endpoint would start a fresh empty thread — confirm this path.
    if not thread_id:
        payload["thread"] = {"messages": thread_messages}
    if "instructions" in optional_params:
        payload["instructions"] = optional_params["instructions"]
    url = self._build_create_thread_and_run_url(api_base, api_version)
    verbose_logger.debug(f"Azure AI Agents streaming - URL: {url}")
    # Use LiteLLM's async HTTP client for streaming
    client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.AZURE_AI,
        params={"ssl_verify": litellm_params.get("ssl_verify", None)},
    )
    response = await client.post(
        url=url,
        headers=headers,
        data=json.dumps(payload),
        stream=True,
    )
    if response.status_code not in [200, 201]:
        # Must read the (streamed) body before it can be reported.
        error_text = await response.aread()
        raise AzureAIAgentsError(
            status_code=response.status_code,
            message=f"Streaming request failed: {error_text.decode()}",
        )
    async for chunk in self._process_sse_stream(response, model):
        yield chunk
async def _process_sse_stream(
    self,
    response: httpx.Response,
    model: str,
) -> AsyncIterator:
    """Process SSE stream and yield OpenAI-compatible streaming chunks.

    Tracks the current SSE event name ("event:" lines) and parses the
    following "data:" lines. Only two events are acted on:
    - ``thread.created``: captures the new thread id for _hidden_params.
    - ``thread.message.delta``: text deltas become content chunks.
    A ``[DONE]`` sentinel produces a final chunk with finish_reason="stop".

    Args:
        response: The streaming httpx response (SSE body).
        model: Model name echoed into every chunk.
    """
    from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices

    # Chunk id is generated locally; the service's ids are not reused here.
    response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
    created = int(time.time())
    thread_id = None
    current_event = None
    async for line in response.aiter_lines():
        line = line.strip()
        if line.startswith("event:"):
            current_event = line[6:].strip()
            continue
        if line.startswith("data:"):
            data_str = line[5:].strip()
            if data_str == "[DONE]":
                # Send final chunk with finish_reason
                final_chunk = ModelResponseStream(
                    id=response_id,
                    created=created,
                    model=model,
                    object="chat.completion.chunk",
                    choices=[
                        StreamingChoices(
                            finish_reason="stop",
                            index=0,
                            delta=Delta(content=None),
                        )
                    ],
                )
                if thread_id:
                    final_chunk._hidden_params = {"thread_id": thread_id}
                yield final_chunk
                return
            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                # Ignore malformed/partial data lines rather than aborting.
                continue
            # Extract thread_id from thread.created event
            if current_event == "thread.created" and "id" in data:
                thread_id = data["id"]
                verbose_logger.debug(f"Stream created thread: {thread_id}")
            # Process message deltas - this is where the actual content comes
            if current_event == "thread.message.delta":
                delta_content = data.get("delta", {}).get("content", [])
                for content_item in delta_content:
                    if content_item.get("type") == "text":
                        text_value = content_item.get("text", {}).get("value", "")
                        if text_value:
                            chunk = ModelResponseStream(
                                id=response_id,
                                created=created,
                                model=model,
                                object="chat.completion.chunk",
                                choices=[
                                    StreamingChoices(
                                        finish_reason=None,
                                        index=0,
                                        delta=Delta(
                                            content=text_value, role="assistant"
                                        ),
                                    )
                                ],
                            )
                            if thread_id:
                                chunk._hidden_params = {"thread_id": thread_id}
                            yield chunk
# Module-level singleton; reused by AzureAIAgentsConfig.completion for all requests.
azure_ai_agents_handler = AzureAIAgentsHandler()

View File

@@ -0,0 +1,402 @@
"""
Transformation for Azure Foundry Agent Service API.
Azure Foundry Agent Service provides an Assistants-like API for running agents.
This follows the OpenAI Assistants pattern: create thread -> add messages -> create/poll run.
Model format: azure_ai/agents/<agent_id>
API Base format: https://<AIFoundryResourceName>.services.ai.azure.com/api/projects/<ProjectName>
Authentication: Uses Azure AD Bearer tokens (not API keys)
Get token via: az account get-access-token --resource 'https://ai.azure.com'
The API uses these endpoints:
- POST /threads - Create a thread
- POST /threads/{thread_id}/messages - Add message to thread
- POST /threads/{thread_id}/runs - Create a run
- GET /threads/{thread_id}/runs/{run_id} - Poll run status
- GET /threads/{thread_id}/messages - List messages in thread
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
HTTPHandler = Any
AsyncHTTPHandler = Any
class AzureAIAgentsError(BaseLLMException):
    """Raised when the Azure AI Agent Service API returns an error response."""
class AzureAIAgentsConfig(BaseConfig):
    """
    Configuration for Azure AI Agent Service API.

    Azure AI Agent Service is a fully managed service for building AI agents
    that can understand natural language and perform tasks.

    Model format: azure_ai/agents/<agent_id>

    The flow is:
    1. Create a thread
    2. Add user messages to the thread
    3. Create and poll a run
    4. Retrieve the assistant's response messages
    """

    # Default API version for Azure Foundry Agent Service
    # GA version: 2025-05-01, Preview: 2025-05-15-preview
    # See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
    DEFAULT_API_VERSION = "2025-05-01"

    # Polling configuration: at most MAX_POLL_ATTEMPTS status checks,
    # POLL_INTERVAL_SECONDS apart (~60s worst case before a 408 timeout).
    MAX_POLL_ATTEMPTS = 60
    POLL_INTERVAL_SECONDS = 1.0

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @staticmethod
    def is_azure_ai_agents_route(model: str) -> bool:
        """
        Check if the model is an Azure AI Agents route.

        Model format: azure_ai/agents/<agent_id>

        NOTE(review): this is a substring check, so any model name that
        merely contains "agents/" also matches — confirm that is intended.
        """
        return "agents/" in model

    @staticmethod
    def get_agent_id_from_model(model: str) -> str:
        """
        Extract agent ID from the model string.

        Model format: azure_ai/agents/<agent_id> -> <agent_id>
        or: agents/<agent_id> -> <agent_id>

        Returns the model unchanged when no "agents/" prefix is present.
        """
        if "agents/" in model:
            # Split on "agents/" and take the part after it
            parts = model.split("agents/", 1)
            if len(parts) == 2:
                return parts[1]
        return model

    def _get_openai_compatible_provider_info(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Get Azure AI Agent Service API base and key from params or environment.

        Explicit arguments take precedence over the AZURE_AI_API_BASE /
        AZURE_AI_API_KEY environment variables.

        Returns:
            Tuple of (api_base, api_key)
        """
        from litellm.secret_managers.main import get_secret_str

        api_base = api_base or get_secret_str("AZURE_AI_API_BASE")
        api_key = api_key or get_secret_str("AZURE_AI_API_KEY")
        return api_base, api_key

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        Azure Agents supports minimal OpenAI params since it's an agent runtime.
        """
        return ["stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI params to Azure Agents params.

        No mapping is needed; optional_params is returned unchanged.
        """
        return optional_params

    def _get_api_version(self, optional_params: dict) -> str:
        """Get API version from optional params or use default."""
        return optional_params.get("api_version", self.DEFAULT_API_VERSION)

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the base URL for Azure AI Agent Service.

        The actual endpoint will vary based on the operation:
        - /openai/threads for creating threads
        - /openai/threads/{thread_id}/messages for adding messages
        - /openai/threads/{thread_id}/runs for creating runs

        This returns the base URL that will be modified for each operation.

        Raises:
            ValueError: if no api_base was provided or found in the env.
        """
        if api_base is None:
            raise ValueError(
                "api_base is required for Azure AI Agents. Set it via AZURE_AI_API_BASE env var or api_base parameter."
            )
        # Remove trailing slash if present
        api_base = api_base.rstrip("/")
        # Return base URL - actual endpoints will be constructed during request
        return api_base

    def _get_agent_id(self, model: str, optional_params: dict) -> str:
        """
        Get the agent ID from model or optional_params.

        Explicit "agent_id"/"assistant_id" in optional_params wins over the
        model-string encoding.

        model format: "azure_ai/agents/<agent_id>" or "agents/<agent_id>" or just "<agent_id>"
        """
        agent_id = optional_params.get("agent_id") or optional_params.get(
            "assistant_id"
        )
        if agent_id:
            return agent_id
        # Extract from model name using the static method
        return self.get_agent_id_from_model(model)

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the request for Azure Agents.

        This stores the necessary data for the multi-step agent flow.
        The actual API calls happen in the custom handler.

        Returns:
            Payload dict with agent_id, normalized string messages,
            api_version, and optional thread_id / instructions.
        """
        agent_id = self._get_agent_id(model, optional_params)
        # Convert messages to a format we can use
        converted_messages = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            # Handle content that might be a list
            if isinstance(content, list):
                content = convert_content_list_to_str(msg)
            # Ensure content is a string
            if not isinstance(content, str):
                content = str(content)
            converted_messages.append({"role": role, "content": content})
        payload: Dict[str, Any] = {
            "agent_id": agent_id,
            "messages": converted_messages,
            "api_version": self._get_api_version(optional_params),
        }
        # Pass through thread_id if provided (for continuing conversations)
        if "thread_id" in optional_params:
            payload["thread_id"] = optional_params["thread_id"]
        # Pass through any additional instructions
        if "instructions" in optional_params:
            payload["instructions"] = optional_params["instructions"]
        verbose_logger.debug(f"Azure AI Agents request payload: {payload}")
        return payload

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate and set up environment for Azure Foundry Agents requests.

        Azure Foundry Agents uses Bearer token authentication with Azure AD tokens.
        Get token via: az account get-access-token --resource 'https://ai.azure.com'

        See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
        """
        headers["Content-Type"] = "application/json"
        # Azure Foundry Agents uses Bearer token authentication
        # The api_key here is expected to be an Azure AD token
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Map an HTTP error to the provider-specific exception type."""
        return AzureAIAgentsError(status_code=status_code, message=error_message)

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Azure Agents uses polling, so we fake stream by returning the final response.

        NOTE(review): the async path in ``completion`` below dispatches to
        native SSE streaming; this flag presumably only affects the generic
        (sync) pipeline — confirm.
        """
        return True

    @property
    def has_custom_stream_wrapper(self) -> bool:
        """Azure Agents doesn't have native streaming - uses fake stream."""
        return False

    @property
    def supports_stream_param_in_request_body(self) -> bool:
        """
        Azure Agents does not use a stream param in request body.
        """
        return False

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform the Azure Agents response to LiteLLM ModelResponse format.
        """
        # This is not used since we have a custom handler
        return model_response

    @staticmethod
    def completion(
        model: str,
        messages: List,
        api_base: str,
        api_key: Optional[str],
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        litellm_params: dict,
        timeout: Union[float, int, Any],
        acompletion: bool,
        stream: Optional[bool] = False,
        headers: Optional[dict] = None,
    ) -> Any:
        """
        Dispatch method for Azure Foundry Agents completion.

        Routes to sync or async completion based on acompletion flag.
        Supports native streaming via SSE when stream=True and acompletion=True.

        Authentication: Uses Azure AD Bearer tokens.
        - Pass api_key directly as an Azure AD token
        - Or set up Azure AD credentials via environment variables for automatic token retrieval:
          - AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET (Service Principal)

        See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart

        Raises:
            ValueError: when no api_key is given and no Azure AD token can
                be acquired from the configured credentials.
        """
        from litellm.llms.azure.common_utils import get_azure_ad_token
        from litellm.llms.azure_ai.agents.handler import azure_ai_agents_handler
        from litellm.types.router import GenericLiteLLMParams

        # If no api_key is provided, try to get Azure AD token
        if api_key is None:
            # Try to get Azure AD token using the existing Azure auth mechanisms
            # This uses the scope for Azure AI (ai.azure.com) instead of cognitive services
            # Create a GenericLiteLLMParams with the scope override for Azure Foundry Agents
            azure_auth_params = dict(litellm_params) if litellm_params else {}
            azure_auth_params["azure_scope"] = "https://ai.azure.com/.default"
            api_key = get_azure_ad_token(GenericLiteLLMParams(**azure_auth_params))
        if api_key is None:
            raise ValueError(
                "api_key (Azure AD token) is required for Azure Foundry Agents. "
                "Either pass api_key directly, or set AZURE_TENANT_ID, AZURE_CLIENT_ID, "
                "and AZURE_CLIENT_SECRET environment variables for Service Principal auth. "
                "Manual token: az account get-access-token --resource 'https://ai.azure.com'"
            )
        if acompletion:
            if stream:
                # Native async streaming via SSE - return the async generator directly
                return azure_ai_agents_handler.acompletion_stream(
                    model=model,
                    messages=messages,
                    api_base=api_base,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    timeout=timeout,
                    headers=headers,
                )
            else:
                return azure_ai_agents_handler.acompletion(
                    model=model,
                    messages=messages,
                    api_base=api_base,
                    api_key=api_key,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    timeout=timeout,
                    headers=headers,
                )
        else:
            # Sync completion - streaming not supported for sync
            return azure_ai_agents_handler.completion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=api_key,
                model_response=model_response,
                logging_obj=logging_obj,
                optional_params=optional_params,
                litellm_params=litellm_params,
                timeout=timeout,
                headers=headers,
            )

View File

@@ -0,0 +1,16 @@
"""
Azure Anthropic provider - supports Claude models via Azure Foundry
"""
from .handler import AzureAnthropicChatCompletion
from .transformation import AzureAnthropicConfig
try:
    from .messages_transformation import AzureAnthropicMessagesConfig

    __all__ = [
        "AzureAnthropicChatCompletion",
        "AzureAnthropicConfig",
        "AzureAnthropicMessagesConfig",
    ]
except ImportError:
    # messages_transformation (or one of its dependencies) may be absent;
    # fall back to exporting only the core chat-completion symbols.
    __all__ = ["AzureAnthropicChatCompletion", "AzureAnthropicConfig"]

View File

@@ -0,0 +1,19 @@
"""
Azure AI Anthropic CountTokens API implementation.
"""
from litellm.llms.azure_ai.anthropic.count_tokens.handler import (
AzureAIAnthropicCountTokensHandler,
)
from litellm.llms.azure_ai.anthropic.count_tokens.token_counter import (
AzureAIAnthropicTokenCounter,
)
from litellm.llms.azure_ai.anthropic.count_tokens.transformation import (
AzureAIAnthropicCountTokensConfig,
)
# Public surface of the count_tokens subpackage.
__all__ = [
    "AzureAIAnthropicCountTokensHandler",
    "AzureAIAnthropicCountTokensConfig",
    "AzureAIAnthropicTokenCounter",
]

View File

@@ -0,0 +1,133 @@
"""
Azure AI Anthropic CountTokens API handler.
Uses httpx for HTTP requests with Azure authentication.
"""
from typing import Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.llms.anthropic.common_utils import AnthropicError
from litellm.llms.azure_ai.anthropic.count_tokens.transformation import (
AzureAIAnthropicCountTokensConfig,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
class AzureAIAnthropicCountTokensHandler(AzureAIAnthropicCountTokensConfig):
    """
    Handler for Azure AI Anthropic CountTokens API requests.
    Uses httpx for HTTP requests with Azure authentication.
    """

    async def handle_count_tokens_request(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        api_key: str,
        api_base: str,
        litellm_params: Optional[Dict[str, Any]] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using httpx with Azure authentication.

        Args:
            model: The model identifier (e.g., "claude-3-5-sonnet")
            messages: The messages to count tokens for
            api_key: The Azure AI API key
            api_base: The Azure AI API base URL
            litellm_params: Optional LiteLLM parameters
            timeout: Optional timeout for the request (defaults to litellm.request_timeout)
            tools: Optional tool definitions included in the token count
            system: Optional system prompt included in the token count

        Returns:
            Dictionary containing token count response (Anthropic format,
            returned as-is from the service)

        Raises:
            AnthropicError: If the API request fails
        """
        try:
            # Validate the request
            self.validate_request(model, messages)
            verbose_logger.debug(
                f"Processing Azure AI Anthropic CountTokens request for model: {model}"
            )
            # Transform request to Anthropic format
            request_body = self.transform_request_to_count_tokens(
                model=model,
                messages=messages,
                tools=tools,
                system=system,
            )
            verbose_logger.debug(f"Transformed request: {request_body}")
            # Get endpoint URL
            endpoint_url = self.get_count_tokens_endpoint(api_base)
            verbose_logger.debug(f"Making request to: {endpoint_url}")
            # Get required headers with Azure authentication
            headers = self.get_required_headers(
                api_key=api_key,
                litellm_params=litellm_params,
            )
            # Use LiteLLM's async httpx client
            async_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.AZURE_AI
            )
            # Use provided timeout or fall back to litellm.request_timeout
            request_timeout = (
                timeout if timeout is not None else litellm.request_timeout
            )
            response = await async_client.post(
                endpoint_url,
                headers=headers,
                json=request_body,
                timeout=request_timeout,
            )
            verbose_logger.debug(f"Response status: {response.status_code}")
            if response.status_code != 200:
                error_text = response.text
                verbose_logger.error(f"Azure AI Anthropic API error: {error_text}")
                raise AnthropicError(
                    status_code=response.status_code,
                    message=error_text,
                )
            azure_response = response.json()
            verbose_logger.debug(f"Azure AI Anthropic response: {azure_response}")
            # Return Anthropic-compatible response directly - no transformation needed
            return azure_response
        except AnthropicError:
            # Re-raise Anthropic exceptions as-is
            raise
        except httpx.HTTPStatusError as e:
            # HTTP errors - preserve the actual status code
            verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        except Exception as e:
            # Any other failure (validation, transform, network) becomes a 500.
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=500,
                message=f"CountTokens processing error: {str(e)}",
            )

View File

@@ -0,0 +1,123 @@
"""
Azure AI Anthropic Token Counter implementation using the CountTokens API.
"""
import os
from typing import Any, Dict, List, Optional
from litellm._logging import verbose_logger
from litellm.llms.azure_ai.anthropic.count_tokens.handler import (
AzureAIAnthropicCountTokensHandler,
)
from litellm.llms.base_llm.base_utils import BaseTokenCounter
from litellm.types.utils import LlmProviders, TokenCountResponse
# Global handler instance - reuse across all token counting requests
azure_ai_anthropic_count_tokens_handler = AzureAIAnthropicCountTokensHandler()
class AzureAIAnthropicTokenCounter(BaseTokenCounter):
    """Token counter implementation for Azure AI Anthropic provider using the CountTokens API."""

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Return True only when the provider is azure_ai."""
        return custom_llm_provider == LlmProviders.AZURE_AI.value

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count tokens using Azure AI Anthropic's CountTokens API.

        Args:
            model_to_use: The model identifier
            messages: The messages to count tokens for
            contents: Alternative content format (not used for Anthropic)
            deployment: Deployment configuration containing litellm_params
            request_model: The original request model name
            tools: Optional tool definitions included in the count
            system: Optional system prompt included in the count

        Returns:
            TokenCountResponse with token count; an error-flagged
            TokenCountResponse on API failure; or None when there are no
            messages, credentials are missing, or the API returned None.
        """
        from litellm.llms.anthropic.common_utils import AnthropicError

        if not messages:
            return None
        deployment = deployment or {}
        litellm_params = deployment.get("litellm_params", {})
        # Get Azure AI API key from deployment config or environment
        api_key = litellm_params.get("api_key")
        if not api_key:
            api_key = os.getenv("AZURE_AI_API_KEY")
        # Get API base from deployment config or environment
        api_base = litellm_params.get("api_base")
        if not api_base:
            api_base = os.getenv("AZURE_AI_API_BASE")
        if not api_key:
            verbose_logger.warning("No Azure AI API key found for token counting")
            return None
        if not api_base:
            verbose_logger.warning("No Azure AI API base found for token counting")
            return None
        try:
            result = await azure_ai_anthropic_count_tokens_handler.handle_count_tokens_request(
                model=model_to_use,
                messages=messages,
                api_key=api_key,
                api_base=api_base,
                litellm_params=litellm_params,
                tools=tools,
                system=system,
            )
            if result is not None:
                return TokenCountResponse(
                    total_tokens=result.get("input_tokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type="azure_ai_anthropic_api",
                    original_response=result,
                )
        except AnthropicError as e:
            # Surface API errors as an error-flagged response so callers can
            # distinguish "API failed" from "counting unavailable" (None).
            verbose_logger.warning(
                f"Azure AI Anthropic CountTokens API error: status={e.status_code}, message={e.message}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="azure_ai_anthropic_api",
                error=True,
                error_message=e.message,
                status_code=e.status_code,
            )
        except Exception as e:
            verbose_logger.warning(
                f"Error calling Azure AI Anthropic CountTokens API: {e}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="azure_ai_anthropic_api",
                error=True,
                error_message=str(e),
                status_code=500,
            )
        # Reached only when the handler returned None (no countable result).
        return None

View File

@@ -0,0 +1,90 @@
"""
Azure AI Anthropic CountTokens API transformation logic.
Extends the base Anthropic CountTokens transformation with Azure authentication.
"""
from typing import Any, Dict, Optional
from litellm.constants import ANTHROPIC_TOKEN_COUNTING_BETA_VERSION
from litellm.llms.anthropic.count_tokens.transformation import (
AnthropicCountTokensConfig,
)
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.types.router import GenericLiteLLMParams
class AzureAIAnthropicCountTokensConfig(AnthropicCountTokensConfig):
    """
    Configuration and transformation logic for the Azure AI Anthropic
    CountTokens API.

    Extends AnthropicCountTokensConfig with Azure authentication. Azure AI
    Anthropic uses the same endpoint format as Anthropic's native API, but
    with Azure auth headers layered on top.
    """

    def get_required_headers(
        self,
        api_key: str,
        litellm_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, str]:
        """
        Build the required headers for the Azure AI Anthropic CountTokens API.

        Azure AI Anthropic uses Anthropic's native API format, which requires
        the x-api-key header for authentication (in addition to Azure's
        api-key / Authorization headers).

        Args:
            api_key: The Azure AI API key.
            litellm_params: Optional LiteLLM parameters for additional auth
                config. The caller's dict is not mutated.

        Returns:
            Dictionary of required headers with both x-api-key and Azure
            authentication.
        """
        # Base headers, including x-api-key for Anthropic API compatibility.
        headers = {
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
            "anthropic-beta": ANTHROPIC_TOKEN_COUNTING_BETA_VERSION,
            "x-api-key": api_key,  # Azure AI Anthropic requires this header
        }
        # Work on a copy so header construction never mutates the caller's
        # litellm_params dict as a side effect.
        params: Dict[str, Any] = dict(litellm_params or {})
        params.setdefault("api_key", api_key)
        litellm_params_obj = GenericLiteLLMParams(**params)
        # Layer on Azure auth headers (api-key or Authorization) for flexibility.
        azure_headers = BaseAzureLLM._base_validate_azure_environment(
            headers={}, litellm_params=litellm_params_obj
        )
        headers.update(azure_headers)
        return headers

    def get_count_tokens_endpoint(self, api_base: str) -> str:
        """
        Get the Azure AI Anthropic CountTokens API endpoint.

        Args:
            api_base: The Azure AI API base URL
                (e.g., https://my-resource.services.ai.azure.com or
                https://my-resource.services.ai.azure.com/anthropic)

        Returns:
            The endpoint URL for the CountTokens API, e.g.
            https://<resource>.services.ai.azure.com/anthropic/v1/messages/count_tokens
        """
        api_base = api_base.rstrip("/")
        # Ensure the URL contains the /anthropic path segment exactly once
        # before appending the count_tokens path.
        if not api_base.endswith("/anthropic") and "/anthropic" not in api_base:
            api_base = f"{api_base}/anthropic"
        return f"{api_base}/v1/messages/count_tokens"

View File

@@ -0,0 +1,226 @@
"""
Azure Anthropic handler - reuses AnthropicChatCompletion logic with Azure authentication
"""
import copy
import json
from typing import TYPE_CHECKING, Callable, Union
import httpx
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
)
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
from .transformation import AzureAnthropicConfig
if TYPE_CHECKING:
pass
class AzureAnthropicChatCompletion(AnthropicChatCompletion):
    """
    Azure Anthropic chat completion handler.

    Reuses all Anthropic request/response logic from AnthropicChatCompletion,
    but builds headers via AzureAnthropicConfig so requests authenticate with
    Azure credentials (api-key header or Azure AD bearer token) instead of
    Anthropic's x-api-key.
    """
    def __init__(self) -> None:
        super().__init__()
    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        acompletion=None,
        logger_fn=None,
        headers={},
        client=None,
    ):
        """
        Completion method that uses Azure authentication instead of Anthropic's
        x-api-key. All other logic is the same as AnthropicChatCompletion.

        Dispatches to one of four paths based on `acompletion` and `stream`:
        async streaming, async non-streaming, sync streaming, or sync
        non-streaming (the last is handled inline below).
        """
        # Deep-copy mutable inputs so the pops/transforms below don't leak
        # back into the caller's dicts/lists.
        optional_params = copy.deepcopy(optional_params)
        stream = optional_params.pop("stream", None)
        json_mode: bool = optional_params.pop("json_mode", False)
        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
        # Kept only for the debug print and pass-through below; never set to
        # True in this handler.
        _is_function_call = False
        messages = copy.deepcopy(messages)
        # Use AzureAnthropicConfig for both azure_anthropic and azure_ai Claude models
        config = AzureAnthropicConfig()
        # Azure auth headers (api-key / Authorization) plus anthropic-beta flags.
        headers = config.validate_environment(
            api_key=api_key,
            headers=headers,
            model=model,
            messages=messages,
            optional_params={**optional_params, "is_vertex_request": is_vertex_request},
            litellm_params=litellm_params,
        )
        # Anthropic-format request payload (Azure-unsupported params removed).
        data = config.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": headers,
            },
        )
        print_verbose(f"_is_function_call: {_is_function_call}")
        if acompletion is True:
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                print_verbose("makes async azure anthropic streaming POST request")
                data["stream"] = stream
                # Async streaming path - delegates to the parent handler.
                return self.acompletion_stream_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    _is_function_call=_is_function_call,
                    json_mode=json_mode,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                    # Only forward the client when it is an async handler;
                    # otherwise let the parent create its own.
                    client=(
                        client
                        if client is not None and isinstance(client, AsyncHTTPHandler)
                        else None
                    ),
                )
            else:
                # Async non-streaming path - delegates to the parent handler.
                return self.acompletion_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    provider_config=config,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    _is_function_call=_is_function_call,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    client=client,
                    json_mode=json_mode,
                    timeout=timeout,
                )
        else:
            ## COMPLETION CALL
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                data["stream"] = stream
                # Import the make_sync_call from parent
                from litellm.llms.anthropic.chat.handler import make_sync_call
                completion_stream, response_headers = make_sync_call(
                    client=client,
                    api_base=api_base,
                    headers=headers,  # type: ignore
                    data=json.dumps(data),
                    model=model,
                    messages=messages,
                    logging_obj=logging_obj,
                    timeout=timeout,
                    json_mode=json_mode,
                )
                from litellm.llms.anthropic.common_utils import (
                    process_anthropic_headers,
                )
                # Wrap the raw anthropic stream so it yields OpenAI-format chunks.
                return CustomStreamWrapper(
                    completion_stream=completion_stream,
                    model=model,
                    custom_llm_provider="azure_ai",
                    logging_obj=logging_obj,
                    _response_headers=process_anthropic_headers(response_headers),
                )
            else:
                # Sync non-streaming path - issue the POST call directly.
                if client is None or not isinstance(client, HTTPHandler):
                    from litellm.llms.custom_httpx.http_handler import _get_httpx_client
                    client = _get_httpx_client(params={"timeout": timeout})
                else:
                    client = client
                try:
                    response = client.post(
                        api_base,
                        headers=headers,
                        data=json.dumps(data),
                        timeout=timeout,
                    )
                except Exception as e:
                    from litellm.llms.anthropic.common_utils import AnthropicError
                    # Pull status/headers/body off the exception (or its
                    # attached response, if any) so callers get a faithful
                    # provider error instead of a bare traceback.
                    status_code = getattr(e, "status_code", 500)
                    error_headers = getattr(e, "headers", None)
                    error_text = getattr(e, "text", str(e))
                    error_response = getattr(e, "response", None)
                    if error_headers is None and error_response:
                        error_headers = getattr(error_response, "headers", None)
                    if error_response and hasattr(error_response, "text"):
                        error_text = getattr(error_response, "text", error_text)
                    raise AnthropicError(
                        message=error_text,
                        status_code=status_code,
                        headers=error_headers,
                    )
                # Convert the raw anthropic response into a litellm ModelResponse.
                return config.transform_response(
                    model=model,
                    raw_response=response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    api_key=api_key,
                    request_data=data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                    json_mode=json_mode,
                )

View File

@@ -0,0 +1,166 @@
"""
Azure Anthropic messages transformation config - extends AnthropicMessagesConfig with Azure authentication
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from litellm.llms.anthropic.experimental_pass_through.messages.transformation import (
AnthropicMessagesConfig,
)
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.types.router import GenericLiteLLMParams
if TYPE_CHECKING:
pass
class AzureAnthropicMessagesConfig(AnthropicMessagesConfig):
    """
    Azure Anthropic messages configuration that extends AnthropicMessagesConfig.

    The only difference is authentication - Azure uses the x-api-key header
    (not api-key) - and the Azure endpoint format.
    """

    def validate_anthropic_messages_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[dict, Optional[str]]:
        """
        Validate environment and set up Azure authentication headers for the
        /v1/messages endpoint.

        Azure Anthropic uses the x-api-key header (not api-key).

        Args:
            headers: Existing request headers to extend in place.
            model: Model name (unused here; part of the shared interface).
            messages: Request messages (unused here; part of the shared interface).
            optional_params: Inspected for anthropic-beta feature flags.
            litellm_params: Auth config as a dict or GenericLiteLLMParams.
            api_key: Optional explicit API key, folded into litellm_params
                when not already configured there.
            api_base: Passed through unchanged.

        Returns:
            Tuple of (headers, api_base).
        """
        # Normalize litellm_params into a GenericLiteLLMParams object,
        # injecting the explicit api_key if one was not already configured.
        if isinstance(litellm_params, dict):
            if api_key and "api_key" not in litellm_params:
                litellm_params = {**litellm_params, "api_key": api_key}
            litellm_params_obj = GenericLiteLLMParams(**litellm_params)
        else:
            litellm_params_obj = litellm_params or GenericLiteLLMParams()
            if api_key and not litellm_params_obj.api_key:
                litellm_params_obj.api_key = api_key
        # Azure authentication logic sets api-key or Authorization.
        headers = BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params_obj
        )
        # Azure Anthropic expects x-api-key; translate Azure's api-key header.
        if "api-key" in headers and "x-api-key" not in headers:
            headers["x-api-key"] = headers.pop("api-key")
        headers.setdefault("anthropic-version", "2023-06-01")
        headers.setdefault("content-type", "application/json")
        headers = self._update_headers_with_anthropic_beta(
            headers=headers,
            optional_params=optional_params,
        )
        return headers, api_base

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the Azure Anthropic /v1/messages endpoint.

        Azure Foundry endpoint format:
        https://<resource-name>.services.ai.azure.com/anthropic/v1/messages

        Raises:
            ValueError: If no api_base is provided and AZURE_API_BASE is unset.
        """
        from litellm.secret_managers.main import get_secret_str

        api_base = api_base or get_secret_str("AZURE_API_BASE")
        if api_base is None:
            raise ValueError(
                "Missing Azure API Base - Please set `api_base` or `AZURE_API_BASE` environment variable. "
                "Expected format: https://<resource-name>.services.ai.azure.com/anthropic"
            )
        api_base = api_base.rstrip("/")
        # A URL already ending in /v1/messages (with or without a preceding
        # /anthropic segment) is complete; otherwise normalize the path so it
        # ends with /anthropic/v1/messages exactly once.
        if not api_base.endswith("/v1/messages"):
            if "/anthropic" in api_base:
                # /anthropic exists - truncate to the base URL up to and
                # including the first /anthropic segment.
                api_base = api_base.split("/anthropic", 1)[0] + "/anthropic"
            else:
                # /anthropic not in path - append it.
                api_base = api_base + "/anthropic"
            api_base = api_base + "/v1/messages"
        return api_base

    def _remove_scope_from_cache_control(
        self, anthropic_messages_request: Dict
    ) -> None:
        """
        Remove the `scope` field from cache_control blocks, in place.

        Azure AI Foundry's Anthropic endpoint does not support the `scope`
        field (e.g., "global" for cross-request caching); only `type` and
        `ttl` are supported. Processes both `system` and `messages` content
        blocks.
        """
        def _sanitize(cache_control: Any) -> None:
            # Non-dict cache_control values are left untouched.
            if isinstance(cache_control, dict):
                cache_control.pop("scope", None)

        def _process_content_list(content: list) -> None:
            for item in content:
                if isinstance(item, dict) and "cache_control" in item:
                    _sanitize(item["cache_control"])

        if "system" in anthropic_messages_request:
            system = anthropic_messages_request["system"]
            # `system` may also be a plain string, which carries no cache_control.
            if isinstance(system, list):
                _process_content_list(system)
        if "messages" in anthropic_messages_request:
            for message in anthropic_messages_request["messages"]:
                if isinstance(message, dict) and "content" in message:
                    content = message["content"]
                    # String content blocks carry no cache_control.
                    if isinstance(content, list):
                        _process_content_list(content)

    def transform_anthropic_messages_request(
        self,
        model: str,
        messages: List[Dict],
        anthropic_messages_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Build the request via the parent transformation, then strip the
        Azure-unsupported cache_control `scope` field.
        """
        anthropic_messages_request = super().transform_anthropic_messages_request(
            model=model,
            messages=messages,
            anthropic_messages_optional_request_params=anthropic_messages_optional_request_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        self._remove_scope_from_cache_control(anthropic_messages_request)
        return anthropic_messages_request

View File

@@ -0,0 +1,117 @@
"""
Azure Anthropic transformation config - extends AnthropicConfig with Azure authentication
"""
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.azure.common_utils import BaseAzureLLM
from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import GenericLiteLLMParams
if TYPE_CHECKING:
pass
class AzureAnthropicConfig(AnthropicConfig):
    """
    Azure-flavored Anthropic chat configuration.

    Behaves exactly like AnthropicConfig except for authentication: requests
    are authorized with Azure's api-key header or an Azure AD bearer token
    instead of Anthropic's x-api-key header.
    """

    @property
    def custom_llm_provider(self) -> Optional[str]:
        # Requests built with this config are routed as azure_ai.
        return "azure_ai"

    @staticmethod
    def _resolve_litellm_params(
        litellm_params: Union[dict, GenericLiteLLMParams],
        api_key: Optional[str],
    ) -> GenericLiteLLMParams:
        """Coerce litellm_params to GenericLiteLLMParams, folding in api_key
        when it is not already configured."""
        if isinstance(litellm_params, dict):
            if api_key and "api_key" not in litellm_params:
                litellm_params = {**litellm_params, "api_key": api_key}
            return GenericLiteLLMParams(**litellm_params)
        resolved = litellm_params or GenericLiteLLMParams()
        if api_key and not resolved.api_key:
            resolved.api_key = api_key
        return resolved

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: Union[dict, GenericLiteLLMParams],
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        """
        Build request headers using Azure authentication.

        Azure supports:
        1. API key via the 'api-key' header
        2. Azure AD token via 'Authorization: Bearer <token>'
        """
        params_obj = self._resolve_litellm_params(litellm_params, api_key)
        # Azure auth populates api-key or Authorization on the headers.
        headers = BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=params_obj
        )
        # Collect anthropic-beta feature flags this request needs; the
        # x-api-key the parent would set gets overridden by Azure auth below.
        request_tools = optional_params.get("tools")
        beta_headers = self.get_anthropic_headers(
            computer_tool_used=self.is_computer_tool_used(tools=request_tools),
            prompt_caching_set=self.is_cache_control_set(messages=messages),
            pdf_used=self.is_pdf_used(messages=messages),
            api_key=api_key or "",  # Azure auth is already in headers
            file_id_used=self.is_file_id_used(messages=messages),
            is_vertex_request=optional_params.get("is_vertex_request", False),
            user_anthropic_beta_headers=self._get_user_anthropic_beta_headers(
                anthropic_beta_header=headers.get("anthropic-beta")
            ),
            mcp_server_used=self.is_mcp_server_used(
                mcp_servers=optional_params.get("mcp_servers")
            ),
        )
        # Merge: Azure auth headers (api-key or Authorization) take precedence.
        headers = {**beta_headers, **headers}
        headers.setdefault("anthropic-version", "2023-06-01")
        return headers

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the request payload via AnthropicConfig, then strip parameters
        that Azure AI Anthropic rejects.
        """
        data = super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        # Azure AI Anthropic does not accept these OpenAI-style parameters.
        for unsupported in ("extra_body", "max_retries", "stream_options"):
            data.pop(unsupported, None)
        return data

Some files were not shown because too many files have changed in this diff Show More