chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,12 @@
|
||||
## File Structure
|
||||
|
||||
### August 27th, 2024
|
||||
|
||||
To make it easy to see how calls are transformed for each model/provider:
|
||||
|
||||
We are working on moving all supported LiteLLM providers to a folder structure, where the folder name is the supported LiteLLM provider name.
|
||||
|
||||
Each folder will contain a `*_transformation.py` file, which has all the request/response transformation logic, making it easy to see how calls are modified.
|
||||
|
||||
E.g. `cohere/`, `bedrock/`.
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
import importlib
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Type
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
from . import *
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import (
|
||||
BaseTranslation,
|
||||
)
|
||||
from litellm.types.utils import ModelInfo, Usage
|
||||
|
||||
|
||||
def get_cost_for_web_search_request(
    custom_llm_provider: str, usage: "Usage", model_info: "ModelInfo"
) -> Optional[float]:
    """
    Compute the web-search surcharge for a single request, per provider.

    Args:
        custom_llm_provider: Name of the LLM provider (e.g. "gemini", "xai").
        usage: The usage object for the request.
        model_info: Pricing/metadata for the model.

    Returns:
        The web-search cost, 0.0 when the provider accounts for the cost
        elsewhere, or None when the provider has no web-search pricing here.
    """
    if custom_llm_provider == "perplexity":
        # Perplexity handles search costs internally in its own cost calculator
        # Return 0.0 to indicate costs are already accounted for
        return 0.0

    if custom_llm_provider == "gemini":
        from .gemini.cost_calculator import cost_per_web_search_request

        return cost_per_web_search_request(usage=usage, model_info=model_info)

    if custom_llm_provider == "anthropic":
        from .anthropic.cost_calculation import get_cost_for_anthropic_web_search

        return get_cost_for_anthropic_web_search(model_info=model_info, usage=usage)

    if custom_llm_provider.startswith("vertex_ai"):
        from .vertex_ai.gemini.cost_calculator import (
            cost_per_web_search_request as cost_per_web_search_request_vertex_ai,
        )

        return cost_per_web_search_request_vertex_ai(usage=usage, model_info=model_info)

    if custom_llm_provider == "xai":
        from .xai.cost_calculator import cost_per_web_search_request

        return cost_per_web_search_request(usage=usage, model_info=model_info)

    # Provider without web-search pricing support.
    return None
|
||||
|
||||
|
||||
def discover_guardrail_translation_mappings() -> (
    Dict[CallTypes, Type["BaseTranslation"]]
):
    """
    Discover guardrail translation mappings by scanning the llms directory structure.

    Scans for modules with guardrail_translation_mappings dictionaries and aggregates them.

    Discovery is best-effort: an import or processing failure for any single
    provider module is logged and skipped, and a top-level failure returns
    whatever was collected so far (possibly an empty dict).

    Returns:
        Dict[CallTypes, Type[BaseTranslation]]: A dictionary mapping call types to their translation handler classes
    """
    discovered_mappings: Dict[CallTypes, Type["BaseTranslation"]] = {}

    try:
        # Get the path to the llms directory (this module's own directory)
        current_dir = os.path.dirname(__file__)
        llms_dir = current_dir

        if not os.path.exists(llms_dir):
            verbose_logger.debug("llms directory not found")
            return discovered_mappings

        # Recursively scan for guardrail_translation directories
        for root, dirs, files in os.walk(llms_dir):
            # Skip __pycache__ and base_llm directories.
            # Mutating dirs in place (dirs[:] = ...) prunes os.walk so it
            # never descends into the excluded directories.
            dirs[:] = [d for d in dirs if not d.startswith("__") and d != "base_llm"]

            # Check if this is a guardrail_translation directory with __init__.py
            if (
                os.path.basename(root) == "guardrail_translation"
                and "__init__.py" in files
            ):
                # Build the dotted module path relative to the litellm package
                # root (the parent of this llms directory), e.g.
                # "litellm.llms.<provider>.....guardrail_translation"
                rel_path = os.path.relpath(root, os.path.dirname(llms_dir))
                module_path = "litellm." + rel_path.replace(os.sep, ".")

                try:
                    # Import the module
                    verbose_logger.debug(
                        f"Discovering guardrail translations in: {module_path}"
                    )

                    module = importlib.import_module(module_path)

                    # Check for guardrail_translation_mappings dictionary
                    if hasattr(module, "guardrail_translation_mappings"):
                        mappings = getattr(module, "guardrail_translation_mappings")
                        if isinstance(mappings, dict):
                            # Later modules win on duplicate call-type keys.
                            discovered_mappings.update(mappings)
                            verbose_logger.debug(
                                f"Found guardrail_translation_mappings in {module_path}: {list(mappings.keys())}"
                            )

                except ImportError as e:
                    verbose_logger.error(f"Could not import {module_path}: {e}")
                    continue
                except Exception as e:
                    verbose_logger.error(f"Error processing {module_path}: {e}")
                    continue

        # MCP mappings live outside the llms tree, so pull them in explicitly;
        # the import is optional and silently skipped when unavailable.
        try:
            from litellm.proxy._experimental.mcp_server.guardrail_translation import (
                guardrail_translation_mappings as mcp_guardrail_translation_mappings,
            )

            discovered_mappings.update(mcp_guardrail_translation_mappings)
            verbose_logger.debug(
                "Loaded MCP guardrail translation mappings: %s",
                list(mcp_guardrail_translation_mappings.keys()),
            )
        except ImportError:
            verbose_logger.debug(
                "MCP guardrail translation mappings not available; skipping"
            )

        verbose_logger.debug(
            f"Discovered {len(discovered_mappings)} guardrail translation mappings: {list(discovered_mappings.keys())}"
        )

    except Exception as e:
        verbose_logger.error(f"Error discovering guardrail translation mappings: {e}")

    return discovered_mappings
|
||||
|
||||
|
||||
# Cache the discovered mappings so the filesystem walk + imports run only once.
endpoint_guardrail_translation_mappings: Optional[
    Dict[CallTypes, Type["BaseTranslation"]]
] = None


def load_guardrail_translation_mappings():
    """Return the guardrail translation mappings, discovering them on first call."""
    global endpoint_guardrail_translation_mappings
    if endpoint_guardrail_translation_mappings is not None:
        return endpoint_guardrail_translation_mappings
    endpoint_guardrail_translation_mappings = discover_guardrail_translation_mappings()
    return endpoint_guardrail_translation_mappings
|
||||
|
||||
|
||||
def get_guardrail_translation_mapping(call_type: CallTypes) -> Type["BaseTranslation"]:
    """
    Get the guardrail translation handler for a given call type.

    Args:
        call_type: The type of call (e.g., completion, acompletion, anthropic_messages)

    Returns:
        The translation handler class for the given call type

    Raises:
        ValueError: If no translation mapping exists for the given call type
    """
    # Lazily load (and cache) the discovered mappings via the shared helper
    # instead of duplicating the lazy-init logic here.
    mappings = load_guardrail_translation_mappings()

    # Get the translation handler class for the call type
    if call_type not in mappings:
        raise ValueError(
            f"No guardrail translation mapping found for call_type: {call_type}. "
            f"Available mappings: {list(mappings.keys())}"
        )

    # Return the handler class directly
    return mappings[call_type]
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
A2A (Agent-to-Agent) Protocol Provider for LiteLLM
|
||||
"""
|
||||
from .chat.transformation import A2AConfig
|
||||
|
||||
__all__ = ["A2AConfig"]
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
A2A Chat Completion Implementation
|
||||
"""
|
||||
from .transformation import A2AConfig
|
||||
|
||||
__all__ = ["A2AConfig"]
|
||||
@@ -0,0 +1,155 @@
|
||||
# A2A Protocol Guardrail Translation Handler
|
||||
|
||||
Handler for processing A2A (Agent-to-Agent) Protocol messages with guardrails.
|
||||
|
||||
## Overview
|
||||
|
||||
This handler processes A2A JSON-RPC 2.0 input/output by:
|
||||
1. Extracting text from message parts (`kind: "text"`)
|
||||
2. Applying guardrails to text content
|
||||
3. Mapping guardrailed text back to original structure
|
||||
|
||||
## A2A Protocol Format
|
||||
|
||||
### Input Format (JSON-RPC 2.0)
|
||||
|
||||
```json
|
||||
{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "request-id",
|
||||
"method": "message/send",
|
||||
"params": {
|
||||
"message": {
|
||||
"kind": "message",
|
||||
"messageId": "...",
|
||||
"role": "user",
|
||||
"parts": [
|
||||
{"kind": "text", "text": "Hello, my SSN is 123-45-6789"}
|
||||
]
|
||||
},
|
||||
"metadata": {
|
||||
"guardrails": ["block-ssn"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Output Formats
|
||||
|
||||
The handler supports multiple A2A response formats:
|
||||
|
||||
**Direct message:**
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"kind": "message",
|
||||
"parts": [{"kind": "text", "text": "Response text"}]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Nested message:**
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"message": {
|
||||
"parts": [{"kind": "text", "text": "Response text"}]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Task with artifacts:**
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"kind": "task",
|
||||
"artifacts": [
|
||||
{"parts": [{"kind": "text", "text": "Artifact text"}]}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Task with status message:**
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"kind": "task",
|
||||
"status": {
|
||||
"message": {
|
||||
"parts": [{"kind": "text", "text": "Status message"}]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Streaming artifact-update:**
|
||||
```json
|
||||
{
|
||||
"result": {
|
||||
"kind": "artifact-update",
|
||||
"artifact": {
|
||||
"parts": [{"kind": "text", "text": "Streaming text"}]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The handler is automatically discovered and applied when guardrails are used with A2A endpoints.
|
||||
|
||||
### Via LiteLLM Proxy
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/a2a/my-agent' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"jsonrpc": "2.0",
|
||||
"id": "1",
|
||||
"method": "message/send",
|
||||
"params": {
|
||||
"message": {
|
||||
"kind": "message",
|
||||
"messageId": "msg-1",
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello, my SSN is 123-45-6789"}]
|
||||
},
|
||||
"metadata": {
|
||||
"guardrails": ["block-ssn"]
|
||||
}
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### Specifying Guardrails
|
||||
|
||||
Guardrails can be specified in the A2A request via the `metadata.guardrails` field:
|
||||
|
||||
```json
|
||||
{
|
||||
"params": {
|
||||
"message": {...},
|
||||
"metadata": {
|
||||
"guardrails": ["block-ssn", "pii-filter"]
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Extension
|
||||
|
||||
Override these methods to customize behavior:
|
||||
|
||||
- `_extract_texts_from_result()`: Custom text extraction from A2A responses
|
||||
- `_extract_texts_from_parts()`: Custom text extraction from message parts
|
||||
- `_apply_text_to_path()`: Custom application of guardrailed text
|
||||
|
||||
## Call Types
|
||||
|
||||
This handler is registered for:
|
||||
- `CallTypes.send_message`: Synchronous A2A message sending
|
||||
- `CallTypes.asend_message`: Asynchronous A2A message sending
|
||||
@@ -0,0 +1,11 @@
|
||||
"""A2A Protocol handler for Unified Guardrails."""
|
||||
|
||||
from litellm.llms.a2a.chat.guardrail_translation.handler import A2AGuardrailHandler
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
# Map A2A call types (sync and async message send) to the handler that knows
# how to apply guardrails to A2A JSON-RPC payloads. Discovered automatically
# by the guardrail-translation scanner via this module's name.
guardrail_translation_mappings = {
    CallTypes.send_message: A2AGuardrailHandler,
    CallTypes.asend_message: A2AGuardrailHandler,
}

__all__ = ["guardrail_translation_mappings"]
|
||||
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
A2A Protocol Handler for Unified Guardrails
|
||||
|
||||
This module provides guardrail translation support for A2A (Agent-to-Agent) Protocol.
|
||||
It handles both JSON-RPC 2.0 input requests and output responses, extracting text
|
||||
from message parts and applying guardrails.
|
||||
|
||||
A2A Protocol Format:
|
||||
- Input: JSON-RPC 2.0 with params.message.parts containing text parts
|
||||
- Output: JSON-RPC 2.0 with result containing message/artifact parts
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
class A2AGuardrailHandler(BaseTranslation):
|
||||
"""
|
||||
Handler for processing A2A Protocol messages with guardrails.
|
||||
|
||||
This class provides methods to:
|
||||
1. Process input messages (pre-call hook) - extracts text from A2A message parts
|
||||
2. Process output responses (post-call hook) - extracts text from A2A response parts
|
||||
|
||||
A2A Message Format:
|
||||
- Input: params.message.parts[].text (where kind == "text")
|
||||
- Output: result.message.parts[].text or result.artifacts[].parts[].text
|
||||
"""
|
||||
|
||||
async def process_input_messages(
    self,
    data: dict,
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
) -> Any:
    """
    Apply a guardrail to the text parts of an A2A input request.

    The A2A JSON-RPC request carries text under
    ``params.message.parts[*].text`` for parts with ``kind == "text"``.
    All such texts are sent to the guardrail in one batch, then the
    (possibly modified) texts are written back into the same parts.

    Args:
        data: The A2A JSON-RPC 2.0 request data
        guardrail_to_apply: The guardrail instance to apply
        litellm_logging_obj: Optional logging object

    Returns:
        The request data with guardrailed text applied in place
    """
    message = data.get("params", {}).get("message", {})
    parts = message.get("parts", [])

    if not parts:
        verbose_proxy_logger.debug("A2A: No parts in message, skipping guardrail")
        return data

    # Collect text from every non-empty text part, remembering its index
    # so the guardrailed value can be written back to the same slot.
    texts_to_check: List[str] = []
    text_part_indices: List[int] = []
    for idx, part in enumerate(parts):
        if part.get("kind") != "text":
            continue
        text_value = part.get("text", "")
        if text_value:
            texts_to_check.append(text_value)
            text_part_indices.append(idx)

    if texts_to_check:
        inputs = GenericGuardrailAPIInputs(texts=texts_to_check)

        # Give guardrails access to the full structured A2A message as well.
        inputs["structured_messages"] = [message]

        # Include agent model info if available
        model = data.get("model")
        if model:
            inputs["model"] = model

        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )

        new_texts = guardrailed_inputs.get("texts", [])

        # Only write back when the guardrail returned one text per text part.
        if new_texts and len(new_texts) == len(text_part_indices):
            for new_text, part_idx in zip(new_texts, text_part_indices):
                parts[part_idx]["text"] = new_text

    verbose_proxy_logger.debug("A2A: Processed input message: %s", message)

    return data
|
||||
|
||||
async def process_output_response(
    self,
    response: Any,
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
    user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
) -> Any:
    """
    Process A2A output response by applying guardrails to text content.

    Handles multiple A2A response formats:
    - Direct message: {"result": {"kind": "message", "parts": [...]}}
    - Nested message: {"result": {"message": {"parts": [...]}}}
    - Task with artifacts: {"result": {"kind": "task", "artifacts": [{"parts": [...]}]}}
    - Task with status message: {"result": {"kind": "task", "status": {"message": {"parts": [...]}}}}

    Args:
        response: A2A JSON-RPC 2.0 response dict or object
        guardrail_to_apply: The guardrail instance to apply
        litellm_logging_obj: Optional logging object
        user_api_key_dict: User API key metadata

    Returns:
        Modified response with guardrails applied to text content
    """
    # Handle both dict and Pydantic model responses
    if hasattr(response, "model_dump"):
        response_dict = response.model_dump()
        is_pydantic = True
    elif isinstance(response, dict):
        response_dict = response
        is_pydantic = False
    else:
        # Unknown response types are passed through untouched.
        verbose_proxy_logger.warning(
            "A2A: Unknown response type %s, skipping guardrail", type(response)
        )
        return response

    result = response_dict.get("result", {})
    if not result or not isinstance(result, dict):
        verbose_proxy_logger.debug("A2A: No result in response, skipping guardrail")
        return response

    # Find all text-containing parts in the response
    texts_to_check: List[str] = []
    # Each mapping is (path_to_parts_list, part_index)
    # path_to_parts_list is a tuple of keys to navigate to the parts list
    task_mappings: List[Tuple[Tuple[str, ...], int]] = []

    # Extract texts from all possible locations
    self._extract_texts_from_result(
        result=result,
        texts_to_check=texts_to_check,
        task_mappings=task_mappings,
    )

    if not texts_to_check:
        verbose_proxy_logger.debug("A2A: No text content in response")
        return response

    # Step 2: Apply guardrail to all texts in batch
    # Create a request_data dict with response info and user API key metadata
    request_data: dict = {"response": response_dict}

    # Add user API key metadata with prefixed keys
    user_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
    if user_metadata:
        request_data["litellm_metadata"] = user_metadata

    inputs = GenericGuardrailAPIInputs(texts=texts_to_check)

    guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
        inputs=inputs,
        request_data=request_data,
        input_type="response",
        logging_obj=litellm_logging_obj,
    )

    guardrailed_texts = guardrailed_inputs.get("texts", [])

    # Step 3: Apply guardrailed text back to original response
    # (skip write-back if the guardrail returned a different number of texts)
    if guardrailed_texts and len(guardrailed_texts) == len(task_mappings):
        for task_idx, (path, part_idx) in enumerate(task_mappings):
            self._apply_text_to_path(
                result=result,
                path=path,
                part_idx=part_idx,
                text=guardrailed_texts[task_idx],
            )

    verbose_proxy_logger.debug("A2A: Processed output response")

    # Update the original response
    if is_pydantic:
        # NOTE(review): model_dump() returns a new dict, so mutating
        # response_dict/result here does not obviously propagate back to
        # the Pydantic `response` object being returned — confirm whether
        # guardrailed text is actually reflected for Pydantic responses.
        response_dict["result"] = result
        return response
    else:
        # Dict responses were mutated in place via `result`; reassign for clarity.
        response["result"] = result
        return response
|
||||
|
||||
async def process_output_streaming_response(
    self,
    responses_so_far: List[Any],
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional["LiteLLMLoggingObj"] = None,
    user_api_key_dict: Optional["UserAPIKeyAuth"] = None,
) -> List[Any]:
    """
    Process A2A streaming output by applying guardrails to accumulated text.

    responses_so_far can be a list of JSON-RPC 2.0 objects (dict or NDJSON str), e.g.:
    - task with history, status-update, artifact-update (with result.artifact.parts),
    - then status-update (final). Text is extracted from result.artifact.parts,
    result.message.parts, result.parts, etc., concatenated in order, guardrailed once,
    then the combined guardrailed text is written into the first chunk that had text
    and all other text parts in other chunks are cleared (in-place).

    Returns:
        responses_so_far, mutated in place (NDJSON strings re-serialized).
    """
    from litellm.llms.a2a.common_utils import extract_text_from_a2a_response

    # Parse each item; keep alignment with responses_so_far (None where unparseable)
    parsed: List[Optional[Dict[str, Any]]] = [None] * len(responses_so_far)
    for i, item in enumerate(responses_so_far):
        if isinstance(item, dict):
            obj = item
        elif isinstance(item, str):
            try:
                obj = json.loads(item.strip())
            except (json.JSONDecodeError, TypeError):
                continue
        else:
            # Unknown chunk type — leave as-is.
            continue
        # Only chunks with a dict "result" participate in guardrailing.
        if isinstance(obj.get("result"), dict):
            parsed[i] = obj

    valid_parsed = [(i, obj) for i, obj in enumerate(parsed) if obj is not None]
    if not valid_parsed:
        return responses_so_far

    # Collect text from each chunk in order (by original index in responses_so_far)
    text_parts: List[str] = []
    chunk_indices_with_text: List[int] = []  # original indices into responses_so_far
    for idx, (orig_i, obj) in enumerate(valid_parsed):
        t = extract_text_from_a2a_response(obj)
        if t:
            text_parts.append(t)
            chunk_indices_with_text.append(orig_i)

    combined_text = "".join(text_parts)
    if not combined_text:
        return responses_so_far

    # Build request context (response chunks + user API key metadata).
    request_data: dict = {"responses_so_far": responses_so_far}
    user_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
    if user_metadata:
        request_data["litellm_metadata"] = user_metadata

    # Guardrail the concatenated text once, as a single string.
    inputs = GenericGuardrailAPIInputs(texts=[combined_text])
    guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
        inputs=inputs,
        request_data=request_data,
        input_type="response",
        logging_obj=litellm_logging_obj,
    )
    guardrailed_texts = guardrailed_inputs.get("texts", [])
    if not guardrailed_texts:
        return responses_so_far
    guardrailed_text = guardrailed_texts[0]

    # Find first chunk (by original index) that has text; put full guardrailed text there and clear rest
    first_chunk_with_text: Optional[int] = (
        chunk_indices_with_text[0] if chunk_indices_with_text else None
    )

    for orig_i, obj in valid_parsed:
        result = obj.get("result", {})
        if not isinstance(result, dict):
            continue
        # Re-locate this chunk's text parts so they can be rewritten in place.
        texts_in_chunk: List[str] = []
        mappings: List[Tuple[Tuple[str, ...], int]] = []
        self._extract_texts_from_result(
            result=result,
            texts_to_check=texts_in_chunk,
            task_mappings=mappings,
        )
        if not mappings:
            continue
        if orig_i == first_chunk_with_text:
            # Put full guardrailed text in first text part; clear others
            for task_idx, (path, part_idx) in enumerate(mappings):
                text = guardrailed_text if task_idx == 0 else ""
                self._apply_text_to_path(
                    result=result,
                    path=path,
                    part_idx=part_idx,
                    text=text,
                )
        else:
            # Every other chunk's text parts are blanked so the combined
            # text is not duplicated across the stream.
            for path, part_idx in mappings:
                self._apply_text_to_path(
                    result=result,
                    path=path,
                    part_idx=part_idx,
                    text="",
                )

    # Write back to responses_so_far where we had NDJSON strings
    for i, item in enumerate(responses_so_far):
        if isinstance(item, str) and parsed[i] is not None:
            responses_so_far[i] = json.dumps(parsed[i]) + "\n"

    return responses_so_far
|
||||
|
||||
def _extract_texts_from_result(
    self,
    result: Dict[str, Any],
    texts_to_check: List[str],
    task_mappings: List[Tuple[Tuple[str, ...], int]],
) -> None:
    """
    Collect text parts from every location an A2A result may carry them.

    Locations scanned, in order:
      1. result.parts                  (direct message)
      2. result.message.parts         (nested message)
      3. result.artifact.parts        (streaming artifact-update)
      4. result.status.message.parts  (task status message)
      5. result.artifacts[i].parts    (task artifacts array)
    """
    # Gather (parts_list, navigation_path) candidates, then extract once.
    candidates: List[Tuple[Any, Tuple[str, ...]]] = []

    if "parts" in result:
        candidates.append((result["parts"], ("parts",)))

    message = result.get("message")
    if isinstance(message, dict) and "parts" in message:
        candidates.append((message["parts"], ("message", "parts")))

    artifact = result.get("artifact")
    if isinstance(artifact, dict) and "parts" in artifact:
        candidates.append((artifact["parts"], ("artifact", "parts")))

    status = result.get("status", {})
    if isinstance(status, dict):
        status_message = status.get("message")
        if isinstance(status_message, dict) and "parts" in status_message:
            candidates.append(
                (status_message["parts"], ("status", "message", "parts"))
            )

    artifacts = result.get("artifacts", [])
    if isinstance(artifacts, list):
        for artifact_idx, art in enumerate(artifacts):
            if isinstance(art, dict) and "parts" in art:
                # List index is recorded as a digit string in the path.
                candidates.append(
                    (art["parts"], ("artifacts", str(artifact_idx), "parts"))
                )

    for parts, path in candidates:
        self._extract_texts_from_parts(
            parts=parts,
            path=path,
            texts_to_check=texts_to_check,
            task_mappings=task_mappings,
        )
|
||||
|
||||
def _extract_texts_from_parts(
    self,
    parts: List[Dict[str, Any]],
    path: Tuple[str, ...],
    texts_to_check: List[str],
    task_mappings: List[Tuple[Tuple[str, ...], int]],
) -> None:
    """Record every non-empty text part, plus its (path, index) so the
    guardrailed text can later be written back to the same location."""
    for part_idx, part in enumerate(parts):
        if part.get("kind") != "text":
            continue
        text = part.get("text", "")
        if not text:
            continue
        texts_to_check.append(text)
        task_mappings.append((path, part_idx))
|
||||
|
||||
def _apply_text_to_path(
    self,
    result: Dict[Union[str, int], Any],
    path: Tuple[str, ...],
    part_idx: int,
    text: str,
) -> None:
    """Write *text* into ``result/<path>[part_idx]["text"]``.

    Path components that are digit strings (e.g. an artifacts index recorded
    as "0") are treated as list indices; all others as dict keys.
    """
    node: Any = result
    for key in path:
        node = node[int(key)] if key.isdigit() else node[key]
    node[part_idx]["text"] = text
|
||||
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
A2A Streaming Response Iterator
|
||||
"""
|
||||
from typing import Optional, Union
|
||||
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
|
||||
|
||||
from ..common_utils import extract_text_from_a2a_response
|
||||
|
||||
|
||||
class A2AModelResponseIterator(BaseModelResponseIterator):
    """
    Iterator that parses A2A JSON-RPC streaming chunks into
    OpenAI-compatible streaming chunks.

    Each incoming chunk is a JSON-RPC object whose "result" carries either
    a message with parts or a task with status/artifacts; text is pulled
    out and wrapped in a GenericStreamingChunk.
    """

    def __init__(
        self,
        streaming_response,
        sync_stream: bool,
        json_mode: Optional[bool] = False,
        model: str = "a2a/agent",
    ):
        super().__init__(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
        self.model = model

    def chunk_parser(
        self, chunk: dict
    ) -> Union[GenericStreamingChunk, ModelResponseStream]:
        """
        Convert one A2A streaming chunk to OpenAI format.

        Text is extracted via extract_text_from_a2a_response (which
        understands message/task/artifact result shapes), and the finish
        reason is derived from the task status or a "done" marker. Any
        parse failure produces an empty, non-final chunk instead of
        raising, so one malformed payload does not abort the stream.
        """
        try:
            extracted_text = extract_text_from_a2a_response(chunk)
            finish_reason = self._get_finish_reason(chunk)

            return GenericStreamingChunk(
                text=extracted_text,
                is_finished=bool(finish_reason),
                finish_reason=finish_reason or "",
                usage=None,
                index=0,
                tool_use=None,
            )
        except Exception:
            # Best-effort: swallow the parse error and keep streaming.
            return GenericStreamingChunk(
                text="",
                is_finished=False,
                finish_reason="",
                usage=None,
                index=0,
                tool_use=None,
            )

    def _get_finish_reason(self, chunk: dict) -> Optional[str]:
        """Return "stop" when the chunk signals completion, else None."""
        result = chunk.get("result", {})

        # Terminal task states: both "completed" and "failed" map to "stop",
        # since "stop" is the valid OpenAI finish_reason available here.
        if isinstance(result, dict):
            status = result.get("status", {})
            if isinstance(status, dict):
                if status.get("state") in ("completed", "failed"):
                    return "stop"

        # Explicit end-of-stream marker.
        if chunk.get("done") is True:
            return "stop"

        return None
|
||||
@@ -0,0 +1,373 @@
|
||||
"""
|
||||
A2A Protocol Transformation for LiteLLM
|
||||
"""
|
||||
import uuid
|
||||
from typing import Any, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices, Message, ModelResponse
|
||||
|
||||
from ..common_utils import (
|
||||
A2AError,
|
||||
convert_messages_to_prompt,
|
||||
extract_text_from_a2a_response,
|
||||
)
|
||||
from .streaming_iterator import A2AModelResponseIterator
|
||||
|
||||
|
||||
class A2AConfig(BaseConfig):
|
||||
"""
|
||||
Configuration for A2A (Agent-to-Agent) Protocol.
|
||||
|
||||
Handles transformation between OpenAI and A2A JSON-RPC 2.0 formats.
|
||||
"""
|
||||
|
||||
@staticmethod
def resolve_agent_config_from_registry(
    model: str,
    api_base: Optional[str],
    api_key: Optional[str],
    headers: Optional[Dict[str, Any]],
    optional_params: Dict[str, Any],
) -> tuple[Optional[str], Optional[str], Optional[Dict[str, Any]]]:
    """
    Resolve agent configuration from registry if model format is "a2a/<agent-name>".

    Extracts agent name from model string and looks up configuration in the
    agent registry (if available in proxy context). Explicit caller-supplied
    values always take precedence; the registry only fills in the gaps.

    Args:
        model: Model string (e.g., "a2a/my-agent")
        api_base: Explicit api_base (takes precedence over registry)
        api_key: Explicit api_key (takes precedence over registry)
        headers: Explicit headers (takes precedence over registry)
        optional_params: Dict to merge additional litellm_params into.
            NOTE: mutated in place; existing keys are never overwritten.

    Returns:
        Tuple of (api_base, api_key, headers) with registry values filled in
    """
    # Extract agent name from model (e.g., "a2a/my-agent" -> "my-agent")
    agent_name = model.split("/", 1)[1] if "/" in model else None

    # Only lookup if agent name exists and some config is missing —
    # if everything was supplied explicitly the registry is skipped entirely.
    if not agent_name or (
        api_base is not None and api_key is not None and headers is not None
    ):
        return api_base, api_key, headers

    # Try registry lookup (only available in proxy context)
    try:
        from litellm.proxy.agent_endpoints.agent_registry import (
            global_agent_registry,
        )

        agent = global_agent_registry.get_agent_by_name(agent_name)
        if agent:
            # Get api_base from agent card URL
            if api_base is None and agent.agent_card_params:
                api_base = agent.agent_card_params.get("url")

            # Get api_key, headers, and other params from litellm_params
            if agent.litellm_params:
                if api_key is None:
                    api_key = agent.litellm_params.get("api_key")

                if headers is None:
                    agent_headers = agent.litellm_params.get("headers")
                    if agent_headers:
                        headers = agent_headers

                # Merge other litellm_params (timeout, max_retries, etc.),
                # excluding connection fields handled above and never
                # clobbering params the caller already set.
                for key, value in agent.litellm_params.items():
                    if (
                        key not in ["api_key", "api_base", "headers", "model"]
                        and key not in optional_params
                    ):
                        optional_params[key] = value
    except ImportError:
        pass  # Registry not available (not running in proxy context)

    return api_base, api_key, headers
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
    """OpenAI parameters the A2A provider understands (model-independent)."""
    supported = ("stream", "temperature", "max_tokens", "top_p")
    return list(supported)
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Copy the OpenAI params A2A cares about into ``optional_params``.

    Only ``stream`` matters for the A2A protocol: transform_request inspects
    it to pick the JSON-RPC method ("message/stream" vs "message/send").
    All other parameters are intentionally left unmapped.
    """
    if non_default_params.get("stream") is True:
        optional_params["stream"] = True
    return optional_params
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Populate required headers for an A2A JSON-RPC 2.0 request.

    Ensures a JSON content type (respecting either header casing the caller
    may already have set) and attaches a Bearer token when an API key is
    supplied — A2A agents may be unauthenticated, so the key is optional.

    Returns:
        The (mutated) headers dict.
    """
    has_content_type = any(k in headers for k in ("content-type", "Content-Type"))
    if not has_content_type:
        headers["Content-Type"] = "application/json"

    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    return headers
|
||||
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """Return the A2A agent endpoint URL.

    A2A speaks JSON-RPC 2.0 at the agent's base URL; the method
    ("message/send" vs "message/stream") lives in the request body, not the
    path. So the complete URL is simply ``api_base`` with any trailing
    slash removed. ``model`` and ``stream`` do not affect the URL.

    Raises:
        ValueError: if ``api_base`` was not provided.
    """
    if api_base is None:
        raise ValueError("api_base is required for A2A provider")
    return api_base.rstrip("/")
|
||||
|
||||
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """Build the A2A JSON-RPC 2.0 request body from OpenAI-style inputs.

    The whole conversation history is flattened into one text part (the
    A2A request carries a single user message), and the JSON-RPC method is
    chosen from the ``stream`` flag: "message/stream" when streaming,
    "message/send" otherwise.

    Raises:
        ValueError: if ``messages`` is empty.
    """
    if not messages:
        raise ValueError("At least one message is required for A2A completion")

    # Flatten the full history into one prompt so the agent keeps context.
    prompt = convert_messages_to_prompt(messages)

    streaming = bool(optional_params.get("stream", False))
    return {
        "jsonrpc": "2.0",
        "id": str(uuid.uuid4()),
        "method": "message/stream" if streaming else "message/send",
        "params": {
            "message": {
                "role": "user",
                "parts": [{"kind": "text", "text": prompt}],
                "messageId": str(uuid.uuid4()),
            }
        },
    }
|
||||
|
||||
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: Any,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """Convert an A2A JSON-RPC 2.0 HTTP response into an OpenAI ModelResponse.

    The agent's text is extracted (whatever result shape it used — see
    extract_text_from_a2a_response) and placed in a single assistant choice
    with finish_reason "stop".

    Raises:
        A2AError: when the body is not valid JSON, or carries a JSON-RPC
            ``error`` object.
    """
    try:
        payload = raw_response.json()
    except Exception as e:
        raise A2AError(
            status_code=raw_response.status_code,
            message=f"Failed to parse A2A response: {str(e)}",
            headers=dict(raw_response.headers),
        )

    # JSON-RPC level failures come back as an "error" object, not an HTTP error.
    if "error" in payload:
        rpc_error = payload["error"]
        raise A2AError(
            status_code=raw_response.status_code,
            message=f"A2A error: {rpc_error.get('message', 'Unknown error')}",
            headers=dict(raw_response.headers),
        )

    content = extract_text_from_a2a_response(payload)

    model_response.choices = [
        Choices(
            finish_reason="stop",
            index=0,
            message=Message(content=content, role="assistant"),
        )
    ]
    model_response.model = model
    # Prefer the JSON-RPC response id; mint a fresh UUID if it is absent.
    model_response.id = payload.get("id", str(uuid.uuid4()))
    return model_response
|
||||
|
||||
def get_model_response_iterator(
    self,
    streaming_response: Union[Iterator, Any],
    sync_stream: bool,
    json_mode: Optional[bool] = False,
) -> BaseModelResponseIterator:
    """
    Get streaming iterator for A2A responses.

    Thin factory — all per-chunk parsing logic lives in
    A2AModelResponseIterator; arguments are forwarded unchanged.

    Args:
        streaming_response: Streaming response iterator
        sync_stream: Whether this is a sync stream
        json_mode: JSON mode flag

    Returns:
        A2A streaming iterator
    """
    return A2AModelResponseIterator(
        streaming_response=streaming_response,
        sync_stream=sync_stream,
        json_mode=json_mode,
    )
|
||||
|
||||
def _openai_message_to_a2a_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one OpenAI-format message dict into an A2A message dict.

    The content is coerced to a string and wrapped in a single text part;
    a fresh ``messageId`` is minted on every call. Role defaults to "user"
    when absent.
    """
    text_part = {"kind": "text", "text": str(message.get("content", ""))}
    return {
        "role": message.get("role", "user"),
        "parts": [text_part],
        "messageId": str(uuid.uuid4()),
    }
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """Wrap an error in A2AError, normalizing httpx.Headers to a plain dict."""
    if isinstance(headers, httpx.Headers):
        headers = dict(headers)
    return A2AError(
        status_code=status_code,
        message=error_message,
        headers=headers,
    )
|
||||
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
Common utilities for A2A (Agent-to-Agent) Protocol
|
||||
"""
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
class A2AError(BaseLLMException):
    """Base exception for A2A protocol errors.

    Args:
        status_code: HTTP status code of the failed call.
        message: Human-readable error description.
        headers: Response headers, if any.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        headers: Dict[str, Any] = {},
    ):
        # Fix: copy the headers dict before handing it to the base class so
        # the shared mutable default `{}` can never be mutated downstream
        # and leak state between unrelated exceptions.
        super().__init__(
            status_code=status_code,
            message=message,
            headers=dict(headers),
        )
|
||||
|
||||
|
||||
def convert_messages_to_prompt(messages: List[AllMessageValues]) -> str:
    """Flatten OpenAI messages into one newline-joined "role: content" prompt.

    Uses LiteLLM's content helper so both string and content-list message
    bodies are handled. Messages whose extracted content is empty are
    dropped. Role falls back to "user" when it cannot be determined.
    """
    lines: List[str] = []
    for message in messages:
        # Extract text regardless of whether content is a str or a list.
        text = convert_content_list_to_str(message=message)
        if not text:
            continue

        if isinstance(message, BaseModel):
            role = message.model_dump().get("role", "user")
        elif isinstance(message, dict):
            role = message.get("role", "user")
        else:
            role = dict(message).get("role", "user")  # type: ignore

        lines.append(f"{role}: {text}")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def extract_text_from_a2a_message(
    message: Dict[str, Any], depth: int = 0, max_depth: int = 10
) -> str:
    """Collect the text carried by an A2A message's parts.

    Text parts contribute their ``text`` field; container parts (those
    carrying a nested ``parts`` list) are walked recursively, with
    ``max_depth`` bounding the recursion to guard against degenerate or
    cyclic payloads. Collected fragments are joined with single spaces.

    Args:
        message: A2A message dict with 'parts' containing text parts
        depth: Current recursion depth (internal use)
        max_depth: Maximum recursion depth to prevent infinite loops

    Returns:
        Concatenated text from all text parts
    """
    if message is None or depth >= max_depth:
        return ""

    collected: List[str] = []
    for part in message.get("parts", []):
        if part.get("kind") == "text":
            collected.append(part.get("text", ""))
        elif "parts" in part:
            # Container part — recurse with a reduced depth budget.
            nested = extract_text_from_a2a_message(part, depth + 1, max_depth)
            if nested:
                collected.append(nested)

    return " ".join(collected)
|
||||
|
||||
|
||||
def extract_text_from_a2a_response(
    response_dict: Dict[str, Any], max_depth: int = 10
) -> str:
    """Pull the message text out of an A2A JSON-RPC result, whatever its shape.

    Known result layouts, probed in priority order (first match wins):
      1. direct message:       {"result": {"parts": [...]}}
      2. nested message:       {"result": {"message": {"parts": [...]}}}
      3. artifact-update:      {"result": {"artifact": {"parts": [...]}}} (streaming)
      4. task status message:  {"result": {"status": {"message": {...}}}}
         (common in Gemini A2A agents)
      5. task artifacts array: {"result": {"artifacts": [{...}, ...]}} — only
         the first artifact is read.

    Returns "" when ``result`` is missing, not a dict, or matches no layout.
    """
    result = response_dict.get("result", {})
    if not isinstance(result, dict):
        return ""

    # 1. The result itself is a message with parts.
    if "parts" in result:
        return extract_text_from_a2a_message(result, depth=0, max_depth=max_depth)

    # 2. Message nested under "message".
    nested_message = result.get("message")
    if nested_message:
        return extract_text_from_a2a_message(
            nested_message, depth=0, max_depth=max_depth
        )

    # 3. Streaming artifact-update carries a singular "artifact".
    single_artifact = result.get("artifact")
    if single_artifact and isinstance(single_artifact, dict):
        return extract_text_from_a2a_message(
            single_artifact, depth=0, max_depth=max_depth
        )

    # 4. Task results may tuck the message inside the status object.
    status = result.get("status", {})
    if isinstance(status, dict):
        status_message = status.get("message")
        if status_message:
            return extract_text_from_a2a_message(
                status_message, depth=0, max_depth=max_depth
            )

    # 5. Completed tasks may expose an artifacts array.
    artifact_list = result.get("artifacts", [])
    if artifact_list and len(artifact_list) > 0:
        return extract_text_from_a2a_message(
            artifact_list[0], depth=0, max_depth=max_depth
        )

    return ""
|
||||
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
AI21 Chat Completions API
|
||||
|
||||
this is OpenAI compatible - no translation needed / occurs
|
||||
"""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...openai_like.chat.transformation import OpenAILikeChatConfig
|
||||
|
||||
|
||||
class AI21ChatConfig(OpenAILikeChatConfig):
    """Config for AI21's OpenAI-compatible Jamba chat endpoint.

    Reference: https://docs.ai21.com/reference/jamba-15-api-ref#request-parameters

    Records which request parameters the endpoint accepts; no request or
    response translation is needed beyond the OpenAI-like base class.
    """

    tools: Optional[list] = None
    response_format: Optional[dict] = None
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    seed: Optional[int] = None
    tool_choice: Optional[str] = None
    user: Optional[str] = None

    def __init__(
        self,
        tools: Optional[list] = None,
        response_format: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        n: Optional[int] = None,
        stream: Optional[bool] = None,
        seed: Optional[int] = None,
        tool_choice: Optional[str] = None,
        user: Optional[str] = None,
    ) -> None:
        # litellm config convention: any explicitly-passed value is promoted
        # to a class attribute so get_config() can report it.
        for name, value in locals().copy().items():
            if name == "self" or value is None:
                continue
            setattr(self.__class__, name, value)

    @classmethod
    def get_config(cls):
        """Expose this provider's config via the shared base-class helper."""
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        """OpenAI parameters the AI21 Jamba chat endpoint accepts."""
        supported = (
            "tools",
            "response_format",
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "stop",
            "n",
            "stream",
            "seed",
            "tool_choice",
        )
        return list(supported)
|
||||
@@ -0,0 +1,5 @@
|
||||
from .image_generation import get_aiml_image_generation_config
|
||||
|
||||
__all__ = [
|
||||
"get_aiml_image_generation_config",
|
||||
]
|
||||
@@ -0,0 +1,24 @@
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class AIMLChatConfig(OpenAIGPTConfig):
    """OpenAI-compatible config for the AI/ML API (api.aimlapi.com)."""

    @property
    def custom_llm_provider(self) -> Optional[str]:
        """Provider slug litellm uses for routing and cost lookup."""
        return "aiml"

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Resolve the API base URL and key for AI/ML.

        Precedence: explicit argument, then environment via the secret
        manager (AIML_API_BASE / AIML_API_KEY), then the public default URL.
        The resolved key may be None; callers are expected to handle that.
        """
        # AIML is openai compatible, we just need to set the api_base
        api_base = (
            api_base
            or get_secret_str("AIML_API_BASE")
            or "https://api.aimlapi.com/v1"  # Default AIML API base URL
        )  # type: ignore
        dynamic_api_key = api_key or get_secret_str("AIML_API_KEY")
        return api_base, dynamic_api_key
        # Fix: removed a stray dead `pass` statement that followed the last
        # method at class level.
|
||||
@@ -0,0 +1,13 @@
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
|
||||
from .transformation import AimlImageGenerationConfig
|
||||
|
||||
__all__ = [
|
||||
"AimlImageGenerationConfig",
|
||||
]
|
||||
|
||||
|
||||
def get_aiml_image_generation_config(model: str) -> BaseImageGenerationConfig:
    # A single config class currently serves every AI/ML image model, so
    # `model` is unused here; the parameter is kept so the factory matches
    # the provider-agnostic signature used elsewhere.
    return AimlImageGenerationConfig()
|
||||
@@ -0,0 +1,27 @@
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
    model: str,
    image_response: Any,
) -> float:
    """
    AI/ML flux image generation cost calculator.

    Cost = per-image output price (from litellm's model registry, key
    "output_cost_per_image", defaulting to 0.0) multiplied by the number of
    images returned.

    Raises:
        ValueError: if ``image_response`` is not an ImageResponse.
    """
    _model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider=litellm.LlmProviders.AIML.value,
    )
    per_image_cost: float = _model_info.get("output_cost_per_image") or 0.0

    if not isinstance(image_response, ImageResponse):
        raise ValueError(
            f"image_response must be of type ImageResponse got type={type(image_response)}"
        )

    image_count: int = len(image_response.data) if image_response.data else 0
    return per_image_cost * image_count
|
||||
@@ -0,0 +1,234 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.aiml import AimlImageGenerationRequestParams
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIImageGenerationOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import ImageObject, ImageResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AimlImageGenerationConfig(BaseImageGenerationConfig):
|
||||
DEFAULT_BASE_URL: str = "https://api.aimlapi.com"
|
||||
IMAGE_GENERATION_ENDPOINT: str = "v1/images/generations"
|
||||
|
||||
def get_supported_openai_params(
    self, model: str
) -> List[OpenAIImageGenerationOptionalParams]:
    """
    OpenAI image-generation params AI/ML accepts (model-independent).

    https://api.aimlapi.com/v1/images/generations
    """
    return ["n", "response_format", "size"]
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Translate OpenAI image-generation params into AI/ML request fields.

    Mapping: n -> num_images, response_format -> output_format, and
    size -> image_size (a "WxH" string becomes {"width": W, "height": H};
    predefined size names and non-string values pass through unchanged).
    Supported params with no special mapping are copied as-is. Unsupported
    params raise unless ``drop_params`` is set. Values already present in
    ``optional_params`` are never overwritten.
    """
    supported_params = self.get_supported_openai_params(model)

    for name, value in non_default_params.items():
        if name in optional_params.keys():
            continue  # a value already mapped/provided wins

        if name not in supported_params:
            if drop_params:
                continue
            raise ValueError(
                f"Parameter {name} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
            )

        if name == "n":
            optional_params["num_images"] = value
        elif name == "response_format":
            optional_params["output_format"] = value
        elif name == "size":
            if isinstance(value, str) and "x" in value:
                # Standard OpenAI "WxH" sizes become explicit dimensions.
                width, height = map(int, value.split("x"))
                optional_params["image_size"] = {
                    "width": width,
                    "height": height,
                }
            else:
                # Predefined size names / structured values pass through.
                optional_params["image_size"] = value
        else:
            optional_params[name] = value

    return optional_params
|
||||
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """Build the AI/ML image-generation endpoint URL.

    Base resolution order: explicit ``api_base``, the AIML_API_BASE secret,
    then the public default. A trailing "/v1" is stripped because
    IMAGE_GENERATION_ENDPOINT already carries the version segment.
    """
    base: str = api_base or get_secret_str("AIML_API_BASE") or self.DEFAULT_BASE_URL
    base = base.rstrip("/")
    if base.endswith("/v1"):
        # Avoid a doubled version segment in the final URL.
        base = base[: -len("/v1")]
    return f"{base}/{self.IMAGE_GENERATION_ENDPOINT}"
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Attach auth and content-type headers, failing fast without a key.

    Key resolution order: explicit ``api_key``, then the AIML_API_KEY
    secret, then AIMLAPI_KEY (alternative name).

    Raises:
        ValueError: when no API key can be resolved.
    """
    resolved_key: Optional[str] = (
        api_key
        or get_secret_str("AIML_API_KEY")
        or get_secret_str("AIMLAPI_KEY")  # Alternative name
    )
    if not resolved_key:
        raise ValueError("AIML_API_KEY or AIMLAPI_KEY is not set")

    headers.update(
        {
            "Authorization": f"Bearer {resolved_key}",
            "Content-Type": "application/json",
        }
    )
    return headers
|
||||
|
||||
def transform_image_generation_request(
    self,
    model: str,
    prompt: str,
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """Assemble the AI/ML image-generation request body.

    https://api.aimlapi.com/v1/images/generations

    ``optional_params`` are expected to already be in AI/ML form (see
    map_openai_params) and are splatted into the typed request structure.
    """
    request_body: AimlImageGenerationRequestParams = AimlImageGenerationRequestParams(
        prompt=prompt,
        model=model,
        **optional_params,
    )
    return dict(request_body)
|
||||
|
||||
def transform_image_generation_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ImageResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ImageResponse:
    """
    Transform the image generation response to the litellm image response

    https://api.aimlapi.com/v1/images/generations

    Handles the three response shapes AI/ML is known to emit (see inline
    comments below). Images that match none of the recognized keys are
    silently skipped; an unrecognized top-level shape yields an empty
    ``data`` list rather than an error.
    """
    try:
        response_data = raw_response.json()
    except Exception as e:
        # Surface malformed / non-JSON bodies as a provider error carrying
        # the original HTTP status and headers.
        raise self.get_error_class(
            error_message=f"Error transforming image generation response: {e}",
            status_code=raw_response.status_code,
            headers=raw_response.headers,
        )

    if not model_response.data:
        model_response.data = []

    # AI/ML API can return images in multiple formats:
    # 1. Top-level data array with url (OpenAI-like format)
    # 2. output.choices array with image_base64
    # 3. images array with url (and optional width, height, content_type)

    if "data" in response_data and isinstance(response_data["data"], list):
        # Handle OpenAI-like format: {"data": [{"url": "...", "width": 1024, "height": 768, "content_type": "image/jpeg"}]}
        for image in response_data["data"]:
            if "url" in image:
                model_response.data.append(
                    ImageObject(
                        b64_json=None,
                        url=image["url"],
                        revised_prompt=image.get("revised_prompt"),
                    )
                )
            elif "b64_json" in image or "image_base64" in image:
                # Accept either key name for base64-encoded payloads.
                model_response.data.append(
                    ImageObject(
                        b64_json=image.get("b64_json") or image.get("image_base64"),
                        url=None,
                        revised_prompt=image.get("revised_prompt"),
                    )
                )
    elif "output" in response_data and "choices" in response_data["output"]:
        # Format 2: {"output": {"choices": [{"image_base64": ...} | {"url": ...}]}}
        for choice in response_data["output"]["choices"]:
            if "image_base64" in choice:
                model_response.data.append(
                    ImageObject(
                        b64_json=choice["image_base64"],
                        url=None,
                    )
                )
            elif "url" in choice:
                model_response.data.append(
                    ImageObject(
                        b64_json=None,
                        url=choice["url"],
                    )
                )
    elif "images" in response_data:
        # Handle alternative format: {"images": [{"url": "...", "width": 1024, "height": 768, "content_type": "image/jpeg"}]}
        for image in response_data["images"]:
            if "url" in image:
                model_response.data.append(
                    ImageObject(
                        b64_json=None,
                        url=image["url"],
                    )
                )
            elif "image_base64" in image:
                model_response.data.append(
                    ImageObject(
                        b64_json=image["image_base64"],
                        url=None,
                    )
                )
    return model_response
|
||||
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
*New config* for using aiohttp to make the request to the custom OpenAI-like provider
|
||||
|
||||
This leads to 10x higher RPS than httpx
|
||||
https://github.com/BerriAI/litellm/issues/6592
|
||||
|
||||
New config to ensure we introduce this without causing breaking changes for users
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
from aiohttp import ClientResponse
|
||||
|
||||
from litellm.llms.openai_like.chat.transformation import OpenAILikeChatConfig
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AiohttpOpenAIChatConfig(OpenAILikeChatConfig):
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """Ensure the resolved URL terminates in ``/chat/completions``.

    NOTE(review): the fallback base is "https://api.openai.com" with no
    "/v1" segment, so the default resolves to
    "https://api.openai.com/chat/completions" — presumably callers always
    pass a fully-versioned api_base; confirm before relying on the default.
    """
    target = api_base if api_base is not None else "https://api.openai.com"
    if target.endswith("/chat/completions"):
        return target
    return target + "/chat/completions"
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Build the auth header for the aiohttp transport.

    NOTE(review): the incoming ``headers`` dict is intentionally discarded —
    only the Authorization header is returned; confirm no caller relies on
    extra headers surviving this call.
    """
    auth_headers = {"Authorization": f"Bearer {api_key}"}
    return auth_headers
|
||||
|
||||
async def transform_response(  # type: ignore
    self,
    model: str,
    raw_response: ClientResponse,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """Copy the provider's OpenAI-shaped JSON body onto ``model_response``.

    The payload is already OpenAI-compatible, so fields transfer one-to-one;
    each choice dict is rehydrated into a Choices object.
    """
    payload = await raw_response.json()
    model_response.choices = [Choices(**choice) for choice in payload.get("choices")]
    # Scalar fields map directly by name.
    for field in ("id", "created", "model", "object", "system_fingerprint"):
        setattr(model_response, field, payload.get(field))
    return model_response
|
||||
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Translate from OpenAI's `/v1/chat/completions` to Amazon Nova's `/v1/chat/completions`
|
||||
"""
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ...openai_like.chat.transformation import OpenAILikeChatConfig
|
||||
|
||||
|
||||
class AmazonNovaChatConfig(OpenAILikeChatConfig):
    """
    Config for Amazon Nova's OpenAI-compatible `/v1/chat/completions` endpoint.

    Resolves Nova's api_base/api_key and re-brands the response model with an
    ``amazon-nova/`` prefix so downstream cost calculation can locate pricing.
    """

    # Class-level defaults; __init__ promotes explicitly-passed values onto
    # the class (litellm config convention).
    max_completion_tokens: Optional[int] = None
    max_tokens: Optional[int] = None
    metadata: Optional[dict] = None  # fixed: metadata is a key/value mapping, not an int
    temperature: Optional[float] = None  # fixed: accepts fractional values
    top_p: Optional[float] = None  # fixed: accepts fractional values
    tools: Optional[list] = None
    reasoning_effort: Optional[str] = None  # fixed: a string level, not a list

    def __init__(
        self,
        max_completion_tokens: Optional[int] = None,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        tools: Optional[list] = None,
        reasoning_effort: Optional[str] = None,
    ) -> None:
        # Any value passed explicitly is stored on the *class* so it applies
        # to all subsequent calls — the standard litellm config pattern.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @property
    def custom_llm_provider(self) -> Optional[str]:
        """Provider key used by litellm routing and cost logic."""
        return "amazon_nova"

    @classmethod
    def get_config(cls):
        """Return the current (class-level) config values."""
        return super().get_config()

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Resolve ``(api_base, api_key)`` for Amazon Nova.

        Amazon Nova is openai compatible; we just need the api_base to be
        Nova's endpoint. Both values fall back through explicit argument,
        env var, and litellm globals.
        """
        api_base = (
            api_base
            or get_secret_str("AMAZON_NOVA_API_BASE")
            or "https://api.nova.amazon.com/v1"
        )  # type: ignore

        # Get API key from multiple sources: explicit arg > litellm setting
        # > env var > global litellm.api_key
        key = (
            api_key
            or litellm.amazon_nova_api_key
            or get_secret_str("AMAZON_NOVA_API_KEY")
            or litellm.api_key
        )
        return api_base, key

    def get_supported_openai_params(self, model: str) -> List:
        """Return the OpenAI params this provider accepts for any model."""
        return [
            "top_p",
            "temperature",
            "max_tokens",
            "max_completion_tokens",
            "metadata",
            "stop",
            "stream",
            "stream_options",
            "tools",
            "tool_choice",
            "reasoning_effort",
        ]

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Run the openai-like transformation, then re-brand the model name.

        Delegates all parsing to the parent class and only rewrites
        ``model_response.model`` afterwards.
        """
        model_response = super().transform_response(
            model=model,
            model_response=model_response,
            raw_response=raw_response,
            messages=messages,
            logging_obj=logging_obj,
            request_data=request_data,
            encoding=encoding,
            optional_params=optional_params,
            json_mode=json_mode,
            litellm_params=litellm_params,
            api_key=api_key,
        )

        # Storing amazon_nova in the model response for easier cost calculation later
        # NOTE(review): the prefix uses a hyphen ("amazon-nova/") while the
        # provider key uses an underscore ("amazon_nova") — confirm the cost
        # map expects the hyphenated form before changing either.
        setattr(model_response, "model", "amazon-nova/" + model)

        return model_response
|
||||
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Helper util for handling amazon nova cost calculation
|
||||
- e.g.: prompt caching
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.utils import Usage
|
||||
|
||||
|
||||
def cost_per_token(model: str, usage: "Usage") -> Tuple[float, float]:
    """Return ``(prompt_cost, completion_cost)`` for an Amazon Nova call.

    Delegates to the shared generic per-token calculator under the
    ``amazon_nova`` provider key, mirroring how Anthropic costs are computed.
    """
    return generic_cost_per_token(
        usage=usage,
        model=model,
        custom_llm_provider="amazon_nova",
    )
|
||||
@@ -0,0 +1,15 @@
|
||||
from typing import Type, Union
|
||||
|
||||
from .batches.transformation import AnthropicBatchesConfig
|
||||
from .chat.transformation import AnthropicConfig
|
||||
|
||||
__all__ = ["AnthropicBatchesConfig", "AnthropicConfig"]
|
||||
|
||||
|
||||
def get_anthropic_config(
    url_route: str,
) -> Union[Type[AnthropicBatchesConfig], Type[AnthropicConfig]]:
    """Pick the Anthropic config class for a given URL route.

    The batches config is used only for batch *results* routes
    (both "messages/batches" and "results" appear in the route);
    everything else gets the regular chat config.
    """
    is_batch_results_route = (
        "messages/batches" in url_route and "results" in url_route
    )
    return AnthropicBatchesConfig if is_batch_results_route else AnthropicConfig
|
||||
@@ -0,0 +1,4 @@
|
||||
from .handler import AnthropicBatchesHandler
|
||||
from .transformation import AnthropicBatchesConfig
|
||||
|
||||
__all__ = ["AnthropicBatchesHandler", "AnthropicBatchesConfig"]
|
||||
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Anthropic Batches API Handler
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch, LlmProviders
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
from ..common_utils import AnthropicModelInfo
|
||||
from .transformation import AnthropicBatchesConfig
|
||||
|
||||
|
||||
class AnthropicBatchesHandler:
    """
    Handler for Anthropic Message Batches API.

    Supports:
    - retrieve_batch() - Retrieve batch status and information
    """

    def __init__(self):
        # Helpers: credential resolution and response transformation.
        self.anthropic_model_info = AnthropicModelInfo()
        self.provider_config = AnthropicBatchesConfig()

    async def aretrieve_batch(
        self,
        batch_id: str,
        api_base: Optional[str],
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        logging_obj: Optional[LiteLLMLoggingObj] = None,
    ) -> LiteLLMBatch:
        """
        Async: Retrieve a batch from Anthropic.

        Args:
            batch_id: The batch ID to retrieve
            api_base: Anthropic API base URL
            api_key: Anthropic API key
            timeout: Request timeout (currently not forwarded — see NOTE below)
            max_retries: Max retry attempts (unused for now)
            logging_obj: Optional logging object

        Raises:
            ValueError: if no API key can be resolved.
            httpx.HTTPStatusError: via raise_for_status() on non-2xx responses.

        Returns:
            LiteLLMBatch: Batch information in OpenAI format
        """
        # NOTE(review): `timeout` and `max_retries` are accepted but never
        # passed to the HTTP call below — confirm whether the shared client
        # applies its own defaults or whether these should be forwarded.
        # Resolve API credentials
        api_base = api_base or self.anthropic_model_info.get_api_base(api_base)
        api_key = api_key or self.anthropic_model_info.get_api_key()

        if not api_key:
            raise ValueError("Missing Anthropic API Key")

        # Create a minimal logging object if not provided
        if logging_obj is None:
            # Imported lazily to avoid a module-level import cycle.
            from litellm.litellm_core_utils.litellm_logging import (
                Logging as LiteLLMLoggingObjClass,
            )

            logging_obj = LiteLLMLoggingObjClass(
                model="anthropic/unknown",
                messages=[],
                stream=False,
                call_type="batch_retrieve",
                start_time=None,
                litellm_call_id=f"batch_retrieve_{batch_id}",
                function_id="batch_retrieve",
            )

        # Get the complete URL for batch retrieval
        retrieve_url = self.provider_config.get_retrieve_batch_url(
            api_base=api_base,
            batch_id=batch_id,
            optional_params={},
            litellm_params={},
        )

        # Validate environment and get headers (x-api-key, anthropic-version, beta)
        headers = self.provider_config.validate_environment(
            headers={},
            model="",
            messages=[],
            optional_params={},
            litellm_params={},
            api_key=api_key,
            api_base=api_base,
        )

        # Log the outbound request before making it.
        logging_obj.pre_call(
            input=batch_id,
            api_key=api_key,
            additional_args={
                "api_base": retrieve_url,
                "headers": headers,
                "complete_input_dict": {},
            },
        )
        # Make the request
        async_client = get_async_httpx_client(llm_provider=LlmProviders.ANTHROPIC)
        response = await async_client.get(url=retrieve_url, headers=headers)
        response.raise_for_status()

        # Transform response to LiteLLM format
        return self.provider_config.transform_retrieve_batch_response(
            model=None,
            raw_response=response,
            logging_obj=logging_obj,
            litellm_params={},
        )

    def retrieve_batch(
        self,
        _is_async: bool,
        batch_id: str,
        api_base: Optional[str],
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        logging_obj: Optional[LiteLLMLoggingObj] = None,
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        """
        Retrieve a batch from Anthropic.

        Args:
            _is_async: Whether to run asynchronously
            batch_id: The batch ID to retrieve
            api_base: Anthropic API base URL
            api_key: Anthropic API key
            timeout: Request timeout
            max_retries: Max retry attempts (unused for now)
            logging_obj: Optional logging object

        Returns:
            LiteLLMBatch or Coroutine: Batch information in OpenAI format
        """
        if _is_async:
            # Caller awaits the coroutine itself.
            return self.aretrieve_batch(
                batch_id=batch_id,
                api_base=api_base,
                api_key=api_key,
                timeout=timeout,
                max_retries=max_retries,
                logging_obj=logging_obj,
            )
        else:
            # NOTE: asyncio.run raises RuntimeError if invoked from a thread
            # that already has a running event loop — the sync path must only
            # be used from synchronous call sites.
            return asyncio.run(
                self.aretrieve_batch(
                    batch_id=batch_id,
                    api_base=api_base,
                    api_key=api_key,
                    timeout=timeout,
                    max_retries=max_retries,
                    logging_obj=logging_obj,
                )
            )
|
||||
@@ -0,0 +1,312 @@
|
||||
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from httpx import Headers, Response
|
||||
|
||||
from litellm.llms.base_llm.batches.transformation import BaseBatchesConfig
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues, CreateBatchRequest
|
||||
from litellm.types.utils import LiteLLMBatch, LlmProviders, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
LoggingClass = LiteLLMLoggingObj
|
||||
else:
|
||||
LoggingClass = Any
|
||||
|
||||
|
||||
class AnthropicBatchesConfig(BaseBatchesConfig):
    """Transformation config for Anthropic's Message Batches API.

    Maps Anthropic MessageBatch objects and batch-result JSONL payloads to
    the OpenAI-style shapes litellm exposes (``LiteLLMBatch`` /
    ``ModelResponse``). Batch *creation* is not implemented yet.
    """

    def __init__(self):
        # Imported lazily to avoid circular imports at module load time.
        from ..chat.transformation import AnthropicConfig
        from ..common_utils import AnthropicModelInfo

        self.anthropic_chat_config = AnthropicConfig()  # initialize once
        self.anthropic_model_info = AnthropicModelInfo()

    @property
    def custom_llm_provider(self) -> LlmProviders:
        """Return the LLM provider type for this configuration."""
        return LlmProviders.ANTHROPIC

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Validate and prepare environment-specific headers and parameters.

        Resolves the API key (argument, then environment), then merges the
        standard Anthropic headers into ``headers`` (mutated in place).

        Raises:
            ValueError: if no API key is available.
        """
        # Resolve api_key from environment if not provided
        api_key = api_key or self.anthropic_model_info.get_api_key()
        if api_key is None:
            raise ValueError(
                "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
            )
        _headers = {
            "accept": "application/json",
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
            "x-api-key": api_key,
        }
        # Add beta header for message batches, unless the caller already set one.
        if "anthropic-beta" not in headers:
            headers["anthropic-beta"] = "message-batches-2024-09-24"
        headers.update(_headers)
        return headers

    def get_complete_batch_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: Dict,
        litellm_params: Dict,
        data: CreateBatchRequest,
    ) -> str:
        """Get the complete URL for a batch creation request."""
        api_base = api_base or self.anthropic_model_info.get_api_base(api_base)
        if not api_base.endswith("/v1/messages/batches"):
            api_base = f"{api_base.rstrip('/')}/v1/messages/batches"
        return api_base

    def transform_create_batch_request(
        self,
        model: str,
        create_batch_data: CreateBatchRequest,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[bytes, str, Dict[str, Any]]:
        """
        Transform the batch creation request to Anthropic format.

        Not currently implemented - placeholder to satisfy abstract base class.
        """
        raise NotImplementedError("Batch creation not yet implemented for Anthropic")

    def transform_create_batch_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LoggingClass,
        litellm_params: dict,
    ) -> LiteLLMBatch:
        """
        Transform Anthropic MessageBatch creation response to LiteLLM format.

        Not currently implemented - placeholder to satisfy abstract base class.
        """
        raise NotImplementedError("Batch creation not yet implemented for Anthropic")

    def get_retrieve_batch_url(
        self,
        api_base: Optional[str],
        batch_id: str,
        optional_params: Dict,
        litellm_params: Dict,
    ) -> str:
        """
        Get the complete URL for batch retrieval request.

        Args:
            api_base: Base API URL (optional, will use default if not provided)
            batch_id: Batch ID to retrieve
            optional_params: Optional parameters
            litellm_params: LiteLLM parameters

        Returns:
            Complete URL for Anthropic batch retrieval: {api_base}/v1/messages/batches/{batch_id}
        """
        api_base = api_base or self.anthropic_model_info.get_api_base(api_base)
        return f"{api_base.rstrip('/')}/v1/messages/batches/{batch_id}"

    def transform_retrieve_batch_request(
        self,
        batch_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[bytes, str, Dict[str, Any]]:
        """
        Transform batch retrieval request for Anthropic.

        For Anthropic, the URL is constructed by get_retrieve_batch_url(),
        so this method returns an empty dict (no additional request params needed).
        """
        # No additional request params needed - URL is handled by get_retrieve_batch_url
        return {}

    def transform_retrieve_batch_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LoggingClass,
        litellm_params: dict,
    ) -> LiteLLMBatch:
        """Transform Anthropic MessageBatch retrieval response to LiteLLM format.

        Raises:
            ValueError: if the response body is not valid JSON.
        """
        try:
            response_data = raw_response.json()
        except Exception as e:
            raise ValueError(f"Failed to parse Anthropic batch response: {e}")

        # Map Anthropic MessageBatch to OpenAI Batch format
        batch_id = response_data.get("id", "")
        processing_status = response_data.get("processing_status", "in_progress")

        # Map Anthropic processing_status to OpenAI status
        status_mapping: Dict[
            str,
            Literal[
                "validating",
                "failed",
                "in_progress",
                "finalizing",
                "completed",
                "expired",
                "cancelling",
                "cancelled",
            ],
        ] = {
            "in_progress": "in_progress",
            "canceling": "cancelling",
            "ended": "completed",
        }
        openai_status = status_mapping.get(processing_status, "in_progress")

        # Hoisted out of the helper so the import happens once per call,
        # not once per timestamp.
        from datetime import datetime

        # Parse ISO-8601 timestamps into unix seconds; None on any failure.
        def parse_timestamp(ts_str: Optional[str]) -> Optional[int]:
            if not ts_str:
                return None
            try:
                dt = datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
                return int(dt.timestamp())
            except Exception:
                return None

        created_at = parse_timestamp(response_data.get("created_at"))
        ended_at = parse_timestamp(response_data.get("ended_at"))
        expires_at = parse_timestamp(response_data.get("expires_at"))
        cancel_initiated_at = parse_timestamp(response_data.get("cancel_initiated_at"))
        archived_at = parse_timestamp(response_data.get("archived_at"))

        # Extract request counts
        request_counts_data = response_data.get("request_counts", {})
        from openai.types.batch import BatchRequestCounts

        request_counts = BatchRequestCounts(
            total=sum(
                [
                    request_counts_data.get("processing", 0),
                    request_counts_data.get("succeeded", 0),
                    request_counts_data.get("errored", 0),
                    request_counts_data.get("canceled", 0),
                    request_counts_data.get("expired", 0),
                ]
            ),
            completed=request_counts_data.get("succeeded", 0),
            failed=request_counts_data.get("errored", 0),
        )

        return LiteLLMBatch(
            id=batch_id,
            object="batch",
            endpoint="/v1/messages",
            errors=None,
            input_file_id="None",
            completion_window="24h",
            status=openai_status,
            output_file_id=batch_id,
            error_file_id=None,
            created_at=created_at or int(time.time()),
            in_progress_at=created_at if processing_status == "in_progress" else None,
            expires_at=expires_at,
            finalizing_at=None,
            completed_at=ended_at if processing_status == "ended" else None,
            failed_at=None,
            expired_at=archived_at if archived_at else None,
            cancelling_at=cancel_initiated_at
            if processing_status == "canceling"
            else None,
            cancelled_at=ended_at
            if processing_status == "canceling" and ended_at
            else None,
            request_counts=request_counts,
            metadata={},
        )

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, Headers]
    ) -> "BaseLLMException":
        """Get the appropriate error class for Anthropic."""
        from ..common_utils import AnthropicError

        # Convert Dict to Headers if needed
        if isinstance(headers, dict):
            headers_obj: Optional[Headers] = Headers(headers)
        else:
            headers_obj = headers if isinstance(headers, Headers) else None

        return AnthropicError(
            status_code=status_code, message=error_message, headers=headers_obj
        )

    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LoggingClass,
        request_data: Dict,
        messages: List[AllMessageValues],
        optional_params: Dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Transform a batch-results JSONL payload into a single ModelResponse.

        Each non-empty line of the body is one batch result. Usage from every
        successful result is summed into ``model_response.usage``; the last
        parsed message populates the remaining fields of ``model_response``.

        Lines that are not valid JSON, or whose result carries no "message"
        (errored / canceled / expired results), are skipped — previously a
        non-succeeded result raised KeyError on ``result["message"]``.
        """
        from litellm.cost_calculator import BaseTokenUsageProcessor
        from litellm.types.utils import Usage

        response_text = raw_response.text.strip()
        all_usage: List[Usage] = []

        # Split by newlines and try to parse each line as JSON (JSONL format)
        for line in response_text.split("\n"):
            line = line.strip()
            if not line:
                continue
            try:
                response_json = json.loads(line)
            except json.JSONDecodeError:
                continue

            # Only succeeded results carry a "message"; skip the rest.
            result = response_json.get("result") or {}
            completion_response = result.get("message")
            if completion_response is None:
                continue

            transformed_response = (
                self.anthropic_chat_config.transform_parsed_response(
                    completion_response=completion_response,
                    raw_response=raw_response,
                    model_response=model_response,
                )
            )

            transformed_response_usage = getattr(
                transformed_response, "usage", None
            )
            if transformed_response_usage:
                all_usage.append(cast(Usage, transformed_response_usage))

        ## SUM ALL USAGE
        combined_usage = BaseTokenUsageProcessor.combine_usage_objects(all_usage)
        setattr(model_response, "usage", combined_usage)

        return model_response
|
||||
@@ -0,0 +1 @@
|
||||
from .handler import AnthropicChatCompletion, ModelResponseIterator
|
||||
@@ -0,0 +1,10 @@
|
||||
from litellm.llms.anthropic.chat.guardrail_translation.handler import (
|
||||
AnthropicMessagesHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
# Maps each litellm call type to the handler class that translates that
# call's payload for unified guardrail processing.
guardrail_translation_mappings = {
    CallTypes.anthropic_messages: AnthropicMessagesHandler,
}

# Public API of this module.
__all__ = ["guardrail_translation_mappings"]
|
||||
@@ -0,0 +1,688 @@
|
||||
"""
|
||||
Anthropic Message Handler for Unified Guardrails
|
||||
|
||||
This module provides a class-based handler for Anthropic-format messages.
|
||||
The class methods can be overridden for custom behavior.
|
||||
|
||||
Pattern Overview:
|
||||
-----------------
|
||||
1. Extract text content from messages/responses (both string and list formats)
|
||||
2. Create async tasks to apply guardrails to each text segment
|
||||
3. Track mappings to know where each response belongs
|
||||
4. Apply guardrail responses back to the original structure
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
|
||||
from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import (
|
||||
LiteLLMAnthropicMessagesAdapter,
|
||||
)
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.anthropic_passthrough_logging_handler import (
|
||||
AnthropicPassthroughLoggingHandler,
|
||||
)
|
||||
from litellm.types.llms.anthropic import (
|
||||
AllAnthropicToolsValues,
|
||||
AnthropicMessagesRequest,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Choices,
|
||||
GenericGuardrailAPIInputs,
|
||||
ModelResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicMessagesHandler(BaseTranslation):
|
||||
"""
|
||||
Handler for processing Anthropic messages with guardrails.
|
||||
|
||||
This class provides methods to:
|
||||
1. Process input messages (pre-call hook)
|
||||
2. Process output responses (post-call hook)
|
||||
|
||||
Methods can be overridden to customize behavior for different message formats.
|
||||
"""
|
||||
|
||||
def __init__(self):
    super().__init__()
    # Adapter that translates Anthropic-format requests/tools into the
    # OpenAI chat-completion shapes the guardrail layer expects.
    self.adapter = LiteLLMAnthropicMessagesAdapter()
|
||||
|
||||
async def process_input_messages(
    self,
    data: dict,
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
) -> Any:
    """
    Process input messages by applying guardrails to text content.

    Flow:
        1. Translate the Anthropic request to OpenAI format (for structured
           messages/tools the guardrail may inspect).
        2. Extract text segments and images from each original message,
           recording where each text came from.
        3. Apply the guardrail once over the whole batch of texts.
        4. Write guardrailed texts (and tools, if returned) back into
           ``data`` in place.

    Returns:
        The (possibly mutated) ``data`` dict.
    """
    messages = data.get("messages")
    if messages is None:
        # Nothing to guard — pass the request through untouched.
        return data

    (
        chat_completion_compatible_request,
        _tool_name_mapping,
    ) = LiteLLMAnthropicMessagesAdapter().translate_anthropic_to_openai(
        # Use a shallow copy to avoid mutating request data (pop on litellm_metadata).
        anthropic_message_request=cast(AnthropicMessagesRequest, data.copy())
    )

    structured_messages = chat_completion_compatible_request.get("messages", [])

    texts_to_check: List[str] = []
    images_to_check: List[str] = []
    # Tools are taken from the already-translated (OpenAI-format) request.
    tools_to_check: List[
        ChatCompletionToolParam
    ] = chat_completion_compatible_request.get("tools", [])
    task_mappings: List[Tuple[int, Optional[int]]] = []
    # Track (message_index, content_index) for each text
    # content_index is None for string content, int for list content

    # Step 1: Extract all text content and images
    for msg_idx, message in enumerate(messages):
        self._extract_input_text_and_images(
            message=message,
            msg_idx=msg_idx,
            texts_to_check=texts_to_check,
            images_to_check=images_to_check,
            task_mappings=task_mappings,
        )

    # Step 2: Apply guardrail to all texts in batch
    if texts_to_check:
        inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
        if images_to_check:
            inputs["images"] = images_to_check
        if tools_to_check:
            inputs["tools"] = tools_to_check
        if structured_messages:
            inputs["structured_messages"] = structured_messages
        # Include model information if available
        model = data.get("model")
        if model:
            inputs["model"] = model
        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )

        guardrailed_texts = guardrailed_inputs.get("texts", [])
        # The guardrail may rewrite tool definitions; only overwrite when it
        # actually returned some.
        guardrailed_tools = guardrailed_inputs.get("tools")
        if guardrailed_tools is not None:
            data["tools"] = guardrailed_tools

        # Step 3: Map guardrail responses back to original message structure
        await self._apply_guardrail_responses_to_input(
            messages=messages,
            responses=guardrailed_texts,
            task_mappings=task_mappings,
        )

    verbose_proxy_logger.debug(
        "Anthropic Messages: Processed input messages: %s", messages
    )

    return data
|
||||
|
||||
def extract_request_tool_names(self, data: dict) -> List[str]:
    """Extract tool names from Anthropic messages request (tools[].name)."""
    request_tools = data.get("tools") or []
    return [
        str(tool["name"])
        for tool in request_tools
        if isinstance(tool, dict) and tool.get("name")
    ]
|
||||
|
||||
def _extract_input_text_and_images(
    self,
    message: Dict[str, Any],
    msg_idx: int,
    texts_to_check: List[str],
    images_to_check: List[str],
    task_mappings: List[Tuple[int, Optional[int]]],
) -> None:
    """
    Extract text content and images from a single message.

    Appends text segments to ``texts_to_check`` (with a matching
    ``(msg_idx, content_idx)`` entry in ``task_mappings``; ``content_idx``
    is None for plain string content) and base64/url image data to
    ``images_to_check``. All output lists are mutated in place.

    Override this method to customize text/image extraction logic.
    """
    content = message.get("content", None)
    tools = message.get("tools", None)
    if content is None and tools is None:
        return

    ## CHECK FOR TEXT + IMAGES
    if isinstance(content, str):
        # Plain string content: one text segment, no per-item index.
        texts_to_check.append(content)
        task_mappings.append((msg_idx, None))
        return

    if not isinstance(content, list):
        return

    # Multimodal list content: collect every text block and any image data.
    for item_idx, content_item in enumerate(content):
        item_text = content_item.get("text", None)
        if item_text is not None:
            texts_to_check.append(item_text)
            task_mappings.append((msg_idx, int(item_idx)))

        if content_item.get("type") == "image":
            image_source = content_item.get("source", {})
            if isinstance(image_source, dict):
                # Could be base64 or url
                image_data = image_source.get("data")
                if image_data:
                    images_to_check.append(image_data)
|
||||
def _extract_input_tools(
    self,
    tools: List[Dict[str, Any]],
    tools_to_check: List[ChatCompletionToolParam],
) -> None:
    """
    Extract tools from a message.

    Translates Anthropic-format tool definitions into OpenAI format via the
    adapter and appends them to ``tools_to_check`` (mutated in place).
    NOTE(review): not invoked by process_input_messages, which reads tools
    from the already-translated request instead — presumably kept as an
    override hook; confirm before removing.
    """
    ## CHECK FOR TOOLS
    if tools is not None and isinstance(tools, list):
        # TRANSFORM ANTHROPIC TOOLS TO OPENAI TOOLS
        openai_tools = self.adapter.translate_anthropic_tools_to_openai(
            tools=cast(List[AllAnthropicToolsValues], tools)
        )
        tools_to_check.extend(openai_tools)  # type: ignore
|
||||
|
||||
async def _apply_guardrail_responses_to_input(
    self,
    messages: List[Dict[str, Any]],
    responses: List[str],
    task_mappings: List[Tuple[int, Optional[int]]],
) -> None:
    """
    Apply guardrail responses back to the original input messages.

    Each response is written to the location recorded in ``task_mappings``:
    a ``None`` content index replaces the whole string content, an integer
    index replaces the "text" field of that list item. Mutates ``messages``
    in place.

    Override this method to customize how responses are applied.
    """
    for task_idx, new_text in enumerate(responses):
        msg_idx, content_idx = task_mappings[task_idx]

        current_content = messages[msg_idx].get("content", None)
        if current_content is None:
            continue

        if content_idx is None:
            # String content: replace the whole string.
            if isinstance(current_content, str):
                messages[msg_idx]["content"] = new_text
        else:
            # List content: replace the text of the recorded item.
            if isinstance(current_content, list):
                current_content[content_idx]["text"] = new_text
|
||||
|
||||
async def process_output_response(
    self,
    response: "AnthropicMessagesResponse",
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
    user_api_key_dict: Optional[Any] = None,
) -> Any:
    """
    Process output response by applying guardrails to text content and tool calls.

    Args:
        response: Anthropic MessagesResponse object (dict or Pydantic object)
        guardrail_to_apply: The guardrail instance to apply
        litellm_logging_obj: Optional logging object
        user_api_key_dict: User API key metadata to pass to guardrails

    Returns:
        Modified response with guardrail applied to content

    Response Format Support:
        - List content: response.content = [
            {"type": "text", "text": "text here"},
            {"type": "tool_use", "id": "...", "name": "...", "input": {...}},
            ...
        ]
    """
    texts_to_check: List[str] = []
    images_to_check: List[str] = []
    tool_calls_to_check: List[ChatCompletionToolCallChunk] = []
    task_mappings: List[Tuple[int, Optional[int]]] = []
    # Track (content_index, None) for each text

    # Handle both dict and object responses
    response_content: List[Any] = []
    if isinstance(response, dict):
        response_content = response.get("content", []) or []
    elif hasattr(response, "content"):
        content = getattr(response, "content", None)
        response_content = content or []
    else:
        response_content = []

    # No content at all -> nothing for the guardrail to inspect.
    if not response_content:
        return response

    # Step 1: Extract all text content and tool calls from response
    for content_idx, content_block in enumerate(response_content):
        # Handle both dict and Pydantic object content blocks
        block_dict: Dict[str, Any] = {}
        if isinstance(content_block, dict):
            block_type = content_block.get("type")
            block_dict = cast(Dict[str, Any], content_block)
        elif hasattr(content_block, "type"):
            block_type = getattr(content_block, "type", None)
            # Convert Pydantic object to dict for processing
            if hasattr(content_block, "model_dump"):
                block_dict = content_block.model_dump()
            else:
                # Fallback for objects without model_dump: keep only the
                # fields the extraction step actually reads.
                block_dict = {
                    "type": block_type,
                    "text": getattr(content_block, "text", None),
                }
        else:
            # Unknown block shape — skip rather than fail the response.
            continue

        if block_type in ["text", "tool_use"]:
            self._extract_output_text_and_images(
                content_block=block_dict,
                content_idx=content_idx,
                texts_to_check=texts_to_check,
                images_to_check=images_to_check,
                task_mappings=task_mappings,
                tool_calls_to_check=tool_calls_to_check,
            )

    # Step 2: Apply guardrail to all texts in batch
    if texts_to_check or tool_calls_to_check:
        # Create a request_data dict with response info and user API key metadata
        request_data: dict = {"response": response}

        # Add user API key metadata with prefixed keys
        user_metadata = self.transform_user_api_key_dict_to_metadata(
            user_api_key_dict
        )
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata

        inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
        if images_to_check:
            inputs["images"] = images_to_check
        if tool_calls_to_check:
            inputs["tool_calls"] = tool_calls_to_check
        # Include model information from the response if available
        response_model = None
        if isinstance(response, dict):
            response_model = response.get("model")
        elif hasattr(response, "model"):
            response_model = getattr(response, "model", None)
        if response_model:
            inputs["model"] = response_model

        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )

        guardrailed_texts = guardrailed_inputs.get("texts", [])

        # Step 3: Map guardrail responses back to original response structure
        await self._apply_guardrail_responses_to_output(
            response=response,
            responses=guardrailed_texts,
            task_mappings=task_mappings,
        )

    verbose_proxy_logger.debug(
        "Anthropic Messages: Processed output response: %s", response
    )

    # Response is modified in place above; return it for convenience.
    return response
async def process_output_streaming_response(
    self,
    responses_so_far: List[Any],
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
    user_api_key_dict: Optional[Any] = None,
) -> List[Any]:
    """
    Process output streaming response by applying guardrails to text content.

    Get the string so far, check the apply guardrail to the string so far,
    and return the list of responses so far.

    While the stream is in flight, only the accumulated text is checked;
    once a stop_reason arrives, the full response (including tool calls)
    is rebuilt and checked as a whole.
    """
    has_ended = self._check_streaming_has_ended(responses_so_far)
    if has_ended:
        # build the model response from the responses_so_far
        built_response = (
            AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
                all_chunks=responses_so_far,
                litellm_logging_obj=cast("LiteLLMLoggingObj", litellm_logging_obj),
                model="",
            )
        )

        # Check if model_response is valid and has choices before accessing
        if (
            built_response is not None
            and hasattr(built_response, "choices")
            and built_response.choices
        ):
            model_response = cast(ModelResponse, built_response)
            first_choice = cast(Choices, model_response.choices[0])
            tool_calls_list = cast(
                Optional[List[ChatCompletionMessageToolCall]],
                first_choice.message.tool_calls,
            )
            string_so_far = first_choice.message.content
            guardrail_inputs = GenericGuardrailAPIInputs()
            # Only include keys that actually have content — the guardrail
            # API treats absent and empty inputs differently.
            if string_so_far:
                guardrail_inputs["texts"] = [string_so_far]
            if tool_calls_list:
                guardrail_inputs["tool_calls"] = tool_calls_list

            _guardrailed_inputs = await guardrail_to_apply.apply_guardrail(  # allow rejecting the response, if invalid
                inputs=guardrail_inputs,
                request_data={},
                input_type="response",
                logging_obj=litellm_logging_obj,
            )
        else:
            verbose_proxy_logger.debug(
                "Skipping output guardrail - model response has no choices"
            )
        return responses_so_far

    # Stream still in progress: check only the text accumulated so far.
    string_so_far = self.get_streaming_string_so_far(responses_so_far)
    _guardrailed_inputs = await guardrail_to_apply.apply_guardrail(  # allow rejecting the response, if invalid
        inputs={"texts": [string_so_far]},
        request_data={},
        input_type="response",
        logging_obj=litellm_logging_obj,
    )
    return responses_so_far
def get_streaming_string_so_far(self, responses_so_far: List[Any]) -> str:
    """
    Parse streaming responses and return the accumulated text content.

    Handles two chunk formats:
    1. Raw bytes in SSE (Server-Sent Events) format from the Anthropic API
    2. Parsed dict objects (for backwards compatibility)

    SSE format example:
        b'event: content_block_delta\\ndata: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" curious"}}\\n\\n'

    Dict format example:
        {"type": "content_block_delta", "index": 0,
         "delta": {"type": "text_delta", "text": " curious"}}
    """
    collected_pieces: List[str] = []
    for chunk in responses_so_far:
        if isinstance(chunk, bytes):
            # Raw SSE bytes: delegate parsing to the SSE helper.
            collected_pieces.append(self._extract_text_from_sse(chunk))
        elif isinstance(chunk, dict):
            # Already-parsed chunk: only text_delta entries carry text.
            delta = chunk.get("delta")
            if delta and delta.get("type") == "text_delta":
                piece = delta.get("text", "")
                if piece:
                    collected_pieces.append(piece)
    return "".join(collected_pieces)
def _extract_text_from_sse(self, sse_bytes: bytes) -> str:
|
||||
"""
|
||||
Extract text content from Server-Sent Events (SSE) format.
|
||||
|
||||
Args:
|
||||
sse_bytes: Raw bytes in SSE format
|
||||
|
||||
Returns:
|
||||
Accumulated text from all content_block_delta events
|
||||
"""
|
||||
text = ""
|
||||
try:
|
||||
# Decode bytes to string
|
||||
sse_string = sse_bytes.decode("utf-8")
|
||||
|
||||
# Split by double newline to get individual events
|
||||
events = sse_string.split("\n\n")
|
||||
|
||||
for event in events:
|
||||
if not event.strip():
|
||||
continue
|
||||
|
||||
# Parse event lines
|
||||
lines = event.strip().split("\n")
|
||||
event_type = None
|
||||
data_line = None
|
||||
|
||||
for line in lines:
|
||||
if line.startswith("event:"):
|
||||
event_type = line[6:].strip()
|
||||
elif line.startswith("data:"):
|
||||
data_line = line[5:].strip()
|
||||
|
||||
# Only process content_block_delta events
|
||||
if event_type == "content_block_delta" and data_line:
|
||||
try:
|
||||
data = json.loads(data_line)
|
||||
delta = data.get("delta", {})
|
||||
if delta.get("type") == "text_delta":
|
||||
text += delta.get("text", "")
|
||||
except json.JSONDecodeError:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Failed to parse JSON from SSE data: {data_line}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error extracting text from SSE: {e}")
|
||||
|
||||
return text
|
||||
|
||||
def _check_streaming_has_ended(self, responses_so_far: List[Any]) -> bool:
    """
    Check if streaming response has ended by looking for non-null stop_reason.

    Handles two formats:
    1. Raw bytes in SSE (Server-Sent Events) format from Anthropic API
    2. Parsed dict objects (for backwards compatibility)

    SSE format example:
        b'event: message_delta\\ndata: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},...}\\n\\n'

    Dict format example:
        {
            "type": "message_delta",
            "delta": {
                "stop_reason": "tool_use",
                "stop_sequence": null
            }
        }

    Returns:
        True if stop_reason is set to a non-null value, indicating stream has ended
    """
    for response in responses_so_far:
        # Handle raw bytes in SSE format
        if isinstance(response, bytes):
            try:
                # Decode bytes to string
                sse_string = response.decode("utf-8")

                # Split by double newline to get individual events
                events = sse_string.split("\n\n")

                for event in events:
                    if not event.strip():
                        continue

                    # Parse event lines
                    lines = event.strip().split("\n")
                    event_type = None
                    data_line = None

                    for line in lines:
                        if line.startswith("event:"):
                            event_type = line[6:].strip()
                        elif line.startswith("data:"):
                            data_line = line[5:].strip()

                    # Check for message_delta event with stop_reason
                    if event_type == "message_delta" and data_line:
                        try:
                            data = json.loads(data_line)
                            delta = data.get("delta", {})
                            stop_reason = delta.get("stop_reason")
                            if stop_reason is not None:
                                return True
                        except json.JSONDecodeError:
                            # Malformed chunk: log and keep scanning the
                            # remaining events rather than failing the check.
                            verbose_proxy_logger.warning(
                                f"Failed to parse JSON from SSE data: {data_line}"
                            )

            except Exception as e:
                # Best-effort: a bad chunk must not break stream handling.
                verbose_proxy_logger.error(
                    f"Error checking streaming end in SSE: {e}"
                )

        # Handle already-parsed dict format
        elif isinstance(response, dict):
            if response.get("type") == "message_delta":
                delta = response.get("delta", {})
                stop_reason = delta.get("stop_reason")
                if stop_reason is not None:
                    return True

    # No chunk carried a non-null stop_reason: stream still in progress.
    return False
def _has_text_content(self, response: "AnthropicMessagesResponse") -> bool:
|
||||
"""
|
||||
Check if response has any text content to process.
|
||||
|
||||
Override this method to customize text content detection.
|
||||
"""
|
||||
if isinstance(response, dict):
|
||||
response_content = response.get("content", [])
|
||||
else:
|
||||
response_content = getattr(response, "content", None) or []
|
||||
|
||||
if not response_content:
|
||||
return False
|
||||
for content_block in response_content:
|
||||
# Check if this is a text block by checking the 'type' field
|
||||
if isinstance(content_block, dict) and content_block.get("type") == "text":
|
||||
content_text = content_block.get("text")
|
||||
if content_text and isinstance(content_text, str):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _extract_output_text_and_images(
    self,
    content_block: Dict[str, Any],
    content_idx: int,
    texts_to_check: List[str],
    images_to_check: List[str],
    task_mappings: List[Tuple[int, Optional[int]]],
    tool_calls_to_check: Optional[List[ChatCompletionToolCallChunk]] = None,
) -> None:
    """
    Collect text content and tool calls from one response content block.

    Appends to the ``*_to_check`` accumulators in place and records a
    ``(content_idx, None)`` mapping for each text found.

    Override this method to customize text/image/tool extraction logic.
    """
    block_type = content_block.get("type")

    if block_type == "text":
        text_value = content_block.get("text")
        # Only non-empty plain strings are sent to the guardrail.
        if isinstance(text_value, str) and text_value:
            texts_to_check.append(text_value)
            task_mappings.append((content_idx, None))
        return

    if block_type == "tool_use":
        converted_call = AnthropicConfig.convert_tool_use_to_openai_format(
            anthropic_tool_content=content_block,
            index=content_idx,
        )
        if tool_calls_to_check is None:
            # NOTE(review): rebinding the parameter here means the appended
            # call is never visible to the caller — presumably defensive;
            # callers in this file always pass a list. TODO confirm.
            tool_calls_to_check = []
        tool_calls_to_check.append(converted_call)
async def _apply_guardrail_responses_to_output(
|
||||
self,
|
||||
response: "AnthropicMessagesResponse",
|
||||
responses: List[str],
|
||||
task_mappings: List[Tuple[int, Optional[int]]],
|
||||
) -> None:
|
||||
"""
|
||||
Apply guardrail responses back to output response.
|
||||
|
||||
Override this method to customize how responses are applied.
|
||||
"""
|
||||
for task_idx, guardrail_response in enumerate(responses):
|
||||
mapping = task_mappings[task_idx]
|
||||
content_idx = cast(int, mapping[0])
|
||||
|
||||
# Handle both dict and object responses
|
||||
response_content: List[Any] = []
|
||||
if isinstance(response, dict):
|
||||
response_content = response.get("content", []) or []
|
||||
elif hasattr(response, "content"):
|
||||
content = getattr(response, "content", None)
|
||||
response_content = content or []
|
||||
else:
|
||||
continue
|
||||
|
||||
if not response_content:
|
||||
continue
|
||||
|
||||
# Get the content block at the index
|
||||
if content_idx >= len(response_content):
|
||||
continue
|
||||
|
||||
content_block = response_content[content_idx]
|
||||
|
||||
# Verify it's a text block and update the text field
|
||||
# Handle both dict and Pydantic object content blocks
|
||||
if isinstance(content_block, dict):
|
||||
if content_block.get("type") == "text":
|
||||
cast(Dict[str, Any], content_block)["text"] = guardrail_response
|
||||
elif (
|
||||
hasattr(content_block, "type")
|
||||
and getattr(content_block, "type", None) == "text"
|
||||
):
|
||||
# Update Pydantic object's text attribute
|
||||
if hasattr(content_block, "text"):
|
||||
content_block.text = guardrail_response
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,627 @@
|
||||
"""
|
||||
This file contains common utils for anthropic calls.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_file_ids_from_messages,
|
||||
)
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo, BaseTokenCounter
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.anthropic import (
|
||||
ANTHROPIC_HOSTED_TOOLS,
|
||||
ANTHROPIC_OAUTH_BETA_HEADER,
|
||||
ANTHROPIC_OAUTH_TOKEN_PREFIX,
|
||||
AllAnthropicToolsValues,
|
||||
AnthropicMcpServerTool,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
|
||||
def is_anthropic_oauth_key(value: Optional[str]) -> bool:
    """Check if a value contains an Anthropic OAuth token (sk-ant-oat*)."""
    if value is None:
        return False
    # Accept both the raw token and an "Authorization: Bearer <token>" value.
    token = value[7:] if value.startswith("Bearer ") else value
    return token.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX)
def _merge_beta_headers(existing: Optional[str], new_beta: str) -> str:
|
||||
"""Merge a new beta value into an existing comma-separated anthropic-beta header."""
|
||||
if not existing:
|
||||
return new_beta
|
||||
betas = {b.strip() for b in existing.split(",") if b.strip()}
|
||||
betas.add(new_beta)
|
||||
return ",".join(sorted(betas))
|
||||
|
||||
|
||||
def optionally_handle_anthropic_oauth(
    headers: dict, api_key: Optional[str]
) -> tuple[dict, Optional[str]]:
    """
    Handle Anthropic OAuth token detection and header setup.

    If an OAuth token is detected in the Authorization header (passthrough /
    forwarded requests) or in ``api_key`` itself (standard chat/completion
    flow), extracts it and sets the required OAuth headers.

    Args:
        headers: Request headers dict (mutated in place)
        api_key: Current API key (may be None)

    Returns:
        Tuple of (updated headers, api_key)
    """

    def _apply_oauth_headers() -> None:
        # OAuth tokens are sent via the Authorization header, not x-api-key,
        # and require the OAuth beta + browser-access headers.
        headers.pop("x-api-key", None)
        headers["anthropic-beta"] = _merge_beta_headers(
            headers.get("anthropic-beta"), ANTHROPIC_OAUTH_BETA_HEADER
        )
        headers["anthropic-dangerous-direct-browser-access"] = "true"

    # Check Authorization header (passthrough / forwarded requests)
    auth_header = headers.get("authorization", "")
    if auth_header and auth_header.startswith(f"Bearer {ANTHROPIC_OAUTH_TOKEN_PREFIX}"):
        api_key = auth_header.replace("Bearer ", "")
        _apply_oauth_headers()
        return headers, api_key

    # Check api_key directly (standard chat/completion flow)
    if api_key and api_key.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX):
        headers["authorization"] = f"Bearer {api_key}"
        _apply_oauth_headers()
        return headers, api_key

    # BUG FIX: the non-OAuth path previously fell through without a return,
    # yielding None and crashing callers that unpack the result, e.g.
    # `headers, api_key = optionally_handle_anthropic_oauth(...)` in
    # validate_environment.
    return headers, api_key
class AnthropicError(BaseLLMException):
    """Exception raised for errors returned by the Anthropic API.

    Thin wrapper over ``BaseLLMException``; adds no behavior of its own and
    exists so callers can catch Anthropic-specific failures distinctly.
    """

    def __init__(
        self,
        status_code: int,
        message,
        headers: Optional[httpx.Headers] = None,
    ):
        """Initialize with the HTTP status code, error message, and optional response headers."""
        super().__init__(status_code=status_code, message=message, headers=headers)
class AnthropicModelInfo(BaseLLMModelInfo):
|
||||
def is_cache_control_set(self, messages: List[AllMessageValues]) -> bool:
    """
    Return True if {"cache_control": ..} appears on a message or inside one
    of its content blocks.

    Used to check if anthropic prompt caching headers need to be set.
    """
    for message in messages:
        # cache_control may be set at the message level...
        if message.get("cache_control", None) is not None:
            return True
        # ...or on an individual content block of a structured message.
        content = message.get("content")
        if isinstance(content, list) and any(
            "cache_control" in part for part in content
        ):
            return True
    return False
def is_file_id_used(self, messages: List[AllMessageValues]) -> bool:
    """
    Return True if any message content block references a file via
    {"source": {"type": "file", "file_id": ..}}.
    """
    # Delegate the scan to the shared prompt-template helper; any hit means
    # the files beta headers are required.
    return bool(get_file_ids_from_messages(messages))
def is_mcp_server_used(
    self, mcp_servers: Optional[List[AnthropicMcpServerTool]]
) -> bool:
    """
    Return True if any MCP servers were supplied with the request.

    Args:
        mcp_servers: Optional list of MCP server tool definitions.
    """
    # Collapses the original three-branch None / empty / non-empty check:
    # both None and [] are falsy, so bool() covers every case identically.
    return bool(mcp_servers)
def is_computer_tool_used(
    self, tools: Optional[List[AllAnthropicToolsValues]]
) -> Optional[str]:
    """Returns the computer tool version if used, e.g. 'computer_20250124' or None"""
    # `tools or []` also covers tools=None without a separate branch.
    for tool in tools or []:
        tool_type = tool.get("type")
        if tool_type is not None and tool_type.startswith("computer_"):
            # Return the exact version string so the matching beta header
            # can be selected.
            return tool_type
    return None
def is_web_search_tool_used(
    self, tools: Optional[List[AllAnthropicToolsValues]]
) -> bool:
    """Returns True if web_search tool is used"""
    if tools is None:
        return False
    # Match any versioned web_search tool type by prefix.
    web_search_prefix = ANTHROPIC_HOSTED_TOOLS.WEB_SEARCH.value
    return any(
        "type" in tool and tool["type"].startswith(web_search_prefix)
        for tool in tools
    )
def is_pdf_used(self, messages: List[AllMessageValues]) -> bool:
    """
    Return True if media (any non-text content block) is passed into
    messages.
    """
    for message in messages:
        content = message.get("content")
        # Only structured (list) content can carry media blocks.
        if not isinstance(content, list):
            continue
        for part in content:
            if "type" in part and part["type"] != "text":
                return True
    return False
def is_tool_search_used(self, tools: Optional[List]) -> bool:
    """
    Check if tool search tools are present in the tools list.
    """
    if not tools:
        return False
    # Both the regex- and BM25-based tool-search variants require the beta.
    tool_search_types = {
        "tool_search_tool_regex_20251119",
        "tool_search_tool_bm25_20251119",
    }
    return any(tool.get("type", "") in tool_search_types for tool in tools)
def is_programmatic_tool_calling_used(self, tools: Optional[List]) -> bool:
    """
    Check if programmatic tool calling is being used (tools with
    allowed_callers field).

    Returns True if any tool has allowed_callers containing
    'code_execution_20250825'.
    """
    if not tools:
        return False

    required_caller = "code_execution_20250825"
    for tool in tools:
        # allowed_callers may sit at the top level (Anthropic format)...
        caller_lists = [tool.get("allowed_callers", None)]
        # ...or nested under "function" (OpenAI format).
        function_def = tool.get("function", {})
        if isinstance(function_def, dict):
            caller_lists.append(function_def.get("allowed_callers", None))
        for callers in caller_lists:
            if callers and isinstance(callers, list) and required_caller in callers:
                return True
    return False
def is_input_examples_used(self, tools: Optional[List]) -> bool:
    """
    Check if input_examples is being used in any tools.

    Returns True if any tool carries a non-empty input_examples list,
    either at the top level or nested under "function" (OpenAI format).
    """
    if not tools:
        return False

    for tool in tools:
        # Gather the places input_examples can appear for this tool.
        candidate_containers = [tool]
        function_def = tool.get("function", {})
        if isinstance(function_def, dict):
            candidate_containers.append(function_def)

        for container in candidate_containers:
            examples = container.get("input_examples", None)
            if examples and isinstance(examples, list) and len(examples) > 0:
                return True

    return False
@staticmethod
|
||||
def _is_claude_4_6_model(model: str) -> bool:
|
||||
"""Check if the model is a Claude 4.6 model (Opus 4.6 or Sonnet 4.6)."""
|
||||
model_lower = model.lower()
|
||||
return any(
|
||||
v in model_lower
|
||||
for v in (
|
||||
"opus-4-6",
|
||||
"opus_4_6",
|
||||
"opus-4.6",
|
||||
"opus_4.6",
|
||||
"sonnet-4-6",
|
||||
"sonnet_4_6",
|
||||
"sonnet-4.6",
|
||||
"sonnet_4.6",
|
||||
)
|
||||
)
|
||||
|
||||
def is_effort_used(
    self, optional_params: Optional[dict], model: Optional[str] = None
) -> bool:
    """
    Check if effort parameter is being used and requires a beta header.

    Returns True if effort-related parameters are present and the model
    requires the effort beta header. Claude 4.6 models use output_config
    as a stable API feature — no beta header needed.
    """
    if not optional_params:
        return False

    # Claude 4.6 models use output_config as a stable API feature — no beta header needed
    if model and self._is_claude_4_6_model(model):
        return False

    # Claude Opus 4.5 takes effort via the `reasoning_effort` string param.
    if model and ("opus-4-5" in model.lower() or "opus_4_5" in model.lower()):
        effort_param = optional_params.get("reasoning_effort")
        if isinstance(effort_param, str) and effort_param:
            return True

    # Non-4.6 models may also pass output_config={"effort": ...} directly.
    config = optional_params.get("output_config")
    if isinstance(config, dict):
        effort_value = config.get("effort")
        if isinstance(effort_value, str) and effort_value:
            return True

    return False
def is_code_execution_tool_used(self, tools: Optional[List]) -> bool:
    """
    Check if code execution tool is being used.

    Returns True if any tool has type "code_execution_20250825".
    """
    if not tools:
        return False
    return any(
        tool.get("type", "") == "code_execution_20250825" for tool in tools
    )
def is_container_with_skills_used(self, optional_params: Optional[dict]) -> bool:
    """
    Check if container with skills is being used.

    Returns True if optional_params contains a container dict with a
    non-empty skills list.
    """
    if not optional_params:
        return False
    container = optional_params.get("container")
    if not isinstance(container, dict):
        return False
    skills = container.get("skills")
    # Non-empty list of skills triggers the skills beta header.
    return isinstance(skills, list) and len(skills) > 0
def _get_user_anthropic_beta_headers(
|
||||
self, anthropic_beta_header: Optional[str]
|
||||
) -> Optional[List[str]]:
|
||||
if anthropic_beta_header is None:
|
||||
return None
|
||||
return anthropic_beta_header.split(",")
|
||||
|
||||
def get_computer_tool_beta_header(self, computer_tool_version: str) -> str:
    """
    Get the appropriate beta header for a given computer tool version.

    Args:
        computer_tool_version: The computer tool version (e.g.
            'computer_20250124', 'computer_20241022')

    Returns:
        The corresponding beta header string; unknown versions fall back
        to the 2024-10-22 header.
    """
    fallback_beta = "computer-use-2024-10-22"
    version_to_beta = {
        "computer_20250124": "computer-use-2025-01-24",
        "computer_20241022": "computer-use-2024-10-22",
    }
    return version_to_beta.get(computer_tool_version, fallback_beta)
def get_anthropic_beta_list(
    self,
    model: str,
    optional_params: Optional[dict] = None,
    computer_tool_used: Optional[str] = None,
    prompt_caching_set: bool = False,
    file_id_used: bool = False,
    mcp_server_used: bool = False,
) -> List[str]:
    """
    Get list of common beta headers based on the features that are active.

    Args:
        model: Model name (used for effort-beta detection).
        optional_params: Request params, scanned for effort usage.
        computer_tool_used: Computer tool version string, if any.
        prompt_caching_set: Kept for interface compatibility; prompt caching
            no longer requires a beta header.
        file_id_used: Whether a file_id appears in the messages.
        mcp_server_used: Whether MCP servers are configured.

    Returns:
        De-duplicated list of beta header strings, in sorted order.
    """
    from litellm.types.llms.anthropic import (
        ANTHROPIC_EFFORT_BETA_HEADER,
    )

    betas = []

    # Detect features
    effort_used = self.is_effort_used(optional_params, model)

    if effort_used:
        betas.append(ANTHROPIC_EFFORT_BETA_HEADER)  # effort-2025-11-24

    if computer_tool_used:
        beta_header = self.get_computer_tool_beta_header(computer_tool_used)
        betas.append(beta_header)

    # Anthropic no longer requires the prompt-caching beta header
    # Prompt caching now works automatically when cache_control is used in messages
    # Reference: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching

    if file_id_used:
        betas.append("files-api-2025-04-14")
        betas.append("code-execution-2025-05-22")

    if mcp_server_used:
        betas.append("mcp-client-2025-04-04")

    # FIX: sorted() instead of list(set(...)) — set iteration order varies
    # across processes (string-hash seeding), which made the returned list
    # nondeterministic. Sorting matches _merge_beta_headers' behavior.
    return sorted(set(betas))
def get_anthropic_headers(
    self,
    api_key: str,
    anthropic_version: Optional[str] = None,
    computer_tool_used: Optional[str] = None,
    prompt_caching_set: bool = False,
    pdf_used: bool = False,
    file_id_used: bool = False,
    mcp_server_used: bool = False,
    web_search_tool_used: bool = False,
    tool_search_used: bool = False,
    programmatic_tool_calling_used: bool = False,
    input_examples_used: bool = False,
    effort_used: bool = False,
    is_vertex_request: bool = False,
    user_anthropic_beta_headers: Optional[List[str]] = None,
    code_execution_tool_used: bool = False,
    container_with_skills_used: bool = False,
) -> dict:
    """
    Build the request headers for an Anthropic API call.

    Collects the beta headers required by the active features, handles
    OAuth-token (sk-ant-oat*) vs API-key auth, and suppresses beta headers
    for Vertex AI requests (which only accept the web-search beta).

    Returns:
        Headers dict with anthropic-version, auth, content-type headers,
        and a merged `anthropic-beta` header when any betas are required.
    """
    betas = set()
    # Anthropic no longer requires the prompt-caching beta header
    # Prompt caching now works automatically when cache_control is used in messages
    # Reference: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
    if computer_tool_used:
        beta_header = self.get_computer_tool_beta_header(computer_tool_used)
        betas.add(beta_header)
    # if pdf_used:
    #     betas.add("pdfs-2024-09-25")
    if file_id_used:
        betas.add("files-api-2025-04-14")
        betas.add("code-execution-2025-05-22")
    if mcp_server_used:
        betas.add("mcp-client-2025-04-04")
    # Tool search, programmatic tool calling, and input_examples all use the same beta header
    if tool_search_used or programmatic_tool_calling_used or input_examples_used:
        from litellm.types.llms.anthropic import ANTHROPIC_TOOL_SEARCH_BETA_HEADER

        betas.add(ANTHROPIC_TOOL_SEARCH_BETA_HEADER)

    # Effort parameter uses a separate beta header
    if effort_used:
        from litellm.types.llms.anthropic import ANTHROPIC_EFFORT_BETA_HEADER

        betas.add(ANTHROPIC_EFFORT_BETA_HEADER)

    # Code execution tool uses a separate beta header
    if code_execution_tool_used:
        betas.add("code-execution-2025-08-25")

    # Container with skills uses a separate beta header
    if container_with_skills_used:
        betas.add("skills-2025-10-02")

    _is_oauth = api_key and api_key.startswith(ANTHROPIC_OAUTH_TOKEN_PREFIX)
    headers = {
        "anthropic-version": anthropic_version or "2023-06-01",
        "accept": "application/json",
        "content-type": "application/json",
    }
    if _is_oauth:
        # OAuth tokens go in the Authorization header instead of x-api-key
        # and require the OAuth beta + browser-access headers.
        headers["authorization"] = f"Bearer {api_key}"
        headers["anthropic-dangerous-direct-browser-access"] = "true"
        betas.add(ANTHROPIC_OAUTH_BETA_HEADER)
    else:
        headers["x-api-key"] = api_key

    if user_anthropic_beta_headers is not None:
        betas.update(user_anthropic_beta_headers)

    # Don't send any beta headers to Vertex, except web search which is required
    if is_vertex_request is True:
        # Vertex AI requires web search beta header for web search to work
        if web_search_tool_used:
            from litellm.types.llms.anthropic import ANTHROPIC_BETA_HEADER_VALUES

            headers[
                "anthropic-beta"
            ] = ANTHROPIC_BETA_HEADER_VALUES.WEB_SEARCH_2025_03_05.value
    elif len(betas) > 0:
        # FIX: sort before joining — iterating a set is nondeterministic
        # across processes (string-hash seeding), so the header value could
        # differ run to run. Sorting matches _merge_beta_headers' behavior.
        headers["anthropic-beta"] = ",".join(sorted(betas))

    return headers
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        """
        Build the complete header dict for an Anthropic API call.

        Resolves the API key (including OAuth tokens embedded in *headers*),
        inspects the request (messages / tools / optional_params) to detect
        which beta features are in use, and merges the computed Anthropic
        headers over the caller-supplied ones.

        Raises:
            litellm.AuthenticationError: if no API key can be resolved.
        """
        # Check for Anthropic OAuth token in headers; may rewrite both the
        # headers and the effective api_key.
        headers, api_key = optionally_handle_anthropic_oauth(
            headers=headers, api_key=api_key
        )
        if api_key is None:
            raise litellm.AuthenticationError(
                message="Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params. Please set `ANTHROPIC_API_KEY` in your environment vars",
                llm_provider="anthropic",
                model=model,
            )

        # Feature detection: each flag below gates a corresponding
        # `anthropic-beta` header (or similar) in get_anthropic_headers().
        tools = optional_params.get("tools")
        prompt_caching_set = self.is_cache_control_set(messages=messages)
        computer_tool_used = self.is_computer_tool_used(tools=tools)
        mcp_server_used = self.is_mcp_server_used(
            mcp_servers=optional_params.get("mcp_servers")
        )
        pdf_used = self.is_pdf_used(messages=messages)
        file_id_used = self.is_file_id_used(messages=messages)
        web_search_tool_used = self.is_web_search_tool_used(tools=tools)
        tool_search_used = self.is_tool_search_used(tools=tools)
        programmatic_tool_calling_used = self.is_programmatic_tool_calling_used(
            tools=tools
        )
        input_examples_used = self.is_input_examples_used(tools=tools)
        effort_used = self.is_effort_used(optional_params=optional_params, model=model)
        code_execution_tool_used = self.is_code_execution_tool_used(tools=tools)
        container_with_skills_used = self.is_container_with_skills_used(
            optional_params=optional_params
        )
        # Preserve any `anthropic-beta` values the caller set explicitly so
        # they get merged with the auto-detected ones.
        user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
            anthropic_beta_header=headers.get("anthropic-beta")
        )
        anthropic_headers = self.get_anthropic_headers(
            computer_tool_used=computer_tool_used,
            prompt_caching_set=prompt_caching_set,
            pdf_used=pdf_used,
            api_key=api_key,
            file_id_used=file_id_used,
            web_search_tool_used=web_search_tool_used,
            is_vertex_request=optional_params.get("is_vertex_request", False),
            user_anthropic_beta_headers=user_anthropic_beta_headers,
            mcp_server_used=mcp_server_used,
            tool_search_used=tool_search_used,
            programmatic_tool_calling_used=programmatic_tool_calling_used,
            input_examples_used=input_examples_used,
            effort_used=effort_used,
            code_execution_tool_used=code_execution_tool_used,
            container_with_skills_used=container_with_skills_used,
        )

        # Computed Anthropic headers win over caller-supplied ones on conflict.
        headers = {**headers, **anthropic_headers}

        return headers
|
||||
|
||||
@staticmethod
|
||||
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
return (
|
||||
api_base
|
||||
or get_secret_str("ANTHROPIC_API_BASE")
|
||||
or "https://api.anthropic.com"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
return api_key or get_secret_str("ANTHROPIC_API_KEY")
|
||||
|
||||
@staticmethod
|
||||
def get_base_model(model: Optional[str] = None) -> Optional[str]:
|
||||
return model.replace("anthropic/", "") if model else None
|
||||
|
||||
def get_models(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None
|
||||
) -> List[str]:
|
||||
api_base = AnthropicModelInfo.get_api_base(api_base)
|
||||
api_key = AnthropicModelInfo.get_api_key(api_key)
|
||||
if api_base is None or api_key is None:
|
||||
raise ValueError(
|
||||
"ANTHROPIC_API_BASE or ANTHROPIC_API_KEY is not set. Please set the environment variable, to query Anthropic's `/models` endpoint."
|
||||
)
|
||||
response = litellm.module_level_client.get(
|
||||
url=f"{api_base}/v1/models",
|
||||
headers={"x-api-key": api_key, "anthropic-version": "2023-06-01"},
|
||||
)
|
||||
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError:
|
||||
raise Exception(
|
||||
f"Failed to fetch models from Anthropic. Status code: {response.status_code}, Response: {response.text}"
|
||||
)
|
||||
|
||||
models = response.json()["data"]
|
||||
|
||||
litellm_model_names = []
|
||||
for model in models:
|
||||
stripped_model_name = model["id"]
|
||||
litellm_model_name = "anthropic/" + stripped_model_name
|
||||
litellm_model_names.append(litellm_model_name)
|
||||
return litellm_model_names
|
||||
|
||||
    def get_token_counter(self) -> Optional[BaseTokenCounter]:
        """
        Factory method to create an Anthropic token counter.

        Returns:
            AnthropicTokenCounter instance for this provider.
        """
        # Imported lazily — presumably to avoid an import cycle at module
        # load time; TODO confirm.
        from litellm.llms.anthropic.count_tokens.token_counter import (
            AnthropicTokenCounter,
        )

        return AnthropicTokenCounter()
|
||||
|
||||
|
||||
def process_anthropic_headers(headers: Union[httpx.Headers, dict]) -> dict:
    """Translate Anthropic rate-limit response headers into their
    OpenAI-style equivalents and prefix every raw header with
    ``llm_provider-``.

    The OpenAI-style keys are merged last, so they win on any conflict.
    """
    ratelimit_header_map = {
        "anthropic-ratelimit-requests-limit": "x-ratelimit-limit-requests",
        "anthropic-ratelimit-requests-remaining": "x-ratelimit-remaining-requests",
        "anthropic-ratelimit-tokens-limit": "x-ratelimit-limit-tokens",
        "anthropic-ratelimit-tokens-remaining": "x-ratelimit-remaining-tokens",
    }
    openai_headers = {
        openai_key: headers[anthropic_key]
        for anthropic_key, openai_key in ratelimit_header_map.items()
        if anthropic_key in headers
    }

    # Expose every raw provider header under a namespaced key.
    llm_response_headers = {
        "llm_provider-{}".format(k): v for k, v in headers.items()
    }

    return {**llm_response_headers, **openai_headers}
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Anthropic /complete API - uses `llm_http_handler.py` to make httpx requests
|
||||
|
||||
Request/Response transformation is handled in `transformation.py`
|
||||
"""
|
||||
@@ -0,0 +1,310 @@
|
||||
"""
|
||||
Translation logic for anthropic's `/v1/complete` endpoint
|
||||
|
||||
Litellm provider slug: `anthropic_text/<model_name>`
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.constants import DEFAULT_MAX_TOKENS
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
custom_prompt,
|
||||
prompt_factory,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.chat.transformation import (
|
||||
BaseConfig,
|
||||
BaseLLMException,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
GenericStreamingChunk,
|
||||
ModelResponse,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicTextError(BaseLLMException):
    """Error raised for failures against Anthropic's legacy `/v1/complete` API.

    Carries synthetic httpx request/response objects so callers expecting an
    HTTP-shaped exception (status code, request, response) can introspect it
    uniformly.
    """

    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        # Synthesize a request/response pair pointed at the /complete
        # endpoint; the original httpx objects are not available here.
        self.request = httpx.Request(
            method="POST", url="https://api.anthropic.com/v1/complete"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            message=self.message,
            status_code=self.status_code,
            request=self.request,
            response=self.response,
        )  # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class AnthropicTextConfig(BaseConfig):
    """
    Reference: https://docs.anthropic.com/claude/reference/complete_post

    to pass metadata to anthropic, it's {"user_id": "any-relevant-information"}
    """

    # NOTE: these are class-level defaults; __init__ below also writes to the
    # class, so configuration is shared process-wide (litellm config pattern).
    max_tokens_to_sample: Optional[
        int
    ] = litellm.max_tokens  # anthropic requires a default
    stop_sequences: Optional[list] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    top_k: Optional[int] = None
    metadata: Optional[dict] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[
            int
        ] = DEFAULT_MAX_TOKENS,  # anthropic requires a default
        stop_sequences: Optional[list] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        top_k: Optional[int] = None,
        metadata: Optional[dict] = None,
    ) -> None:
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                # Deliberately set on the *class* (not the instance) so the
                # values are visible via get_config() everywhere.
                setattr(self.__class__, key, value)

    # makes headers for API call
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Build the static headers for a `/v1/complete` call.

        Raises:
            ValueError: if no API key was provided.
        """
        if api_key is None:
            raise ValueError(
                "Missing Anthropic API Key - A call is being made to anthropic but no key is set either in the environment variables or via params"
            )
        _headers = {
            "accept": "application/json",
            "anthropic-version": "2023-06-01",
            "content-type": "application/json",
            "x-api-key": api_key,
        }
        headers.update(_headers)
        return headers

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Turn OpenAI-style messages into the `/v1/complete` request body:
        a single rendered `prompt` string plus the optional params."""
        prompt = self._get_anthropic_text_prompt_from_messages(
            messages=messages, model=model
        )
        ## Load Config
        config = litellm.AnthropicTextConfig.get_config()
        for k, v in config.items():
            if (
                k not in optional_params
            ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                optional_params[k] = v

        data = {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }

        return data

    def get_supported_openai_params(self, model: str):
        """
        Anthropic /complete API Ref: https://docs.anthropic.com/en/api/complete
        """
        return [
            "stream",
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "temperature",
            "top_p",
            "extra_headers",
            "user",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Follows the same logic as the AnthropicConfig.map_openai_params method (which is the Anthropic /messages API)

        Note: the only difference is in the get supported openai params method between the AnthropicConfig and AnthropicTextConfig
        API Ref: https://docs.anthropic.com/en/api/complete
        """
        for param, value in non_default_params.items():
            if param == "max_tokens":
                optional_params["max_tokens_to_sample"] = value
            if param == "max_completion_tokens":
                # NOTE: both max_tokens and max_completion_tokens write the
                # same key; if both are passed, the one iterated last wins.
                optional_params["max_tokens_to_sample"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
            if param == "stop" and (isinstance(value, str) or isinstance(value, list)):
                _value = litellm.AnthropicConfig()._map_stop_sequences(value)
                if _value is not None:
                    optional_params["stop_sequences"] = _value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "user":
                optional_params["metadata"] = {"user_id": value}

        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: str,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Map a raw `/v1/complete` HTTP response onto *model_response*,
        computing token usage with *encoding*.

        Raises:
            AnthropicTextError: on a non-JSON body or an `error` payload.
        """
        try:
            completion_response = raw_response.json()
        except Exception:
            raise AnthropicTextError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        prompt = self._get_anthropic_text_prompt_from_messages(
            messages=messages, model=model
        )
        if "error" in completion_response:
            raise AnthropicTextError(
                message=str(completion_response["error"]),
                status_code=raw_response.status_code,
            )
        else:
            if len(completion_response["completion"]) > 0:
                model_response.choices[0].message.content = completion_response[  # type: ignore
                    "completion"
                ]
            # stop_reason is passed through unmapped from the provider.
            model_response.choices[0].finish_reason = completion_response["stop_reason"]

        ## CALCULATING USAGE
        prompt_tokens = len(
            encoding.encode(prompt)
        )  ##[TODO] use the anthropic tokenizer here
        completion_tokens = len(
            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
        )  ##[TODO] use the anthropic tokenizer here

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )

        setattr(model_response, "usage", usage)
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[Dict, httpx.Headers]
    ) -> BaseLLMException:
        """Return the provider-specific exception for an HTTP error."""
        return AnthropicTextError(
            status_code=status_code,
            message=error_message,
        )

    @staticmethod
    def _is_anthropic_text_model(model: str) -> bool:
        # Only the two legacy completion-only models route through this config.
        return model == "claude-2" or model == "claude-instant-1"

    def _get_anthropic_text_prompt_from_messages(
        self, messages: List[AllMessageValues], model: str
    ) -> str:
        """Render chat messages into the single prompt string the /complete
        endpoint expects, honoring any custom prompt template registered
        for *model*."""
        custom_prompt_dict = litellm.custom_prompt_dict
        if model in custom_prompt_dict:
            # check if the model has a registered custom prompt
            model_prompt_details = custom_prompt_dict[model]
            prompt = custom_prompt(
                role_dict=model_prompt_details["roles"],
                initial_prompt_value=model_prompt_details["initial_prompt_value"],
                final_prompt_value=model_prompt_details["final_prompt_value"],
                messages=messages,
            )
        else:
            prompt = prompt_factory(
                model=model, messages=messages, custom_llm_provider="anthropic"
            )

        return str(prompt)

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ):
        """Return the streaming-chunk iterator for /complete responses."""
        return AnthropicTextCompletionResponseIterator(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
|
||||
|
||||
|
||||
class AnthropicTextCompletionResponseIterator(BaseModelResponseIterator):
    """Parses decoded streaming chunks from Anthropic's legacy `/v1/complete`
    endpoint into litellm ``GenericStreamingChunk`` objects."""

    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
        """Convert one decoded SSE chunk (a dict) into a GenericStreamingChunk.

        The /complete stream carries incremental text in ``completion`` and
        signals termination via a non-empty ``stop_reason``.
        """
        try:
            text = ""
            tool_use: Optional[ChatCompletionToolCallChunk] = None  # /complete has no tool calls
            is_finished = False
            finish_reason = ""
            usage: Optional[ChatCompletionUsageBlock] = None  # not reported per-chunk
            provider_specific_fields = None
            index = int(chunk.get("index", 0))
            _chunk_text = chunk.get("completion", None)
            if _chunk_text is not None and isinstance(_chunk_text, str):
                text = _chunk_text
            finish_reason = chunk.get("stop_reason") or ""
            # BUG FIX: finish_reason is coerced to "" above and therefore is
            # never None — the previous `is not None` check marked EVERY
            # chunk as finished. Only flag completion when a stop_reason is
            # actually present.
            if finish_reason:
                is_finished = True
            returned_chunk = GenericStreamingChunk(
                text=text,
                tool_use=tool_use,
                is_finished=is_finished,
                finish_reason=finish_reason,
                usage=usage,
                index=index,
                provider_specific_fields=provider_specific_fields,
            )

            return returned_chunk

        except json.JSONDecodeError:
            raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Helper util for handling anthropic-specific cost calculation
|
||||
- e.g.: prompt caching
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import (
|
||||
_get_token_base_cost,
|
||||
_parse_prompt_tokens_details,
|
||||
calculate_cache_writing_cost,
|
||||
generic_cost_per_token,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.utils import ModelInfo, Usage
|
||||
import litellm
|
||||
|
||||
|
||||
def _compute_cache_only_cost(model_info: "ModelInfo", usage: "Usage") -> float:
    """
    Return only the cache-related portion of the prompt cost (cache read + cache write).

    These costs must NOT be scaled by geo/speed multipliers because the old
    explicit ``fast/`` model entries carried unchanged cache rates while
    multiplying only the regular input/output token costs.
    """
    # No cached-token breakdown -> no cache cost.
    if usage.prompt_tokens_details is None:
        return 0.0

    prompt_tokens_details = _parse_prompt_tokens_details(usage)
    # Only the cache-specific rates are needed; the first two tuple slots
    # (regular input/output rates) are intentionally discarded.
    (
        _,
        _,
        cache_creation_cost,
        cache_creation_cost_above_1hr,
        cache_read_cost,
    ) = _get_token_base_cost(model_info=model_info, usage=usage)

    # Cache-read portion.
    cache_cost = float(prompt_tokens_details["cache_hit_tokens"]) * cache_read_cost

    # Cache-write portion (only when a cache write actually happened).
    if (
        prompt_tokens_details["cache_creation_tokens"]
        or prompt_tokens_details["cache_creation_token_details"] is not None
    ):
        cache_cost += calculate_cache_writing_cost(
            cache_creation_tokens=prompt_tokens_details["cache_creation_tokens"],
            cache_creation_token_details=prompt_tokens_details[
                "cache_creation_token_details"
            ],
            cache_creation_cost_above_1hr=cache_creation_cost_above_1hr,
            cache_creation_cost=cache_creation_cost,
        )

    return cache_cost
|
||||
|
||||
|
||||
def cost_per_token(model: str, usage: "Usage") -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Input:
        - model: str, the model name without provider prefix
        - usage: LiteLLM Usage block, containing anthropic caching information

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    prompt_cost, completion_cost = generic_cost_per_token(
        model=model, usage=usage, custom_llm_provider="anthropic"
    )

    # Apply provider_specific_entry multipliers for geo/speed routing
    try:
        model_info = litellm.get_model_info(
            model=model, custom_llm_provider="anthropic"
        )
        provider_specific_entry: dict = model_info.get("provider_specific_entry") or {}

        multiplier = 1.0
        # Geo multiplier applies only to explicit regional routing, not the
        # "global"/"not_available" markers.
        if (
            hasattr(usage, "inference_geo")
            and usage.inference_geo
            and usage.inference_geo.lower() not in ["global", "not_available"]
        ):
            multiplier *= provider_specific_entry.get(usage.inference_geo.lower(), 1.0)
        if hasattr(usage, "speed") and usage.speed == "fast":
            multiplier *= provider_specific_entry.get("fast", 1.0)

        if multiplier != 1.0:
            # Cache read/write costs are exempt from the multiplier; scale
            # only the non-cache portion of the prompt cost.
            cache_cost = _compute_cache_only_cost(model_info=model_info, usage=usage)
            prompt_cost = (prompt_cost - cache_cost) * multiplier + cache_cost
            completion_cost *= multiplier
    except Exception:
        # Deliberate best-effort: if model_info lookup or multiplier math
        # fails, fall back to the unscaled generic costs.
        pass

    return prompt_cost, completion_cost
|
||||
|
||||
|
||||
def get_cost_for_anthropic_web_search(
    model_info: Optional["ModelInfo"] = None,
    usage: Optional["Usage"] = None,
) -> float:
    """
    Get the cost of using a web search tool for Anthropic.

    Returns 0.0 whenever pricing info or server-tool usage is unavailable.
    """
    from litellm.types.utils import SearchContextCostPerQuery

    ## Check if web search requests are in the usage object
    if model_info is None:
        return 0.0

    if (
        usage is None
        or usage.server_tool_use is None
        or usage.server_tool_use.web_search_requests is None
    ):
        return 0.0

    ## Get the cost per web search request
    search_context_pricing: SearchContextCostPerQuery = (
        model_info.get("search_context_cost_per_query") or SearchContextCostPerQuery()
    )
    # NOTE(review): the "medium" bucket is used as the flat per-request rate
    # here — presumably Anthropic pricing has a single rate; confirm against
    # the model pricing config.
    cost_per_web_search_request = search_context_pricing.get(
        "search_context_size_medium", 0.0
    )
    if cost_per_web_search_request is None or cost_per_web_search_request == 0.0:
        return 0.0

    ## Calculate the total cost
    total_cost = cost_per_web_search_request * usage.server_tool_use.web_search_requests
    return total_cost
|
||||
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Anthropic CountTokens API implementation.
|
||||
"""
|
||||
|
||||
from litellm.llms.anthropic.count_tokens.handler import AnthropicCountTokensHandler
|
||||
from litellm.llms.anthropic.count_tokens.token_counter import AnthropicTokenCounter
|
||||
from litellm.llms.anthropic.count_tokens.transformation import (
|
||||
AnthropicCountTokensConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnthropicCountTokensHandler",
|
||||
"AnthropicCountTokensConfig",
|
||||
"AnthropicTokenCounter",
|
||||
]
|
||||
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Anthropic CountTokens API handler.
|
||||
|
||||
Uses httpx for HTTP requests instead of the Anthropic SDK.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.anthropic.common_utils import AnthropicError
|
||||
from litellm.llms.anthropic.count_tokens.transformation import (
|
||||
AnthropicCountTokensConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
|
||||
|
||||
class AnthropicCountTokensHandler(AnthropicCountTokensConfig):
    """
    Handler for Anthropic CountTokens API requests.

    Uses httpx for HTTP requests, following the same pattern as BedrockCountTokensHandler.
    """

    async def handle_count_tokens_request(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        api_key: str,
        api_base: Optional[str] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using httpx.

        Args:
            model: The model identifier (e.g., "claude-3-5-sonnet-20241022")
            messages: The messages to count tokens for
            api_key: The Anthropic API key
            api_base: Optional custom API base URL
            timeout: Optional timeout for the request (defaults to litellm.request_timeout)
            tools: Optional tool definitions, included for accurate counting
            system: Optional system prompt, included for accurate counting

        Returns:
            Dictionary containing token count response

        Raises:
            AnthropicError: If the API request fails
        """
        try:
            # Validate the request
            self.validate_request(model, messages)

            verbose_logger.debug(
                f"Processing Anthropic CountTokens request for model: {model}"
            )

            # Transform request to Anthropic format
            request_body = self.transform_request_to_count_tokens(
                model=model,
                messages=messages,
                tools=tools,
                system=system,
            )

            verbose_logger.debug(f"Transformed request: {request_body}")

            # Get endpoint URL
            endpoint_url = api_base or self.get_anthropic_count_tokens_endpoint()

            verbose_logger.debug(f"Making request to: {endpoint_url}")

            # Get required headers
            headers = self.get_required_headers(api_key)

            # Use LiteLLM's async httpx client
            async_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.ANTHROPIC
            )

            # Use provided timeout or fall back to litellm.request_timeout
            request_timeout = (
                timeout if timeout is not None else litellm.request_timeout
            )

            response = await async_client.post(
                endpoint_url,
                headers=headers,
                json=request_body,
                timeout=request_timeout,
            )

            verbose_logger.debug(f"Response status: {response.status_code}")

            # Surface non-200s as AnthropicError with the raw body as message.
            if response.status_code != 200:
                error_text = response.text
                verbose_logger.error(f"Anthropic API error: {error_text}")
                raise AnthropicError(
                    status_code=response.status_code,
                    message=error_text,
                )

            anthropic_response = response.json()

            verbose_logger.debug(f"Anthropic response: {anthropic_response}")

            # Return Anthropic response directly - no transformation needed
            return anthropic_response

        except AnthropicError:
            # Re-raise Anthropic exceptions as-is
            raise
        except httpx.HTTPStatusError as e:
            # HTTP errors - preserve the actual status code
            verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        except Exception as e:
            # Anything else (network errors, validation errors) becomes a 500.
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=500,
                message=f"CountTokens processing error: {str(e)}",
            )
|
||||
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Anthropic Token Counter implementation using the CountTokens API.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.anthropic.count_tokens.handler import AnthropicCountTokensHandler
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.types.utils import LlmProviders, TokenCountResponse
|
||||
|
||||
# Global handler instance - reuse across all token counting requests
|
||||
anthropic_count_tokens_handler = AnthropicCountTokensHandler()
|
||||
|
||||
|
||||
class AnthropicTokenCounter(BaseTokenCounter):
    """Token counter implementation for Anthropic provider using the CountTokens API."""

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Only direct "anthropic" calls use the remote counting API.
        return custom_llm_provider == LlmProviders.ANTHROPIC.value

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count tokens using Anthropic's CountTokens API.

        Args:
            model_to_use: The model identifier
            messages: The messages to count tokens for
            contents: Alternative content format (not used for Anthropic)
            deployment: Deployment configuration containing litellm_params
            request_model: The original request model name
            tools: Optional tool definitions, forwarded to the API
            system: Optional system prompt, forwarded to the API

        Returns:
            TokenCountResponse with token count, or None if counting fails
        """
        from litellm.llms.anthropic.common_utils import AnthropicError

        if not messages:
            return None

        deployment = deployment or {}
        litellm_params = deployment.get("litellm_params", {})

        # Get Anthropic API key from deployment config or environment
        api_key = litellm_params.get("api_key")
        if not api_key:
            api_key = os.getenv("ANTHROPIC_API_KEY")

        if not api_key:
            verbose_logger.warning("No Anthropic API key found for token counting")
            return None

        try:
            result = await anthropic_count_tokens_handler.handle_count_tokens_request(
                model=model_to_use,
                messages=messages,
                api_key=api_key,
                tools=tools,
                system=system,
            )

            if result is not None:
                return TokenCountResponse(
                    total_tokens=result.get("input_tokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type="anthropic_api",
                    original_response=result,
                )
        except AnthropicError as e:
            # API-level failure: return an error-shaped response rather than
            # raising, preserving the provider's status code and message.
            verbose_logger.warning(
                f"Anthropic CountTokens API error: status={e.status_code}, message={e.message}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="anthropic_api",
                error=True,
                error_message=e.message,
                status_code=e.status_code,
            )
        except Exception as e:
            # Unexpected failure: same error-shaped response with a generic 500.
            verbose_logger.warning(f"Error calling Anthropic CountTokens API: {e}")
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="anthropic_api",
                error=True,
                error_message=str(e),
                status_code=500,
            )

        # Reached only when the handler returned None (no exception raised).
        return None
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Anthropic CountTokens API transformation logic.
|
||||
|
||||
This module handles the transformation of requests to Anthropic's CountTokens API format.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm.constants import ANTHROPIC_TOKEN_COUNTING_BETA_VERSION
|
||||
|
||||
|
||||
class AnthropicCountTokensConfig:
    """
    Configuration and transformation logic for Anthropic CountTokens API.

    Anthropic CountTokens API Specification:
    - Endpoint: POST https://api.anthropic.com/v1/messages/count_tokens
    - Beta header required: anthropic-beta: token-counting-2024-11-01
    - Response: {"input_tokens": <number>}
    """

    def get_anthropic_count_tokens_endpoint(self) -> str:
        """
        Get the Anthropic CountTokens API endpoint.

        Returns:
            The endpoint URL for the CountTokens API
        """
        return "https://api.anthropic.com/v1/messages/count_tokens"

    def transform_request_to_count_tokens(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Transform request to Anthropic CountTokens format.

        Includes optional system and tools fields for accurate token counting.
        """
        request: Dict[str, Any] = {
            "model": model,
            "messages": messages,
        }

        # Omit optional fields entirely when unset rather than sending nulls.
        if system is not None:
            request["system"] = system

        if tools is not None:
            request["tools"] = tools

        return request

    def get_required_headers(self, api_key: str) -> Dict[str, str]:
        """
        Get the required headers for the CountTokens API.

        Args:
            api_key: The Anthropic API key

        Returns:
            Dictionary of required headers
        """
        from litellm.llms.anthropic.common_utils import (
            optionally_handle_anthropic_oauth,
        )

        headers: Dict[str, str] = {
            "Content-Type": "application/json",
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "anthropic-beta": ANTHROPIC_TOKEN_COUNTING_BETA_VERSION,
        }
        # OAuth tokens are detected here and may rewrite the auth headers.
        headers, _ = optionally_handle_anthropic_oauth(headers=headers, api_key=api_key)
        return headers

    def validate_request(self, model: str, messages: List[Dict[str, Any]]) -> None:
        """
        Validate the incoming count tokens request.

        Args:
            model: The model name
            messages: The messages to count tokens for

        Raises:
            ValueError: If the request is invalid
        """
        if not model:
            raise ValueError("model parameter is required")

        if not messages:
            raise ValueError("messages parameter is required")

        if not isinstance(messages, list):
            raise ValueError("messages must be a list")

        # Each message must be a dict carrying at least role + content.
        for i, message in enumerate(messages):
            if not isinstance(message, dict):
                raise ValueError(f"Message {i} must be a dictionary")

            if "role" not in message:
                raise ValueError(f"Message {i} must have a 'role' field")

            if "content" not in message:
                raise ValueError(f"Message {i} must have a 'content' field")
|
||||
@@ -0,0 +1,3 @@
|
||||
from .transformation import LiteLLMAnthropicMessagesAdapter
|
||||
|
||||
__all__ = ["LiteLLMAnthropicMessagesAdapter"]
|
||||
@@ -0,0 +1,345 @@
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Coroutine,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import litellm
|
||||
from litellm.llms.anthropic.experimental_pass_through.adapters.transformation import (
|
||||
AnthropicAdapter,
|
||||
)
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
########################################################
# Module-level singleton adapter, used by the handler below to translate
# between Anthropic /v1/messages format and OpenAI chat-completion format.
ANTHROPIC_ADAPTER = AnthropicAdapter()
########################################################
|
||||
|
||||
|
||||
class LiteLLMMessagesToCompletionTransformationHandler:
    """
    Bridges Anthropic `/v1/messages` requests onto `litellm.completion` /
    `litellm.acompletion` for providers without a native Anthropic Messages
    implementation.

    Flow: the Anthropic-format request is translated into OpenAI
    chat-completion kwargs via `ANTHROPIC_ADAPTER`, executed through litellm,
    and the response (streaming or non-streaming) is translated back into
    Anthropic's response format.
    """

    @staticmethod
    def _route_openai_thinking_to_responses_api_if_needed(
        completion_kwargs: Dict[str, Any],
        *,
        thinking: Optional[Dict[str, Any]],
    ) -> None:
        """
        When users call `litellm.anthropic.messages.*` with a non-Anthropic model and
        `thinking={"type": "enabled", ...}`, LiteLLM converts this into OpenAI
        `reasoning_effort`.

        For OpenAI models, Chat Completions typically does not return reasoning text
        (only token accounting). To return a thinking-like content block in the
        Anthropic response format, we route the request through OpenAI's Responses API
        and request a reasoning summary.

        Mutates `completion_kwargs` in place; returns None.
        """
        custom_llm_provider = completion_kwargs.get("custom_llm_provider")
        if custom_llm_provider is None:
            # Provider not set explicitly — try to infer it from the model
            # string; on failure leave it unset so routing below no-ops.
            try:
                _, inferred_provider, _, _ = litellm.utils.get_llm_provider(
                    model=cast(str, completion_kwargs.get("model"))
                )
                custom_llm_provider = inferred_provider
            except Exception:
                custom_llm_provider = None

        # Only OpenAI models with thinking explicitly enabled are rerouted.
        if custom_llm_provider != "openai":
            return

        if not isinstance(thinking, dict) or thinking.get("type") != "enabled":
            return

        model = completion_kwargs.get("model")
        try:
            model_info = get_model_info(
                model=cast(str, model), custom_llm_provider=custom_llm_provider
            )
            if model_info and model_info.get("supports_reasoning") is False:
                # Model doesn't support reasoning/responses API, don't route
                return
        except Exception:
            # Unknown model: fall through and attempt routing anyway.
            pass

        if isinstance(model, str) and model and not model.startswith("responses/"):
            # Prefix model with "responses/" to route to OpenAI Responses API
            completion_kwargs["model"] = f"responses/{model}"

        reasoning_effort = completion_kwargs.get("reasoning_effort")
        if isinstance(reasoning_effort, str) and reasoning_effort:
            # String form (e.g. "low"/"medium"/"high"): wrap into the
            # Responses API dict shape and request a detailed summary.
            completion_kwargs["reasoning_effort"] = {
                "effort": reasoning_effort,
                "summary": "detailed",
            }
        elif isinstance(reasoning_effort, dict):
            if (
                "summary" not in reasoning_effort
                and "generate_summary" not in reasoning_effort
            ):
                # Copy rather than mutate the caller's dict; only add the
                # summary request if no summary preference was given.
                updated_reasoning_effort = dict(reasoning_effort)
                updated_reasoning_effort["summary"] = "detailed"
                completion_kwargs["reasoning_effort"] = updated_reasoning_effort

    @staticmethod
    def _prepare_completion_kwargs(
        *,
        max_tokens: int,
        messages: List[Dict],
        model: str,
        metadata: Optional[Dict] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: Optional[bool] = False,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking: Optional[Dict] = None,
        tool_choice: Optional[Dict] = None,
        tools: Optional[List[Dict]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        output_format: Optional[Dict] = None,
        extra_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], Dict[str, str]]:
        """Prepare kwargs for litellm.completion/acompletion.

        Returns:
            Tuple of (completion_kwargs, tool_name_mapping)
            - tool_name_mapping maps truncated tool names back to original names
              for tools that exceeded OpenAI's 64-char limit

        Raises:
            ValueError: If the adapter fails to translate the request to
                OpenAI format.
        """
        from litellm.litellm_core_utils.litellm_logging import (
            Logging as LiteLLMLoggingObject,
        )

        # Rebuild an Anthropic-format request dict from the keyword args,
        # including only the optional fields the caller actually set.
        request_data = {
            "model": model,
            "messages": messages,
            "max_tokens": max_tokens,
        }

        if metadata:
            request_data["metadata"] = metadata
        if stop_sequences:
            request_data["stop_sequences"] = stop_sequences
        if system:
            request_data["system"] = system
        if temperature is not None:
            request_data["temperature"] = temperature
        if thinking:
            request_data["thinking"] = thinking
        if tool_choice:
            request_data["tool_choice"] = tool_choice
        if tools:
            request_data["tools"] = tools
        if top_k is not None:
            request_data["top_k"] = top_k
        if top_p is not None:
            request_data["top_p"] = top_p
        if output_format:
            request_data["output_format"] = output_format

        # Translate Anthropic request -> OpenAI chat-completion request.
        (
            openai_request,
            tool_name_mapping,
        ) = ANTHROPIC_ADAPTER.translate_completion_input_params_with_tool_mapping(
            request_data
        )

        if openai_request is None:
            raise ValueError("Failed to translate request to OpenAI format")

        completion_kwargs: Dict[str, Any] = dict(openai_request)

        if stream:
            completion_kwargs["stream"] = stream
            # Usage in the final chunk is needed to emit Anthropic's
            # message_delta usage event at end of stream.
            completion_kwargs["stream_options"] = {
                "include_usage": True,
            }

        # Forward any remaining caller kwargs, except internal markers and
        # anything the translation already set.
        excluded_keys = {"anthropic_messages"}
        extra_kwargs = extra_kwargs or {}
        for key, value in extra_kwargs.items():
            if (
                key == "litellm_logging_obj"
                and value is not None
                and isinstance(value, LiteLLMLoggingObject)
            ):
                from litellm.types.utils import CallTypes

                # Re-tag the logging object as a plain completion call so
                # downstream logging matches the actual call being made.
                setattr(value, "call_type", CallTypes.completion.value)
                setattr(
                    value, "stream_options", completion_kwargs.get("stream_options")
                )
            if (
                key not in excluded_keys
                and key not in completion_kwargs
                and value is not None
            ):
                completion_kwargs[key] = value

        LiteLLMMessagesToCompletionTransformationHandler._route_openai_thinking_to_responses_api_if_needed(
            completion_kwargs,
            thinking=thinking,
        )

        return completion_kwargs, tool_name_mapping

    @staticmethod
    async def async_anthropic_messages_handler(
        max_tokens: int,
        messages: List[Dict],
        model: str,
        metadata: Optional[Dict] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: Optional[bool] = False,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking: Optional[Dict] = None,
        tool_choice: Optional[Dict] = None,
        tools: Optional[List[Dict]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        output_format: Optional[Dict] = None,
        **kwargs,
    ) -> Union[AnthropicMessagesResponse, AsyncIterator]:
        """Handle non-Anthropic models asynchronously using the adapter.

        Returns an Anthropic-format response, or an async iterator of
        Anthropic streaming events when ``stream`` is True.

        Raises:
            ValueError: If the adapter fails to transform the response.
        """
        (
            completion_kwargs,
            tool_name_mapping,
        ) = LiteLLMMessagesToCompletionTransformationHandler._prepare_completion_kwargs(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            metadata=metadata,
            stop_sequences=stop_sequences,
            stream=stream,
            system=system,
            temperature=temperature,
            thinking=thinking,
            tool_choice=tool_choice,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            output_format=output_format,
            extra_kwargs=kwargs,
        )

        completion_response = await litellm.acompletion(**completion_kwargs)

        if stream:
            transformed_stream = (
                ANTHROPIC_ADAPTER.translate_completion_output_params_streaming(
                    completion_response,
                    model=model,
                    tool_name_mapping=tool_name_mapping,
                )
            )
            if transformed_stream is not None:
                return transformed_stream
            raise ValueError("Failed to transform streaming response")
        else:
            anthropic_response = ANTHROPIC_ADAPTER.translate_completion_output_params(
                cast(ModelResponse, completion_response),
                tool_name_mapping=tool_name_mapping,
            )
            if anthropic_response is not None:
                return anthropic_response
            raise ValueError("Failed to transform response to Anthropic format")

    @staticmethod
    def anthropic_messages_handler(
        max_tokens: int,
        messages: List[Dict],
        model: str,
        metadata: Optional[Dict] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: Optional[bool] = False,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking: Optional[Dict] = None,
        tool_choice: Optional[Dict] = None,
        tools: Optional[List[Dict]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        output_format: Optional[Dict] = None,
        _is_async: bool = False,
    **kwargs,
    ) -> Union[
        AnthropicMessagesResponse,
        AsyncIterator[Any],
        Coroutine[Any, Any, Union[AnthropicMessagesResponse, AsyncIterator[Any]]],
    ]:
        """Handle non-Anthropic models using the adapter.

        When ``_is_async`` is True, returns the (unawaited) coroutine from
        the async handler; otherwise runs synchronously.

        Raises:
            ValueError: If the adapter fails to transform the response.
        """
        if _is_async is True:
            # Caller awaits the returned coroutine.
            return LiteLLMMessagesToCompletionTransformationHandler.async_anthropic_messages_handler(
                max_tokens=max_tokens,
                messages=messages,
                model=model,
                metadata=metadata,
                stop_sequences=stop_sequences,
                stream=stream,
                system=system,
                temperature=temperature,
                thinking=thinking,
                tool_choice=tool_choice,
                tools=tools,
                top_k=top_k,
                top_p=top_p,
                output_format=output_format,
                **kwargs,
            )

        (
            completion_kwargs,
            tool_name_mapping,
        ) = LiteLLMMessagesToCompletionTransformationHandler._prepare_completion_kwargs(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            metadata=metadata,
            stop_sequences=stop_sequences,
            stream=stream,
            system=system,
            temperature=temperature,
            thinking=thinking,
            tool_choice=tool_choice,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            output_format=output_format,
            extra_kwargs=kwargs,
        )

        completion_response = litellm.completion(**completion_kwargs)

        if stream:
            transformed_stream = (
                ANTHROPIC_ADAPTER.translate_completion_output_params_streaming(
                    completion_response,
                    model=model,
                    tool_name_mapping=tool_name_mapping,
                )
            )
            if transformed_stream is not None:
                return transformed_stream
            raise ValueError("Failed to transform streaming response")
        else:
            anthropic_response = ANTHROPIC_ADAPTER.translate_completion_output_params(
                cast(ModelResponse, completion_response),
                tool_name_mapping=tool_name_mapping,
            )
            if anthropic_response is not None:
                return anthropic_response
            raise ValueError("Failed to transform response to Anthropic format")
|
||||
@@ -0,0 +1,488 @@
|
||||
# What is this?
|
||||
## Translates OpenAI call to Anthropic `/v1/messages` format
|
||||
import json
|
||||
import traceback
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, Literal, Optional
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.types.llms.anthropic import UsageDelta
|
||||
from litellm.types.utils import AdapterCompletionStreamWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.utils import ModelResponseStream
|
||||
|
||||
|
||||
class AnthropicStreamWrapper(AdapterCompletionStreamWrapper):
    """
    Adapts an OpenAI-format completion stream into Anthropic `/v1/messages`
    streaming events.

    Invariants enforced on the emitted event sequence:
    - first chunk returned is 'message_start'
    - every content block must be started and stopped
    - finish_reason must map exactly to anthropic reason, else anthropic client won't be able to parse it.
    """

    # Imported at class scope so the annotation below resolves at class
    # creation time without a module-level import.
    from litellm.types.llms.anthropic import ContentBlockContentBlockDict

    # Per-stream state — all initialized in __init__ (see BUGFIX note there).
    sent_first_chunk: bool
    sent_content_block_start: bool
    sent_content_block_finish: bool
    current_content_block_type: Literal["text", "tool_use", "thinking"]
    sent_last_message: bool
    holding_chunk: Optional[Any]
    holding_stop_reason_chunk: Optional[Any]
    queued_usage_chunk: bool
    current_content_block_index: int
    current_content_block_start: ContentBlockContentBlockDict
    chunk_queue: deque  # Queue for buffering multiple chunks

    def __init__(
        self,
        completion_stream: Any,
        model: str,
        tool_name_mapping: Optional[Dict[str, str]] = None,
    ):
        from litellm.types.llms.anthropic import TextBlock

        super().__init__(completion_stream)
        self.model = model
        # Mapping of truncated tool names to original names (for OpenAI's 64-char limit)
        self.tool_name_mapping = tool_name_mapping or {}
        # BUGFIX: this state used to be defined as class attributes; the
        # mutable `chunk_queue` deque was therefore shared across ALL wrapper
        # instances, letting buffered chunks from one stream leak into
        # another. All per-stream state is now initialized per instance.
        self.sent_first_chunk = False
        self.sent_content_block_start = False
        self.sent_content_block_finish = False
        self.current_content_block_type = "text"
        self.sent_last_message = False
        self.holding_chunk = None
        self.holding_stop_reason_chunk = None
        self.queued_usage_chunk = False
        self.current_content_block_index = 0
        self.current_content_block_start = TextBlock(
            type="text",
            text="",
        )
        self.chunk_queue = deque()  # Queue for buffering multiple chunks

    def _create_initial_usage_delta(self) -> UsageDelta:
        """
        Create the initial UsageDelta for the message_start event.

        Initializes cache token fields (cache_creation_input_tokens, cache_read_input_tokens)
        to 0 to indicate to clients (like Claude Code) that prompt caching is supported.

        The actual cache token values will be provided in the message_delta event at the
        end of the stream, since Bedrock Converse API only returns usage data in the final
        response chunk.

        Returns:
            UsageDelta with all token counts initialized to 0.
        """
        return UsageDelta(
            input_tokens=0,
            output_tokens=0,
            cache_creation_input_tokens=0,
            cache_read_input_tokens=0,
        )

    def __next__(self):
        from .transformation import LiteLLMAnthropicMessagesAdapter

        try:
            # Always return queued chunks first
            if self.chunk_queue:
                return self.chunk_queue.popleft()

            # Queue initial chunks if not sent yet
            if self.sent_first_chunk is False:
                self.sent_first_chunk = True
                self.chunk_queue.append(
                    {
                        "type": "message_start",
                        "message": {
                            "id": "msg_{}".format(uuid.uuid4()),
                            "type": "message",
                            "role": "assistant",
                            "content": [],
                            "model": self.model,
                            "stop_reason": None,
                            "stop_sequence": None,
                            "usage": self._create_initial_usage_delta(),
                        },
                    }
                )
                return self.chunk_queue.popleft()

            if self.sent_content_block_start is False:
                self.sent_content_block_start = True
                self.chunk_queue.append(
                    {
                        "type": "content_block_start",
                        "index": self.current_content_block_index,
                        "content_block": {"type": "text", "text": ""},
                    }
                )
                return self.chunk_queue.popleft()

            for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    raise Exception

                should_start_new_block = self._should_start_new_content_block(chunk)
                if should_start_new_block:
                    self._increment_content_block_index()

                processed_chunk = LiteLLMAnthropicMessagesAdapter().translate_streaming_openai_response_to_anthropic(
                    response=chunk,
                    current_content_block_index=self.current_content_block_index,
                )

                if should_start_new_block and not self.sent_content_block_finish:
                    # Queue the sequence: content_block_stop -> content_block_start
                    # The trigger chunk itself is not emitted as a delta since the
                    # content_block_start already carries the relevant information.
                    self.chunk_queue.append(
                        {
                            "type": "content_block_stop",
                            "index": max(self.current_content_block_index - 1, 0),
                        }
                    )
                    self.chunk_queue.append(
                        {
                            "type": "content_block_start",
                            "index": self.current_content_block_index,
                            "content_block": self.current_content_block_start,
                        }
                    )
                    self.sent_content_block_finish = False
                    return self.chunk_queue.popleft()

                if (
                    processed_chunk["type"] == "message_delta"
                    and self.sent_content_block_finish is False
                ):
                    # Queue both the content_block_stop and the message_delta
                    self.chunk_queue.append(
                        {
                            "type": "content_block_stop",
                            "index": self.current_content_block_index,
                        }
                    )
                    self.sent_content_block_finish = True
                    self.chunk_queue.append(processed_chunk)
                    return self.chunk_queue.popleft()
                elif self.holding_chunk is not None:
                    self.chunk_queue.append(self.holding_chunk)
                    self.chunk_queue.append(processed_chunk)
                    self.holding_chunk = None
                    return self.chunk_queue.popleft()
                else:
                    self.chunk_queue.append(processed_chunk)
                    return self.chunk_queue.popleft()

            # Handle any remaining held chunks after stream ends
            if self.holding_chunk is not None:
                self.chunk_queue.append(self.holding_chunk)
                self.holding_chunk = None

            if not self.sent_last_message:
                self.sent_last_message = True
                self.chunk_queue.append({"type": "message_stop"})

            if self.chunk_queue:
                return self.chunk_queue.popleft()

            raise StopIteration
        except StopIteration:
            if self.chunk_queue:
                return self.chunk_queue.popleft()
            if self.sent_last_message is False:
                self.sent_last_message = True
                return {"type": "message_stop"}
            raise StopIteration
        except Exception as e:
            verbose_logger.error(
                "Anthropic Adapter - {}\n{}".format(e, traceback.format_exc())
            )
            # BUGFIX: this previously raised StopAsyncIteration, which a
            # synchronous `for` loop does not treat as end-of-iteration and
            # would propagate as an error; StopIteration ends the sync
            # stream cleanly.
            raise StopIteration

    async def __anext__(self):  # noqa: PLR0915
        from .transformation import LiteLLMAnthropicMessagesAdapter

        try:
            # Always return queued chunks first
            if self.chunk_queue:
                return self.chunk_queue.popleft()

            # Queue initial chunks if not sent yet
            if self.sent_first_chunk is False:
                self.sent_first_chunk = True
                self.chunk_queue.append(
                    {
                        "type": "message_start",
                        "message": {
                            "id": "msg_{}".format(uuid.uuid4()),
                            "type": "message",
                            "role": "assistant",
                            "content": [],
                            "model": self.model,
                            "stop_reason": None,
                            "stop_sequence": None,
                            "usage": self._create_initial_usage_delta(),
                        },
                    }
                )
                return self.chunk_queue.popleft()

            if self.sent_content_block_start is False:
                self.sent_content_block_start = True
                self.chunk_queue.append(
                    {
                        "type": "content_block_start",
                        "index": self.current_content_block_index,
                        "content_block": {"type": "text", "text": ""},
                    }
                )
                return self.chunk_queue.popleft()

            async for chunk in self.completion_stream:
                if chunk == "None" or chunk is None:
                    raise Exception

                # Check if we need to start a new content block
                should_start_new_block = self._should_start_new_content_block(chunk)
                if should_start_new_block:
                    self._increment_content_block_index()

                processed_chunk = LiteLLMAnthropicMessagesAdapter().translate_streaming_openai_response_to_anthropic(
                    response=chunk,
                    current_content_block_index=self.current_content_block_index,
                )

                # Check if this is a usage chunk and we have a held stop_reason chunk
                if (
                    self.holding_stop_reason_chunk is not None
                    and getattr(chunk, "usage", None) is not None
                ):
                    # Merge usage into the held stop_reason chunk so Anthropic
                    # clients receive stop_reason and usage in one message_delta.
                    merged_chunk = self.holding_stop_reason_chunk.copy()
                    if "delta" not in merged_chunk:
                        merged_chunk["delta"] = {}

                    # Add usage to the held chunk. Anthropic's input_tokens
                    # excludes cached tokens, so subtract them from the
                    # OpenAI-style prompt_tokens total.
                    uncached_input_tokens = chunk.usage.prompt_tokens or 0
                    if (
                        hasattr(chunk.usage, "prompt_tokens_details")
                        and chunk.usage.prompt_tokens_details
                    ):
                        cached_tokens = (
                            getattr(
                                chunk.usage.prompt_tokens_details, "cached_tokens", 0
                            )
                            or 0
                        )
                        uncached_input_tokens -= cached_tokens

                    usage_dict: UsageDelta = {
                        "input_tokens": uncached_input_tokens,
                        "output_tokens": chunk.usage.completion_tokens or 0,
                    }
                    # Add cache tokens if available (for prompt caching support)
                    if (
                        hasattr(chunk.usage, "_cache_creation_input_tokens")
                        and chunk.usage._cache_creation_input_tokens > 0
                    ):
                        usage_dict[
                            "cache_creation_input_tokens"
                        ] = chunk.usage._cache_creation_input_tokens
                    if (
                        hasattr(chunk.usage, "_cache_read_input_tokens")
                        and chunk.usage._cache_read_input_tokens > 0
                    ):
                        usage_dict[
                            "cache_read_input_tokens"
                        ] = chunk.usage._cache_read_input_tokens
                    merged_chunk["usage"] = usage_dict

                    # Queue the merged chunk and reset
                    self.chunk_queue.append(merged_chunk)
                    self.queued_usage_chunk = True
                    self.holding_stop_reason_chunk = None
                    return self.chunk_queue.popleft()

                # Check if this processed chunk has a stop_reason - hold it for next chunk

                if not self.queued_usage_chunk:
                    if should_start_new_block and not self.sent_content_block_finish:
                        # Queue the sequence: content_block_stop -> content_block_start
                        # The trigger chunk itself is not emitted as a delta since the
                        # content_block_start already carries the relevant information.

                        # 1. Stop current content block
                        self.chunk_queue.append(
                            {
                                "type": "content_block_stop",
                                "index": max(self.current_content_block_index - 1, 0),
                            }
                        )

                        # 2. Start new content block
                        self.chunk_queue.append(
                            {
                                "type": "content_block_start",
                                "index": self.current_content_block_index,
                                "content_block": self.current_content_block_start,
                            }
                        )

                        # Reset state for new block
                        self.sent_content_block_finish = False

                        # Return the first queued item
                        return self.chunk_queue.popleft()

                    if (
                        processed_chunk["type"] == "message_delta"
                        and self.sent_content_block_finish is False
                    ):
                        # Queue both the content_block_stop and the holding chunk
                        self.chunk_queue.append(
                            {
                                "type": "content_block_stop",
                                "index": self.current_content_block_index,
                            }
                        )
                        self.sent_content_block_finish = True
                        if (
                            processed_chunk.get("delta", {}).get("stop_reason")
                            is not None
                        ):
                            # Hold the stop_reason chunk so a following usage
                            # chunk can be merged into it (see above).
                            self.holding_stop_reason_chunk = processed_chunk
                        else:
                            self.chunk_queue.append(processed_chunk)
                        return self.chunk_queue.popleft()
                    elif self.holding_chunk is not None:
                        # Queue both chunks
                        self.chunk_queue.append(self.holding_chunk)
                        self.chunk_queue.append(processed_chunk)
                        self.holding_chunk = None
                        return self.chunk_queue.popleft()
                    else:
                        # Queue the current chunk
                        self.chunk_queue.append(processed_chunk)
                        return self.chunk_queue.popleft()

            # Handle any remaining held chunks after stream ends
            if not self.queued_usage_chunk:
                if self.holding_stop_reason_chunk is not None:
                    self.chunk_queue.append(self.holding_stop_reason_chunk)
                    self.holding_stop_reason_chunk = None

            if self.holding_chunk is not None:
                self.chunk_queue.append(self.holding_chunk)
                self.holding_chunk = None

            if not self.sent_last_message:
                self.sent_last_message = True
                self.chunk_queue.append({"type": "message_stop"})

            # Return queued items if any
            if self.chunk_queue:
                return self.chunk_queue.popleft()

            # Raised inside the try and converted to StopAsyncIteration by the
            # handler below.
            raise StopIteration

        except StopIteration:
            # Handle any remaining queued chunks before stopping
            if self.chunk_queue:
                return self.chunk_queue.popleft()
            # Handle any held stop_reason chunk
            if self.holding_stop_reason_chunk is not None:
                return self.holding_stop_reason_chunk
            if not self.sent_last_message:
                self.sent_last_message = True
                return {"type": "message_stop"}
            raise StopAsyncIteration

    def anthropic_sse_wrapper(self) -> Iterator[bytes]:
        """
        Convert AnthropicStreamWrapper dict chunks to Server-Sent Events format.
        Similar to the Bedrock bedrock_sse_wrapper implementation.

        This wrapper ensures dict chunks are SSE formatted with both event and data lines.
        """
        for chunk in self:
            if isinstance(chunk, dict):
                event_type: str = str(chunk.get("type", "message"))
                payload = f"event: {event_type}\ndata: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            else:
                # For non-dict chunks, forward the original value unchanged
                yield chunk

    async def async_anthropic_sse_wrapper(self) -> AsyncIterator[bytes]:
        """
        Async version of anthropic_sse_wrapper.
        Convert AnthropicStreamWrapper dict chunks to Server-Sent Events format.
        """
        async for chunk in self:
            if isinstance(chunk, dict):
                event_type: str = str(chunk.get("type", "message"))
                payload = f"event: {event_type}\ndata: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            else:
                # For non-dict chunks, forward the original value unchanged
                yield chunk

    def _increment_content_block_index(self):
        """Advance to the next Anthropic content-block index."""
        self.current_content_block_index += 1

    def _should_start_new_content_block(self, chunk: "ModelResponseStream") -> bool:
        """
        Determine if we should start a new content block based on the processed chunk.

        Side effect: when this returns True, `current_content_block_type` and
        `current_content_block_start` are updated to describe the new block.

        Examples of when a new content block starts:
        - Switching from text to tool calls
        - Different content types in the response
        - A new tool call (signalled by a function name) during parallel tool calls
        """
        from .transformation import LiteLLMAnthropicMessagesAdapter

        # Chunks that carry a finish_reason never open a new block.
        if chunk.choices[0].finish_reason is not None:
            return False

        (
            block_type,
            content_block_start,
        ) = LiteLLMAnthropicMessagesAdapter()._translate_streaming_openai_chunk_to_anthropic_content_block(
            choices=chunk.choices  # type: ignore
        )

        # Restore original tool name if it was truncated for OpenAI's 64-char limit
        if block_type == "tool_use":
            # Type narrowing: content_block_start is ToolUseBlock when block_type is "tool_use"
            from typing import cast

            from litellm.types.llms.anthropic import ToolUseBlock

            tool_block = cast(ToolUseBlock, content_block_start)

            if tool_block.get("name"):
                truncated_name = tool_block["name"]
                original_name = self.tool_name_mapping.get(
                    truncated_name, truncated_name
                )
                tool_block["name"] = original_name

        if block_type != self.current_content_block_type:
            self.current_content_block_type = block_type
            self.current_content_block_start = content_block_start
            return True

        # For parallel tool calls, we'll necessarily have a new content block
        # if we get a function name since it signals a new tool call
        if block_type == "tool_use":
            from typing import cast

            from litellm.types.llms.anthropic import ToolUseBlock

            tool_block = cast(ToolUseBlock, content_block_start)
            if tool_block.get("name"):
                self.current_content_block_type = block_type
                self.current_content_block_start = content_block_start
                return True

        return False
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,51 @@
|
||||
# Anthropic Messages Pass-Through Architecture
|
||||
|
||||
## Request Flow
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[litellm.anthropic.messages.acreate] --> B{Provider?}
|
||||
|
||||
B -->|anthropic| C[AnthropicMessagesConfig]
|
||||
B -->|azure_ai| D[AzureAnthropicMessagesConfig]
|
||||
B -->|bedrock invoke| E[BedrockAnthropicMessagesConfig]
|
||||
B -->|vertex_ai| F[VertexAnthropicMessagesConfig]
|
||||
B -->|Other providers| G[LiteLLMAnthropicMessagesAdapter]
|
||||
|
||||
C --> H[Direct Anthropic API]
|
||||
D --> I[Azure AI Foundry API]
|
||||
E --> J[Bedrock Invoke API]
|
||||
F --> K[Vertex AI API]
|
||||
|
||||
G --> L[translate_anthropic_to_openai]
|
||||
L --> M[litellm.completion]
|
||||
M --> N[Provider API]
|
||||
N --> O[translate_openai_response_to_anthropic]
|
||||
O --> P[Anthropic Response Format]
|
||||
|
||||
H --> P
|
||||
I --> P
|
||||
J --> P
|
||||
K --> P
|
||||
```
|
||||
|
||||
## Adapter Flow (Non-Native Providers)
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant User
|
||||
participant Handler as anthropic_messages_handler
|
||||
participant Adapter as LiteLLMAnthropicMessagesAdapter
|
||||
participant LiteLLM as litellm.completion
|
||||
participant Provider as Provider API
|
||||
|
||||
User->>Handler: Anthropic Messages Request
|
||||
Handler->>Adapter: translate_anthropic_to_openai()
|
||||
Note over Adapter: messages, tools, thinking,<br/>output_format → response_format
|
||||
Adapter->>LiteLLM: OpenAI Format Request
|
||||
LiteLLM->>Provider: Provider-specific Request
|
||||
Provider->>LiteLLM: Provider Response
|
||||
LiteLLM->>Adapter: OpenAI Format Response
|
||||
Adapter->>Handler: translate_openai_response_to_anthropic()
|
||||
Handler->>User: Anthropic Messages Response
|
||||
```
|
||||
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Fake Streaming Iterator for Anthropic Messages
|
||||
|
||||
This module provides a fake streaming iterator that converts non-streaming
|
||||
Anthropic Messages responses into proper streaming format.
|
||||
|
||||
Used when WebSearch interception converts stream=True to stream=False but
|
||||
the LLM doesn't make a tool call, and we need to return a stream to the user.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, cast
|
||||
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
|
||||
|
||||
class FakeAnthropicMessagesStreamIterator:
    """
    Fake streaming iterator for Anthropic Messages responses.

    Used when a non-streaming response must be replayed in streaming format,
    such as when WebSearch interception converts stream=True to stream=False
    but the LLM doesn't make a tool call.

    Emits a proper Anthropic-style SSE event sequence:
      - message_start
      - content_block_start / content_block_delta / content_block_stop
        (one group per content block; deltas depend on the block type)
      - message_delta (final usage and stop_reason)
      - message_stop
    """

    def __init__(self, response: AnthropicMessagesResponse):
        # Complete, already-materialized response that we replay as a stream.
        self.response = response
        # Pre-render every SSE frame up front; iteration just replays them.
        self.chunks = self._create_streaming_chunks()
        self.current_index = 0

    @staticmethod
    def _sse(event: str, data: Dict[str, Any]) -> bytes:
        """Encode one Server-Sent Events frame for *event* carrying *data*."""
        return f"event: {event}\ndata: {json.dumps(data)}\n\n".encode()

    def _text_block_frames(self, index: int, block: Dict[str, Any]) -> List[bytes]:
        """start / single full-text delta / stop frames for a text block."""
        return [
            self._sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": index,
                    "content_block": {"type": "text", "text": ""},
                },
            ),
            # Full text is sent as one delta for simplicity.
            self._sse(
                "content_block_delta",
                {
                    "type": "content_block_delta",
                    "index": index,
                    "delta": {"type": "text_delta", "text": block.get("text", "")},
                },
            ),
            self._sse(
                "content_block_stop", {"type": "content_block_stop", "index": index}
            ),
        ]

    def _thinking_block_frames(self, index: int, block: Dict[str, Any]) -> List[bytes]:
        """start / optional thinking + signature deltas / stop for a thinking block."""
        frames = [
            self._sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": index,
                    "content_block": {
                        "type": "thinking",
                        "thinking": "",
                        "signature": "",
                    },
                },
            )
        ]
        thinking_text = block.get("thinking", "")
        if thinking_text:
            frames.append(
                self._sse(
                    "content_block_delta",
                    {
                        "type": "content_block_delta",
                        "index": index,
                        "delta": {
                            "type": "thinking_delta",
                            "thinking": thinking_text,
                        },
                    },
                )
            )
        signature = block.get("signature", "")
        if signature:
            frames.append(
                self._sse(
                    "content_block_delta",
                    {
                        "type": "content_block_delta",
                        "index": index,
                        "delta": {
                            "type": "signature_delta",
                            "signature": signature,
                        },
                    },
                )
            )
        frames.append(
            self._sse(
                "content_block_stop", {"type": "content_block_stop", "index": index}
            )
        )
        return frames

    def _redacted_thinking_frames(self, index: int) -> List[bytes]:
        """start / stop only -- redacted thinking carries no delta."""
        return [
            self._sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": index,
                    "content_block": {"type": "redacted_thinking"},
                },
            ),
            self._sse(
                "content_block_stop", {"type": "content_block_stop", "index": index}
            ),
        ]

    def _tool_use_frames(self, index: int, block: Dict[str, Any]) -> List[bytes]:
        """start / single input_json delta / stop frames for a tool_use block."""
        return [
            self._sse(
                "content_block_start",
                {
                    "type": "content_block_start",
                    "index": index,
                    "content_block": {
                        "type": "tool_use",
                        "id": block.get("id"),
                        "name": block.get("name"),
                        "input": {},
                    },
                },
            ),
            # The whole input object is sent as a single JSON delta.
            self._sse(
                "content_block_delta",
                {
                    "type": "content_block_delta",
                    "index": index,
                    "delta": {
                        "type": "input_json_delta",
                        "partial_json": json.dumps(block.get("input", {})),
                    },
                },
            ),
            self._sse(
                "content_block_stop", {"type": "content_block_stop", "index": index}
            ),
        ]

    def _create_streaming_chunks(self) -> List[bytes]:
        """Convert the non-streaming response into an ordered list of SSE frames."""
        payload = cast(Dict[str, Any], self.response)
        usage = payload.get("usage", {})

        # 1. message_start: envelope with empty content and input-token usage.
        frames: List[bytes] = [
            self._sse(
                "message_start",
                {
                    "type": "message_start",
                    "message": {
                        "id": payload.get("id"),
                        "type": "message",
                        "role": payload.get("role", "assistant"),
                        "model": payload.get("model"),
                        "content": [],
                        "stop_reason": None,
                        "stop_sequence": None,
                        "usage": {
                            "input_tokens": usage.get("input_tokens", 0)
                            if usage
                            else 0,
                            "output_tokens": 0,
                        },
                    },
                },
            )
        ]

        # 2-4. start/delta/stop frames per content block; unknown types are skipped.
        for index, raw_block in enumerate(payload.get("content", []) or []):
            block = cast(Dict[str, Any], raw_block)
            kind = block.get("type")
            if kind == "text":
                frames.extend(self._text_block_frames(index, block))
            elif kind == "thinking":
                frames.extend(self._thinking_block_frames(index, block))
            elif kind == "redacted_thinking":
                frames.extend(self._redacted_thinking_frames(index))
            elif kind == "tool_use":
                frames.extend(self._tool_use_frames(index, block))

        # 5. message_delta: final stop reason plus output-token usage.
        frames.append(
            self._sse(
                "message_delta",
                {
                    "type": "message_delta",
                    "delta": {
                        "stop_reason": payload.get("stop_reason"),
                        "stop_sequence": payload.get("stop_sequence"),
                    },
                    "usage": {
                        "output_tokens": usage.get("output_tokens", 0) if usage else 0
                    },
                },
            )
        )

        # 6. message_stop closes the stream.
        frames.append(
            self._sse(
                "message_stop",
                {"type": "message_stop", "usage": usage if usage else {}},
            )
        )

        return frames

    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.current_index >= len(self.chunks):
            raise StopAsyncIteration
        frame = self.chunks[self.current_index]
        self.current_index += 1
        return frame

    def __iter__(self):
        return self

    def __next__(self):
        if self.current_index >= len(self.chunks):
            raise StopIteration
        frame = self.chunks[self.current_index]
        self.current_index += 1
        return frame
|
||||
@@ -0,0 +1,362 @@
|
||||
"""
|
||||
- call /messages on Anthropic API
|
||||
- Make streaming + non-streaming request - just pass it through direct to Anthropic. No need to do anything special here
|
||||
- Ensure requests are logged in the DB - stream + non-stream
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
from functools import partial
|
||||
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.anthropic_messages.transformation import (
|
||||
BaseAnthropicMessagesConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from litellm.types.llms.anthropic_messages.anthropic_request import AnthropicMetadata
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import ProviderConfigManager, client
|
||||
|
||||
from ..adapters.handler import LiteLLMMessagesToCompletionTransformationHandler
|
||||
from ..responses_adapters.handler import LiteLLMMessagesToResponsesAPIHandler
|
||||
from .utils import AnthropicMessagesRequestUtils, mock_response
|
||||
|
||||
# Providers that are routed directly to the OpenAI Responses API instead of
# going through chat/completions.
_RESPONSES_API_PROVIDERS = frozenset({"openai"})


def _should_route_to_responses_api(custom_llm_provider: Optional[str]) -> bool:
    """Return True when the provider should use the Responses API path.

    Set ``litellm.use_chat_completions_url_for_anthropic_messages = True`` to
    opt out and route OpenAI/Azure requests through chat/completions instead.
    """
    # The global opt-out flag wins over provider matching.
    if litellm.use_chat_completions_url_for_anthropic_messages:
        return False

    return custom_llm_provider in _RESPONSES_API_PROVIDERS
|
||||
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
# Initialize any necessary instances or variables here
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
#################################################
|
||||
|
||||
|
||||
async def _execute_pre_request_hooks(
    model: str,
    messages: List[Dict],
    tools: Optional[List[Dict]],
    stream: Optional[bool],
    custom_llm_provider: Optional[str],
    **kwargs,
) -> Dict:
    """
    Run CustomLogger pre-request hooks and return the (possibly modified)
    request parameters.

    CustomLoggers get a chance to rewrite request parameters before the API
    call -- used for WebSearch tool conversion, stream modification, etc.

    Args:
        model: Model name
        messages: List of messages
        tools: Optional tools list
        stream: Optional stream flag
        custom_llm_provider: Provider name (resolved from *model* when absent)
        **kwargs: Additional request parameters

    Returns:
        Dict with all (potentially modified) request parameters, including
        ``tools`` and ``stream``.
    """
    # Resolve the provider from the model name when the caller didn't pass one.
    if not custom_llm_provider:
        try:
            _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
        except Exception:
            # Best-effort: proceed without a provider if detection fails.
            pass

    # Assemble the complete request-kwargs dict handed to each hook.
    request_kwargs: Dict = {
        "tools": tools,
        "stream": stream,
        "litellm_params": {"custom_llm_provider": custom_llm_provider},
        **kwargs,
    }

    if not litellm.callbacks:
        return request_kwargs

    from litellm.integrations.custom_logger import CustomLogger as _CustomLogger

    for cb in litellm.callbacks:
        # Only CustomLogger instances expose the pre-request hook.
        if not isinstance(cb, _CustomLogger):
            continue

        hook_result = await cb.async_pre_request_hook(model, messages, request_kwargs)

        # A non-None return replaces the working request kwargs.
        if hook_result is not None:
            request_kwargs = hook_result

    return request_kwargs
|
||||
|
||||
|
||||
@client
async def anthropic_messages(
    max_tokens: int,
    messages: List[Dict],
    model: str,
    metadata: Optional[Dict] = None,
    stop_sequences: Optional[List[str]] = None,
    stream: Optional[bool] = False,
    system: Optional[str] = None,
    temperature: Optional[float] = None,
    thinking: Optional[Dict] = None,
    tool_choice: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    client: Optional[AsyncHTTPHandler] = None,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[AnthropicMessagesResponse, AsyncIterator]:
    """
    Async: Make llm api request in Anthropic /messages API spec.

    Runs CustomLogger pre-request hooks, then dispatches the synchronous
    handler on the default executor (with contextvars preserved) and awaits
    the result if the handler returned a coroutine.
    """
    # Let CustomLoggers rewrite the request before it goes out.
    hook_kwargs = await _execute_pre_request_hooks(
        model=model,
        messages=messages,
        tools=tools,
        stream=stream,
        custom_llm_provider=custom_llm_provider,
        **kwargs,
    )

    # Pull the (possibly modified) tools/stream back out of the hook result.
    tools = hook_kwargs.pop("tools", tools)
    stream = hook_kwargs.pop("stream", stream)
    # litellm_params is hook-only bookkeeping -- never forwarded.
    hook_kwargs.pop("litellm_params", None)
    # Fold any remaining modifications back into the request kwargs.
    kwargs.update(hook_kwargs)

    event_loop = asyncio.get_event_loop()
    kwargs["is_async"] = True

    bound_call = partial(
        anthropic_messages_handler,
        max_tokens=max_tokens,
        messages=messages,
        model=model,
        metadata=metadata,
        stop_sequences=stop_sequences,
        stream=stream,
        system=system,
        temperature=temperature,
        thinking=thinking,
        tool_choice=tool_choice,
        tools=tools,
        top_k=top_k,
        top_p=top_p,
        api_key=api_key,
        api_base=api_base,
        client=client,
        custom_llm_provider=custom_llm_provider,
        **kwargs,
    )
    # Preserve the current contextvars context across the executor hop.
    run_in_context = partial(contextvars.copy_context().run, bound_call)
    maybe_coro = await event_loop.run_in_executor(None, run_in_context)

    # The handler may hand back either a coroutine or a finished value.
    if asyncio.iscoroutine(maybe_coro):
        return await maybe_coro
    return maybe_coro
|
||||
|
||||
|
||||
def validate_anthropic_api_metadata(metadata: Optional[Dict] = None) -> Optional[Dict]:
    """
    Validate Anthropic API metadata - This is done to ensure only allowed `metadata` fields are passed to Anthropic API

    If there are any litellm specific metadata fields, use `litellm_metadata` key to pass them.
    """
    # Round-tripping through the pydantic model drops unknown/None fields.
    return (
        None
        if metadata is None
        else AnthropicMetadata(**metadata).model_dump(exclude_none=True)
    )
|
||||
|
||||
|
||||
def anthropic_messages_handler(
    max_tokens: int,
    messages: List[Dict],
    model: str,
    metadata: Optional[Dict] = None,
    stop_sequences: Optional[List[str]] = None,
    stream: Optional[bool] = False,
    system: Optional[str] = None,
    temperature: Optional[float] = None,
    thinking: Optional[Dict] = None,
    tool_choice: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    container: Optional[Dict] = None,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
    client: Optional[AsyncHTTPHandler] = None,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
) -> Union[
    AnthropicMessagesResponse,
    AsyncIterator[Any],
    Coroutine[Any, Any, Union[AnthropicMessagesResponse, AsyncIterator[Any]]],
]:
    """
    Makes Anthropic `/v1/messages` API calls In the Anthropic API Spec

    Routing order:
      1. string mock_response short-circuit
      2. provider with a native anthropic-messages config -> direct HTTP handler
      3. otherwise -> Responses API (OpenAI) or chat/completions adapter

    Args:
        container: Container config with skills for code execution

    Raises:
        ValueError: if no custom_llm_provider could be resolved for *model*.
    """
    from litellm.types.utils import LlmProviders

    metadata = validate_anthropic_api_metadata(metadata)

    # Snapshot locals BEFORE mutating anything below -- used later to collect
    # the optional request params for the direct HTTP path.
    local_vars = locals()
    is_async = kwargs.pop("is_async", False)
    # Use provided client or create a new one
    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj")  # type: ignore

    # Store original model name before get_llm_provider strips the provider prefix
    # This is needed by agentic hooks (e.g., websearch_interception) to make follow-up requests
    original_model = model

    litellm_params = GenericLiteLLMParams(
        **kwargs,
        api_key=api_key,
        api_base=api_base,
        custom_llm_provider=custom_llm_provider,
    )
    (
        model,
        custom_llm_provider,
        dynamic_api_key,
        dynamic_api_base,
    ) = litellm.get_llm_provider(
        model=model,
        custom_llm_provider=custom_llm_provider,
        api_base=litellm_params.api_base,
        api_key=litellm_params.api_key,
    )

    # Store agentic loop params in logging object for agentic hooks.
    # This provides original request context needed for follow-up calls.
    if litellm_logging_obj is not None:
        litellm_logging_obj.model_call_details["agentic_loop_params"] = {
            "model": original_model,
            "custom_llm_provider": custom_llm_provider,
        }

        # Check if stream was converted for WebSearch interception.
        # This is set in the async wrapper above when stream=True is converted
        # to stream=False.
        # BUGFIX: this write was previously unguarded and raised AttributeError
        # when no litellm_logging_obj was supplied; it now lives under the same
        # None-check as the agentic_loop_params write above.
        if kwargs.get("_websearch_interception_converted_stream", False):
            litellm_logging_obj.model_call_details[
                "websearch_interception_converted_stream"
            ] = True

    # String mocks short-circuit all routing (used by tests).
    if litellm_params.mock_response and isinstance(litellm_params.mock_response, str):
        return mock_response(
            model=model,
            messages=messages,
            max_tokens=max_tokens,
            mock_response=litellm_params.mock_response,
        )

    anthropic_messages_provider_config: Optional[BaseAnthropicMessagesConfig] = None

    if custom_llm_provider is not None and custom_llm_provider in [
        provider.value for provider in LlmProviders
    ]:
        anthropic_messages_provider_config = (
            ProviderConfigManager.get_provider_anthropic_messages_config(
                model=model,
                provider=litellm.LlmProviders(custom_llm_provider),
            )
        )
    if anthropic_messages_provider_config is None:
        # Route to Responses API for OpenAI / Azure, chat/completions for everything else.
        _shared_kwargs = dict(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            metadata=metadata,
            stop_sequences=stop_sequences,
            stream=stream,
            system=system,
            temperature=temperature,
            thinking=thinking,
            tool_choice=tool_choice,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            _is_async=is_async,
            api_key=api_key,
            api_base=api_base,
            client=client,
            custom_llm_provider=custom_llm_provider,
            **kwargs,
        )
        if _should_route_to_responses_api(custom_llm_provider):
            return LiteLLMMessagesToResponsesAPIHandler.anthropic_messages_handler(
                **_shared_kwargs
            )
        return (
            LiteLLMMessagesToCompletionTransformationHandler.anthropic_messages_handler(
                **_shared_kwargs
            )
        )

    if custom_llm_provider is None:
        raise ValueError(
            f"custom_llm_provider is required for Anthropic messages, passed in model={model}, custom_llm_provider={custom_llm_provider}"
        )

    # Direct HTTP path: collect the optional params from the original locals
    # snapshot (plus any kwargs) and hand off to the shared HTTP handler.
    local_vars.update(kwargs)
    anthropic_messages_optional_request_params = (
        AnthropicMessagesRequestUtils.get_requested_anthropic_messages_optional_param(
            params=local_vars
        )
    )
    return base_llm_http_handler.anthropic_messages_handler(
        model=model,
        messages=messages,
        anthropic_messages_provider_config=anthropic_messages_provider_config,
        anthropic_messages_optional_request_params=dict(
            anthropic_messages_optional_request_params
        ),
        _is_async=is_async,
        client=client,
        custom_llm_provider=custom_llm_provider,
        litellm_params=litellm_params,
        logging_obj=litellm_logging_obj,
        api_key=api_key,
        api_base=api_base,
        stream=stream,
        kwargs=kwargs,
    )
|
||||
@@ -0,0 +1,108 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncIterator, List, Union
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy.pass_through_endpoints.success_handler import (
|
||||
PassThroughEndpointLogging,
|
||||
)
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
|
||||
from litellm.types.utils import GenericStreamingChunk, ModelResponseStream
|
||||
|
||||
GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ = PassThroughEndpointLogging()
|
||||
|
||||
|
||||
class BaseAnthropicMessagesStreamingIterator:
    """
    Shared plumbing for Anthropic Messages streaming iterators: SSE frame
    conversion, pass-through chunk processing, and post-stream logging.
    """

    def __init__(
        self,
        litellm_logging_obj: LiteLLMLoggingObj,
        request_body: dict,
    ):
        self.litellm_logging_obj = litellm_logging_obj
        self.request_body = request_body
        # Captured up front so logging reports the true stream start time.
        self.start_time = datetime.now()

    async def _handle_streaming_logging(self, collected_chunks: List[bytes]):
        """Fire-and-forget logging once every chunk has been collected."""
        from litellm.proxy.pass_through_endpoints.streaming_handler import (
            PassThroughStreamingHandler,
        )

        finished_at = datetime.now()
        # Scheduled as a task so logging never blocks the caller.
        asyncio.create_task(
            PassThroughStreamingHandler._route_streaming_logging_to_handler(
                litellm_logging_obj=self.litellm_logging_obj,
                passthrough_success_handler_obj=GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ,
                url_route="/v1/messages",
                request_body=self.request_body or {},
                endpoint_type=EndpointType.ANTHROPIC,
                start_time=self.start_time,
                raw_bytes=collected_chunks,
                end_time=finished_at,
            )
        )

    def get_async_streaming_response_iterator(
        self,
        httpx_response,
        request_body: dict,
        litellm_logging_obj: LiteLLMLoggingObj,
    ) -> AsyncIterator:
        """Helper function to handle Anthropic streaming responses using the existing logging handlers"""
        from litellm.proxy.pass_through_endpoints.streaming_handler import (
            PassThroughStreamingHandler,
        )

        # Delegate to the shared pass-through processor for Anthropic streams.
        return PassThroughStreamingHandler.chunk_processor(
            response=httpx_response,
            request_body=request_body,
            litellm_logging_obj=litellm_logging_obj,
            endpoint_type=EndpointType.ANTHROPIC,
            start_time=self.start_time,
            passthrough_success_handler_obj=GLOBAL_PASS_THROUGH_SUCCESS_HANDLER_OBJ,
            url_route="/v1/messages",
        )

    def _convert_chunk_to_sse_format(self, chunk: Union[dict, Any]) -> bytes:
        """
        Render *chunk* as a Server-Sent Events frame.

        Subclasses may override this when they need custom chunk formatting.
        Non-dict chunks are passed through untouched (assumed to already be
        encoded bytes).
        """
        if not isinstance(chunk, dict):
            # For non-dict chunks, return as is.
            return chunk

        event_type: str = str(chunk.get("type", "message"))
        return f"event: {event_type}\ndata: {json.dumps(chunk)}\n\n".encode()

    async def async_sse_wrapper(
        self,
        completion_stream: AsyncIterator[
            Union[bytes, GenericStreamingChunk, ModelResponseStream, dict]
        ],
    ) -> AsyncIterator[bytes]:
        """
        Yield SSE-encoded frames from *completion_stream*, collecting them so
        logging can run once the stream is exhausted.

        Provides the common logic for both Anthropic and Bedrock implementations.
        """
        seen_frames: List[bytes] = []

        async for raw_chunk in completion_stream:
            frame = self._convert_chunk_to_sse_format(raw_chunk)
            seen_frames.append(frame)
            yield frame

        # Stream finished -- hand everything to the logging pipeline.
        await self._handle_streaming_logging(seen_frames)
|
||||
@@ -0,0 +1,308 @@
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import verbose_logger
|
||||
from litellm.llms.base_llm.anthropic_messages.transformation import (
|
||||
BaseAnthropicMessagesConfig,
|
||||
)
|
||||
from litellm.types.llms.anthropic import (
|
||||
ANTHROPIC_BETA_HEADER_VALUES,
|
||||
AnthropicMessagesRequest,
|
||||
)
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
from litellm.types.llms.anthropic_tool_search import get_tool_search_beta_header
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
from ...common_utils import (
|
||||
AnthropicError,
|
||||
AnthropicModelInfo,
|
||||
optionally_handle_anthropic_oauth,
|
||||
)
|
||||
|
||||
DEFAULT_ANTHROPIC_API_BASE = "https://api.anthropic.com"
|
||||
DEFAULT_ANTHROPIC_API_VERSION = "2023-06-01"
|
||||
|
||||
|
||||
class AnthropicMessagesConfig(BaseAnthropicMessagesConfig):
|
||||
def get_supported_anthropic_messages_params(self, model: str) -> list:
|
||||
return [
|
||||
"messages",
|
||||
"model",
|
||||
"system",
|
||||
"max_tokens",
|
||||
"stop_sequences",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"thinking",
|
||||
"context_management",
|
||||
"output_format",
|
||||
"inference_geo",
|
||||
"speed",
|
||||
"output_config",
|
||||
# TODO: Add Anthropic `metadata` support
|
||||
# "metadata",
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _filter_billing_headers_from_system(system_param):
|
||||
"""
|
||||
Filter out x-anthropic-billing-header metadata from system parameter.
|
||||
|
||||
Args:
|
||||
system_param: Can be a string or a list of system message content blocks
|
||||
|
||||
Returns:
|
||||
Filtered system parameter (string or list), or None if all content was filtered
|
||||
"""
|
||||
if isinstance(system_param, str):
|
||||
# If it's a string and starts with billing header, filter it out
|
||||
if system_param.startswith("x-anthropic-billing-header:"):
|
||||
return None
|
||||
return system_param
|
||||
elif isinstance(system_param, list):
|
||||
# Filter list of system content blocks
|
||||
filtered_list = []
|
||||
for content_block in system_param:
|
||||
if isinstance(content_block, dict):
|
||||
text = content_block.get("text", "")
|
||||
content_type = content_block.get("type", "")
|
||||
# Skip text blocks that start with billing header
|
||||
if content_type == "text" and text.startswith(
|
||||
"x-anthropic-billing-header:"
|
||||
):
|
||||
continue
|
||||
filtered_list.append(content_block)
|
||||
else:
|
||||
# Keep non-dict items as-is
|
||||
filtered_list.append(content_block)
|
||||
return filtered_list if len(filtered_list) > 0 else None
|
||||
else:
|
||||
return system_param
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
api_base = api_base or DEFAULT_ANTHROPIC_API_BASE
|
||||
if not api_base.endswith("/v1/messages"):
|
||||
api_base = f"{api_base}/v1/messages"
|
||||
return api_base
|
||||
|
||||
def validate_anthropic_messages_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[Any],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> Tuple[dict, Optional[str]]:
|
||||
import os
|
||||
|
||||
# Check for Anthropic OAuth token in Authorization header
|
||||
headers, api_key = optionally_handle_anthropic_oauth(
|
||||
headers=headers, api_key=api_key
|
||||
)
|
||||
if api_key is None:
|
||||
api_key = os.getenv("ANTHROPIC_API_KEY")
|
||||
|
||||
if "x-api-key" not in headers and "authorization" not in headers and api_key:
|
||||
headers["x-api-key"] = api_key
|
||||
if "anthropic-version" not in headers:
|
||||
headers["anthropic-version"] = DEFAULT_ANTHROPIC_API_VERSION
|
||||
if "content-type" not in headers:
|
||||
headers["content-type"] = "application/json"
|
||||
|
||||
headers = self._update_headers_with_anthropic_beta(
|
||||
headers=headers,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
return headers, api_base
|
||||
|
||||
def transform_anthropic_messages_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[Dict],
|
||||
anthropic_messages_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Dict:
|
||||
"""
|
||||
No transformation is needed for Anthropic messages
|
||||
|
||||
|
||||
This takes in a request in the Anthropic /v1/messages API spec -> transforms it to /v1/messages API spec (i.e) no transformation is needed
|
||||
"""
|
||||
max_tokens = anthropic_messages_optional_request_params.pop("max_tokens", None)
|
||||
if max_tokens is None:
|
||||
raise AnthropicError(
|
||||
message="max_tokens is required for Anthropic /v1/messages API",
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
# Filter out x-anthropic-billing-header from system messages
|
||||
system_param = anthropic_messages_optional_request_params.get("system")
|
||||
if system_param is not None:
|
||||
filtered_system = self._filter_billing_headers_from_system(system_param)
|
||||
if filtered_system is not None and len(filtered_system) > 0:
|
||||
anthropic_messages_optional_request_params["system"] = filtered_system
|
||||
else:
|
||||
# Remove system parameter if all content was filtered out
|
||||
anthropic_messages_optional_request_params.pop("system", None)
|
||||
|
||||
# Transform context_management from OpenAI format to Anthropic format if needed
|
||||
context_management_param = anthropic_messages_optional_request_params.get(
|
||||
"context_management"
|
||||
)
|
||||
if context_management_param is not None:
|
||||
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
|
||||
|
||||
transformed_context_management = (
|
||||
AnthropicConfig.map_openai_context_management_to_anthropic(
|
||||
context_management_param
|
||||
)
|
||||
)
|
||||
if transformed_context_management is not None:
|
||||
anthropic_messages_optional_request_params[
|
||||
"context_management"
|
||||
] = transformed_context_management
|
||||
|
||||
####### get required params for all anthropic messages requests ######
|
||||
verbose_logger.debug(f"TRANSFORMATION DEBUG - Messages: {messages}")
|
||||
anthropic_messages_request: AnthropicMessagesRequest = AnthropicMessagesRequest(
|
||||
messages=messages,
|
||||
max_tokens=max_tokens,
|
||||
model=model,
|
||||
**anthropic_messages_optional_request_params,
|
||||
)
|
||||
return dict(anthropic_messages_request)
|
||||
|
||||
def transform_anthropic_messages_response(
    self,
    model: str,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> AnthropicMessagesResponse:
    """
    Pass-through response handling: the upstream response is already in the
    Anthropic /v1/messages format, so the only work is parsing the JSON body.

    Raises:
        AnthropicError: if the response body cannot be parsed as JSON
            (the raw text and status code are surfaced verbatim).
    """
    try:
        response_body = raw_response.json()
    except Exception:
        # Non-JSON body (e.g. an HTML error page) -> surface it as-is.
        raise AnthropicError(
            message=raw_response.text, status_code=raw_response.status_code
        )
    return AnthropicMessagesResponse(**response_body)
|
||||
|
||||
def get_async_streaming_response_iterator(
    self,
    model: str,
    httpx_response: httpx.Response,
    request_body: dict,
    litellm_logging_obj: LiteLLMLoggingObj,
) -> AsyncIterator:
    """
    Delegate Anthropic streaming to the shared streaming-iterator helper,
    which handles SSE parsing and the existing logging hooks.
    """
    from litellm.llms.anthropic.experimental_pass_through.messages.streaming_iterator import (
        BaseAnthropicMessagesStreamingIterator,
    )

    # The shared iterator owns both chunk translation and logging.
    return BaseAnthropicMessagesStreamingIterator(
        litellm_logging_obj=litellm_logging_obj,
        request_body=request_body,
    ).get_async_streaming_response_iterator(
        httpx_response=httpx_response,
        request_body=request_body,
        litellm_logging_obj=litellm_logging_obj,
    )
|
||||
|
||||
@staticmethod
def _update_headers_with_anthropic_beta(
    headers: dict,
    optional_params: dict,
    custom_llm_provider: str = "anthropic",
) -> dict:
    """
    Auto-inject anthropic-beta headers based on features used.

    Handles:
    - context_management: adds 'context-management-2025-06-27' and/or the
      compact-edit beta value, depending on the edit types present
    - tool_search: adds provider-specific tool search header
    - output_format: adds the structured-outputs beta value
    - speed == "fast": adds the fast-mode beta value

    Args:
        headers: Request headers dict (mutated in place and returned)
        optional_params: Optional parameters including tools, context_management, output_format, speed
        custom_llm_provider: Provider name for looking up correct tool search header
    """
    betas: set = set()

    # Preserve any beta values the caller already supplied.
    current_header = headers.get("anthropic-beta")
    if current_header:
        betas.update(value.strip() for value in current_header.split(","))

    # Context management: compact edits and all other edits each map to
    # their own beta flag (both may be present at once).
    context_management = optional_params.get("context_management")
    if context_management is not None:
        edit_types = [
            edit.get("type", "") for edit in context_management.get("edits", [])
        ]
        if any(t == "compact_20260112" for t in edit_types):
            betas.add(ANTHROPIC_BETA_HEADER_VALUES.COMPACT_2026_01_12.value)
        if any(t != "compact_20260112" for t in edit_types):
            betas.add(
                ANTHROPIC_BETA_HEADER_VALUES.CONTEXT_MANAGEMENT_2025_06_27.value
            )

    # Structured outputs.
    if optional_params.get("output_format") is not None:
        betas.add(ANTHROPIC_BETA_HEADER_VALUES.STRUCTURED_OUTPUT_2025_09_25.value)

    # Fast mode.
    if optional_params.get("speed") == "fast":
        betas.add(ANTHROPIC_BETA_HEADER_VALUES.FAST_MODE_2026_02_01.value)

    # Tool search requires a provider-specific beta header.
    tools = optional_params.get("tools")
    if tools and AnthropicModelInfo().is_tool_search_used(tools):
        betas.add(get_tool_search_beta_header(custom_llm_provider))

    if betas:
        # Sorted for a deterministic header value.
        headers["anthropic-beta"] = ",".join(sorted(betas))

    return headers
|
||||
@@ -0,0 +1,75 @@
|
||||
from typing import Any, Dict, List, cast, get_type_hints
|
||||
|
||||
from litellm.types.llms.anthropic import AnthropicMessagesRequestOptionalParams
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
|
||||
|
||||
class AnthropicMessagesRequestUtils:
    @staticmethod
    def get_requested_anthropic_messages_optional_param(
        params: Dict[str, Any],
    ) -> AnthropicMessagesRequestOptionalParams:
        """
        Keep only the entries of ``params`` that are declared on
        AnthropicMessagesRequestOptionalParams and are not None.

        Args:
            params: Dictionary of parameters to filter

        Returns:
            AnthropicMessagesRequestOptionalParams instance with only the valid parameters
        """
        allowed_keys = set(get_type_hints(AnthropicMessagesRequestOptionalParams))
        selected: Dict[str, Any] = {}
        for key, value in params.items():
            if key in allowed_keys and value is not None:
                selected[key] = value
        return cast(AnthropicMessagesRequestOptionalParams, selected)
|
||||
|
||||
|
||||
def mock_response(
    model: str,
    messages: List[Dict],
    max_tokens: int,
    mock_response: str = "Hi! My name is Claude.",
    **kwargs,
) -> AnthropicMessagesResponse:
    """
    Return a canned Anthropic /v1/messages response for testing.

    Special ``mock_response`` sentinel strings raise the matching litellm
    exception instead of returning a response.
    """
    from litellm.exceptions import (
        ContextWindowExceededError,
        InternalServerError,
        RateLimitError,
    )

    # Sentinel string -> (exception class, message) dispatch table.
    error_map = {
        "litellm.InternalServerError": (
            InternalServerError,
            "this is a mock internal server error",
        ),
        "litellm.ContextWindowExceededError": (
            ContextWindowExceededError,
            "this is a mock context window exceeded error",
        ),
        "litellm.RateLimitError": (
            RateLimitError,
            "this is a mock rate limit error",
        ),
    }
    if mock_response in error_map:
        exc_cls, exc_message = error_map[mock_response]
        raise exc_cls(
            message=exc_message,
            llm_provider="anthropic",
            model=model,
        )

    return AnthropicMessagesResponse(
        **{
            "content": [{"text": mock_response, "type": "text"}],
            "id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
            "model": "claude-sonnet-4-20250514",
            "role": "assistant",
            "stop_reason": "end_turn",
            "stop_sequence": None,
            "type": "message",
            "usage": {"input_tokens": 2095, "output_tokens": 503},
        }
    )
|
||||
@@ -0,0 +1,3 @@
|
||||
from .transformation import LiteLLMAnthropicToResponsesAPIAdapter
|
||||
|
||||
__all__ = ["LiteLLMAnthropicToResponsesAPIAdapter"]
|
||||
@@ -0,0 +1,239 @@
|
||||
"""
|
||||
Handler for the Anthropic v1/messages -> OpenAI Responses API path.
|
||||
|
||||
Used when the target model is an OpenAI or Azure model.
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.types.llms.anthropic import AnthropicMessagesRequest
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
from litellm.types.llms.openai import ResponsesAPIResponse
|
||||
|
||||
from .streaming_iterator import AnthropicResponsesStreamWrapper
|
||||
from .transformation import LiteLLMAnthropicToResponsesAPIAdapter
|
||||
|
||||
# Module-level shared adapter; the adapter appears stateless (methods only,
# no instance attributes), so a single instance is reused for all requests.
_ADAPTER = LiteLLMAnthropicToResponsesAPIAdapter()
|
||||
|
||||
|
||||
def _build_responses_kwargs(
    *,
    max_tokens: int,
    messages: List[Dict],
    model: str,
    context_management: Optional[Dict] = None,
    metadata: Optional[Dict] = None,
    output_config: Optional[Dict] = None,
    stop_sequences: Optional[List[str]] = None,
    stream: Optional[bool] = False,
    system: Optional[str] = None,
    temperature: Optional[float] = None,
    thinking: Optional[Dict] = None,
    tool_choice: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    output_format: Optional[Dict] = None,
    extra_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """
    Build the kwargs dict to pass directly to litellm.responses() / litellm.aresponses().

    Assembles a typed AnthropicMessagesRequest from the individual Anthropic
    /v1/messages parameters, runs it through the adapter's request translation,
    and forwards litellm-specific extras (api_key, api_base, logging obj, ...).

    Args:
        max_tokens / messages / model: required Anthropic request fields.
        context_management ... output_format: optional Anthropic request fields;
            only included when truthy (or not-None for numeric params).
        extra_kwargs: litellm-internal kwargs forwarded to the responses call.

    Returns:
        Dict of kwargs suitable for litellm.responses()/aresponses().
    """
    # Build a typed AnthropicMessagesRequest for the adapter
    request_data: Dict[str, Any] = {
        "model": model,
        "messages": messages,
        "max_tokens": max_tokens,
    }
    if context_management:
        request_data["context_management"] = context_management
    if output_config:
        request_data["output_config"] = output_config
    if metadata:
        request_data["metadata"] = metadata
    if stop_sequences:
        # Bug fix: stop_sequences was accepted but silently dropped —
        # forward it so the adapter can translate it.
        request_data["stop_sequences"] = stop_sequences
    if system:
        request_data["system"] = system
    if temperature is not None:
        request_data["temperature"] = temperature
    if thinking:
        request_data["thinking"] = thinking
    if tool_choice:
        request_data["tool_choice"] = tool_choice
    if tools:
        request_data["tools"] = tools
    if top_k is not None:
        # Bug fix: top_k was accepted but silently dropped — forward it; the
        # adapter decides whether the target API supports it.
        request_data["top_k"] = top_k
    if top_p is not None:
        request_data["top_p"] = top_p
    if output_format:
        request_data["output_format"] = output_format

    anthropic_request = AnthropicMessagesRequest(**request_data)  # type: ignore[typeddict-item]
    responses_kwargs = _ADAPTER.translate_request(anthropic_request)

    if stream:
        responses_kwargs["stream"] = True

    # Forward litellm-specific kwargs (api_key, api_base, logging obj, etc.)
    excluded = {"anthropic_messages"}
    for key, value in (extra_kwargs or {}).items():
        if key == "litellm_logging_obj" and value is not None:
            from litellm.litellm_core_utils.litellm_logging import (
                Logging as LiteLLMLoggingObject,
            )
            from litellm.types.utils import CallTypes

            if isinstance(value, LiteLLMLoggingObject):
                # Reclassify as acompletion so the success handler doesn't try to
                # validate the Responses API event as an AnthropicResponse.
                # (Mirrors the pattern used in LiteLLMMessagesToCompletionTransformationHandler.)
                setattr(value, "call_type", CallTypes.acompletion.value)
                responses_kwargs[key] = value
        elif key not in excluded and key not in responses_kwargs and value is not None:
            responses_kwargs[key] = value

    return responses_kwargs
|
||||
|
||||
|
||||
class LiteLLMMessagesToResponsesAPIHandler:
    """
    Handles Anthropic /v1/messages requests for OpenAI / Azure models by
    calling litellm.responses() / litellm.aresponses() directly and translating
    the response back to Anthropic format.
    """

    @staticmethod
    async def async_anthropic_messages_handler(
        max_tokens: int,
        messages: List[Dict],
        model: str,
        context_management: Optional[Dict] = None,
        metadata: Optional[Dict] = None,
        output_config: Optional[Dict] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: Optional[bool] = False,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking: Optional[Dict] = None,
        tool_choice: Optional[Dict] = None,
        tools: Optional[List[Dict]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        output_format: Optional[Dict] = None,
        **kwargs,
    ) -> Union[AnthropicMessagesResponse, AsyncIterator]:
        """
        Async path: translate the Anthropic request, call litellm.aresponses(),
        and translate the result back to Anthropic format.

        Returns:
            AnthropicMessagesResponse for non-streaming calls, or an async
            iterator of Anthropic SSE bytes when ``stream`` is True.

        Raises:
            ValueError: if the non-streaming call returns something other than
                a ResponsesAPIResponse.
        """
        responses_kwargs = _build_responses_kwargs(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            context_management=context_management,
            metadata=metadata,
            output_config=output_config,
            stop_sequences=stop_sequences,
            stream=stream,
            system=system,
            temperature=temperature,
            thinking=thinking,
            tool_choice=tool_choice,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            output_format=output_format,
            extra_kwargs=kwargs,
        )

        result = await litellm.aresponses(**responses_kwargs)

        if stream:
            # Re-emit Responses API stream events as Anthropic SSE chunks.
            wrapper = AnthropicResponsesStreamWrapper(
                responses_stream=result, model=model
            )
            return wrapper.async_anthropic_sse_wrapper()

        if not isinstance(result, ResponsesAPIResponse):
            raise ValueError(f"Expected ResponsesAPIResponse, got {type(result)}")

        return _ADAPTER.translate_response(result)

    @staticmethod
    def anthropic_messages_handler(
        max_tokens: int,
        messages: List[Dict],
        model: str,
        context_management: Optional[Dict] = None,
        metadata: Optional[Dict] = None,
        output_config: Optional[Dict] = None,
        stop_sequences: Optional[List[str]] = None,
        stream: Optional[bool] = False,
        system: Optional[str] = None,
        temperature: Optional[float] = None,
        thinking: Optional[Dict] = None,
        tool_choice: Optional[Dict] = None,
        tools: Optional[List[Dict]] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        output_format: Optional[Dict] = None,
        _is_async: bool = False,
        **kwargs,
    ) -> Union[
        AnthropicMessagesResponse,
        AsyncIterator[Any],
        Coroutine[Any, Any, Union[AnthropicMessagesResponse, AsyncIterator[Any]]],
    ]:
        """
        Sync entry point.

        When ``_is_async`` is True, returns the (un-awaited) coroutine from
        async_anthropic_messages_handler — the caller is expected to await it.
        Otherwise performs a blocking litellm.responses() call.
        """
        if _is_async:
            return (
                LiteLLMMessagesToResponsesAPIHandler.async_anthropic_messages_handler(
                    max_tokens=max_tokens,
                    messages=messages,
                    model=model,
                    context_management=context_management,
                    metadata=metadata,
                    output_config=output_config,
                    stop_sequences=stop_sequences,
                    stream=stream,
                    system=system,
                    temperature=temperature,
                    thinking=thinking,
                    tool_choice=tool_choice,
                    tools=tools,
                    top_k=top_k,
                    top_p=top_p,
                    output_format=output_format,
                    **kwargs,
                )
            )

        # Sync path
        responses_kwargs = _build_responses_kwargs(
            max_tokens=max_tokens,
            messages=messages,
            model=model,
            context_management=context_management,
            metadata=metadata,
            output_config=output_config,
            stop_sequences=stop_sequences,
            stream=stream,
            system=system,
            temperature=temperature,
            thinking=thinking,
            tool_choice=tool_choice,
            tools=tools,
            top_k=top_k,
            top_p=top_p,
            output_format=output_format,
            extra_kwargs=kwargs,
        )

        result = litellm.responses(**responses_kwargs)

        if stream:
            # NOTE(review): this wraps the result of the *sync* litellm.responses()
            # call in an async SSE wrapper and returns an async iterator from a
            # sync function — confirm callers of the sync streaming path can
            # consume an async iterator.
            wrapper = AnthropicResponsesStreamWrapper(
                responses_stream=result, model=model
            )
            return wrapper.async_anthropic_sse_wrapper()

        if not isinstance(result, ResponsesAPIResponse):
            raise ValueError(f"Expected ResponsesAPIResponse, got {type(result)}")

        return _ADAPTER.translate_response(result)
|
||||
@@ -0,0 +1,344 @@
|
||||
# What is this?
|
||||
## Translates OpenAI call to Anthropic `/v1/messages` format
|
||||
import json
|
||||
import traceback
|
||||
from collections import deque
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
|
||||
|
||||
class AnthropicResponsesStreamWrapper:
    """
    Wraps a Responses API streaming iterator and re-emits events in Anthropic SSE format.

    Responses API event flow (relevant subset):
        response.created -> message_start
        response.output_item.added -> content_block_start (if message/function_call)
        response.output_text.delta -> content_block_delta (text_delta)
        response.reasoning_summary_text.delta -> content_block_delta (thinking_delta)
        response.function_call_arguments.delta -> content_block_delta (input_json_delta)
        response.output_item.done -> content_block_stop
        response.completed -> message_delta + message_stop

    Events may arrive as attribute-bearing objects or plain dicts; every field
    access below tries ``getattr`` first and falls back to ``dict.get``.
    """

    def __init__(
        self,
        responses_stream: Any,
        model: str,
    ) -> None:
        self.responses_stream = responses_stream
        self.model = model
        # Synthetic Anthropic message id covering the whole stream.
        self._message_id: str = f"msg_{uuid.uuid4()}"
        self._current_block_index: int = -1
        # Map item_id -> content_block_index so we can stop the right block later
        self._item_id_to_block_index: Dict[str, int] = {}
        # Track open function_call items by item_id so we can emit tool_use start
        self._pending_tool_ids: Dict[
            str, str
        ] = {}  # item_id -> call_id / name accumulator
        self._sent_message_start = False
        self._sent_message_stop = False
        # Chunks produced by _process_event, yielded one at a time by __anext__.
        self._chunk_queue: deque = deque()

    def _make_message_start(self) -> Dict[str, Any]:
        """Build the initial Anthropic message_start chunk (zeroed usage)."""
        return {
            "type": "message_start",
            "message": {
                "id": self._message_id,
                "type": "message",
                "role": "assistant",
                "content": [],
                "model": self.model,
                "stop_reason": None,
                "stop_sequence": None,
                "usage": {
                    "input_tokens": 0,
                    "output_tokens": 0,
                    "cache_creation_input_tokens": 0,
                    "cache_read_input_tokens": 0,
                },
            },
        }

    def _next_block_index(self) -> int:
        """Allocate the next sequential Anthropic content-block index."""
        self._current_block_index += 1
        return self._current_block_index

    def _process_event(self, event: Any) -> None:  # noqa: PLR0915
        """Convert one Responses API event into zero or more Anthropic chunks queued for emission."""
        event_type = getattr(event, "type", None)
        if event_type is None and isinstance(event, dict):
            event_type = event.get("type")

        if event_type is None:
            return

        # ---- message_start ----
        if event_type == "response.created":
            # Bug fix: only emit message_start once. __anext__ already emits a
            # fallback message_start before consuming the upstream stream, so
            # without this guard response.created produced a duplicate
            # message_start chunk on every stream.
            if not self._sent_message_start:
                self._sent_message_start = True
                self._chunk_queue.append(self._make_message_start())
            return

        # ---- content_block_start for a new output message item ----
        if event_type == "response.output_item.added":
            item = getattr(event, "item", None) or (
                event.get("item") if isinstance(event, dict) else None
            )
            if item is None:
                return
            item_type = getattr(item, "type", None) or (
                item.get("type") if isinstance(item, dict) else None
            )
            item_id = getattr(item, "id", None) or (
                item.get("id") if isinstance(item, dict) else None
            )

            if item_type == "message":
                block_idx = self._next_block_index()
                if item_id:
                    self._item_id_to_block_index[item_id] = block_idx
                self._chunk_queue.append(
                    {
                        "type": "content_block_start",
                        "index": block_idx,
                        "content_block": {"type": "text", "text": ""},
                    }
                )
            elif item_type == "function_call":
                call_id = (
                    getattr(item, "call_id", None)
                    or (item.get("call_id") if isinstance(item, dict) else None)
                    or ""
                )
                name = (
                    getattr(item, "name", None)
                    or (item.get("name") if isinstance(item, dict) else None)
                    or ""
                )
                block_idx = self._next_block_index()
                if item_id:
                    self._item_id_to_block_index[item_id] = block_idx
                    self._pending_tool_ids[item_id] = call_id
                self._chunk_queue.append(
                    {
                        "type": "content_block_start",
                        "index": block_idx,
                        "content_block": {
                            "type": "tool_use",
                            "id": call_id,
                            "name": name,
                            "input": {},
                        },
                    }
                )
            elif item_type == "reasoning":
                block_idx = self._next_block_index()
                if item_id:
                    self._item_id_to_block_index[item_id] = block_idx
                self._chunk_queue.append(
                    {
                        "type": "content_block_start",
                        "index": block_idx,
                        "content_block": {"type": "thinking", "thinking": ""},
                    }
                )
            return

        # ---- text delta ----
        if event_type == "response.output_text.delta":
            item_id = getattr(event, "item_id", None) or (
                event.get("item_id") if isinstance(event, dict) else None
            )
            delta = getattr(event, "delta", "") or (
                event.get("delta", "") if isinstance(event, dict) else ""
            )
            # Fall back to the most recently opened block if the item is unknown.
            block_idx = (
                self._item_id_to_block_index.get(item_id, self._current_block_index)
                if item_id
                else self._current_block_index
            )
            self._chunk_queue.append(
                {
                    "type": "content_block_delta",
                    "index": block_idx,
                    "delta": {"type": "text_delta", "text": delta},
                }
            )
            return

        # ---- reasoning summary text delta ----
        if event_type == "response.reasoning_summary_text.delta":
            item_id = getattr(event, "item_id", None) or (
                event.get("item_id") if isinstance(event, dict) else None
            )
            delta = getattr(event, "delta", "") or (
                event.get("delta", "") if isinstance(event, dict) else ""
            )
            block_idx = (
                self._item_id_to_block_index.get(item_id, self._current_block_index)
                if item_id
                else self._current_block_index
            )
            self._chunk_queue.append(
                {
                    "type": "content_block_delta",
                    "index": block_idx,
                    "delta": {"type": "thinking_delta", "thinking": delta},
                }
            )
            return

        # ---- function call arguments delta ----
        if event_type == "response.function_call_arguments.delta":
            item_id = getattr(event, "item_id", None) or (
                event.get("item_id") if isinstance(event, dict) else None
            )
            delta = getattr(event, "delta", "") or (
                event.get("delta", "") if isinstance(event, dict) else ""
            )
            block_idx = (
                self._item_id_to_block_index.get(item_id, self._current_block_index)
                if item_id
                else self._current_block_index
            )
            self._chunk_queue.append(
                {
                    "type": "content_block_delta",
                    "index": block_idx,
                    "delta": {"type": "input_json_delta", "partial_json": delta},
                }
            )
            return

        # ---- output item done -> content_block_stop ----
        if event_type == "response.output_item.done":
            item = getattr(event, "item", None) or (
                event.get("item") if isinstance(event, dict) else None
            )
            item_id = (
                getattr(item, "id", None)
                or (item.get("id") if isinstance(item, dict) else None)
                if item
                else None
            )
            block_idx = (
                self._item_id_to_block_index.get(item_id, self._current_block_index)
                if item_id
                else self._current_block_index
            )
            self._chunk_queue.append(
                {
                    "type": "content_block_stop",
                    "index": block_idx,
                }
            )
            return

        # ---- response completed -> message_delta + message_stop ----
        if event_type in (
            "response.completed",
            "response.failed",
            "response.incomplete",
        ):
            response_obj = getattr(event, "response", None) or (
                event.get("response") if isinstance(event, dict) else None
            )
            stop_reason = "end_turn"
            input_tokens = 0
            output_tokens = 0
            cache_creation_tokens = 0
            cache_read_tokens = 0

            if response_obj is not None:
                status = getattr(response_obj, "status", None)
                if status == "incomplete":
                    stop_reason = "max_tokens"
                usage = getattr(response_obj, "usage", None)
                if usage is not None:
                    input_tokens = getattr(usage, "input_tokens", 0) or 0
                    output_tokens = getattr(usage, "output_tokens", 0) or 0
                    # Bug fix: removed dead assignments of input_tokens_details /
                    # output_tokens_details that were immediately overwritten by
                    # the direct cache fields below.
                    cache_creation_tokens = int(
                        getattr(usage, "cache_creation_input_tokens", 0) or 0
                    )
                    cache_read_tokens = int(
                        getattr(usage, "cache_read_input_tokens", 0) or 0
                    )

            # Check if tool_use was in the output to override stop_reason
            if response_obj is not None:
                output = getattr(response_obj, "output", []) or []
                for out_item in output:
                    out_type = getattr(out_item, "type", None) or (
                        out_item.get("type") if isinstance(out_item, dict) else None
                    )
                    if out_type == "function_call":
                        stop_reason = "tool_use"
                        break

            usage_delta: Dict[str, Any] = {
                "input_tokens": input_tokens,
                "output_tokens": output_tokens,
            }
            # Cache counts are only reported when non-zero.
            if cache_creation_tokens:
                usage_delta["cache_creation_input_tokens"] = cache_creation_tokens
            if cache_read_tokens:
                usage_delta["cache_read_input_tokens"] = cache_read_tokens

            self._chunk_queue.append(
                {
                    "type": "message_delta",
                    "delta": {"stop_reason": stop_reason, "stop_sequence": None},
                    "usage": usage_delta,
                }
            )
            self._chunk_queue.append({"type": "message_stop"})
            self._sent_message_stop = True
            return

    def __aiter__(self) -> "AnthropicResponsesStreamWrapper":
        return self

    async def __anext__(self) -> Dict[str, Any]:
        """Return the next Anthropic chunk, consuming upstream events as needed."""
        # Return any queued chunks first
        if self._chunk_queue:
            return self._chunk_queue.popleft()

        # Emit message_start if not yet done (fallback if response.created wasn't fired)
        if not self._sent_message_start:
            self._sent_message_start = True
            self._chunk_queue.append(self._make_message_start())
            return self._chunk_queue.popleft()

        # Consume the upstream stream
        try:
            async for event in self.responses_stream:
                self._process_event(event)
                if self._chunk_queue:
                    return self._chunk_queue.popleft()
        except StopAsyncIteration:
            pass
        except Exception as e:
            # Best-effort: log and fall through to drain what we have.
            verbose_logger.error(
                f"AnthropicResponsesStreamWrapper error: {e}\n{traceback.format_exc()}"
            )

        # Drain any remaining queued chunks
        if self._chunk_queue:
            return self._chunk_queue.popleft()

        raise StopAsyncIteration

    async def async_anthropic_sse_wrapper(self) -> AsyncIterator[bytes]:
        """Yield SSE-encoded bytes for each Anthropic event chunk."""
        async for chunk in self:
            if isinstance(chunk, dict):
                event_type: str = str(chunk.get("type", "message"))
                payload = f"event: {event_type}\ndata: {json.dumps(chunk)}\n\n"
                yield payload.encode()
            else:
                # Pass through anything already encoded.
                yield chunk
|
||||
@@ -0,0 +1,488 @@
|
||||
"""
|
||||
Transformation layer: Anthropic /v1/messages <-> OpenAI Responses API.
|
||||
|
||||
This module owns all format conversions for the direct v1/messages -> Responses API
|
||||
path used for OpenAI and Azure models.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from litellm.types.llms.anthropic import (
|
||||
AllAnthropicToolsValues,
|
||||
AnthopicMessagesAssistantMessageParam,
|
||||
AnthropicFinishReason,
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicMessagesToolChoice,
|
||||
AnthropicMessagesUserMessageParam,
|
||||
AnthropicResponseContentBlockText,
|
||||
AnthropicResponseContentBlockThinking,
|
||||
AnthropicResponseContentBlockToolUse,
|
||||
)
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
AnthropicUsage,
|
||||
)
|
||||
from litellm.types.llms.openai import ResponsesAPIResponse
|
||||
|
||||
|
||||
class LiteLLMAnthropicToResponsesAPIAdapter:
|
||||
"""
|
||||
Converts Anthropic /v1/messages requests to OpenAI Responses API format and
|
||||
converts Responses API responses back to Anthropic format.
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Request translation: Anthropic -> Responses API #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@staticmethod
|
||||
def _translate_anthropic_image_source_to_url(source: dict) -> Optional[str]:
|
||||
"""Convert Anthropic image source to a URL string."""
|
||||
source_type = source.get("type")
|
||||
if source_type == "base64":
|
||||
media_type = source.get("media_type", "image/jpeg")
|
||||
data = source.get("data", "")
|
||||
return f"data:{media_type};base64,{data}" if data else None
|
||||
elif source_type == "url":
|
||||
return source.get("url")
|
||||
return None
|
||||
|
||||
def translate_messages_to_responses_input(  # noqa: PLR0915
    self,
    messages: List[
        Union[
            AnthropicMessagesUserMessageParam,
            AnthopicMessagesAssistantMessageParam,
        ]
    ],
) -> List[Dict[str, Any]]:
    """
    Convert Anthropic messages list to Responses API `input` items.

    Mapping:
        user text -> message(role=user, input_text)
        user image -> message(role=user, input_image)
        user tool_result -> function_call_output
        assistant text -> message(role=assistant, output_text)
        assistant tool_use -> function_call

    Ordering note: tool_result / tool_use blocks are emitted as top-level
    items as soon as they are seen, while text/image parts of the same
    message are accumulated and appended as one message item after the
    content loop — so a top-level item can precede its sibling message.
    Non-dict content blocks and roles other than user/assistant are skipped.
    """
    input_items: List[Dict[str, Any]] = []

    for m in messages:
        role = m["role"]
        content = m.get("content")

        if role == "user":
            if isinstance(content, str):
                # Plain-string user content becomes a single input_text part.
                input_items.append(
                    {
                        "type": "message",
                        "role": "user",
                        "content": [{"type": "input_text", "text": content}],
                    }
                )
            elif isinstance(content, list):
                user_parts: List[Dict[str, Any]] = []
                for block in content:
                    if not isinstance(block, dict):
                        continue
                    btype = block.get("type")
                    if btype == "text":
                        user_parts.append(
                            {"type": "input_text", "text": block.get("text", "")}
                        )
                    elif btype == "image":
                        # Images are forwarded as URLs (data: URI for base64
                        # sources); blocks with no usable source are dropped.
                        url = self._translate_anthropic_image_source_to_url(
                            block.get("source", {})
                        )
                        if url:
                            user_parts.append(
                                {"type": "input_image", "image_url": url}
                            )
                    elif btype == "tool_result":
                        tool_use_id = block.get("tool_use_id", "")
                        inner = block.get("content")
                        # Flatten the tool result to plain text: strings pass
                        # through, lists keep only their text blocks (joined
                        # with newlines), anything else is str()-ed.
                        if inner is None:
                            output_text = ""
                        elif isinstance(inner, str):
                            output_text = inner
                        elif isinstance(inner, list):
                            parts = [
                                c.get("text", "")
                                for c in inner
                                if isinstance(c, dict) and c.get("type") == "text"
                            ]
                            output_text = "\n".join(parts)
                        else:
                            output_text = str(inner)
                        # tool_result is a top-level item, not inside the message
                        input_items.append(
                            {
                                "type": "function_call_output",
                                "call_id": tool_use_id,
                                "output": output_text,
                            }
                        )
                if user_parts:
                    input_items.append(
                        {
                            "type": "message",
                            "role": "user",
                            "content": user_parts,
                        }
                    )

        elif role == "assistant":
            if isinstance(content, str):
                input_items.append(
                    {
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "output_text", "text": content}],
                    }
                )
            elif isinstance(content, list):
                asst_parts: List[Dict[str, Any]] = []
                for block in content:
                    if not isinstance(block, dict):
                        continue
                    btype = block.get("type")
                    if btype == "text":
                        asst_parts.append(
                            {"type": "output_text", "text": block.get("text", "")}
                        )
                    elif btype == "tool_use":
                        # tool_use becomes a top-level function_call item
                        input_items.append(
                            {
                                "type": "function_call",
                                "call_id": block.get("id", ""),
                                "name": block.get("name", ""),
                                "arguments": json.dumps(block.get("input", {})),
                            }
                        )
                    elif btype == "thinking":
                        # NOTE(review): thinking content is down-converted to
                        # plain output_text (the Responses API input has no
                        # thinking part here) — confirm this lossy mapping is
                        # intended.
                        thinking_text = block.get("thinking", "")
                        if thinking_text:
                            asst_parts.append(
                                {"type": "output_text", "text": thinking_text}
                            )
                if asst_parts:
                    input_items.append(
                        {
                            "type": "message",
                            "role": "assistant",
                            "content": asst_parts,
                        }
                    )

    return input_items
|
||||
|
||||
def translate_tools_to_responses_api(
    self,
    tools: List[AllAnthropicToolsValues],
) -> List[Dict[str, Any]]:
    """Map Anthropic tool definitions onto Responses API tool specs.

    Web-search tools (type prefixed ``web_search`` or named ``web_search``)
    become a bare ``web_search_preview`` entry; every other tool is emitted
    as a ``function`` tool carrying the optional description and the
    ``input_schema`` as ``parameters``.
    """
    converted: List[Dict[str, Any]] = []
    for raw_tool in tools:
        spec = cast(Dict[str, Any], raw_tool)
        kind = spec.get("type", "")
        name = spec.get("name", "")

        # Anthropic web_search tools (e.g. "web_search_20250305") have no
        # function-call equivalent; surface them as the preview tool.
        is_web_search = (
            isinstance(kind, str) and kind.startswith("web_search")
        ) or name == "web_search"
        if is_web_search:
            converted.append({"type": "web_search_preview"})
            continue

        entry: Dict[str, Any] = {"type": "function", "name": name}
        if "description" in spec:
            entry["description"] = spec["description"]
        if "input_schema" in spec:
            entry["parameters"] = spec["input_schema"]
        converted.append(entry)
    return converted
|
||||
|
||||
@staticmethod
def translate_tool_choice_to_responses_api(
    tool_choice: AnthropicMessagesToolChoice,
) -> Dict[str, Any]:
    """Translate an Anthropic ``tool_choice`` into the Responses API form.

    ``any`` forces a tool call (``required``), ``tool`` pins a specific
    function by name, and anything else falls back to automatic selection.
    """
    choice_type = tool_choice.get("type")
    if choice_type == "tool":
        return {"type": "function", "name": tool_choice.get("name", "")}
    if choice_type == "any":
        return {"type": "required"}
    return {"type": "auto"}
|
||||
|
||||
@staticmethod
def translate_context_management_to_responses_api(
    context_management: Dict[str, Any],
) -> Optional[List[Dict[str, Any]]]:
    """
    Convert Anthropic's context_management dict into the OpenAI Responses
    API array format.

    Anthropic: {"edits": [{"type": "compact_20260112",
                           "trigger": {"type": "input_tokens", "value": 150000}}]}
    OpenAI:    [{"type": "compaction", "compact_threshold": 150000}]

    Returns None when the input is malformed or nothing is translatable.
    """
    if not isinstance(context_management, dict):
        return None
    raw_edits = context_management.get("edits", [])
    if not isinstance(raw_edits, list):
        return None

    translated: List[Dict[str, Any]] = []
    for raw_edit in raw_edits:
        if not isinstance(raw_edit, dict):
            continue
        # Only the dated "compact" edit type has an OpenAI equivalent today.
        if raw_edit.get("type", "") != "compact_20260112":
            continue
        compaction: Dict[str, Any] = {"type": "compaction"}
        trigger = raw_edit.get("trigger")
        if isinstance(trigger, dict) and trigger.get("value") is not None:
            compaction["compact_threshold"] = int(trigger["value"])
        translated.append(compaction)

    return translated or None
|
||||
|
||||
@staticmethod
def translate_thinking_to_reasoning(
    thinking: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Convert Anthropic thinking param to Responses API reasoning param.

    thinking.budget_tokens maps to reasoning effort:
    >= 10000 -> high, >= 5000 -> medium, >= 2000 -> low, < 2000 -> minimal

    Returns:
        A reasoning dict with "effort" and "summary", or None when
        thinking is not a dict or not enabled.
    """
    if not isinstance(thinking, dict) or thinking.get("type") != "enabled":
        return None
    # Treat a missing OR explicit-None budget as 0 so the comparisons
    # below cannot raise TypeError on {"budget_tokens": None}.
    budget = thinking.get("budget_tokens") or 0
    if budget >= 10000:
        effort = "high"
    elif budget >= 5000:
        effort = "medium"
    elif budget >= 2000:
        effort = "low"
    else:
        effort = "minimal"
    return {"effort": effort, "summary": "detailed"}
|
||||
|
||||
def translate_request(
    self,
    anthropic_request: AnthropicMessagesRequest,
) -> Dict[str, Any]:
    """
    Translate a full Anthropic /v1/messages request dict to
    litellm.responses() / litellm.aresponses() kwargs.

    Args:
        anthropic_request: Raw Anthropic request body. "model" and
            "messages" are required; every other field is optional and
            is only copied into the output when present.

    Returns:
        Kwargs dict for litellm.responses()/aresponses(). Always contains
        "model" and "input"; other keys are added conditionally below.
    """
    model: str = anthropic_request["model"]
    messages_list = cast(
        List[
            Union[
                AnthropicMessagesUserMessageParam,
                AnthopicMessagesAssistantMessageParam,
            ]
        ],
        anthropic_request["messages"],
    )

    responses_kwargs: Dict[str, Any] = {
        "model": model,
        "input": self.translate_messages_to_responses_input(messages_list),
    }

    # system -> instructions
    # A list-form system prompt is flattened to its text blocks, joined
    # with newlines; empty strings are dropped by filter(None, ...).
    system = anthropic_request.get("system")
    if system:
        if isinstance(system, str):
            responses_kwargs["instructions"] = system
        elif isinstance(system, list):
            text_parts = [
                b.get("text", "")
                for b in system
                if isinstance(b, dict) and b.get("type") == "text"
            ]
            responses_kwargs["instructions"] = "\n".join(filter(None, text_parts))

    # max_tokens -> max_output_tokens (falsy values, incl. 0, are skipped)
    max_tokens = anthropic_request.get("max_tokens")
    if max_tokens:
        responses_kwargs["max_output_tokens"] = max_tokens

    # temperature / top_p passed through
    if "temperature" in anthropic_request:
        responses_kwargs["temperature"] = anthropic_request["temperature"]
    if "top_p" in anthropic_request:
        responses_kwargs["top_p"] = anthropic_request["top_p"]

    # tools
    tools = anthropic_request.get("tools")
    if tools:
        responses_kwargs["tools"] = self.translate_tools_to_responses_api(
            cast(List[AllAnthropicToolsValues], tools)
        )

    # tool_choice
    tool_choice = anthropic_request.get("tool_choice")
    if tool_choice:
        responses_kwargs[
            "tool_choice"
        ] = self.translate_tool_choice_to_responses_api(
            cast(AnthropicMessagesToolChoice, tool_choice)
        )

    # thinking -> reasoning (only set when translation yields a value)
    thinking = anthropic_request.get("thinking")
    if isinstance(thinking, dict):
        reasoning = self.translate_thinking_to_reasoning(thinking)
        if reasoning:
            responses_kwargs["reasoning"] = reasoning

    # output_format / output_config.format -> text format
    # output_format: {"type": "json_schema", "schema": {...}}
    # output_config: {"format": {"type": "json_schema", "schema": {...}}}
    # output_format takes precedence; output_config.format is the fallback.
    output_format: Any = anthropic_request.get("output_format")
    output_config = anthropic_request.get("output_config")
    if not isinstance(output_format, dict) and isinstance(output_config, dict):
        output_format = output_config.get("format")  # type: ignore[assignment]
    if (
        isinstance(output_format, dict)
        and output_format.get("type") == "json_schema"
    ):
        schema = output_format.get("schema")
        if schema:
            responses_kwargs["text"] = {
                "format": {
                    "type": "json_schema",
                    "name": "structured_output",
                    "schema": schema,
                    "strict": True,
                }
            }

    # context_management: Anthropic dict -> OpenAI array
    context_management = anthropic_request.get("context_management")
    if isinstance(context_management, dict):
        openai_cm = self.translate_context_management_to_responses_api(
            context_management
        )
        if openai_cm is not None:
            responses_kwargs["context_management"] = openai_cm

    # metadata user_id -> user
    # Truncated to 64 chars — presumably the OpenAI `user` field length
    # limit; TODO(review): confirm against the Responses API spec.
    metadata = anthropic_request.get("metadata")
    if isinstance(metadata, dict) and "user_id" in metadata:
        responses_kwargs["user"] = str(metadata["user_id"])[:64]

    return responses_kwargs
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Response translation: Responses API -> Anthropic #
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def translate_response(
    self,
    response: ResponsesAPIResponse,
) -> AnthropicMessagesResponse:
    """
    Translate an OpenAI ResponsesAPIResponse to AnthropicMessagesResponse.

    Output items are mapped as follows:
    - reasoning items       -> "thinking" content blocks (one per summary)
    - output messages       -> "text" content blocks (output_text parts only)
    - function tool calls   -> "tool_use" content blocks (sets stop_reason)
    Dict-shaped items (from providers that return plain dicts instead of
    the openai pydantic types) are handled by a parallel dict branch.
    """
    # Imported locally — presumably to keep the openai types import off the
    # module import path; NOTE(review): confirm this is intentional.
    from openai.types.responses import (
        ResponseFunctionToolCall,
        ResponseOutputMessage,
        ResponseReasoningItem,
    )

    from litellm.types.llms.openai import ResponseAPIUsage

    content: List[Dict[str, Any]] = []
    stop_reason: AnthropicFinishReason = "end_turn"

    for item in response.output:
        if isinstance(item, ResponseReasoningItem):
            # Each reasoning summary becomes its own thinking block.
            for summary in item.summary:
                text = getattr(summary, "text", "")
                if text:
                    content.append(
                        AnthropicResponseContentBlockThinking(
                            type="thinking",
                            thinking=text,
                            signature=None,
                        ).model_dump()
                    )

        elif isinstance(item, ResponseOutputMessage):
            # Only output_text parts are translated; other part types are
            # silently dropped.
            for part in item.content:
                if getattr(part, "type", None) == "output_text":
                    content.append(
                        AnthropicResponseContentBlockText(
                            type="text", text=getattr(part, "text", "")
                        ).model_dump()
                    )

        elif isinstance(item, ResponseFunctionToolCall):
            # Malformed/missing arguments degrade to an empty input dict
            # rather than failing the whole response translation.
            try:
                input_data = json.loads(item.arguments) if item.arguments else {}
            except (json.JSONDecodeError, TypeError):
                input_data = {}
            content.append(
                AnthropicResponseContentBlockToolUse(
                    type="tool_use",
                    id=item.call_id or item.id or "",
                    name=item.name,
                    input=input_data,
                ).model_dump()
            )
            stop_reason = "tool_use"

        elif isinstance(item, dict):
            # Fallback branch mirroring the typed branches above for raw
            # dict output items.
            item_type = item.get("type")
            if item_type == "message":
                for part in item.get("content", []):
                    if isinstance(part, dict) and part.get("type") == "output_text":
                        content.append(
                            AnthropicResponseContentBlockText(
                                type="text", text=part.get("text", "")
                            ).model_dump()
                        )
            elif item_type == "function_call":
                try:
                    input_data = json.loads(item.get("arguments", "{}"))
                except (json.JSONDecodeError, TypeError):
                    input_data = {}
                content.append(
                    AnthropicResponseContentBlockToolUse(
                        type="tool_use",
                        id=item.get("call_id") or item.get("id", ""),
                        name=item.get("name", ""),
                        input=input_data,
                    ).model_dump()
                )
                stop_reason = "tool_use"

    # status -> stop_reason override: an incomplete response is reported
    # to Anthropic callers as having hit the token limit.
    if response.status == "incomplete":
        stop_reason = "max_tokens"

    # usage — getattr + "or 0" guards against a missing usage object and
    # against explicit-None token counts.
    raw_usage: Optional[ResponseAPIUsage] = response.usage
    input_tokens = int(getattr(raw_usage, "input_tokens", 0) or 0)
    output_tokens = int(getattr(raw_usage, "output_tokens", 0) or 0)

    anthropic_usage = AnthropicUsage(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
    )

    return AnthropicMessagesResponse(
        id=response.id,
        type="message",
        role="assistant",
        model=response.model or "unknown-model",
        stop_sequence=None,
        usage=anthropic_usage,  # type: ignore
        content=content,  # type: ignore
        stop_reason=stop_reason,
    )
|
||||
@@ -0,0 +1,4 @@
|
||||
from .handler import AnthropicFilesHandler
|
||||
from .transformation import AnthropicFilesConfig
|
||||
|
||||
__all__ = ["AnthropicFilesHandler", "AnthropicFilesConfig"]
|
||||
@@ -0,0 +1,366 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Coroutine, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.types.llms.openai import (
|
||||
FileContentRequest,
|
||||
HttpxBinaryResponseContent,
|
||||
OpenAIBatchResult,
|
||||
OpenAIChatCompletionResponse,
|
||||
OpenAIErrorBody,
|
||||
)
|
||||
from litellm.types.utils import CallTypes, LlmProviders, ModelResponse
|
||||
|
||||
from ..chat.transformation import AnthropicConfig
|
||||
from ..common_utils import AnthropicModelInfo
|
||||
|
||||
# Map Anthropic error types to HTTP status codes.
# Used when converting "errored" batch result lines to the OpenAI batch
# format; unknown error types fall back to 500 at the lookup site.
ANTHROPIC_ERROR_STATUS_CODE_MAP = {
    "invalid_request_error": 400,
    "authentication_error": 401,
    "permission_error": 403,
    "not_found_error": 404,
    "rate_limit_error": 429,
    "api_error": 500,
    "overloaded_error": 503,
    "timeout_error": 504,
}
|
||||
|
||||
|
||||
class AnthropicFilesHandler:
    """
    Handles Anthropic Files API operations.

    Currently supports:
    - file_content() for retrieving Anthropic Message Batch results
    """

    def __init__(self):
        # Resolves the Anthropic api_base / api_key from params or env.
        self.anthropic_model_info = AnthropicModelInfo()

    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = 600.0,
        max_retries: Optional[int] = None,
    ) -> HttpxBinaryResponseContent:
        """
        Async: Retrieve file content from Anthropic.

        For batch results, the file_id should be the batch_id.
        This will call Anthropic's /v1/messages/batches/{batch_id}/results endpoint.

        Args:
            file_content_request: Contains file_id (batch_id for batch results)
            api_base: Anthropic API base URL
            api_key: Anthropic API key
            timeout: Request timeout. NOTE(review): accepted for interface
                parity but not currently applied to the HTTP call — confirm
                whether the shared async client enforces its own timeout.
            max_retries: Max retry attempts (unused for now)

        Returns:
            HttpxBinaryResponseContent: Binary content wrapped in compatible response format
        """
        file_id = file_content_request.get("file_id")
        if not file_id:
            raise ValueError("file_id is required in file_content_request")

        # Extract batch_id from file_id
        # Handle both formats: "anthropic_batch_results:{batch_id}" or just "{batch_id}"
        if file_id.startswith("anthropic_batch_results:"):
            batch_id = file_id.replace("anthropic_batch_results:", "", 1)
        else:
            batch_id = file_id

        # Get Anthropic API credentials
        api_base = self.anthropic_model_info.get_api_base(api_base)
        api_key = api_key or self.anthropic_model_info.get_api_key()

        if not api_key:
            raise ValueError("Missing Anthropic API Key")

        # Construct the Anthropic batch results URL
        results_url = f"{api_base.rstrip('/')}/v1/messages/batches/{batch_id}/results"

        # Prepare headers
        headers = {
            "accept": "application/json",
            "anthropic-version": "2023-06-01",
            "x-api-key": api_key,
        }

        # Make the request to Anthropic
        async_client = get_async_httpx_client(llm_provider=LlmProviders.ANTHROPIC)
        anthropic_response = await async_client.get(url=results_url, headers=headers)
        anthropic_response.raise_for_status()

        # Transform Anthropic batch results to OpenAI format
        transformed_content = self._transform_anthropic_batch_results_to_openai_format(
            anthropic_response.content
        )

        # Create a new response with transformed content
        transformed_response = httpx.Response(
            status_code=anthropic_response.status_code,
            headers=anthropic_response.headers,
            content=transformed_content,
            request=anthropic_response.request,
        )

        # Return the transformed response content
        return HttpxBinaryResponseContent(response=transformed_response)

    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: Union[float, httpx.Timeout] = 600.0,
        max_retries: Optional[int] = None,
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        """
        Retrieve file content from Anthropic.

        For batch results, the file_id should be the batch_id.
        This will call Anthropic's /v1/messages/batches/{batch_id}/results endpoint.

        Args:
            _is_async: Whether to run asynchronously
            file_content_request: Contains file_id (batch_id for batch results)
            api_base: Anthropic API base URL
            api_key: Anthropic API key
            timeout: Request timeout
            max_retries: Max retry attempts (unused for now)

        Returns:
            HttpxBinaryResponseContent or Coroutine: Binary content wrapped in compatible response format
        """
        if _is_async:
            # Fix: the async path previously dropped `timeout`, silently
            # ignoring the caller's setting; forward it like the sync path.
            return self.afile_content(
                file_content_request=file_content_request,
                api_base=api_base,
                api_key=api_key,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            return asyncio.run(
                self.afile_content(
                    file_content_request=file_content_request,
                    api_base=api_base,
                    api_key=api_key,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            )

    def _transform_anthropic_batch_results_to_openai_format(
        self, anthropic_content: bytes
    ) -> bytes:
        """
        Transform Anthropic batch results JSONL to OpenAI batch results JSONL format.

        Anthropic format:
        {
            "custom_id": "...",
            "result": {
                "type": "succeeded",
                "message": { ... }  // Anthropic message format
            }
        }

        OpenAI format:
        {
            "custom_id": "...",
            "response": {
                "status_code": 200,
                "request_id": "...",
                "body": { ... }  // OpenAI chat completion format
            }
        }

        On any transformation error the original bytes are returned
        unchanged (best-effort, never raises).
        """
        try:
            anthropic_config = AnthropicConfig()
            transformed_lines = []

            # Parse JSONL content
            content_str = anthropic_content.decode("utf-8")
            for line in content_str.strip().split("\n"):
                if not line.strip():
                    continue

                anthropic_result = json.loads(line)
                custom_id = anthropic_result.get("custom_id", "")
                result = anthropic_result.get("result", {})
                result_type = result.get("type", "")

                # Transform based on result type
                if result_type == "succeeded":
                    # Transform Anthropic message to OpenAI format
                    anthropic_message = result.get("message", {})
                    if anthropic_message:
                        openai_response_body = (
                            self._transform_anthropic_message_to_openai_format(
                                anthropic_message=anthropic_message,
                                anthropic_config=anthropic_config,
                            )
                        )

                        # Create OpenAI batch result format
                        openai_result: OpenAIBatchResult = {
                            "custom_id": custom_id,
                            "response": {
                                "status_code": 200,
                                "request_id": anthropic_message.get("id", ""),
                                "body": openai_response_body,
                            },
                        }
                        transformed_lines.append(json.dumps(openai_result))
                elif result_type == "errored":
                    # Handle error case — map the Anthropic error type to an
                    # HTTP status via ANTHROPIC_ERROR_STATUS_CODE_MAP.
                    error = result.get("error", {})
                    error_obj = error.get("error", {})
                    error_message = error_obj.get("message", "Unknown error")
                    error_type = error_obj.get("type", "api_error")

                    status_code = ANTHROPIC_ERROR_STATUS_CODE_MAP.get(error_type, 500)

                    error_body_errored: OpenAIErrorBody = {
                        "error": {
                            "message": error_message,
                            "type": error_type,
                        }
                    }
                    openai_result_errored: OpenAIBatchResult = {
                        "custom_id": custom_id,
                        "response": {
                            "status_code": status_code,
                            "request_id": error.get("request_id", ""),
                            "body": error_body_errored,
                        },
                    }
                    transformed_lines.append(json.dumps(openai_result_errored))
                elif result_type in ["canceled", "expired"]:
                    # Handle canceled/expired cases as 400-level request errors
                    error_body_canceled: OpenAIErrorBody = {
                        "error": {
                            "message": f"Batch request was {result_type}",
                            "type": "invalid_request_error",
                        }
                    }
                    openai_result_canceled: OpenAIBatchResult = {
                        "custom_id": custom_id,
                        "response": {
                            "status_code": 400,
                            "request_id": "",
                            "body": error_body_canceled,
                        },
                    }
                    transformed_lines.append(json.dumps(openai_result_canceled))

            # Join lines and encode back to bytes
            transformed_content = "\n".join(transformed_lines)
            if transformed_lines:
                transformed_content += "\n"  # Add trailing newline for JSONL format
            return transformed_content.encode("utf-8")
        except Exception as e:
            verbose_logger.error(
                f"Error transforming Anthropic batch results to OpenAI format: {e}"
            )
            # Return original content if transformation fails
            return anthropic_content

    def _transform_anthropic_message_to_openai_format(
        self, anthropic_message: dict, anthropic_config: AnthropicConfig
    ) -> OpenAIChatCompletionResponse:
        """
        Transform a single Anthropic message to OpenAI chat completion format.

        Falls back to a minimal error-shaped chat completion (empty content,
        finish_reason "error", zeroed usage) if transformation fails.
        """
        try:
            # Create a mock httpx.Response for transformation
            mock_response = httpx.Response(
                status_code=200,
                content=json.dumps(anthropic_message).encode("utf-8"),
            )

            # Create a ModelResponse object
            model_response = ModelResponse()
            # Initialize with required fields - will be populated by transform_parsed_response
            model_response.choices = [
                litellm.Choices(
                    finish_reason="stop",
                    index=0,
                    message=litellm.Message(content="", role="assistant"),
                )
            ]  # type: ignore

            # Create a logging object for transformation — transform_parsed_response
            # requires one even though this is offline batch processing.
            logging_obj = Logging(
                model=anthropic_message.get("model", "claude-3-5-sonnet-20241022"),
                messages=[{"role": "user", "content": "batch_request"}],
                stream=False,
                call_type=CallTypes.aretrieve_batch,
                start_time=time.time(),
                litellm_call_id="batch_" + str(uuid.uuid4()),
                function_id="batch_processing",
                litellm_trace_id=str(uuid.uuid4()),
                kwargs={"optional_params": {}},
            )
            logging_obj.optional_params = {}

            # Transform using AnthropicConfig
            transformed_response = anthropic_config.transform_parsed_response(
                completion_response=anthropic_message,
                raw_response=mock_response,
                model_response=model_response,
                json_mode=False,
                prefix_prompt=None,
            )

            # Convert ModelResponse to OpenAI format dict - it's already in OpenAI format
            openai_body: OpenAIChatCompletionResponse = transformed_response.model_dump(
                exclude_none=True
            )

            # Ensure id comes from anthropic_message if not set
            if not openai_body.get("id"):
                openai_body["id"] = anthropic_message.get("id", "")

            return openai_body
        except Exception as e:
            verbose_logger.error(
                f"Error transforming Anthropic message to OpenAI format: {e}"
            )
            # Return a basic error response if transformation fails
            error_response: OpenAIChatCompletionResponse = {
                "id": anthropic_message.get("id", ""),
                "object": "chat.completion",
                "created": int(time.time()),
                "model": anthropic_message.get("model", ""),
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": ""},
                        "finish_reason": "error",
                    }
                ],
                "usage": {
                    "prompt_tokens": 0,
                    "completion_tokens": 0,
                    "total_tokens": 0,
                },
            }
            return error_response
|
||||
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
Anthropic Files API transformation config.
|
||||
|
||||
Implements BaseFilesConfig for Anthropic's Files API (beta).
|
||||
Reference: https://docs.anthropic.com/en/docs/build-with-claude/files
|
||||
|
||||
Anthropic Files API endpoints:
|
||||
- POST /v1/files - Upload a file
|
||||
- GET /v1/files - List files
|
||||
- GET /v1/files/{file_id} - Retrieve file metadata
|
||||
- DELETE /v1/files/{file_id} - Delete a file
|
||||
- GET /v1/files/{file_id}/content - Download file content
|
||||
"""
|
||||
|
||||
import calendar
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from openai.types.file_deleted import FileDeleted
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.files.transformation import (
|
||||
BaseFilesConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
CreateFileRequest,
|
||||
FileContentRequest,
|
||||
HttpxBinaryResponseContent,
|
||||
OpenAICreateFileRequestOptionalParams,
|
||||
OpenAIFileObject,
|
||||
)
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
from ..common_utils import AnthropicError, AnthropicModelInfo
|
||||
|
||||
ANTHROPIC_FILES_API_BASE = "https://api.anthropic.com"
|
||||
ANTHROPIC_FILES_BETA_HEADER = "files-api-2025-04-14"
|
||||
|
||||
|
||||
class AnthropicFilesConfig(BaseFilesConfig):
|
||||
"""
|
||||
Transformation config for Anthropic Files API.
|
||||
|
||||
Anthropic uses:
|
||||
- x-api-key header for authentication
|
||||
- anthropic-beta: files-api-2025-04-14 header
|
||||
- multipart/form-data for file uploads
|
||||
- purpose="messages" (Anthropic-specific, not for batches/fine-tuning)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> LlmProviders:
|
||||
return LlmProviders.ANTHROPIC
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
api_base = AnthropicModelInfo.get_api_base(api_base) or ANTHROPIC_FILES_API_BASE
|
||||
return f"{api_base.rstrip('/')}/v1/files"
|
||||
|
||||
def get_error_class(
|
||||
self,
|
||||
error_message: str,
|
||||
status_code: int,
|
||||
headers: Union[dict, httpx.Headers],
|
||||
) -> BaseLLMException:
|
||||
return AnthropicError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=cast(httpx.Headers, headers) if isinstance(headers, dict) else headers,
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: list,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
api_key = AnthropicModelInfo.get_api_key(api_key)
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"Anthropic API key is required. Set ANTHROPIC_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
headers.update(
|
||||
{
|
||||
"x-api-key": api_key,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"anthropic-beta": ANTHROPIC_FILES_BETA_HEADER,
|
||||
}
|
||||
)
|
||||
return headers
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAICreateFileRequestOptionalParams]:
|
||||
return ["purpose"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return optional_params
|
||||
|
||||
def transform_create_file_request(
|
||||
self,
|
||||
model: str,
|
||||
create_file_data: CreateFileRequest,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform to multipart form data for Anthropic file upload.
|
||||
|
||||
Anthropic expects: POST /v1/files with multipart form-data
|
||||
- file: the file content
|
||||
- purpose: "messages" (defaults to "messages" if not provided)
|
||||
"""
|
||||
file_data = create_file_data.get("file")
|
||||
if file_data is None:
|
||||
raise ValueError("File data is required")
|
||||
|
||||
extracted = extract_file_data(file_data)
|
||||
filename = extracted["filename"] or f"file_{int(time.time())}"
|
||||
content = extracted["content"]
|
||||
content_type = extracted.get("content_type", "application/octet-stream")
|
||||
|
||||
purpose = create_file_data.get("purpose", "messages")
|
||||
|
||||
return {
|
||||
"file": (filename, content, content_type),
|
||||
"purpose": (None, purpose),
|
||||
}
|
||||
|
||||
def transform_create_file_response(
|
||||
self,
|
||||
model: Optional[str],
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_params: dict,
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Transform Anthropic file response to OpenAI format.
|
||||
|
||||
Anthropic response:
|
||||
{
|
||||
"id": "file-xxx",
|
||||
"type": "file",
|
||||
"filename": "document.pdf",
|
||||
"mime_type": "application/pdf",
|
||||
"size_bytes": 12345,
|
||||
"created_at": "2025-01-01T00:00:00Z"
|
||||
}
|
||||
"""
|
||||
response_json = raw_response.json()
|
||||
return self._parse_anthropic_file(response_json)
|
||||
|
||||
def transform_retrieve_file_request(
|
||||
self,
|
||||
file_id: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> tuple[str, dict]:
|
||||
api_base = (
|
||||
AnthropicModelInfo.get_api_base(litellm_params.get("api_base"))
|
||||
or ANTHROPIC_FILES_API_BASE
|
||||
)
|
||||
return f"{api_base.rstrip('/')}/v1/files/{file_id}", {}
|
||||
|
||||
def transform_retrieve_file_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_params: dict,
|
||||
) -> OpenAIFileObject:
|
||||
response_json = raw_response.json()
|
||||
return self._parse_anthropic_file(response_json)
|
||||
|
||||
def transform_delete_file_request(
|
||||
self,
|
||||
file_id: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> tuple[str, dict]:
|
||||
api_base = (
|
||||
AnthropicModelInfo.get_api_base(litellm_params.get("api_base"))
|
||||
or ANTHROPIC_FILES_API_BASE
|
||||
)
|
||||
return f"{api_base.rstrip('/')}/v1/files/{file_id}", {}
|
||||
|
||||
def transform_delete_file_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_params: dict,
|
||||
) -> FileDeleted:
|
||||
response_json = raw_response.json()
|
||||
file_id = response_json.get("id", "")
|
||||
return FileDeleted(
|
||||
id=file_id,
|
||||
deleted=True,
|
||||
object="file",
|
||||
)
|
||||
|
||||
def transform_list_files_request(
|
||||
self,
|
||||
purpose: Optional[str],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> tuple[str, dict]:
|
||||
api_base = (
|
||||
AnthropicModelInfo.get_api_base(litellm_params.get("api_base"))
|
||||
or ANTHROPIC_FILES_API_BASE
|
||||
)
|
||||
url = f"{api_base.rstrip('/')}/v1/files"
|
||||
params: Dict[str, Any] = {}
|
||||
if purpose:
|
||||
params["purpose"] = purpose
|
||||
return url, params
|
||||
|
||||
def transform_list_files_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_params: dict,
|
||||
) -> List[OpenAIFileObject]:
|
||||
"""
|
||||
Anthropic list response:
|
||||
{
|
||||
"data": [...],
|
||||
"has_more": false,
|
||||
"first_id": "...",
|
||||
"last_id": "..."
|
||||
}
|
||||
"""
|
||||
response_json = raw_response.json()
|
||||
files_data = response_json.get("data", [])
|
||||
return [self._parse_anthropic_file(f) for f in files_data]
|
||||
|
||||
def transform_file_content_request(
|
||||
self,
|
||||
file_content_request: FileContentRequest,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> tuple[str, dict]:
|
||||
file_id = file_content_request.get("file_id")
|
||||
api_base = (
|
||||
AnthropicModelInfo.get_api_base(litellm_params.get("api_base"))
|
||||
or ANTHROPIC_FILES_API_BASE
|
||||
)
|
||||
return f"{api_base.rstrip('/')}/v1/files/{file_id}/content", {}
|
||||
|
||||
def transform_file_content_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
litellm_params: dict,
|
||||
) -> HttpxBinaryResponseContent:
|
||||
return HttpxBinaryResponseContent(response=raw_response)
|
||||
|
||||
@staticmethod
def _parse_anthropic_file(file_data: dict) -> OpenAIFileObject:
    """Parse an Anthropic file object into OpenAI file-object format.

    Maps Anthropic fields onto OpenAI names: ``size_bytes`` -> ``bytes``
    (falling back to ``bytes`` if already OpenAI-shaped) and the ISO-8601
    ``created_at`` string -> Unix epoch seconds.

    Args:
        file_data: Raw file dict from the Anthropic Files API.

    Returns:
        OpenAIFileObject populated from ``file_data``; missing or malformed
        timestamps fall back to the current time.
    """
    created_at_str = file_data.get("created_at", "")
    if created_at_str:
        try:
            # Parse only the date/time portion ("YYYY-MM-DDTHH:MM:SS").
            # The [:19] slice drops any timezone suffix (e.g. a trailing
            # "Z"), so no pre-processing of the suffix is needed; the
            # previous .replace("Z", "+00:00") was a no-op and has been
            # removed. timegm() interprets the struct_time as UTC.
            created_at = int(
                calendar.timegm(
                    time.strptime(
                        created_at_str[:19],
                        "%Y-%m-%dT%H:%M:%S",
                    )
                )
            )
        except (ValueError, TypeError):
            # Malformed timestamp: fall back to "now".
            created_at = int(time.time())
    else:
        created_at = int(time.time())

    return OpenAIFileObject(
        id=file_data.get("id", ""),
        bytes=file_data.get("size_bytes", file_data.get("bytes", 0)),
        created_at=created_at,
        filename=file_data.get("filename", ""),
        object="file",
        purpose=file_data.get("purpose", "messages"),
        status="uploaded",
        status_details=None,
    )
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Anthropic Skills API integration"""
|
||||
|
||||
from .transformation import AnthropicSkillsConfig
|
||||
|
||||
__all__ = ["AnthropicSkillsConfig"]
|
||||
@@ -0,0 +1,279 @@
|
||||
# Anthropic Skills API Integration
|
||||
|
||||
This module provides comprehensive support for the Anthropic Skills API through LiteLLM.
|
||||
|
||||
## Features
|
||||
|
||||
The Skills API allows you to:
|
||||
- **Create skills**: Define reusable AI capabilities
|
||||
- **List skills**: Browse all available skills
|
||||
- **Get skills**: Retrieve detailed information about a specific skill
|
||||
- **Delete skills**: Remove skills that are no longer needed
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Prerequisites
|
||||
|
||||
Set your Anthropic API key:
|
||||
```python
|
||||
import os
|
||||
os.environ["ANTHROPIC_API_KEY"] = "your-api-key-here"
|
||||
```
|
||||
|
||||
### Basic Usage
|
||||
|
||||
#### Create a Skill
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
# Create a skill with files
|
||||
# Note: All files must be in the same top-level directory
|
||||
# and must include a SKILL.md file at the root
|
||||
skill = litellm.create_skill(
|
||||
files=[
|
||||
# List of file objects to upload
|
||||
# Must include SKILL.md
|
||||
],
|
||||
display_title="Python Code Generator",
|
||||
custom_llm_provider="anthropic"
|
||||
)
|
||||
print(f"Created skill: {skill.id}")
|
||||
|
||||
# Asynchronous version
|
||||
skill = await litellm.acreate_skill(
|
||||
files=[...], # Your files here
|
||||
display_title="Python Code Generator",
|
||||
custom_llm_provider="anthropic"
|
||||
)
|
||||
```
|
||||
|
||||
#### List Skills
|
||||
|
||||
```python
|
||||
# List all skills
|
||||
skills = litellm.list_skills(
|
||||
custom_llm_provider="anthropic"
|
||||
)
|
||||
|
||||
for skill in skills.data:
|
||||
print(f"{skill.display_title}: {skill.id}")
|
||||
|
||||
# With pagination and filtering
|
||||
skills = litellm.list_skills(
|
||||
limit=20,
|
||||
source="custom", # Filter by 'custom' or 'anthropic'
|
||||
custom_llm_provider="anthropic"
|
||||
)
|
||||
|
||||
# Get next page if available
|
||||
if skills.has_more:
|
||||
next_page = litellm.list_skills(
|
||||
page=skills.next_page,
|
||||
custom_llm_provider="anthropic"
|
||||
)
|
||||
```
|
||||
|
||||
#### Get a Skill
|
||||
|
||||
```python
|
||||
skill = litellm.get_skill(
|
||||
skill_id="skill_abc123",
|
||||
custom_llm_provider="anthropic"
|
||||
)
|
||||
|
||||
print(f"Skill: {skill.display_title}")
|
||||
print(f"Created: {skill.created_at}")
|
||||
print(f"Latest version: {skill.latest_version}")
|
||||
print(f"Source: {skill.source}")
|
||||
```
|
||||
|
||||
#### Delete a Skill
|
||||
|
||||
```python
|
||||
result = litellm.delete_skill(
|
||||
skill_id="skill_abc123",
|
||||
custom_llm_provider="anthropic"
|
||||
)
|
||||
|
||||
print(f"Deleted skill {result.id}, type: {result.type}")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### `create_skill()`
|
||||
|
||||
Create a new skill.
|
||||
|
||||
**Parameters:**
|
||||
- `files` (List[Any], optional): Files to upload for the skill. All files must be in the same top-level directory and must include a SKILL.md file at the root.
|
||||
- `display_title` (str, optional): Display title for the skill
|
||||
- `custom_llm_provider` (str, optional): Provider name (default: "anthropic")
|
||||
- `extra_headers` (dict, optional): Additional HTTP headers
|
||||
- `timeout` (float, optional): Request timeout
|
||||
|
||||
**Returns:**
|
||||
- `Skill`: The created skill object
|
||||
|
||||
**Async version:** `acreate_skill()`
|
||||
|
||||
### `list_skills()`
|
||||
|
||||
List all skills.
|
||||
|
||||
**Parameters:**
|
||||
- `limit` (int, optional): Number of results to return per page (max 100, default 20)
|
||||
- `page` (str, optional): Pagination token for fetching a specific page of results
|
||||
- `source` (str, optional): Filter skills by source ('custom' or 'anthropic')
|
||||
- `custom_llm_provider` (str, optional): Provider name (default: "anthropic")
|
||||
- `extra_headers` (dict, optional): Additional HTTP headers
|
||||
- `timeout` (float, optional): Request timeout
|
||||
|
||||
**Returns:**
|
||||
- `ListSkillsResponse`: Object containing a list of skills and pagination info
|
||||
|
||||
**Async version:** `alist_skills()`
|
||||
|
||||
### `get_skill()`
|
||||
|
||||
Get a specific skill by ID.
|
||||
|
||||
**Parameters:**
|
||||
- `skill_id` (str, required): The skill ID
|
||||
- `custom_llm_provider` (str, optional): Provider name (default: "anthropic")
|
||||
- `extra_headers` (dict, optional): Additional HTTP headers
|
||||
- `timeout` (float, optional): Request timeout
|
||||
|
||||
**Returns:**
|
||||
- `Skill`: The requested skill object
|
||||
|
||||
**Async version:** `aget_skill()`
|
||||
|
||||
### `delete_skill()`
|
||||
|
||||
Delete a skill.
|
||||
|
||||
**Parameters:**
|
||||
- `skill_id` (str, required): The skill ID to delete
|
||||
- `custom_llm_provider` (str, optional): Provider name (default: "anthropic")
|
||||
- `extra_headers` (dict, optional): Additional HTTP headers
|
||||
- `timeout` (float, optional): Request timeout
|
||||
|
||||
**Returns:**
|
||||
- `DeleteSkillResponse`: Object with `id` and `type` fields
|
||||
|
||||
**Async version:** `adelete_skill()`
|
||||
|
||||
## Response Types
|
||||
|
||||
### `Skill`
|
||||
|
||||
Represents a skill from the Anthropic Skills API.
|
||||
|
||||
**Fields:**
|
||||
- `id` (str): Unique identifier
|
||||
- `created_at` (str): ISO 8601 timestamp
|
||||
- `display_title` (str, optional): Display title
|
||||
- `latest_version` (str, optional): Latest version identifier
|
||||
- `source` (str): Source ("custom" or "anthropic")
|
||||
- `type` (str): Object type (always "skill")
|
||||
- `updated_at` (str): ISO 8601 timestamp
|
||||
|
||||
### `ListSkillsResponse`
|
||||
|
||||
Response from listing skills.
|
||||
|
||||
**Fields:**
|
||||
- `data` (List[Skill]): List of skills
|
||||
- `next_page` (str, optional): Pagination token for the next page
|
||||
- `has_more` (bool): Whether more skills are available
|
||||
|
||||
### `DeleteSkillResponse`
|
||||
|
||||
Response from deleting a skill.
|
||||
|
||||
**Fields:**
|
||||
- `id` (str): The deleted skill ID
|
||||
- `type` (str): Deleted object type (always "skill_deleted")
|
||||
|
||||
## Architecture
|
||||
|
||||
The Skills API implementation follows LiteLLM's standard patterns:
|
||||
|
||||
1. **Type Definitions** (`litellm/types/llms/anthropic_skills.py`)
|
||||
- Pydantic models for request/response types
|
||||
- TypedDict definitions for request parameters
|
||||
|
||||
2. **Base Configuration** (`litellm/llms/base_llm/skills/transformation.py`)
|
||||
- Abstract base class `BaseSkillsAPIConfig`
|
||||
- Defines transformation interface for provider-specific implementations
|
||||
|
||||
3. **Provider Implementation** (`litellm/llms/anthropic/skills/transformation.py`)
|
||||
- `AnthropicSkillsConfig` - Anthropic-specific transformations
|
||||
- Handles API authentication, URL construction, and response mapping
|
||||
|
||||
4. **Main Handler** (`litellm/skills/main.py`)
|
||||
- Public API functions (sync and async)
|
||||
- Request validation and routing
|
||||
- Error handling
|
||||
|
||||
5. **HTTP Handlers** (`litellm/llms/custom_httpx/llm_http_handler.py`)
|
||||
- Low-level HTTP request/response handling
|
||||
- Connection pooling and retry logic
|
||||
|
||||
## Beta API Support
|
||||
|
||||
The Skills API is in beta. The beta header (`skills-2025-10-02`) is automatically added by the Anthropic provider configuration. You can customize it if needed:
|
||||
|
||||
```python
|
||||
skill = litellm.create_skill(
|
||||
display_title="My Skill",
|
||||
extra_headers={
|
||||
"anthropic-beta": "skills-2025-10-02" # Or any other beta version
|
||||
},
|
||||
custom_llm_provider="anthropic"
|
||||
)
|
||||
```
|
||||
|
||||
The default beta version is configured in `litellm.constants.ANTHROPIC_SKILLS_API_BETA_VERSION`.
|
||||
|
||||
## Error Handling
|
||||
|
||||
All Skills API functions follow LiteLLM's standard error handling:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
try:
|
||||
skill = litellm.create_skill(
|
||||
display_title="My Skill",
|
||||
custom_llm_provider="anthropic"
|
||||
)
|
||||
except litellm.exceptions.AuthenticationError as e:
|
||||
print(f"Authentication failed: {e}")
|
||||
except litellm.exceptions.RateLimitError as e:
|
||||
print(f"Rate limit exceeded: {e}")
|
||||
except litellm.exceptions.APIError as e:
|
||||
print(f"API error: {e}")
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
To add support for Skills API to a new provider:
|
||||
|
||||
1. Create provider-specific configuration class inheriting from `BaseSkillsAPIConfig`
|
||||
2. Implement all abstract methods for request/response transformations
|
||||
3. Register the config in `ProviderConfigManager.get_provider_skills_api_config()`
|
||||
4. Add appropriate tests
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- [Anthropic Skills API Documentation](https://platform.claude.com/docs/en/api/beta/skills/create)
|
||||
- [LiteLLM Responses API](../../../responses/)
|
||||
- [Provider Configuration System](../../base_llm/)
|
||||
|
||||
## Support
|
||||
|
||||
For issues or questions:
|
||||
- GitHub Issues: https://github.com/BerriAI/litellm/issues
|
||||
- Discord: https://discord.gg/wuPM9dRgDw
|
||||
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Anthropic Skills API configuration and transformations
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.skills.transformation import (
|
||||
BaseSkillsAPIConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.anthropic_skills import (
|
||||
CreateSkillRequest,
|
||||
DeleteSkillResponse,
|
||||
ListSkillsParams,
|
||||
ListSkillsResponse,
|
||||
Skill,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
|
||||
class AnthropicSkillsConfig(BaseSkillsAPIConfig):
    """Anthropic-specific Skills API configuration.

    Handles authentication headers (including the required ``anthropic-beta``
    skills flag), URL construction, and request/response transformations for
    the create/list/get/delete skill operations.
    """

    @property
    def custom_llm_provider(self) -> LlmProviders:
        return LlmProviders.ANTHROPIC

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Add Anthropic-specific headers (API key, version, skills beta flag).

        Mutates and returns ``headers``.

        Raises:
            ValueError: If no Anthropic API key can be resolved.
        """
        from litellm.llms.anthropic.common_utils import AnthropicModelInfo

        # Resolve API key: litellm_params first, then environment fallback.
        api_key = None
        if litellm_params:
            api_key = litellm_params.api_key
        api_key = AnthropicModelInfo.get_api_key(api_key)

        if not api_key:
            raise ValueError("ANTHROPIC_API_KEY is required for Skills API")

        # Required auth/version headers.
        headers["x-api-key"] = api_key
        headers["anthropic-version"] = "2023-06-01"

        # Ensure the skills beta flag is present without clobbering any beta
        # flags the caller already set (supports both str and list forms).
        from litellm.constants import ANTHROPIC_SKILLS_API_BETA_VERSION

        if "anthropic-beta" not in headers:
            headers["anthropic-beta"] = ANTHROPIC_SKILLS_API_BETA_VERSION
        elif isinstance(headers["anthropic-beta"], list):
            if ANTHROPIC_SKILLS_API_BETA_VERSION not in headers["anthropic-beta"]:
                headers["anthropic-beta"].append(ANTHROPIC_SKILLS_API_BETA_VERSION)
        elif isinstance(headers["anthropic-beta"], str):
            if ANTHROPIC_SKILLS_API_BETA_VERSION not in headers["anthropic-beta"]:
                headers["anthropic-beta"] = [
                    headers["anthropic-beta"],
                    ANTHROPIC_SKILLS_API_BETA_VERSION,
                ]

        headers["content-type"] = "application/json"

        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        endpoint: str,
        skill_id: Optional[str] = None,
    ) -> str:
        """Build the full Anthropic Skills API URL.

        Fix: strip any trailing slash from ``api_base`` so a configured base
        like ``https://api.anthropic.com/`` does not yield ``//v1/...`` —
        consistent with the other Anthropic URL builders in this package,
        which all apply ``rstrip('/')``.
        """
        from litellm.llms.anthropic.common_utils import AnthropicModelInfo

        if api_base is None:
            api_base = AnthropicModelInfo.get_api_base()
        api_base = api_base.rstrip("/")

        if skill_id:
            return f"{api_base}/v1/skills/{skill_id}"
        return f"{api_base}/v1/{endpoint}"

    def transform_create_skill_request(
        self,
        create_request: CreateSkillRequest,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """Transform a create-skill request into the Anthropic request body.

        Drops ``None`` values; Anthropic expects the remaining fields as the
        request body directly.
        """
        verbose_logger.debug("Transforming create skill request: %s", create_request)

        request_body = {k: v for k, v in create_request.items() if v is not None}

        return request_body

    def transform_create_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Skill:
        """Deserialize the Anthropic create-skill response into a ``Skill``."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming create skill response: %s", response_json)

        return Skill(**response_json)

    def transform_list_skills_request(
        self,
        list_params: ListSkillsParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Build the URL and query parameters for listing skills.

        Supported query params: ``limit``, ``page``, ``source`` — included
        only when present and truthy in ``list_params``.
        """
        from litellm.llms.anthropic.common_utils import AnthropicModelInfo

        api_base = AnthropicModelInfo.get_api_base(
            litellm_params.api_base if litellm_params else None
        )
        url = self.get_complete_url(api_base=api_base, endpoint="skills")

        # Build query parameters from the optional filters.
        query_params: Dict[str, Any] = {}
        if "limit" in list_params and list_params["limit"]:
            query_params["limit"] = list_params["limit"]
        if "page" in list_params and list_params["page"]:
            query_params["page"] = list_params["page"]
        if "source" in list_params and list_params["source"]:
            query_params["source"] = list_params["source"]

        verbose_logger.debug(
            "List skills request made to Anthropic Skills endpoint with params: %s",
            query_params,
        )

        return url, query_params

    def transform_list_skills_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ListSkillsResponse:
        """Deserialize the Anthropic list-skills response."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming list skills response: %s", response_json)

        return ListSkillsResponse(**response_json)

    def transform_get_skill_request(
        self,
        skill_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Build the URL for fetching a single skill; headers pass through."""
        url = self.get_complete_url(
            api_base=api_base, endpoint="skills", skill_id=skill_id
        )

        verbose_logger.debug("Get skill request - URL: %s", url)

        return url, headers

    def transform_get_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Skill:
        """Deserialize the Anthropic get-skill response into a ``Skill``."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming get skill response: %s", response_json)

        return Skill(**response_json)

    def transform_delete_skill_request(
        self,
        skill_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Build the URL for deleting a skill; headers pass through."""
        url = self.get_complete_url(
            api_base=api_base, endpoint="skills", skill_id=skill_id
        )

        verbose_logger.debug("Delete skill request - URL: %s", url)

        return url, headers

    def transform_delete_skill_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteSkillResponse:
        """Deserialize the Anthropic delete-skill response."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming delete skill response: %s", response_json)

        return DeleteSkillResponse(**response_json)
|
||||
@@ -0,0 +1,392 @@
|
||||
"""
|
||||
AWS Polly Text-to-Speech transformation
|
||||
|
||||
Maps OpenAI TTS spec to AWS Polly SynthesizeSpeech API
|
||||
Reference: https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.text_to_speech.transformation import (
|
||||
BaseTextToSpeechConfig,
|
||||
TextToSpeechRequestData,
|
||||
)
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HttpxBinaryResponseContent = Any
|
||||
|
||||
|
||||
class AWSPollyTextToSpeechConfig(BaseTextToSpeechConfig, BaseAWSLLM):
    """
    Configuration for AWS Polly Text-to-Speech.

    Maps the OpenAI TTS spec onto Polly's SynthesizeSpeech API, including
    voice/format mapping, engine selection from the model name, and AWS
    SigV4 request signing.

    Reference: https://docs.aws.amazon.com/polly/latest/dg/API_SynthesizeSpeech.html
    """

    def __init__(self):
        BaseTextToSpeechConfig.__init__(self)
        BaseAWSLLM.__init__(self)

    # Default settings
    DEFAULT_VOICE = "Joanna"
    DEFAULT_ENGINE = "neural"
    DEFAULT_OUTPUT_FORMAT = "mp3"
    DEFAULT_REGION = "us-east-1"

    # Voice name mappings from OpenAI voices to Polly voices
    VOICE_MAPPINGS = {
        "alloy": "Joanna",  # US English female
        "echo": "Matthew",  # US English male
        "fable": "Amy",  # British English female
        "onyx": "Brian",  # British English male
        "nova": "Ivy",  # US English female (child)
        "shimmer": "Kendra",  # US English female
    }

    # Response format mappings from OpenAI to Polly
    FORMAT_MAPPINGS = {
        "mp3": "mp3",
        "opus": "ogg_vorbis",
        "aac": "mp3",  # Polly doesn't support AAC, use MP3
        "flac": "mp3",  # Polly doesn't support FLAC, use MP3
        "wav": "pcm",
        "pcm": "pcm",
    }

    # Valid Polly engines
    VALID_ENGINES = {"standard", "neural", "long-form", "generative"}

    def dispatch_text_to_speech(
        self,
        model: str,
        input: str,
        voice: Optional[Union[str, Dict]],
        optional_params: Dict,
        litellm_params_dict: Dict,
        logging_obj: "LiteLLMLoggingObj",
        timeout: Union[float, httpx.Timeout],
        extra_headers: Optional[Dict[str, Any]],
        base_llm_http_handler: Any,
        aspeech: bool,
        api_base: Optional[str],
        api_key: Optional[str],
        **kwargs: Any,
    ) -> Union[
        "HttpxBinaryResponseContent",
        Coroutine[Any, Any, "HttpxBinaryResponseContent"],
    ]:
        """
        Dispatch method to handle AWS Polly TTS requests.

        Encapsulates AWS-specific credential resolution and parameter handling
        before delegating to the shared HTTP handler.

        Args:
            base_llm_http_handler: The BaseLLMHTTPHandler instance from main.py
        """
        # Get AWS region from kwargs or environment
        aws_region_name = kwargs.get(
            "aws_region_name"
        ) or self._get_aws_region_name_for_polly(optional_params=optional_params)

        # Convert voice to string if it's a dict
        voice_str: Optional[str] = None
        if isinstance(voice, str):
            voice_str = voice
        elif isinstance(voice, dict):
            voice_str = voice.get("name") if voice else None

        # Update litellm_params with resolved values.
        # Note: AWS credentials (aws_access_key_id, aws_secret_access_key, etc.)
        # are already in litellm_params_dict via get_litellm_params() in main.py
        litellm_params_dict["aws_region_name"] = aws_region_name
        litellm_params_dict["api_base"] = api_base
        litellm_params_dict["api_key"] = api_key

        # Delegate to the shared text-to-speech handler (sync or async).
        response = base_llm_http_handler.text_to_speech_handler(
            model=model,
            input=input,
            voice=voice_str,
            text_to_speech_provider_config=self,
            text_to_speech_optional_params=optional_params,
            custom_llm_provider="aws_polly",
            litellm_params=litellm_params_dict,
            logging_obj=logging_obj,
            timeout=timeout,
            extra_headers=extra_headers,
            client=None,
            _is_async=aspeech,
        )

        return response

    def _get_aws_region_name_for_polly(self, optional_params: Dict) -> str:
        """Get AWS region name for Polly API calls (params first, then env)."""
        aws_region_name = optional_params.get("aws_region_name")
        if aws_region_name is None:
            aws_region_name = self.get_aws_region_name_for_non_llm_api_calls()
        return aws_region_name

    def get_supported_openai_params(self, model: str) -> list:
        """
        AWS Polly TTS supports these OpenAI parameters.

        NOTE(review): "speed" is advertised here but map_openai_params does
        not currently translate it to a Polly parameter — confirm intended.
        """
        return ["voice", "response_format", "speed"]

    def map_openai_params(
        self,
        model: str,
        optional_params: Dict,
        voice: Optional[Union[str, Dict]] = None,
        drop_params: bool = False,
        kwargs: Optional[Dict] = None,
    ) -> Tuple[Optional[str], Dict]:
        """
        Map OpenAI parameters to AWS Polly parameters.

        Fix: the ``kwargs`` default was a mutable ``{}`` shared across calls;
        it is now ``None`` and normalized to a fresh dict per call.

        Returns:
            Tuple of (mapped Polly voice name or None, Polly request params).
        """
        kwargs = {} if kwargs is None else kwargs
        mapped_params: Dict[str, Any] = {}

        # Map voice - support both native Polly voices and OpenAI voice mappings
        mapped_voice: Optional[str] = None
        if isinstance(voice, str):
            if voice in self.VOICE_MAPPINGS:
                # OpenAI voice -> Polly voice
                mapped_voice = self.VOICE_MAPPINGS[voice]
            else:
                # Assume it's already a Polly voice name
                mapped_voice = voice

        # Map response format (unknown formats pass through unchanged)
        if "response_format" in optional_params:
            format_name = optional_params["response_format"]
            if format_name in self.FORMAT_MAPPINGS:
                mapped_params["output_format"] = self.FORMAT_MAPPINGS[format_name]
            else:
                mapped_params["output_format"] = format_name
        else:
            mapped_params["output_format"] = self.DEFAULT_OUTPUT_FORMAT

        # Extract engine from model name (e.g., "aws_polly/neural" -> "neural")
        engine = self._extract_engine_from_model(model)
        mapped_params["engine"] = engine

        # Pass through Polly-specific parameters (use AWS API casing)
        if "language_code" in kwargs:
            mapped_params["LanguageCode"] = kwargs["language_code"]
        if "lexicon_names" in kwargs:
            mapped_params["LexiconNames"] = kwargs["lexicon_names"]
        if "sample_rate" in kwargs:
            mapped_params["SampleRate"] = kwargs["sample_rate"]

        return mapped_voice, mapped_params

    def _extract_engine_from_model(self, model: str) -> str:
        """
        Extract engine from model name.

        Examples:
            - aws_polly/neural -> neural
            - aws_polly/standard -> standard
            - aws_polly/long-form -> long-form
            - aws_polly -> neural (default)
        """
        if "/" in model:
            parts = model.split("/")
            if len(parts) >= 2:
                engine = parts[1].lower()
                if engine in self.VALID_ENGINES:
                    return engine
        return self.DEFAULT_ENGINE

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate AWS environment and set up headers.

        AWS SigV4 signing will be done in transform_text_to_speech_request.
        """
        validated_headers = headers.copy()
        validated_headers["Content-Type"] = "application/json"
        return validated_headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for AWS Polly SynthesizeSpeech request.

        Polly endpoint format:
            https://polly.{region}.amazonaws.com/v1/speech
        """
        if api_base is not None:
            return api_base.rstrip("/") + "/v1/speech"

        aws_region_name = litellm_params.get("aws_region_name", self.DEFAULT_REGION)
        return f"https://polly.{aws_region_name}.amazonaws.com/v1/speech"

    def is_ssml_input(self, input: str) -> bool:
        """
        Returns True if input is SSML, False otherwise.

        Based on AWS Polly SSML requirements - must contain <speak> tag.
        """
        return "<speak>" in input or "<speak " in input

    def _sign_polly_request(
        self,
        request_body: Dict[str, Any],
        endpoint_url: str,
        litellm_params: Dict,
    ) -> Tuple[Dict[str, str], str]:
        """
        Sign the AWS Polly request using SigV4.

        Raises:
            ImportError: If boto3/botocore is not installed.

        Returns:
            Tuple of (signed_headers, json_body_string). The body string must
            be sent byte-for-byte as signed, or signature validation fails.
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError(
                "Missing boto3 to call AWS Polly. Run 'pip install boto3'."
            )

        # Get AWS region
        aws_region_name = litellm_params.get("aws_region_name", self.DEFAULT_REGION)

        # Get AWS credentials (supports keys, profiles, roles, web identity)
        credentials = self.get_credentials(
            aws_access_key_id=litellm_params.get("aws_access_key_id"),
            aws_secret_access_key=litellm_params.get("aws_secret_access_key"),
            aws_session_token=litellm_params.get("aws_session_token"),
            aws_region_name=aws_region_name,
            aws_session_name=litellm_params.get("aws_session_name"),
            aws_profile_name=litellm_params.get("aws_profile_name"),
            aws_role_name=litellm_params.get("aws_role_name"),
            aws_web_identity_token=litellm_params.get("aws_web_identity_token"),
            aws_sts_endpoint=litellm_params.get("aws_sts_endpoint"),
            aws_external_id=litellm_params.get("aws_external_id"),
        )

        # Serialize request body to JSON
        json_body = json.dumps(request_body)

        # Create headers for signing
        headers = {
            "Content-Type": "application/json",
        }

        # Create AWS request for signing
        aws_request = AWSRequest(
            method="POST",
            url=endpoint_url,
            data=json_body,
            headers=headers,
        )

        # Sign the request
        SigV4Auth(credentials, "polly", aws_region_name).add_auth(aws_request)

        # Return signed headers and body
        return dict(aws_request.headers), json_body

    def transform_text_to_speech_request(
        self,
        model: str,
        input: str,
        voice: Optional[str],
        optional_params: Dict,
        litellm_params: Dict,
        headers: dict,
    ) -> TextToSpeechRequestData:
        """
        Transform OpenAI TTS request to AWS Polly SynthesizeSpeech format.

        Supports:
        - Native Polly voices (Joanna, Matthew, etc.)
        - OpenAI voice mapping (alloy, echo, etc.)
        - SSML input (auto-detected via <speak> tag)
        - Multiple engines (neural, standard, long-form, generative)

        Returns:
            TextToSpeechRequestData: Contains signed request for Polly API
        """
        # Get voice (already mapped in main.py, or use default)
        polly_voice = voice or self.DEFAULT_VOICE

        # Get output format
        output_format = optional_params.get("output_format", self.DEFAULT_OUTPUT_FORMAT)

        # Get engine
        engine = optional_params.get("engine", self.DEFAULT_ENGINE)

        # Build request body
        request_body: Dict[str, Any] = {
            "Engine": engine,
            "OutputFormat": output_format,
            "Text": input,
            "VoiceId": polly_voice,
        }

        # Auto-detect SSML
        if self.is_ssml_input(input):
            request_body["TextType"] = "ssml"
        else:
            request_body["TextType"] = "text"

        # Add optional Polly parameters (already in AWS casing from map_openai_params)
        for key in ["LanguageCode", "LexiconNames", "SampleRate"]:
            if key in optional_params:
                request_body[key] = optional_params[key]

        # Get endpoint URL
        endpoint_url = self.get_complete_url(
            model=model,
            api_base=litellm_params.get("api_base"),
            litellm_params=litellm_params,
        )

        # Sign the request with AWS SigV4
        signed_headers, json_body = self._sign_polly_request(
            request_body=request_body,
            endpoint_url=endpoint_url,
            litellm_params=litellm_params,
        )

        # Return as ssml_body so the handler uses data= instead of json=
        # This preserves the exact JSON string that was signed
        return TextToSpeechRequestData(
            ssml_body=json_body,
            headers=signed_headers,
        )

    def transform_text_to_speech_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: "LiteLLMLoggingObj",
    ) -> "HttpxBinaryResponseContent":
        """
        Transform AWS Polly response to standard format.

        Polly returns the audio data directly in the response body.
        """
        from litellm.types.llms.openai import HttpxBinaryResponseContent

        return HttpxBinaryResponseContent(raw_response)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,198 @@
|
||||
from litellm._uuid import uuid
|
||||
from typing import Any, Coroutine, Optional, Union
|
||||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
|
||||
from litellm.types.utils import FileTypes
|
||||
from litellm.utils import (
|
||||
TranscriptionResponse,
|
||||
convert_to_model_response_object,
|
||||
extract_duration_from_srt_or_vtt,
|
||||
)
|
||||
|
||||
from .azure import AzureChatCompletion
|
||||
from .common_utils import AzureOpenAIError
|
||||
|
||||
|
||||
class AzureAudioTranscription(AzureChatCompletion):
    """
    Audio transcription (speech-to-text) handler for Azure OpenAI deployments.

    Builds an Azure-configured OpenAI SDK client via ``AzureChatCompletion``
    and calls the ``audio.transcriptions`` endpoint, in sync or async mode.
    """

    def audio_transcriptions(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        logging_obj: Any,
        model_response: TranscriptionResponse,
        timeout: float,
        max_retries: int,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        client=None,
        azure_ad_token: Optional[str] = None,
        atranscription: bool = False,
        litellm_params: Optional[dict] = None,
    ) -> Union[TranscriptionResponse, Coroutine[Any, Any, TranscriptionResponse]]:
        """
        Transcribe ``audio_file`` with an Azure OpenAI deployment.

        Args:
            model: Azure deployment name.
            audio_file: File / file-like object to transcribe.
            optional_params: Extra params forwarded to the SDK call.
            logging_obj: LiteLLM logging object (pre_call/post_call hooks).
            model_response: Response object populated with the result.
            timeout: Per-request timeout in seconds.
            atranscription: When True, returns a coroutine (async path).

        Returns:
            A populated ``TranscriptionResponse``, or a coroutine resolving to
            one when ``atranscription`` is True.

        Raises:
            AzureOpenAIError: If the resolved client is not a sync
                ``AzureOpenAI`` instance.
        """
        data = {"model": model, "file": audio_file, **optional_params}

        if atranscription is True:
            return self.async_audio_transcriptions(
                audio_file=audio_file,
                data=data,
                model_response=model_response,
                timeout=timeout,
                # BUGFIX: forward api_version so the async client is built
                # against the same Azure API version as the sync path
                # (previously it silently fell back to the default).
                api_version=api_version,
                api_key=api_key,
                api_base=api_base,
                client=client,
                max_retries=max_retries,
                logging_obj=logging_obj,
                model=model,
                litellm_params=litellm_params,
            )

        azure_client = self.get_azure_openai_client(
            api_version=api_version,
            api_base=api_base,
            api_key=api_key,
            model=model,
            _is_async=False,
            client=client,
            litellm_params=litellm_params,
        )
        if not isinstance(azure_client, AzureOpenAI):
            raise AzureOpenAIError(
                status_code=500,
                message="azure_client is not an instance of AzureOpenAI",
            )

        ## LOGGING
        logging_obj.pre_call(
            input=f"audio_file_{uuid.uuid4()}",
            api_key=azure_client.api_key,
            additional_args={
                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                "api_base": azure_client._base_url._uri_reference,
                "atranscription": True,
                "complete_input_dict": data,
            },
        )

        response = azure_client.audio.transcriptions.create(
            **data, timeout=timeout  # type: ignore
        )

        # The SDK may return a pydantic model (json/verbose_json formats) or a
        # raw string (e.g. srt/vtt); normalize both into a plain dict.
        if isinstance(response, BaseModel):
            stringified_response = response.model_dump()
        else:
            stringified_response = TranscriptionResponse(text=response).model_dump()

        ## LOGGING
        logging_obj.post_call(
            input=get_audio_file_name(audio_file),
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=stringified_response,
        )
        hidden_params = {"model": model, "custom_llm_provider": "azure"}
        final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
        return final_response

    async def async_audio_transcriptions(
        self,
        audio_file: FileTypes,
        model: str,
        data: dict,
        model_response: TranscriptionResponse,
        timeout: float,
        logging_obj: Any,
        api_version: Optional[str] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        litellm_params: Optional[dict] = None,
    ) -> TranscriptionResponse:
        """
        Async variant of ``audio_transcriptions``.

        Uses ``with_raw_response`` so response headers can be propagated onto
        the returned ``TranscriptionResponse``. Also extracts the audio
        duration from srt/vtt text responses for cost tracking.

        Raises:
            AzureOpenAIError: If the client is not async, or the converted
                response is not a ``TranscriptionResponse``.
        """
        response = None
        try:
            async_azure_client = self.get_azure_openai_client(
                api_version=api_version,
                api_base=api_base,
                api_key=api_key,
                model=model,
                _is_async=True,
                client=client,
                litellm_params=litellm_params,
            )
            if not isinstance(async_azure_client, AsyncAzureOpenAI):
                raise AzureOpenAIError(
                    status_code=500,
                    message="async_azure_client is not an instance of AsyncAzureOpenAI",
                )

            ## LOGGING
            logging_obj.pre_call(
                input=f"audio_file_{uuid.uuid4()}",
                api_key=async_azure_client.api_key,
                additional_args={
                    "headers": {
                        "Authorization": f"Bearer {async_azure_client.api_key}"
                    },
                    "api_base": async_azure_client._base_url._uri_reference,
                    "atranscription": True,
                    "complete_input_dict": data,
                },
            )

            raw_response = (
                await async_azure_client.audio.transcriptions.with_raw_response.create(
                    **data, timeout=timeout
                )
            )  # type: ignore

            headers = dict(raw_response.headers)
            response = raw_response.parse()

            # Normalize pydantic model vs. raw-string responses (see sync path).
            if isinstance(response, BaseModel):
                stringified_response = response.model_dump()
            else:
                stringified_response = TranscriptionResponse(text=response).model_dump()
            duration = extract_duration_from_srt_or_vtt(response)
            stringified_response["_audio_transcription_duration"] = duration

            ## LOGGING
            logging_obj.post_call(
                input=get_audio_file_name(audio_file),
                api_key=api_key,
                additional_args={
                    "headers": {
                        "Authorization": f"Bearer {async_azure_client.api_key}"
                    },
                    "api_base": async_azure_client._base_url._uri_reference,
                    "atranscription": True,
                    "complete_input_dict": data,
                },
                original_response=stringified_response,
            )
            hidden_params = {"model": model, "custom_llm_provider": "azure"}
            response = convert_to_model_response_object(
                _response_headers=headers,
                response_object=stringified_response,
                model_response_object=model_response,
                hidden_params=hidden_params,
                response_type="audio_transcription",
            )
            if not isinstance(response, TranscriptionResponse):
                raise AzureOpenAIError(
                    status_code=500,
                    message="response is not an instance of TranscriptionResponse",
                )
            return response
        except Exception as e:
            ## LOGGING
            # BUGFIX: previously logged `input=input`, which passed the Python
            # builtin `input` function; log the audio file name instead so
            # failures are attributable to the request that caused them.
            logging_obj.post_call(
                input=get_audio_file_name(audio_file),
                api_key=api_key,
                original_response=str(e),
            )
            raise e
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,233 @@
|
||||
"""
|
||||
Azure Batches API Handler
|
||||
"""
|
||||
|
||||
from typing import Any, Coroutine, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
from litellm.llms.azure.azure import AsyncAzureOpenAI, AzureOpenAI
|
||||
from litellm.types.llms.openai import (
|
||||
CancelBatchRequest,
|
||||
CreateBatchRequest,
|
||||
RetrieveBatchRequest,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
from ..common_utils import BaseAzureLLM
|
||||
|
||||
|
||||
class AzureBatchesAPI(BaseAzureLLM):
    """
    Azure methods to support for batches
    - create_batch()
    - retrieve_batch()
    - cancel_batch()
    - list_batch()

    Each sync entry point accepts `_is_async`; when True it validates the
    client is async and returns a coroutine delegating to the `a*` helper.
    """

    def __init__(self) -> None:
        super().__init__()

    async def acreate_batch(
        self,
        create_batch_data: CreateBatchRequest,
        azure_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> LiteLLMBatch:
        """Create a batch with an async client and wrap the SDK response."""
        response = await azure_client.batches.create(**create_batch_data)  # type: ignore[arg-type]
        return LiteLLMBatch(**response.model_dump())

    def create_batch(
        self,
        _is_async: bool,
        create_batch_data: CreateBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
        """
        Create a batch job.

        Returns a ``LiteLLMBatch`` (sync) or a coroutine resolving to one
        (when ``_is_async`` is True).

        Raises:
            ValueError: If no client could be initialized, or the client's
                sync/async flavor does not match ``_is_async``.
        """
        azure_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = self.get_azure_openai_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params or {},
        )
        if azure_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(azure_client, (AsyncAzureOpenAI, AsyncOpenAI)):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.acreate_batch(  # type: ignore
                create_batch_data=create_batch_data, azure_client=azure_client
            )
        # CONSISTENCY FIX: validate the sync client explicitly (as
        # cancel_batch already did) instead of a runtime no-op typing cast,
        # so a mistakenly-passed async client fails with a clear error rather
        # than an obscure coroutine attribute error.
        if not isinstance(azure_client, (AzureOpenAI, OpenAI)):
            raise ValueError(
                "Azure client is not an instance of AzureOpenAI or OpenAI. Make sure you passed a sync client."
            )
        response = azure_client.batches.create(**create_batch_data)  # type: ignore[arg-type]
        return LiteLLMBatch(**response.model_dump())

    async def aretrieve_batch(
        self,
        retrieve_batch_data: RetrieveBatchRequest,
        client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> LiteLLMBatch:
        """Retrieve a batch with an async client and wrap the SDK response."""
        response = await client.batches.retrieve(**retrieve_batch_data)  # type: ignore[arg-type]
        return LiteLLMBatch(**response.model_dump())

    def retrieve_batch(
        self,
        _is_async: bool,
        retrieve_batch_data: RetrieveBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ):
        """
        Retrieve a batch job by id.

        Raises:
            ValueError: If no client could be initialized, or the client's
                sync/async flavor does not match ``_is_async``.
        """
        azure_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = self.get_azure_openai_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params or {},
        )
        if azure_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(azure_client, (AsyncAzureOpenAI, AsyncOpenAI)):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.aretrieve_batch(  # type: ignore
                retrieve_batch_data=retrieve_batch_data, client=azure_client
            )
        # CONSISTENCY FIX: explicit sync-client guard (see create_batch).
        if not isinstance(azure_client, (AzureOpenAI, OpenAI)):
            raise ValueError(
                "Azure client is not an instance of AzureOpenAI or OpenAI. Make sure you passed a sync client."
            )
        response = azure_client.batches.retrieve(**retrieve_batch_data)
        return LiteLLMBatch(**response.model_dump())

    async def acancel_batch(
        self,
        cancel_batch_data: CancelBatchRequest,
        client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> LiteLLMBatch:
        """Cancel a batch with an async client and wrap the SDK response."""
        response = await client.batches.cancel(**cancel_batch_data)
        return LiteLLMBatch(**response.model_dump())

    def cancel_batch(
        self,
        _is_async: bool,
        cancel_batch_data: CancelBatchRequest,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ):
        """
        Cancel an in-progress batch job.

        Raises:
            ValueError: If no client could be initialized, or the client's
                sync/async flavor does not match ``_is_async``.
        """
        azure_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = self.get_azure_openai_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params or {},
        )
        if azure_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(azure_client, (AsyncAzureOpenAI, AsyncOpenAI)):
                raise ValueError(
                    "Azure client is not an instance of AsyncAzureOpenAI or AsyncOpenAI. Make sure you passed an async client."
                )
            return self.acancel_batch(  # type: ignore
                cancel_batch_data=cancel_batch_data, client=azure_client
            )

        # At this point, azure_client is guaranteed to be a sync client
        if not isinstance(azure_client, (AzureOpenAI, OpenAI)):
            raise ValueError(
                "Azure client is not an instance of AzureOpenAI or OpenAI. Make sure you passed a sync client."
            )
        response = azure_client.batches.cancel(**cancel_batch_data)
        return LiteLLMBatch(**response.model_dump())

    async def alist_batches(
        self,
        client: Union[AsyncAzureOpenAI, AsyncOpenAI],
        after: Optional[str] = None,
        limit: Optional[int] = None,
    ):
        """List batches with an async client; returns the SDK page object."""
        response = await client.batches.list(after=after, limit=limit)  # type: ignore
        return response

    def list_batches(
        self,
        _is_async: bool,
        api_key: Optional[str],
        api_base: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        after: Optional[str] = None,
        limit: Optional[int] = None,
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ):
        """
        List batch jobs (paginated via ``after`` / ``limit``).

        Raises:
            ValueError: If no client could be initialized, or the client's
                sync/async flavor does not match ``_is_async``.
        """
        azure_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = self.get_azure_openai_client(
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
            litellm_params=litellm_params or {},
        )
        if azure_client is None:
            raise ValueError(
                "OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(azure_client, (AsyncAzureOpenAI, AsyncOpenAI)):
                raise ValueError(
                    "OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
                )
            return self.alist_batches(  # type: ignore
                client=azure_client, after=after, limit=limit
            )
        # CONSISTENCY FIX: explicit sync-client guard (see create_batch).
        if not isinstance(azure_client, (AzureOpenAI, OpenAI)):
            raise ValueError(
                "Azure client is not an instance of AzureOpenAI or OpenAI. Make sure you passed a sync client."
            )
        response = azure_client.batches.list(after=after, limit=limit)  # type: ignore
        return response
|
||||
@@ -0,0 +1,160 @@
|
||||
"""Support for Azure OpenAI gpt-5 model family."""
|
||||
|
||||
from typing import List
|
||||
|
||||
import litellm
|
||||
from litellm.exceptions import UnsupportedParamsError
|
||||
from litellm.llms.openai.chat.gpt_5_transformation import (
|
||||
OpenAIGPT5Config,
|
||||
_get_effort_level,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
from .gpt_transformation import AzureOpenAIConfig
|
||||
|
||||
|
||||
class AzureOpenAIGPT5Config(AzureOpenAIConfig, OpenAIGPT5Config):
    """Azure specific handling for gpt-5 models."""

    # Routing prefix used to explicitly direct a deployment through this config.
    GPT5_SERIES_ROUTE = "gpt5_series/"

    @classmethod
    def _supports_reasoning_effort_level(cls, model: str, level: str) -> bool:
        """Override to handle gpt5_series/ prefix used for Azure routing.

        The parent class calls ``_supports_factory(model, custom_llm_provider=None)``
        which fails to resolve ``gpt5_series/gpt-5.1`` to the correct Azure model
        entry. Strip the prefix and prepend ``azure/`` so the lookup finds
        ``azure/gpt-5.1`` in model_prices_and_context_window.json.
        """
        if model.startswith(cls.GPT5_SERIES_ROUTE):
            model = "azure/" + model[len(cls.GPT5_SERIES_ROUTE) :]
        elif not model.startswith("azure/"):
            model = "azure/" + model
        return super()._supports_reasoning_effort_level(model, level)

    @classmethod
    def is_model_gpt_5_model(cls, model: str) -> bool:
        """Check if the Azure model string refers to a gpt-5 variant.

        Accepts both explicit gpt-5 model names and the ``gpt5_series/`` prefix
        used for manual routing.
        """
        # gpt-5-chat* is a chat model and shouldn't go through GPT-5 reasoning restrictions.
        return (
            "gpt-5" in model and "gpt-5-chat" not in model
        ) or "gpt5_series" in model

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Get supported parameters for Azure OpenAI GPT-5 models.

        Azure OpenAI GPT-5.2/5.4 models support logprobs, unlike OpenAI's GPT-5.
        This overrides the parent class to add logprobs support back for gpt-5.2+.

        Reference:
        - Tested with Azure OpenAI GPT-5.2 (api-version: 2025-01-01-preview)
        - Azure returns logprobs successfully despite Microsoft's general
          documentation stating reasoning models don't support it.
        """
        params = OpenAIGPT5Config.get_supported_openai_params(self, model=model)

        # Azure supports tool_choice for GPT-5 deployments, but the base GPT-5 config
        # can drop it when the deployment name isn't in the OpenAI model registry.
        if "tool_choice" not in params:
            params.append("tool_choice")

        # Only gpt-5.2+ has been verified to support logprobs on Azure.
        # The base OpenAI class includes logprobs for gpt-5.1+, but Azure
        # hasn't verified support for gpt-5.1, so remove them unless gpt-5.2/5.4+.
        if self._supports_reasoning_effort_level(
            model, "none"
        ) and not self.is_model_gpt_5_2_model(model):
            params = [p for p in params if p not in ["logprobs", "top_logprobs"]]
        elif self.is_model_gpt_5_2_model(model):
            # BUGFIX: only append params not already present — the parent
            # config may already include logprobs for gpt-5.1+ models, and a
            # plain extend() produced duplicate entries in the list.
            for azure_param in ("logprobs", "top_logprobs"):
                if azure_param not in params:
                    params.append(azure_param)

        return params

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
        api_version: str = "",
    ) -> dict:
        """Map OpenAI params for Azure GPT-5, enforcing reasoning_effort rules.

        Raises:
            UnsupportedParamsError: When reasoning_effort='none' is requested
                for a model that doesn't support it and drop_params is off.
        """
        reasoning_effort_value = non_default_params.get(
            "reasoning_effort"
        ) or optional_params.get("reasoning_effort")
        effective_effort = _get_effort_level(reasoning_effort_value)

        # gpt-5.1/5.2/5.4 support reasoning_effort='none', but other gpt-5 models don't
        # See: https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/reasoning
        supports_none = self._supports_reasoning_effort_level(model, "none")

        if effective_effort == "none" and not supports_none:
            if litellm.drop_params is True or (
                drop_params is not None and drop_params is True
            ):
                # Copy before mutating so the caller's dicts are untouched.
                non_default_params = non_default_params.copy()
                optional_params = optional_params.copy()
                if (
                    _get_effort_level(non_default_params.get("reasoning_effort"))
                    == "none"
                ):
                    non_default_params.pop("reasoning_effort")
                if _get_effort_level(optional_params.get("reasoning_effort")) == "none":
                    optional_params.pop("reasoning_effort")
            else:
                raise UnsupportedParamsError(
                    status_code=400,
                    message=(
                        "Azure OpenAI does not support reasoning_effort='none' for this model. "
                        "Supported values are: 'low', 'medium', and 'high'. "
                        "To drop this parameter, set `litellm.drop_params=True` or for proxy:\n\n"
                        "`litellm_settings:\n drop_params: true`\n"
                        "Issue: https://github.com/BerriAI/litellm/issues/16704"
                    ),
                )

        result = OpenAIGPT5Config.map_openai_params(
            self,
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

        # Only drop reasoning_effort='none' for models that don't support it
        result_effort = _get_effort_level(result.get("reasoning_effort"))
        if result_effort == "none" and not supports_none:
            result.pop("reasoning_effort")

        # Azure Chat Completions: gpt-5.4+ does not support tools + reasoning together.
        # Drop reasoning_effort when both are present (OpenAI routes to Responses API; Azure does not).
        if self.is_model_gpt_5_4_plus_model(model):
            has_tools = bool(
                non_default_params.get("tools") or optional_params.get("tools")
            )
            if has_tools and result_effort not in (None, "none"):
                result.pop("reasoning_effort", None)

        return result

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Strip the gpt5_series/ routing prefix before building the request."""
        model = model.replace(self.GPT5_SERIES_ROUTE, "")
        return super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
|
||||
@@ -0,0 +1,334 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from httpx._models import Headers, Response
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
convert_to_azure_openai_messages,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.azure import (
|
||||
API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT,
|
||||
API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ....exceptions import UnsupportedParamsError
|
||||
from ....types.llms.openai import AllMessageValues
|
||||
from ...base_llm.chat.transformation import BaseConfig
|
||||
from ..common_utils import AzureOpenAIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
LoggingClass = LiteLLMLoggingObj
|
||||
else:
|
||||
LoggingClass = Any
|
||||
|
||||
|
||||
class AzureOpenAIConfig(BaseConfig):
|
||||
"""
|
||||
Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
|
||||
|
||||
The class `AzureOpenAIConfig` provides configuration for the OpenAI's Chat API interface, for use with Azure. Below are the parameters::
|
||||
|
||||
- `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.
|
||||
|
||||
- `function_call` (string or object): This optional parameter controls how the model calls functions.
|
||||
|
||||
- `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
|
||||
|
||||
- `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
|
||||
|
||||
- `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.
|
||||
|
||||
- `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.
|
||||
|
||||
- `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.
|
||||
|
||||
- `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
|
||||
|
||||
- `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
|
||||
|
||||
- `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]] = None,
|
||||
functions: Optional[list] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return [
|
||||
"temperature",
|
||||
"n",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"presence_penalty",
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"user",
|
||||
"function_call",
|
||||
"functions",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"top_p",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"response_format",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
"prediction",
|
||||
"modalities",
|
||||
"audio",
|
||||
"web_search_options",
|
||||
"prompt_cache_key",
|
||||
"store",
|
||||
]
|
||||
|
||||
def _is_response_format_supported_model(self, model: str) -> bool:
|
||||
"""
|
||||
Determines if the model supports response_format.
|
||||
- Handles Azure deployment names (e.g., azure/gpt-4.1-suffix)
|
||||
- Normalizes model names (e.g., gpt-4-1 -> gpt-4.1)
|
||||
- Strips deployment-specific suffixes
|
||||
- Passes provider to supports_response_schema
|
||||
- Backwards compatible with previous model name patterns
|
||||
"""
|
||||
import re
|
||||
|
||||
# Normalize model name: e.g., gpt-3-5-turbo -> gpt-3.5-turbo
|
||||
normalized_model = re.sub(r"(\d)-(\d)", r"\1.\2", model)
|
||||
|
||||
if "gpt-3.5" in normalized_model or "gpt-35" in model:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _is_response_format_supported_api_version(
|
||||
self, api_version_year: str, api_version_month: str
|
||||
) -> bool:
|
||||
"""
|
||||
- check if api_version is supported for response_format
|
||||
- returns True if the API version is equal to or newer than the supported version
|
||||
"""
|
||||
api_year = int(api_version_year)
|
||||
api_month = int(api_version_month)
|
||||
supported_year = int(API_VERSION_YEAR_SUPPORTED_RESPONSE_FORMAT)
|
||||
supported_month = int(API_VERSION_MONTH_SUPPORTED_RESPONSE_FORMAT)
|
||||
|
||||
# If the year is greater than supported year, it's definitely supported
|
||||
if api_year > supported_year:
|
||||
return True
|
||||
# If the year is less than supported year, it's not supported
|
||||
elif api_year < supported_year:
|
||||
return False
|
||||
# If same year, check if month is >= supported month
|
||||
else:
|
||||
return api_month >= supported_month
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
api_version: str = "",
|
||||
) -> dict:
|
||||
supported_openai_params = self.get_supported_openai_params(model)
|
||||
api_version_times = api_version.split("-")
|
||||
|
||||
if len(api_version_times) >= 3:
|
||||
api_version_year = api_version_times[0]
|
||||
api_version_month = api_version_times[1]
|
||||
api_version_day = api_version_times[2]
|
||||
else:
|
||||
api_version_year = None
|
||||
api_version_month = None
|
||||
api_version_day = None
|
||||
|
||||
for param, value in non_default_params.items():
|
||||
if param == "tool_choice":
|
||||
"""
|
||||
This parameter requires API version 2023-12-01-preview or later
|
||||
|
||||
tool_choice='required' is not supported as of 2024-05-01-preview
|
||||
"""
|
||||
## check if api version supports this param ##
|
||||
if (
|
||||
api_version_year is None
|
||||
or api_version_month is None
|
||||
or api_version_day is None
|
||||
):
|
||||
optional_params["tool_choice"] = value
|
||||
else:
|
||||
if (
|
||||
api_version_year < "2023"
|
||||
or (api_version_year == "2023" and api_version_month < "12")
|
||||
or (
|
||||
api_version_year == "2023"
|
||||
and api_version_month == "12"
|
||||
and api_version_day < "01"
|
||||
)
|
||||
):
|
||||
if litellm.drop_params is True or (
|
||||
drop_params is not None and drop_params is True
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise UnsupportedParamsError(
|
||||
status_code=400,
|
||||
message=f"""Azure does not support 'tool_choice', for api_version={api_version}. Bump your API version to '2023-12-01-preview' or later. This parameter requires 'api_version="2023-12-01-preview"' or later. Azure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions""",
|
||||
)
|
||||
elif value == "required" and (
|
||||
api_version_year == "2024" and api_version_month <= "05"
|
||||
): ## check if tool_choice value is supported ##
|
||||
if litellm.drop_params is True or (
|
||||
drop_params is not None and drop_params is True
|
||||
):
|
||||
pass
|
||||
else:
|
||||
raise UnsupportedParamsError(
|
||||
status_code=400,
|
||||
message=f"Azure does not support '{value}' as a {param} param, for api_version={api_version}. To drop 'tool_choice=required' for calls with this Azure API version, set `litellm.drop_params=True` or for proxy:\n\n`litellm_settings:\n drop_params: true`\nAzure API Reference: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions",
|
||||
)
|
||||
else:
|
||||
optional_params["tool_choice"] = value
|
||||
elif param == "response_format" and isinstance(value, dict):
|
||||
_is_response_format_supported_model = (
|
||||
self._is_response_format_supported_model(model)
|
||||
)
|
||||
|
||||
if api_version_year is None or api_version_month is None:
|
||||
is_response_format_supported_api_version = True
|
||||
else:
|
||||
is_response_format_supported_api_version = (
|
||||
self._is_response_format_supported_api_version(
|
||||
api_version_year, api_version_month
|
||||
)
|
||||
)
|
||||
is_response_format_supported = (
|
||||
is_response_format_supported_api_version
|
||||
and _is_response_format_supported_model
|
||||
)
|
||||
|
||||
optional_params = self._add_response_format_to_tools(
|
||||
optional_params=optional_params,
|
||||
value=value,
|
||||
is_response_format_supported=is_response_format_supported,
|
||||
)
|
||||
elif param == "tools" and isinstance(value, list):
|
||||
optional_params.setdefault("tools", [])
|
||||
optional_params["tools"].extend(value)
|
||||
elif param in supported_openai_params:
|
||||
optional_params[param] = value
|
||||
return optional_params
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
messages = convert_to_azure_openai_messages(messages)
|
||||
return {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LoggingClass,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError(
|
||||
"Azure OpenAI handler.py has custom logic for transforming response, as it uses the OpenAI SDK."
|
||||
)
|
||||
|
||||
def get_mapped_special_auth_params(self) -> dict:
|
||||
return {"token": "azure_ad_token"}
|
||||
|
||||
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "token":
|
||||
optional_params["azure_ad_token"] = value
|
||||
return optional_params
|
||||
|
||||
def get_eu_regions(self) -> List[str]:
    """
    Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
    """
    eu_region_names = ["europe", "sweden", "switzerland", "france", "uk"]
    return eu_region_names
|
||||
|
||||
def get_us_regions(self) -> List[str]:
    """
    Source: https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/models#gpt-4-and-gpt-4-turbo-model-availability
    """
    us_region_names = [
        "us",
        "eastus",
        "eastus2",
        "eastus2euap",
        "eastus3",
        "southcentralus",
        "westus",
        "westus2",
        "westus3",
        "westus4",
    ]
    return us_region_names
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
    """Wrap an Azure error payload in the provider-specific exception type."""
    return AzureOpenAIError(
        status_code=status_code,
        message=error_message,
        headers=headers,
    )
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Intentionally unsupported.

    Azure credential/environment validation is done by the SDK-based
    handler, not by this transformation config.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "Azure OpenAI has custom logic for validating environment, as it uses the OpenAI SDK."
    )
|
||||
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Handler file for calls to Azure OpenAI's o1/o3 family of models
|
||||
|
||||
Written separately to handle faking streaming for o1 and o3 models.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ...openai.openai import OpenAIChatCompletion
|
||||
from ..common_utils import BaseAzureLLM
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiohttp import ClientSession
|
||||
|
||||
|
||||
class AzureOpenAIO1ChatCompletion(BaseAzureLLM, OpenAIChatCompletion):
    """Chat-completion handler for Azure OpenAI o-series (o1/o3) deployments.

    Resolves an Azure SDK client (sync or async, reusing a cached client
    where possible) and delegates the actual completion to the shared
    OpenAI handler with that client injected.
    """

    def completion(
        self,
        model_response: ModelResponse,
        timeout: Union[float, httpx.Timeout],
        optional_params: dict,
        litellm_params: dict,
        logging_obj: Any,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        dynamic_params: Optional[bool] = None,
        azure_ad_token: Optional[str] = None,
        acompletion: bool = False,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
        organization: Optional[str] = None,
        custom_llm_provider: Optional[str] = None,
        drop_params: Optional[bool] = None,
        shared_session: Optional["ClientSession"] = None,
    ):
        # Build (or reuse) the Azure OpenAI SDK client; async-ness follows
        # the `acompletion` flag.
        azure_client = self.get_azure_openai_client(
            litellm_params=litellm_params,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=acompletion,
        )
        # Delegate to the shared OpenAI completion path, passing the
        # resolved Azure client through unchanged.
        return super().completion(
            model_response=model_response,
            timeout=timeout,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            model=model,
            messages=messages,
            print_verbose=print_verbose,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            dynamic_params=dynamic_params,
            azure_ad_token=azure_ad_token,
            acompletion=acompletion,
            logger_fn=logger_fn,
            headers=headers,
            custom_prompt_dict=custom_prompt_dict,
            client=azure_client,
            organization=organization,
            custom_llm_provider=custom_llm_provider,
            drop_params=drop_params,
            shared_session=shared_session,
        )
|
||||
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Support for o1 and o3 model families
|
||||
|
||||
https://platform.openai.com/docs/guides/reasoning
|
||||
|
||||
Translations handled by LiteLLM:
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
- role: system ==> translate to role 'user'
|
||||
- streaming => faked by LiteLLM
|
||||
- Tools, response_format => drop param (if user opts in to dropping param)
|
||||
- Logprobs => drop param (if user opts in to dropping param)
|
||||
- Temperature => drop param (if user opts in to dropping param)
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.utils import get_model_info, supports_reasoning
|
||||
|
||||
from ...openai.chat.o_series_transformation import OpenAIOSeriesConfig
|
||||
|
||||
|
||||
class AzureOpenAIO1Config(OpenAIOSeriesConfig):
    """Transformation config for Azure OpenAI o-series (o1/o3/o4) models.

    Extends the shared OpenAI o-series config with Azure-specific behavior:
    param filtering, fake-streaming decisions, and stripping the
    ``o_series/`` routing prefix from custom deployment names.
    """

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the Azure O-Series models
        """
        all_openai_params = litellm.OpenAIGPTConfig().get_supported_openai_params(
            model=model
        )
        # Sampling/logprob controls rejected by o-series deployments.
        non_supported_params = [
            "logprobs",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "top_logprobs",
        ]

        o_series_only_param = self._get_o_series_only_params(model)

        all_openai_params.extend(o_series_only_param)
        return [
            param for param in all_openai_params if param not in non_supported_params
        ]

    def _get_o_series_only_params(self, model: str) -> list:
        """
        Helper function to get the o-series only params for the model

        - reasoning_effort
        """
        o_series_only_param = []

        #########################################################
        # Case 1: If the model is recognized and in litellm model cost map
        # then check if it supports reasoning
        #########################################################
        # BUGFIX: previously checked `litellm.model_list_set`, which does not
        # match the stated intent (the model *cost map*); use litellm.model_cost
        # so reasoning support is read from the model's metadata.
        if model in litellm.model_cost:
            if supports_reasoning(model):
                o_series_only_param.append("reasoning_effort")
        #########################################################
        # Case 2: If the model is not recognized, then we assume it supports reasoning
        # This is critical because several users tend to use custom deployment names
        # for azure o-series models.
        #########################################################
        else:
            o_series_only_param.append("reasoning_effort")

        return o_series_only_param

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """
        Currently no Azure O Series models support native streaming.

        Returns True when LiteLLM should fake streaming for this request.
        """
        if stream is not True:
            return False

        if (
            model and "o3" in model
        ):  # o3 models support streaming - https://github.com/BerriAI/litellm/issues/8274
            return False

        if model is not None:
            try:
                # allow user to override default with model_info={"supports_native_streaming": true}
                model_info = get_model_info(
                    model=model, custom_llm_provider=custom_llm_provider
                )

                if model_info.get("supports_native_streaming") is True:
                    return False
            except Exception as e:
                # unrecognized model: keep the default (fake streaming)
                verbose_logger.debug(
                    f"Error getting model info in AzureOpenAIO1Config: {e}"
                )
        return True

    def is_o_series_model(self, model: str) -> bool:
        """Heuristic check for o-series model names / routed deployments."""
        return "o1" in model or "o3" in model or "o4" in model or "o_series/" in model

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Strip the routing prefix so the raw deployment name is sent to Azure.
        model = model.replace(
            "o_series/", ""
        )  # handle o_series/my-random-deployment-name
        return super().transform_request(
            model, messages, optional_params, litellm_params, headers
        )
|
||||
@@ -0,0 +1,844 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Callable, Dict, Literal, NamedTuple, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.openai.common_utils import BaseOpenAILLM
|
||||
from litellm.secret_managers.get_azure_ad_token_provider import (
|
||||
get_azure_ad_token_provider,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import _add_path_to_api_base
|
||||
|
||||
# Module-level cache for Azure AD access tokens obtained via the OIDC
# exchange below; entries are keyed by client/tenant/authority/token and
# expire with the token's `expires_in` TTL.
azure_ad_cache = DualCache()
|
||||
|
||||
|
||||
class AzureOpenAIError(BaseLLMException):
    """Exception type for Azure OpenAI provider failures.

    Thin wrapper over ``BaseLLMException`` so callers can catch
    Azure-specific errors distinctly; all arguments are forwarded
    unchanged to the base class.
    """

    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[Union[httpx.Headers, dict]] = None,
        body: Optional[dict] = None,
    ):
        super().__init__(
            status_code=status_code,
            message=message,
            request=request,
            response=response,
            headers=headers,
            body=body,
        )
|
||||
|
||||
|
||||
def process_azure_headers(headers: Union[httpx.Headers, dict]) -> dict:
    """Normalize Azure response headers for LiteLLM consumers.

    Every header is re-emitted under an ``llm_provider-`` prefix; the four
    OpenAI-style rate-limit headers are additionally passed through under
    their original names (and take precedence on key collision).
    """
    rate_limit_keys = (
        "x-ratelimit-limit-requests",
        "x-ratelimit-remaining-requests",
        "x-ratelimit-limit-tokens",
        "x-ratelimit-remaining-tokens",
    )
    openai_headers = {key: headers[key] for key in rate_limit_keys if key in headers}

    llm_response_headers = {f"llm_provider-{k}": v for k, v in headers.items()}

    return {**llm_response_headers, **openai_headers}
|
||||
|
||||
|
||||
def get_azure_ad_token_from_entra_id(
    tenant_id: str,
    client_id: str,
    client_secret: str,
    scope: str = "https://cognitiveservices.azure.com/.default",
) -> Callable[[], str]:
    """
    Get Azure AD token provider from `client_id`, `client_secret`, and `tenant_id`

    Each argument may be a literal value or an "os.environ/<VAR>" reference,
    which is resolved through the secret manager.

    Args:
        tenant_id: str
        client_id: str
        client_secret: str
        scope: str

    Returns:
        callable that returns a bearer token.

    Raises:
        ValueError: if any of tenant_id/client_id/client_secret resolve to None.
    """
    from azure.identity import ClientSecretCredential, get_bearer_token_provider

    verbose_logger.debug("Getting Azure AD Token from Entra ID")

    if tenant_id.startswith("os.environ/"):
        _tenant_id = get_secret_str(tenant_id)
    else:
        _tenant_id = tenant_id

    if client_id.startswith("os.environ/"):
        _client_id = get_secret_str(client_id)
    else:
        _client_id = client_id

    if client_secret.startswith("os.environ/"):
        _client_secret = get_secret_str(client_secret)
    else:
        _client_secret = client_secret

    # SECURITY FIX: never write the raw client secret to logs - redact it.
    verbose_logger.debug(
        "tenant_id %s, client_id %s, client_secret %s",
        _tenant_id,
        _client_id,
        "***REDACTED***" if _client_secret is not None else None,
    )
    if _tenant_id is None or _client_id is None or _client_secret is None:
        raise ValueError("tenant_id, client_id, and client_secret must be provided")
    credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)

    verbose_logger.debug("credential %s", credential)

    token_provider = get_bearer_token_provider(credential, scope)

    verbose_logger.debug("token_provider %s", token_provider)

    return token_provider
|
||||
|
||||
|
||||
def get_azure_ad_token_from_username_password(
    client_id: str,
    azure_username: str,
    azure_password: str,
    scope: str = "https://cognitiveservices.azure.com/.default",
) -> Callable[[], str]:
    """
    Get Azure AD token provider from `client_id`, `azure_username`, and `azure_password`

    Args:
        client_id: str
        azure_username: str
        azure_password: str
        scope: str

    Returns:
        callable that returns a bearer token.
    """
    from azure.identity import UsernamePasswordCredential, get_bearer_token_provider

    # SECURITY FIX: never write the raw password to logs - redact it.
    verbose_logger.debug(
        "client_id %s, azure_username %s, azure_password %s",
        client_id,
        azure_username,
        "***REDACTED***" if azure_password is not None else None,
    )
    credential = UsernamePasswordCredential(
        client_id=client_id,
        username=azure_username,
        password=azure_password,
    )

    verbose_logger.debug("credential %s", credential)

    token_provider = get_bearer_token_provider(credential, scope)

    verbose_logger.debug("token_provider %s", token_provider)

    return token_provider
|
||||
|
||||
|
||||
def get_azure_ad_token_from_oidc(
    azure_ad_token: str,
    azure_client_id: Optional[str] = None,
    azure_tenant_id: Optional[str] = None,
    scope: Optional[str] = None,
) -> str:
    """
    Get Azure AD token from OIDC token

    Exchanges an OIDC token (federated credential) for an Azure AD access
    token via the client-credentials grant, caching the result for the
    token's lifetime.

    Args:
        azure_ad_token: str (an "oidc/..." secret reference)
        azure_client_id: Optional[str]
        azure_tenant_id: Optional[str]
        scope: str

    Returns:
        `azure_ad_token_access_token` - str

    Raises:
        AzureOpenAIError: on missing config (422), missing OIDC token (401),
            or a non-200 response from the token endpoint.
    """
    if scope is None:
        scope = "https://cognitiveservices.azure.com/.default"
    azure_authority_host = os.getenv(
        "AZURE_AUTHORITY_HOST", "https://login.microsoftonline.com"
    )
    # Explicit arguments win; fall back to env vars.
    azure_client_id = azure_client_id or os.getenv("AZURE_CLIENT_ID")
    azure_tenant_id = azure_tenant_id or os.getenv("AZURE_TENANT_ID")
    if azure_client_id is None or azure_tenant_id is None:
        raise AzureOpenAIError(
            status_code=422,
            message="AZURE_CLIENT_ID and AZURE_TENANT_ID must be set",
        )

    # Resolve the "oidc/..." reference through the secret manager.
    oidc_token = get_secret_str(azure_ad_token)

    if oidc_token is None:
        raise AzureOpenAIError(
            status_code=401,
            message="OIDC token could not be retrieved from secret manager.",
        )

    # Cache key covers everything that affects the issued token.
    azure_ad_token_cache_key = json.dumps(
        {
            "azure_client_id": azure_client_id,
            "azure_tenant_id": azure_tenant_id,
            "azure_authority_host": azure_authority_host,
            "oidc_token": oidc_token,
        }
    )

    azure_ad_token_access_token = azure_ad_cache.get_cache(azure_ad_token_cache_key)
    if azure_ad_token_access_token is not None:
        return azure_ad_token_access_token

    client = litellm.module_level_client

    # OAuth2 client-credentials grant with a JWT client assertion.
    req_token = client.post(
        f"{azure_authority_host}/{azure_tenant_id}/oauth2/v2.0/token",
        data={
            "client_id": azure_client_id,
            "grant_type": "client_credentials",
            "scope": scope,
            "client_assertion_type": "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
            "client_assertion": oidc_token,
        },
    )

    if req_token.status_code != 200:
        raise AzureOpenAIError(
            status_code=req_token.status_code,
            message=req_token.text,
        )

    azure_ad_token_json = req_token.json()
    azure_ad_token_access_token = azure_ad_token_json.get("access_token", None)
    azure_ad_token_expires_in = azure_ad_token_json.get("expires_in", None)

    if azure_ad_token_access_token is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token access_token not returned"
        )

    if azure_ad_token_expires_in is None:
        raise AzureOpenAIError(
            status_code=422, message="Azure AD Token expires_in not returned"
        )

    # Cache with TTL equal to the token lifetime so expired tokens are not reused.
    azure_ad_cache.set_cache(
        key=azure_ad_token_cache_key,
        value=azure_ad_token_access_token,
        ttl=azure_ad_token_expires_in,
    )

    return azure_ad_token_access_token
|
||||
|
||||
|
||||
def select_azure_base_url_or_endpoint(azure_client_params: dict):
    """Decide whether the configured endpoint is an `azure_endpoint` or a `base_url`.

    If the endpoint already embeds a deployment path it must be passed to the
    SDK as ``base_url`` instead of ``azure_endpoint``. Mutates and returns
    ``azure_client_params``.
    """
    endpoint = azure_client_params.get("azure_endpoint", None)
    # see : https://github.com/openai/openai-python/blob/3d61ed42aba652b547029095a7eb269ad4e1e957/src/openai/lib/azure.py#L192
    if endpoint is not None and "/openai/deployments" in endpoint:
        # this is base_url, not an azure_endpoint
        azure_client_params["base_url"] = endpoint
        azure_client_params.pop("azure_endpoint")

    return azure_client_params
|
||||
|
||||
|
||||
def get_azure_ad_token(
    litellm_params: GenericLiteLLMParams,
) -> Optional[str]:
    """
    Get Azure AD token from various authentication methods.

    This function tries different methods to obtain an Azure AD token:
    1. From an existing token provider
    2. From Entra ID using tenant_id, client_id, and client_secret
    3. From username and password
    4. From OIDC token
    5. From a service principal with secret workflow
    6. From DefaultAzureCredential

    Args:
        litellm_params: Dictionary containing authentication parameters
            - azure_ad_token_provider: Optional callable that returns a token
            - azure_ad_token: Optional existing token
            - tenant_id: Optional Azure tenant ID
            - client_id: Optional Azure client ID
            - client_secret: Optional Azure client secret
            - azure_username: Optional Azure username
            - azure_password: Optional Azure password

    Returns:
        Azure AD token as string if successful, None otherwise
    """
    # Extract parameters
    # Use `or` instead of default parameter to handle cases where key exists but value is None
    azure_ad_token_provider = litellm_params.get("azure_ad_token_provider")
    azure_ad_token = litellm_params.get("azure_ad_token") or get_secret_str(
        "AZURE_AD_TOKEN"
    )
    tenant_id = litellm_params.get("tenant_id") or os.getenv("AZURE_TENANT_ID")
    client_id = litellm_params.get("client_id") or os.getenv("AZURE_CLIENT_ID")
    client_secret = litellm_params.get("client_secret") or os.getenv(
        "AZURE_CLIENT_SECRET"
    )
    azure_username = litellm_params.get("azure_username") or os.getenv("AZURE_USERNAME")
    azure_password = litellm_params.get("azure_password") or os.getenv("AZURE_PASSWORD")
    scope = litellm_params.get("azure_scope") or os.getenv(
        "AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"
    )
    if scope is None:
        scope = "https://cognitiveservices.azure.com/.default"

    # Try to get token provider from Entra ID
    if azure_ad_token_provider is None and tenant_id and client_id and client_secret:
        verbose_logger.debug(
            "Using Azure AD Token Provider from Entra ID for Azure Auth"
        )
        azure_ad_token_provider = get_azure_ad_token_from_entra_id(
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
            scope=scope,
        )

    # Try to get token provider from username and password
    if (
        azure_ad_token_provider is None
        and azure_username
        and azure_password
        and client_id
    ):
        verbose_logger.debug("Using Azure Username and Password for Azure Auth")
        azure_ad_token_provider = get_azure_ad_token_from_username_password(
            azure_username=azure_username,
            azure_password=azure_password,
            client_id=client_id,
            scope=scope,
        )

    # Try to get token from OIDC
    # NOTE: this branch yields a concrete token (not a provider), so the
    # `elif` below is skipped when an "oidc/" token is configured.
    if (
        client_id
        and tenant_id
        and azure_ad_token
        and azure_ad_token.startswith("oidc/")
    ):
        verbose_logger.debug("Using Azure OIDC Token for Azure Auth")
        azure_ad_token = get_azure_ad_token_from_oidc(
            azure_ad_token=azure_ad_token,
            azure_client_id=client_id,
            azure_tenant_id=tenant_id,
            scope=scope,
        )
    # Try to get token provider from service principal or DefaultAzureCredential
    elif (
        azure_ad_token_provider is None
        and litellm.enable_azure_ad_token_refresh is True
    ):
        verbose_logger.debug(
            "Using Azure AD token provider based on Service Principal with Secret workflow or DefaultAzureCredential for Azure Auth"
        )
        try:
            azure_ad_token_provider = get_azure_ad_token_provider(azure_scope=scope)
        except ValueError:
            # Expected when the workflow is not configured; fall through.
            verbose_logger.debug("Azure AD Token Provider could not be used.")
        except Exception as e:
            verbose_logger.error(
                f"Error calling Azure AD token provider: {str(e)}. Follow docs - https://docs.litellm.ai/docs/providers/azure/#azure-ad-token-refresh---defaultazurecredential"
            )
            raise e

    #########################################################
    # If litellm.enable_azure_ad_token_refresh is True and no other token provider is available,
    # try to get DefaultAzureCredential provider
    #########################################################
    if azure_ad_token_provider is None and azure_ad_token is None:
        azure_ad_token_provider = (
            BaseAzureLLM._try_get_default_azure_credential_provider(
                scope=scope,
            )
        )

    # Execute the token provider to get the token if available
    if azure_ad_token_provider and callable(azure_ad_token_provider):
        try:
            token = azure_ad_token_provider()
            if not isinstance(token, str):
                verbose_logger.error(
                    f"Azure AD token provider returned non-string value: {type(token)}"
                )
                raise TypeError(f"Azure AD token must be a string, got {type(token)}")
            else:
                azure_ad_token = token
        except TypeError:
            # Re-raise TypeError directly
            raise
        except Exception as e:
            verbose_logger.error(f"Error calling Azure AD token provider: {str(e)}")
            raise RuntimeError(f"Failed to get Azure AD token: {str(e)}") from e

    return azure_ad_token
|
||||
|
||||
|
||||
class BaseAzureLLM(BaseOpenAILLM):
|
||||
@staticmethod
def _try_get_default_azure_credential_provider(
    scope: str,
) -> Optional[Callable[[], str]]:
    """
    Try to get DefaultAzureCredential provider

    Args:
        scope: Azure scope for the token

    Returns:
        Token provider callable if DefaultAzureCredential is enabled and available, None otherwise
    """
    from litellm.types.secret_managers.get_azure_ad_token_provider import (
        AzureCredentialType,
    )

    verbose_logger.debug("Attempting to use DefaultAzureCredential for Azure Auth")

    try:
        token_provider = get_azure_ad_token_provider(
            azure_scope=scope,
            azure_credential=AzureCredentialType.DefaultAzureCredential,
        )
    except Exception as e:
        # Best-effort: any failure simply means this auth path is unavailable.
        verbose_logger.debug(f"DefaultAzureCredential failed: {str(e)}")
        return None

    verbose_logger.debug(
        "Successfully obtained Azure AD token provider using DefaultAzureCredential"
    )
    return token_provider
|
||||
|
||||
def get_azure_openai_client(
    self,
    api_key: Optional[str],
    api_base: Optional[str],
    api_version: Optional[str] = None,
    client: Optional[
        Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
    ] = None,
    litellm_params: Optional[dict] = None,
    _is_async: bool = False,
    model: Optional[str] = None,
) -> Optional[Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]]:
    """Return an (Azure)OpenAI SDK client for this request.

    Reuses a caller-supplied or cached client when possible; otherwise
    builds a new one from resolved Azure auth params. For the Azure v1 API
    surface, a plain OpenAI client pointed at `/openai/v1/` is used instead
    of the AzureOpenAI client.
    """
    openai_client: Optional[
        Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
    ] = None
    # NOTE: locals() snapshot (all arguments incl. self) is used as the
    # cache key for client reuse.
    client_initialization_params: dict = locals()
    client_initialization_params["is_async"] = _is_async
    if client is None:
        cached_client = self.get_cached_openai_client(
            client_initialization_params=client_initialization_params,
            client_type="azure",
        )
        if cached_client:
            if isinstance(
                cached_client, (AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI)
            ):
                return cached_client

        azure_client_params = self.initialize_azure_sdk_client(
            litellm_params=litellm_params or {},
            api_key=api_key,
            api_base=api_base,
            model_name=model,
            api_version=api_version,
            is_async=_is_async,
        )

        # For Azure v1 API, use standard OpenAI client instead of AzureOpenAI
        # See: https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#api-specs
        if self._is_azure_v1_api_version(api_version):
            # Extract only params that OpenAI client accepts
            # Always use /openai/v1/ regardless of whether user passed "v1", "latest", or "preview"
            v1_params = {
                "api_key": azure_client_params.get("api_key"),
                "base_url": f"{api_base}/openai/v1/",
            }
            if "timeout" in azure_client_params:
                v1_params["timeout"] = azure_client_params["timeout"]
            if "max_retries" in azure_client_params:
                v1_params["max_retries"] = azure_client_params["max_retries"]
            if "http_client" in azure_client_params:
                v1_params["http_client"] = azure_client_params["http_client"]

            verbose_logger.debug(
                f"Using Azure v1 API with base_url: {v1_params['base_url']}"
            )

            if _is_async is True:
                openai_client = AsyncOpenAI(**v1_params)  # type: ignore
            else:
                openai_client = OpenAI(**v1_params)  # type: ignore
        else:
            # Traditional Azure API uses AzureOpenAI client
            if _is_async is True:
                openai_client = AsyncAzureOpenAI(**azure_client_params)
            else:
                openai_client = AzureOpenAI(**azure_client_params)  # type: ignore
    else:
        openai_client = client
        if (
            api_version is not None
            and isinstance(openai_client, (AzureOpenAI, AsyncAzureOpenAI))
            and isinstance(openai_client._custom_query, dict)
        ):
            # set api_version to version passed by user
            openai_client._custom_query.setdefault("api-version", api_version)

    # save client in-memory cache
    self.set_cached_openai_client(
        openai_client=openai_client,
        client_initialization_params=client_initialization_params,
        client_type="azure",
    )
    return openai_client
|
||||
|
||||
def initialize_azure_sdk_client(
    self,
    litellm_params: dict,
    api_key: Optional[str],
    api_base: Optional[str],
    model_name: Optional[str],
    api_version: Optional[str],
    is_async: bool,
) -> dict:
    """Resolve auth + connection settings into AzureOpenAI constructor kwargs.

    Auth precedence: explicit api_key > existing token provider >
    Entra ID (tenant/client/secret) > username+password > OIDC token >
    service-principal token refresh (if enabled).
    """
    azure_ad_token_provider = litellm_params.get("azure_ad_token_provider")
    # If we have api_key, then we have higher priority
    azure_ad_token = litellm_params.get("azure_ad_token")

    # litellm_params sometimes contains the key, but the value is None
    # We should respect environment variables in this case
    tenant_id = self._resolve_env_var(
        litellm_params, "tenant_id", "AZURE_TENANT_ID"
    )
    client_id = self._resolve_env_var(
        litellm_params, "client_id", "AZURE_CLIENT_ID"
    )
    client_secret = self._resolve_env_var(
        litellm_params, "client_secret", "AZURE_CLIENT_SECRET"
    )
    azure_username = self._resolve_env_var(
        litellm_params, "azure_username", "AZURE_USERNAME"
    )
    azure_password = self._resolve_env_var(
        litellm_params, "azure_password", "AZURE_PASSWORD"
    )
    scope = self._resolve_env_var(litellm_params, "azure_scope", "AZURE_SCOPE")
    if scope is None:
        scope = "https://cognitiveservices.azure.com/.default"

    max_retries = litellm_params.get("max_retries")
    timeout = litellm_params.get("timeout")
    if (
        not api_key
        and azure_ad_token_provider is None
        and tenant_id
        and client_id
        and client_secret
    ):
        verbose_logger.debug(
            "Using Azure AD Token Provider from Entra ID for Azure Auth"
        )
        azure_ad_token_provider = get_azure_ad_token_from_entra_id(
            tenant_id=tenant_id,
            client_id=client_id,
            client_secret=client_secret,
            scope=scope,
        )
    if (
        azure_ad_token_provider is None
        and azure_username
        and azure_password
        and client_id
    ):
        verbose_logger.debug("Using Azure Username and Password for Azure Auth")
        azure_ad_token_provider = get_azure_ad_token_from_username_password(
            azure_username=azure_username,
            azure_password=azure_password,
            client_id=client_id,
            scope=scope,
        )

    # "oidc/..." tokens are exchanged for a real access token here.
    if azure_ad_token is not None and azure_ad_token.startswith("oidc/"):
        verbose_logger.debug("Using Azure OIDC Token for Azure Auth")
        azure_ad_token = get_azure_ad_token_from_oidc(
            azure_ad_token=azure_ad_token,
            azure_client_id=client_id,
            azure_tenant_id=tenant_id,
            scope=scope,
        )
    elif (
        not api_key
        and azure_ad_token_provider is None
        and litellm.enable_azure_ad_token_refresh is True
    ):
        verbose_logger.debug(
            "Using Azure AD token provider based on Service Principal with Secret workflow for Azure Auth"
        )
        try:
            azure_ad_token_provider = get_azure_ad_token_provider(
                azure_scope=scope,
            )
        except ValueError:
            # Expected when the workflow is not configured; continue without it.
            verbose_logger.debug("Azure AD Token Provider could not be used.")
    if api_version is None:
        api_version = os.getenv(
            "AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION
        )

    _api_key = api_key
    if _api_key is not None and isinstance(_api_key, str):
        # only show first 5 chars of api_key
        _api_key = _api_key[:8] + "*" * 15
    verbose_logger.debug(
        f"Initializing Azure OpenAI Client for {model_name}, Api Base: {str(api_base)}, Api Key:{_api_key}"
    )
    azure_client_params = {
        "api_key": api_key,
        "azure_endpoint": api_base,
        "api_version": api_version,
        "azure_ad_token": azure_ad_token,
        "azure_ad_token_provider": azure_ad_token_provider,
    }
    # init http client + SSL Verification settings
    if is_async is True:
        azure_client_params["http_client"] = self._get_async_http_client()
    else:
        azure_client_params["http_client"] = self._get_sync_http_client()

    if max_retries is not None:
        azure_client_params["max_retries"] = max_retries
    if timeout is not None:
        azure_client_params["timeout"] = timeout

    if azure_ad_token_provider is not None:
        azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider
    # this decides if we should set azure_endpoint or base_url on Azure OpenAI Client
    # required to support GPT-4 vision enhancements, since base_url needs to be set on Azure OpenAI Client

    azure_client_params = select_azure_base_url_or_endpoint(
        azure_client_params=azure_client_params
    )

    return azure_client_params
|
||||
|
||||
    def _init_azure_client_for_cloudflare_ai_gateway(
        self,
        api_base: str,
        model: str,
        api_version: str,
        max_retries: int,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        api_key: Optional[str],
        azure_ad_token: Optional[str],
        azure_ad_token_provider: Optional[Callable[[], str]],
        acompletion: bool,
        client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
    ) -> Union[AzureOpenAI, AsyncAzureOpenAI]:
        """Build an Azure OpenAI client whose base_url routes through a Cloudflare AI Gateway.

        The model name is appended to ``api_base`` (gateway URLs embed the
        deployment in the path), so requests made with this client should not
        rely on the client resolving the deployment itself.

        Auth precedence: ``api_key`` wins; otherwise ``azure_ad_token`` is used
        (resolved via OIDC when prefixed with ``oidc/``); an
        ``azure_ad_token_provider`` is attached independently when given.

        Returns the ``client`` argument unchanged when one is passed in.
        """
        ## build base url - assume api base includes resource name
        # Azure AD / OIDC inputs: explicit litellm_params override env vars.
        tenant_id = litellm_params.get("tenant_id", os.getenv("AZURE_TENANT_ID"))
        client_id = litellm_params.get("client_id", os.getenv("AZURE_CLIENT_ID"))
        scope = litellm_params.get(
            "azure_scope",
            os.getenv("AZURE_SCOPE", "https://cognitiveservices.azure.com/.default"),
        )
        if client is None:
            # Gateway URLs address the deployment by path: <api_base>/<model>.
            if not api_base.endswith("/"):
                api_base += "/"
            api_base += f"{model}"

            azure_client_params: Dict[str, Any] = {
                "api_version": api_version,
                # base_url (not azure_endpoint) so the gateway path is preserved.
                "base_url": f"{api_base}",
                "http_client": litellm.client_session,
                "max_retries": max_retries,
                "timeout": timeout,
            }
            if api_key is not None:
                azure_client_params["api_key"] = api_key
            elif azure_ad_token is not None:
                # "oidc/..." tokens are exchanged for a real AD token first.
                if azure_ad_token.startswith("oidc/"):
                    azure_ad_token = get_azure_ad_token_from_oidc(
                        azure_ad_token=azure_ad_token,
                        azure_client_id=client_id,
                        azure_tenant_id=tenant_id,
                        scope=scope,
                    )

                azure_client_params["azure_ad_token"] = azure_ad_token
            if azure_ad_token_provider is not None:
                azure_client_params["azure_ad_token_provider"] = azure_ad_token_provider

            if acompletion is True:
                client = AsyncAzureOpenAI(**azure_client_params)  # type: ignore
            else:
                client = AzureOpenAI(**azure_client_params)  # type: ignore
        return client
|
||||
|
||||
@staticmethod
|
||||
def _base_validate_azure_environment(
|
||||
headers: dict, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
litellm_params = litellm_params or GenericLiteLLMParams()
|
||||
|
||||
# Check if api-key is already in headers; if so, use it
|
||||
if "api-key" in headers:
|
||||
return headers
|
||||
|
||||
api_key = (
|
||||
litellm_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
if api_key:
|
||||
headers["api-key"] = api_key
|
||||
return headers
|
||||
|
||||
### Fallback to Azure AD token-based authentication if no API key is available
|
||||
### Retrieves Azure AD token and adds it to the Authorization header
|
||||
azure_ad_token = get_azure_ad_token(litellm_params)
|
||||
if azure_ad_token:
|
||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||
|
||||
return headers
|
||||
|
||||
    @staticmethod
    def _get_base_azure_url(
        api_base: Optional[str],
        litellm_params: Optional[Union[GenericLiteLLMParams, Dict[str, Any]]],
        route: Union[Literal["/openai/responses", "/openai/vector_stores"], str],
        default_api_version: Optional[Union[str, Literal["latest", "preview"]]] = None,
    ) -> str:
        """Build the full Azure URL for *route*, attaching an api-version query param.

        Args:
            api_base: Base URL of the Azure API; falls back to ``litellm.api_base``
                and then the ``AZURE_API_BASE`` secret. Raises ValueError if all
                three are unset.
            litellm_params: Source of an explicit ``api_version``.
                NOTE(review): the code calls ``.get(...)`` on it, which assumes a
                dict-like object even when a GenericLiteLLMParams is passed —
                confirm that type supports ``.get``.
            route: Path segment to append (skipped when already present in api_base).
            default_api_version: Used when no api_version is provided. The aliases
                "latest"/"preview"/"v1" switch the path to the ``/openai/v1`` surface.
        """

        api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
        if api_base is None:
            raise ValueError(
                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
            )
        original_url = httpx.URL(api_base)

        # Extract api_version from params, or use the caller's default.
        litellm_params = litellm_params or {}
        api_version = (
            cast(Optional[str], litellm_params.get("api_version"))
            or default_api_version
        )

        # Start from any query params already present on api_base.
        query_params = dict(original_url.params)

        # Only add api-version when the base URL didn't already carry one.
        if "api-version" not in query_params and api_version:
            query_params["api-version"] = api_version

        # Append the route unless api_base already contains it (substring check).
        if route not in api_base:
            new_url = _add_path_to_api_base(api_base=api_base, ending_path=route)
        else:
            new_url = api_base

        if BaseAzureLLM._is_azure_v1_api_version(api_version):
            # v1 alias versions must hit /openai/v1, not just /openai.
            if "/openai/v1" not in new_url:
                parsed_url = httpx.URL(new_url)
                new_url = str(
                    parsed_url.copy_with(
                        path=parsed_url.path.replace("/openai", "/openai/v1")
                    )
                )

        # Re-attach the (possibly extended) query params to the final URL.
        final_url = httpx.URL(new_url).copy_with(params=query_params)

        return str(final_url)
|
||||
|
||||
@staticmethod
|
||||
def _is_azure_v1_api_version(api_version: Optional[str]) -> bool:
|
||||
if api_version is None:
|
||||
return False
|
||||
return api_version in {"preview", "latest", "v1"}
|
||||
|
||||
def _resolve_env_var(
|
||||
self, litellm_params: Dict[str, Any], param_key: str, env_var_key: str
|
||||
) -> Optional[str]:
|
||||
"""Resolve the environment variable for a given parameter key.
|
||||
|
||||
The logic here is different from `params.get(key, os.getenv(env_var))` because
|
||||
litellm_params may contain the key with a None value, in which case we want
|
||||
to fallback to the environment variable.
|
||||
"""
|
||||
param_value = litellm_params.get(param_key)
|
||||
if param_value is not None:
|
||||
return param_value
|
||||
return os.getenv(env_var_key)
|
||||
|
||||
|
||||
class AzureCredentials(NamedTuple):
    """Resolved Azure OpenAI connection settings; any field may be None."""

    # Base endpoint URL (resolved from params / litellm globals / env).
    api_base: Optional[str]
    # API key for key-based auth.
    api_key: Optional[str]
    # Azure REST API version string.
    api_version: Optional[str]
|
||||
|
||||
|
||||
def get_azure_credentials(
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    api_version: Optional[str] = None,
) -> AzureCredentials:
    """Resolve Azure credentials from params, litellm globals, and env vars."""
    # Each field resolves independently: explicit argument first, then the
    # litellm global, then the secret/env lookup (short-circuited by `or`).
    base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
    key = (
        api_key
        or litellm.api_key
        or litellm.azure_key
        or get_secret_str("AZURE_OPENAI_API_KEY")
        or get_secret_str("AZURE_API_KEY")
    )
    version = api_version or litellm.api_version or get_secret_str("AZURE_API_VERSION")
    return AzureCredentials(api_base=base, api_key=key, api_version=version)
|
||||
@@ -0,0 +1,379 @@
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from openai import AsyncAzureOpenAI, AzureOpenAI
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, TextCompletionResponse
|
||||
|
||||
from ...openai.completion.transformation import OpenAITextCompletionConfig
|
||||
from ..common_utils import AzureOpenAIError, BaseAzureLLM
|
||||
|
||||
# Module-level shared config used to convert text-completion responses into
# chat-style ModelResponse objects.
openai_text_completion_config = OpenAITextCompletionConfig()
|
||||
|
||||
|
||||
class AzureTextCompletion(BaseAzureLLM):
|
||||
    def __init__(self) -> None:
        # No state of its own; all client/credential setup lives in BaseAzureLLM.
        super().__init__()
|
||||
|
||||
def validate_environment(self, api_key, azure_ad_token):
|
||||
headers = {
|
||||
"content-type": "application/json",
|
||||
}
|
||||
if api_key is not None:
|
||||
headers["api-key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
headers["Authorization"] = f"Bearer {azure_ad_token}"
|
||||
return headers
|
||||
|
||||
    def completion(  # noqa: PLR0915
        self,
        model: str,
        messages: list,
        model_response: ModelResponse,
        api_key: Optional[str],
        api_base: str,
        api_version: str,
        api_type: str,
        azure_ad_token: Optional[str],
        azure_ad_token_provider: Optional[Callable],
        print_verbose: Callable,
        timeout,
        logging_obj,
        optional_params,
        litellm_params,
        logger_fn,
        acompletion: bool = False,
        headers: Optional[dict] = None,
        client=None,
    ):
        """Entry point for Azure text completions.

        Converts chat-style ``messages`` into a prompt, then dispatches to the
        async / streaming / sync paths based on ``acompletion`` and the
        ``stream`` flag in ``optional_params``. The sync non-streaming path is
        executed inline here; the other three paths delegate to
        ``async_streaming`` / ``acompletion`` / ``streaming``.

        NOTE(review): ``api_type``, ``print_verbose`` and ``logger_fn`` are
        accepted for interface parity but not used in this body.

        Raises:
            AzureOpenAIError: on missing model/messages, bad max_retries,
                wrong client type, or any mapped provider error.
        """
        try:
            if model is None or messages is None:
                raise AzureOpenAIError(
                    status_code=422, message="Missing model or messages"
                )

            # Pop so max_retries is not forwarded inside the request payload.
            max_retries = optional_params.pop("max_retries", 2)
            prompt = prompt_factory(
                messages=messages, model=model, custom_llm_provider="azure_text"
            )

            ### CHECK IF CLOUDFLARE AI GATEWAY ###
            ### if so - set the model as part of the base url
            if api_base is not None and "gateway.ai.cloudflare.com" in api_base:
                ## build base url - assume api base includes resource name
                client = self._init_azure_client_for_cloudflare_ai_gateway(
                    api_key=api_key,
                    api_version=api_version,
                    api_base=api_base,
                    model=model,
                    client=client,
                    max_retries=max_retries,
                    timeout=timeout,
                    azure_ad_token=azure_ad_token,
                    azure_ad_token_provider=azure_ad_token_provider,
                    acompletion=acompletion,
                    litellm_params=litellm_params,
                )

                # Gateway URLs embed the model in the path, so the payload's
                # "model" is deliberately None.
                data = {"model": None, "prompt": prompt, **optional_params}
            else:
                data = {
                    "model": model,  # type: ignore
                    "prompt": prompt,
                    **optional_params,
                }

            if acompletion is True:
                if optional_params.get("stream", False):
                    return self.async_streaming(
                        logging_obj=logging_obj,
                        api_base=api_base,
                        data=data,
                        model=model,
                        api_key=api_key,
                        api_version=api_version,
                        azure_ad_token=azure_ad_token,
                        timeout=timeout,
                        client=client,
                        litellm_params=litellm_params,
                    )
                else:
                    return self.acompletion(
                        api_base=api_base,
                        data=data,
                        model_response=model_response,
                        api_key=api_key,
                        api_version=api_version,
                        model=model,
                        azure_ad_token=azure_ad_token,
                        timeout=timeout,
                        client=client,
                        logging_obj=logging_obj,
                        max_retries=max_retries,
                        litellm_params=litellm_params,
                    )
            elif "stream" in optional_params and optional_params["stream"] is True:
                return self.streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    data=data,
                    model=model,
                    api_key=api_key,
                    api_version=api_version,
                    azure_ad_token=azure_ad_token,
                    timeout=timeout,
                    client=client,
                )
            else:
                ## LOGGING
                logging_obj.pre_call(
                    input=prompt,
                    api_key=api_key,
                    additional_args={
                        "headers": {
                            "api_key": api_key,
                            "azure_ad_token": azure_ad_token,
                        },
                        "api_version": api_version,
                        "api_base": api_base,
                        "complete_input_dict": data,
                    },
                )
                if not isinstance(max_retries, int):
                    raise AzureOpenAIError(
                        status_code=422, message="max retries must be an int"
                    )
                # init AzureOpenAI Client
                azure_client = self.get_azure_openai_client(
                    api_key=api_key,
                    api_base=api_base,
                    api_version=api_version,
                    client=client,
                    litellm_params=litellm_params,
                    _is_async=False,
                    model=model,
                )

                if not isinstance(azure_client, AzureOpenAI):
                    raise AzureOpenAIError(
                        status_code=500,
                        message="azure_client is not an instance of AzureOpenAI",
                    )

                raw_response = azure_client.completions.with_raw_response.create(
                    **data, timeout=timeout
                )
                response = raw_response.parse()
                stringified_response = response.model_dump()
                ## LOGGING
                logging_obj.post_call(
                    input=prompt,
                    api_key=api_key,
                    original_response=stringified_response,
                    additional_args={
                        "headers": headers,
                        "api_version": api_version,
                        "api_base": api_base,
                    },
                )
                # Text-completion response is re-shaped to a chat ModelResponse.
                return (
                    openai_text_completion_config.convert_to_chat_model_response_object(
                        response_object=TextCompletionResponse(**stringified_response),
                        model_response_object=model_response,
                    )
                )
        except AzureOpenAIError as e:
            raise e
        except Exception as e:
            # Map any other failure to AzureOpenAIError, preserving status code
            # and response headers when the SDK exposes them.
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise AzureOpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )
|
||||
|
||||
    async def acompletion(
        self,
        api_key: Optional[str],
        api_version: str,
        model: str,
        api_base: str,
        data: dict,
        timeout: Any,
        model_response: ModelResponse,
        logging_obj: Any,
        max_retries: int,
        azure_ad_token: Optional[str] = None,
        client=None,  # this is the AsyncAzureOpenAI
        # NOTE(review): mutable default argument; safe only if never mutated — confirm.
        litellm_params: dict = {},
    ):
        """Async, non-streaming Azure text completion.

        Builds (or reuses) an AsyncAzureOpenAI client, sends the pre-built
        ``data`` payload, and converts the result into a chat-style
        ModelResponse.

        Raises:
            AzureOpenAIError: for wrong client types or any mapped provider error.
        """
        response = None
        try:
            # init AzureOpenAI Client
            # setting Azure client
            azure_client = self.get_azure_openai_client(
                api_version=api_version,
                api_base=api_base,
                api_key=api_key,
                model=model,
                _is_async=True,
                client=client,
                litellm_params=litellm_params,
            )
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise AzureOpenAIError(
                    status_code=500,
                    message="azure_client is not an instance of AsyncAzureOpenAI",
                )

            ## LOGGING
            logging_obj.pre_call(
                input=data["prompt"],
                api_key=azure_client.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                    "api_base": azure_client._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )
            raw_response = await azure_client.completions.with_raw_response.create(
                **data, timeout=timeout
            )
            response = raw_response.parse()
            return openai_text_completion_config.convert_to_chat_model_response_object(
                response_object=response.model_dump(),
                model_response_object=model_response,
            )
        except AzureOpenAIError as e:
            raise e
        except Exception as e:
            # Preserve status code and headers from the SDK error when available.
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise AzureOpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )
|
||||
|
||||
    def streaming(
        self,
        logging_obj,
        api_base: str,
        api_key: Optional[str],
        api_version: str,
        data: dict,
        model: str,
        timeout: Any,
        # NOTE(review): azure_ad_token is accepted but unused in this body — confirm.
        azure_ad_token: Optional[str] = None,
        client=None,
        # NOTE(review): mutable default argument; safe only if never mutated — confirm.
        litellm_params: dict = {},
    ):
        """Sync streaming Azure text completion; returns a CustomStreamWrapper."""
        # Pops max_retries so it is not forwarded inside the request payload.
        max_retries = data.pop("max_retries", 2)
        if not isinstance(max_retries, int):
            raise AzureOpenAIError(
                status_code=422, message="max retries must be an int"
            )
        # init AzureOpenAI Client
        azure_client = self.get_azure_openai_client(
            api_version=api_version,
            api_base=api_base,
            api_key=api_key,
            model=model,
            _is_async=False,
            client=client,
            litellm_params=litellm_params,
        )
        if not isinstance(azure_client, AzureOpenAI):
            raise AzureOpenAIError(
                status_code=500,
                message="azure_client is not an instance of AzureOpenAI",
            )

        ## LOGGING
        logging_obj.pre_call(
            input=data["prompt"],
            api_key=azure_client.api_key,
            additional_args={
                "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                "api_base": azure_client._base_url._uri_reference,
                "acompletion": True,
                "complete_input_dict": data,
            },
        )
        raw_response = azure_client.completions.with_raw_response.create(
            **data, timeout=timeout
        )
        response = raw_response.parse()
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="azure_text",
            logging_obj=logging_obj,
        )
        return streamwrapper
|
||||
|
||||
    async def async_streaming(
        self,
        logging_obj,
        api_base: str,
        api_key: Optional[str],
        api_version: str,
        data: dict,
        model: str,
        timeout: Any,
        # NOTE(review): azure_ad_token is accepted but unused in this body — confirm.
        azure_ad_token: Optional[str] = None,
        client=None,
        # NOTE(review): mutable default argument; safe only if never mutated — confirm.
        litellm_params: dict = {},
    ):
        """Async streaming Azure text completion; returns a CustomStreamWrapper."""
        try:
            # init AzureOpenAI Client
            azure_client = self.get_azure_openai_client(
                api_version=api_version,
                api_base=api_base,
                api_key=api_key,
                model=model,
                _is_async=True,
                client=client,
                litellm_params=litellm_params,
            )
            if not isinstance(azure_client, AsyncAzureOpenAI):
                raise AzureOpenAIError(
                    status_code=500,
                    message="azure_client is not an instance of AsyncAzureOpenAI",
                )
            ## LOGGING
            logging_obj.pre_call(
                input=data["prompt"],
                api_key=azure_client.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {azure_client.api_key}"},
                    "api_base": azure_client._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )
            raw_response = await azure_client.completions.with_raw_response.create(
                **data, timeout=timeout
            )
            response = raw_response.parse()
            # return response
            streamwrapper = CustomStreamWrapper(
                completion_stream=response,
                model=model,
                custom_llm_provider="azure_text",
                logging_obj=logging_obj,
            )
            return streamwrapper  ## DO NOT make this into an async for ... loop, it will yield an async generator, which won't raise errors if the response fails
        except Exception as e:
            # Preserve status code and headers from the SDK error when available.
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise AzureOpenAIError(
                status_code=status_code, message=str(e), headers=error_headers
            )
|
||||
@@ -0,0 +1,53 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from ...openai.completion.transformation import OpenAITextCompletionConfig
|
||||
|
||||
|
||||
class AzureOpenAITextConfig(OpenAITextCompletionConfig):
    """
    Configuration for Azure OpenAI's legacy *text completions* API.

    Reference: https://platform.openai.com/docs/api-reference/completions/create

    Inherits all behavior from `OpenAITextCompletionConfig`; this subclass pins
    the parameter set supported for the Azure variant. Supported parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the completion.

    - `n` (integer or null): This optional parameter helps to set how many completion choices to generate for each input prompt.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
    ) -> None:
        # Simply forwards every supported knob to the base text-completion config.
        super().__init__(
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            max_tokens=max_tokens,
            n=n,
            presence_penalty=presence_penalty,
            stop=stop,
            temperature=temperature,
            top_p=top_p,
        )
|
||||
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Helper util for handling azure openai-specific cost calculation
|
||||
- e.g.: prompt caching, audio tokens
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
|
||||
from litellm.types.utils import Usage
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
|
||||
def cost_per_token(
    model: str, usage: Usage, response_time_ms: Optional[float] = 0.0
) -> Tuple[float, float]:
    """Calculate per-call cost for an Azure model.

    Args:
        model: Model name without the provider prefix.
        usage: LiteLLM Usage block, containing caching and audio token information.
        response_time_ms: Wall-clock response time; only used by per-second
            (e.g. TTS) pricing.

    Returns:
        Tuple[float, float] - (prompt_cost_in_usd, completion_cost_in_usd)
    """
    ## GET MODEL INFO
    model_info = get_model_info(model=model, custom_llm_provider="azure")

    # Speech/audio (TTS) models are billed per second of response time,
    # not per token: prompt side is free, completion side is rate * seconds.
    per_second_rate = model_info.get("output_cost_per_second")
    if per_second_rate is not None and response_time_ms is not None:
        verbose_logger.debug(
            f"For model={model} - output_cost_per_second: {model_info.get('output_cost_per_second')}; response time: {response_time_ms}"
        )
        return 0.0, per_second_rate * response_time_ms / 1000

    # Token-based pricing for everything else; the generic calculator handles
    # text, audio, cached and reasoning tokens.
    return generic_cost_per_token(
        model=model,
        usage=usage,
        custom_llm_provider="azure",
    )
|
||||
@@ -0,0 +1,91 @@
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from litellm.exceptions import ContentPolicyViolationError
|
||||
|
||||
|
||||
class AzureOpenAIExceptionMapping:
    """
    Class for creating Azure OpenAI specific exceptions
    """

    @staticmethod
    def create_content_policy_violation_error(
        message: str,
        model: str,
        extra_information: str,
        original_exception: Exception,
    ) -> ContentPolicyViolationError:
        """Raise a ContentPolicyViolationError built from an Azure error payload.

        NOTE(review): despite the name and the return annotation, this method
        *raises* the error rather than returning it — callers relying on the
        annotated return value would never see one; confirm intent.

        Args:
            message: Fallback message when the Azure payload has none.
            model: Model name attached to the raised error.
            extra_information: Debug info forwarded as litellm_debug_info.
            original_exception: SDK exception whose ``body``/``response`` are mined.

        Raises:
            ContentPolicyViolationError: always.
        """
        azure_error, inner_error = AzureOpenAIExceptionMapping._extract_azure_error(
            original_exception
        )

        # Prefer the provider message/type/code when present.
        provider_message = (
            azure_error.get("message") if isinstance(azure_error, dict) else None
        ) or message
        provider_type = (
            azure_error.get("type") if isinstance(azure_error, dict) else None
        )
        provider_code = (
            azure_error.get("code") if isinstance(azure_error, dict) else None
        )

        # Keep the OpenAI-style body fields populated so downstream (proxy + SDK)
        # can surface `type` / `code` correctly.
        openai_style_body: Dict[str, Any] = {
            "message": provider_message,
            "type": provider_type or "invalid_request_error",
            "code": provider_code or "content_policy_violation",
            "param": None,
        }

        raise ContentPolicyViolationError(
            message=provider_message,
            llm_provider="azure",
            model=model,
            litellm_debug_info=extra_information,
            response=getattr(original_exception, "response", None),
            provider_specific_fields={
                # Preserve legacy key for backward compatibility.
                "innererror": inner_error,
                # Prefer Azure's current naming.
                "inner_error": inner_error,
                # Include the full Azure error object for clients that want it.
                "azure_error": azure_error or None,
            },
            body=openai_style_body,
        )

    @staticmethod
    def _extract_azure_error(
        original_exception: Exception,
    ) -> Tuple[Dict[str, Any], Optional[dict]]:
        """Extract Azure OpenAI error payload and inner error details.

        Azure error formats can vary by endpoint/version. Common shapes:
        - {"innererror": {...}} (legacy)
        - {"error": {"code": "...", "message": "...", "type": "...", "inner_error": {...}}}
        - {"code": "...", "message": "...", "type": "..."} (already flattened)

        Returns ``({}, None)`` when the exception carries no dict body.
        """
        body_dict = getattr(original_exception, "body", None) or {}
        if not isinstance(body_dict, dict):
            return {}, None

        # Some SDKs place the payload under "error".
        azure_error: Dict[str, Any]
        if isinstance(body_dict.get("error"), dict):
            azure_error = body_dict.get("error", {})  # type: ignore[assignment]
        else:
            azure_error = body_dict

        # Accept both spellings, at either nesting level (first match wins).
        inner_error = (
            azure_error.get("inner_error")
            or azure_error.get("innererror")
            or body_dict.get("innererror")
            or body_dict.get("inner_error")
        )

        return azure_error, inner_error
|
||||
@@ -0,0 +1,308 @@
|
||||
from typing import Any, Coroutine, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
from openai.types.file_deleted import FileDeleted
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import *
|
||||
|
||||
from ..common_utils import BaseAzureLLM
|
||||
|
||||
|
||||
class AzureOpenAIFilesAPI(BaseAzureLLM):
|
||||
"""
|
||||
AzureOpenAI methods to support for batches
|
||||
- create_file()
|
||||
- retrieve_file()
|
||||
- list_files()
|
||||
- delete_file()
|
||||
- file_content()
|
||||
- update_file()
|
||||
"""
|
||||
|
||||
    def __init__(self) -> None:
        # No state of its own; client construction lives in BaseAzureLLM.
        super().__init__()
|
||||
|
||||
@staticmethod
|
||||
def _prepare_create_file_data(
|
||||
create_file_data: CreateFileRequest,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Prepare create_file_data for OpenAI SDK.
|
||||
|
||||
Removes expires_after if None to match SDK's Omit pattern.
|
||||
SDK expects file_create_params.ExpiresAfter | Omit, but FileExpiresAfter works at runtime.
|
||||
"""
|
||||
data = dict(create_file_data)
|
||||
if data.get("expires_after") is None:
|
||||
data.pop("expires_after", None)
|
||||
return data
|
||||
|
||||
    async def acreate_file(
        self,
        create_file_data: CreateFileRequest,
        openai_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> OpenAIFileObject:
        """Upload a file via the async client and return it as OpenAIFileObject."""
        verbose_logger.debug("create_file_data=%s", create_file_data)
        response = await openai_client.files.create(**self._prepare_create_file_data(create_file_data))  # type: ignore[arg-type]
        verbose_logger.debug("create_file_response=%s", response)
        return OpenAIFileObject(**response.model_dump())
|
||||
|
||||
    def create_file(
        self,
        _is_async: bool,
        create_file_data: CreateFileRequest,
        api_base: Optional[str],
        api_key: Optional[str],
        api_version: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
        """Upload a file to Azure OpenAI.

        Dispatches on ``_is_async``: returns a coroutine (from acreate_file)
        when True, otherwise performs the upload synchronously.
        NOTE(review): ``timeout`` and ``max_retries`` are accepted but not
        forwarded in this body — confirm they are consumed by the client factory.

        Raises:
            ValueError: when no client could be built, or the client kind does
                not match ``_is_async``.
        """
        openai_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = self.get_azure_openai_client(
            litellm_params=litellm_params or {},
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, (AsyncAzureOpenAI, AsyncOpenAI)):
                raise ValueError(
                    "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                )
            return self.acreate_file(
                create_file_data=create_file_data, openai_client=openai_client
            )
        response = cast(Union[AzureOpenAI, OpenAI], openai_client).files.create(**self._prepare_create_file_data(create_file_data))  # type: ignore[arg-type]
        return OpenAIFileObject(**response.model_dump())
|
||||
|
||||
    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        openai_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
    ) -> HttpxBinaryResponseContent:
        """Fetch a file's raw content via the async client."""
        response = await openai_client.files.content(**file_content_request)
        return HttpxBinaryResponseContent(response=response.response)
|
||||
|
||||
    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: Optional[str],
        api_key: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        api_version: Optional[str] = None,
        client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = None,
        litellm_params: Optional[dict] = None,
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        """Fetch a file's raw content from Azure OpenAI.

        Dispatches on ``_is_async``: returns a coroutine (from afile_content)
        when True, otherwise fetches synchronously.
        NOTE(review): ``timeout`` and ``max_retries`` are accepted but not
        forwarded in this body — confirm they are consumed by the client factory.

        Raises:
            ValueError: when no client could be built, or the client kind does
                not match ``_is_async``.
        """
        openai_client: Optional[
            Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
        ] = self.get_azure_openai_client(
            litellm_params=litellm_params or {},
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
        )
        if openai_client is None:
            raise ValueError(
                "AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
            )

        if _is_async is True:
            if not isinstance(openai_client, (AsyncAzureOpenAI, AsyncOpenAI)):
                raise ValueError(
                    "AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
                )
            return self.afile_content(  # type: ignore
                file_content_request=file_content_request,
                openai_client=openai_client,
            )
        response = cast(Union[AzureOpenAI, OpenAI], openai_client).files.content(
            **file_content_request
        )

        return HttpxBinaryResponseContent(response=response.response)
|
||||
|
||||
async def aretrieve_file(
|
||||
self,
|
||||
file_id: str,
|
||||
openai_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
|
||||
) -> FileObject:
|
||||
response = await openai_client.files.retrieve(file_id=file_id)
|
||||
return response
|
||||
|
||||
def retrieve_file(
|
||||
self,
|
||||
_is_async: bool,
|
||||
file_id: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
|
||||
] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, (AsyncAzureOpenAI, AsyncOpenAI)):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.aretrieve_file( # type: ignore
|
||||
file_id=file_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
response = openai_client.files.retrieve(file_id=file_id)
|
||||
|
||||
return response
|
||||
|
||||
async def adelete_file(
|
||||
self,
|
||||
file_id: str,
|
||||
openai_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
|
||||
) -> FileDeleted:
|
||||
response = await openai_client.files.delete(file_id=file_id)
|
||||
|
||||
if not isinstance(response, FileDeleted): # azure returns an empty string
|
||||
return FileDeleted(id=file_id, deleted=True, object="file")
|
||||
return response
|
||||
|
||||
def delete_file(
|
||||
self,
|
||||
_is_async: bool,
|
||||
file_id: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
|
||||
] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, (AsyncAzureOpenAI, AsyncOpenAI)):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.adelete_file( # type: ignore
|
||||
file_id=file_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
response = openai_client.files.delete(file_id=file_id)
|
||||
|
||||
if not isinstance(response, FileDeleted): # azure returns an empty string
|
||||
return FileDeleted(id=file_id, deleted=True, object="file")
|
||||
|
||||
return response
|
||||
|
||||
async def alist_files(
|
||||
self,
|
||||
openai_client: Union[AsyncAzureOpenAI, AsyncOpenAI],
|
||||
purpose: Optional[str] = None,
|
||||
):
|
||||
if isinstance(purpose, str):
|
||||
response = await openai_client.files.list(purpose=purpose)
|
||||
else:
|
||||
response = await openai_client.files.list()
|
||||
return response
|
||||
|
||||
def list_files(
|
||||
self,
|
||||
_is_async: bool,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
purpose: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
|
||||
] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
):
|
||||
openai_client: Optional[
|
||||
Union[AzureOpenAI, AsyncAzureOpenAI, OpenAI, AsyncOpenAI]
|
||||
] = self.get_azure_openai_client(
|
||||
litellm_params=litellm_params or {},
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, (AsyncAzureOpenAI, AsyncOpenAI)):
|
||||
raise ValueError(
|
||||
"AzureOpenAI client is not an instance of AsyncAzureOpenAI. Make sure you passed an AsyncAzureOpenAI client."
|
||||
)
|
||||
return self.alist_files( # type: ignore
|
||||
purpose=purpose,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
|
||||
if isinstance(purpose, str):
|
||||
response = openai_client.files.list(purpose=purpose)
|
||||
else:
|
||||
response = openai_client.files.list()
|
||||
|
||||
return response
|
||||
@@ -0,0 +1,40 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
|
||||
|
||||
|
||||
class AzureOpenAIFineTuningAPI(OpenAIFineTuningAPI, BaseAzureLLM):
    """
    AzureOpenAI methods to support fine tuning, inherits from OpenAIFineTuningAPI.
    """

    def get_openai_client(
        self,
        api_key: Optional[str],
        api_base: Optional[str],
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
        organization: Optional[str],
        client: Optional[
            Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
        ] = None,
        _is_async: bool = False,
        api_version: Optional[str] = None,
        litellm_params: Optional[dict] = None,
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]:
        """
        Return an Azure OpenAI client for fine-tuning calls.

        Overrides the OpenAI base to use Azure-specific client initialization.
        A caller-supplied plain OpenAI/AsyncOpenAI client is discarded (it has
        no Azure routing); a caller-supplied Azure client is reused.
        """
        # BUGFIX: in the openai SDK, AzureOpenAI subclasses OpenAI (and
        # AsyncAzureOpenAI subclasses AsyncOpenAI), so the previous check
        # `isinstance(client, OpenAI) or isinstance(client, AsyncOpenAI)`
        # also matched — and discarded — a valid caller-supplied Azure
        # client. Only discard clients that are NOT Azure clients.
        if client is not None and not isinstance(
            client, (AzureOpenAI, AsyncAzureOpenAI)
        ):
            client = None

        return self.get_azure_openai_client(
            litellm_params=litellm_params or {},
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            client=client,
            _is_async=_is_async,
        )
|
||||
@@ -0,0 +1,83 @@
|
||||
from typing import Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.openai.image_edit.transformation import OpenAIImageEditConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.utils import _add_path_to_api_base
|
||||
|
||||
|
||||
class AzureImageEditConfig(OpenAIImageEditConfig):
    """Azure-specific request transformation for the image-edits endpoint."""

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Resolve the Azure API key and attach it as a Bearer token header."""
        resolved_key = (
            api_key
            or litellm.api_key
            or litellm.azure_key
            or get_secret_str("AZURE_OPENAI_API_KEY")
            or get_secret_str("AZURE_API_KEY")
        )
        headers["Authorization"] = f"Bearer {resolved_key}"
        return headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Build the full Azure image-edits URL.

        Accepts either a bare resource base, e.g.
        "https://litellm8397336933.openai.azure.com", or an already-complete
        deployment URL. The deployment path segment is added from `model`
        when missing, and an "api-version" query parameter is appended from
        `litellm_params` when not already present.

        Returns the complete URL string, e.g.
        "https://litellm8397336933.openai.azure.com/openai/deployments/<deployment_name>/images/edits?api-version=2024-05-01-preview"
        """
        api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
        if api_base is None:
            raise ValueError(
                f"api_base is required for Azure AI Studio. Please set the api_base parameter. Passed `api_base={api_base}`"
            )

        parsed = httpx.URL(api_base)
        api_version = cast(Optional[str], litellm_params.get("api_version"))

        # Preserve any query params already present on the base URL.
        query_params = dict(parsed.params)
        if api_version and "api-version" not in query_params:
            query_params["api-version"] = api_version

        # Append the deployment path (model == deployment name) unless the
        # caller already supplied a full deployment URL.
        if "/openai/deployments/" in api_base:
            target = api_base
        else:
            target = _add_path_to_api_base(
                api_base=api_base,
                ending_path=f"/openai/deployments/{model}/images/edits",
            )

        return str(httpx.URL(target).copy_with(params=query_params))
|
||||
@@ -0,0 +1,29 @@
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
|
||||
from .dall_e_2_transformation import AzureDallE2ImageGenerationConfig
|
||||
from .dall_e_3_transformation import AzureDallE3ImageGenerationConfig
|
||||
from .gpt_transformation import AzureGPTImageGenerationConfig
|
||||
|
||||
__all__ = [
|
||||
"AzureDallE2ImageGenerationConfig",
|
||||
"AzureDallE3ImageGenerationConfig",
|
||||
"AzureGPTImageGenerationConfig",
|
||||
]
|
||||
|
||||
|
||||
def get_azure_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """
    Select the Azure image-generation config class for a model name.

    The name is normalized (lowercased, dashes/underscores removed) before
    matching; unknown names fall through to the gpt-image-1 style config.
    """
    model = model.lower().replace("-", "").replace("_", "")
    if not model or "dalle2" in model:  # empty model is dall-e-2
        return AzureDallE2ImageGenerationConfig()
    if "dalle3" in model:
        return AzureDallE3ImageGenerationConfig()
    verbose_logger.debug(
        f"Using AzureGPTImageGenerationConfig for model: {model}. This follows the gpt-image-1 model format."
    )
    return AzureGPTImageGenerationConfig()
|
||||
@@ -0,0 +1,9 @@
|
||||
from litellm.llms.openai.image_generation import DallE2ImageGenerationConfig
|
||||
|
||||
|
||||
class AzureDallE2ImageGenerationConfig(DallE2ImageGenerationConfig):
    """
    Azure dall-e-2 image generation config.

    Inherits all behavior unchanged from the OpenAI DallE2ImageGenerationConfig;
    exists as a distinct class so Azure provider dispatch (see
    get_azure_image_generation_config) can select it by name.
    """

    pass
|
||||
@@ -0,0 +1,9 @@
|
||||
from litellm.llms.openai.image_generation import DallE3ImageGenerationConfig
|
||||
|
||||
|
||||
class AzureDallE3ImageGenerationConfig(DallE3ImageGenerationConfig):
    """
    Azure dall-e-3 image generation config.

    Inherits all behavior unchanged from the OpenAI DallE3ImageGenerationConfig;
    exists as a distinct class so Azure provider dispatch (see
    get_azure_image_generation_config) can select it by name.
    """

    pass
|
||||
@@ -0,0 +1,9 @@
|
||||
from litellm.llms.openai.image_generation import GPTImageGenerationConfig
|
||||
|
||||
|
||||
class AzureGPTImageGenerationConfig(GPTImageGenerationConfig):
    """
    Azure gpt-image-1 image generation config.

    Inherits all behavior unchanged from the OpenAI GPTImageGenerationConfig;
    used as the fallback config for any Azure image model that is not a
    dall-e-2/dall-e-3 variant (see get_azure_image_generation_config).
    """

    pass
|
||||
@@ -0,0 +1,85 @@
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from httpx import URL
|
||||
|
||||
|
||||
class AzurePassthroughConfig(BasePassthroughConfig):
    """Passthrough routing config for Azure: URL construction and auth headers."""

    def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
        # A request is treated as streaming when the body carries a 'stream' key.
        return "stream" in request_data

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        endpoint: str,
        request_query_params: Optional[dict],
        litellm_params: dict,
    ) -> Tuple["URL", str]:
        """Resolve the full target Azure URL and the base URL for a passthrough request."""
        base_target_url = self.get_api_base(api_base)
        if base_target_url is None:
            raise Exception("Azure api base not found")

        # Swap the routed model-group alias in the endpoint path for the
        # actual Azure deployment name.
        metadata = litellm_params.get("litellm_metadata") or {}
        model_group = metadata.get("model_group")
        if model_group and model_group in endpoint:
            endpoint = endpoint.replace(model_group, model)

        complete_url = BaseAzureLLM._get_base_azure_url(
            api_base=base_target_url,
            litellm_params=litellm_params,
            route=endpoint,
            default_api_version=litellm_params.get("api_version"),
        )
        return httpx.URL(complete_url), base_target_url

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Delegate auth-header construction to the shared Azure helper."""
        merged_params = {**litellm_params, "api_key": api_key}
        return BaseAzureLLM._base_validate_azure_environment(
            headers=headers,
            litellm_params=GenericLiteLLMParams(**merged_params),
        )

    @staticmethod
    def get_api_base(
        api_base: Optional[str] = None,
    ) -> Optional[str]:
        """An explicit api_base wins; otherwise fall back to the AZURE_API_BASE secret."""
        if api_base:
            return api_base
        return get_secret_str("AZURE_API_BASE")

    @staticmethod
    def get_api_key(
        api_key: Optional[str] = None,
    ) -> Optional[str]:
        """An explicit api_key wins; otherwise fall back to the AZURE_API_KEY secret."""
        if api_key:
            return api_key
        return get_secret_str("AZURE_API_KEY")

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        # Azure passthrough uses the model name as-is.
        return model

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        # Pure delegation to the base implementation; kept for interface clarity.
        return super().get_models(api_key, api_base)
|
||||
@@ -0,0 +1,126 @@
|
||||
"""
|
||||
This file contains the calling Azure OpenAI's `/openai/realtime` endpoint.
|
||||
|
||||
This requires websockets, and is currently only supported on LiteLLM Proxy.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES
|
||||
|
||||
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
|
||||
from ....llms.custom_httpx.http_handler import get_shared_realtime_ssl_context
|
||||
from ..azure import AzureChatCompletion
|
||||
|
||||
# BACKEND_WS_URL = "ws://localhost:8080/v1/realtime?model=gpt-4o-realtime-preview-2024-10-01"
|
||||
|
||||
|
||||
async def forward_messages(client_ws: Any, backend_ws: Any):
    """Relay messages from the backend realtime socket to the client until the backend closes."""
    import websockets

    try:
        while True:
            await client_ws.send_text(await backend_ws.recv())
    except websockets.exceptions.ConnectionClosed:  # type: ignore
        # Normal termination: the backend closed the connection.
        pass
|
||||
|
||||
|
||||
class AzureOpenAIRealtime(AzureChatCompletion):
    """
    Websocket bridge to Azure OpenAI's `/openai/realtime` endpoint.

    Per the module docstring, this requires websockets and is currently only
    supported on the LiteLLM Proxy.
    """

    def _construct_url(
        self,
        api_base: str,
        model: str,
        api_version: Optional[str],
        realtime_protocol: Optional[str] = None,
    ) -> str:
        """
        Construct Azure realtime WebSocket URL.

        Args:
            api_base: Azure API base URL (will be converted from https:// to wss://)
            model: Model deployment name
            api_version: Azure API version
            realtime_protocol: Protocol version to use:
                - "GA" or "v1": Uses /openai/v1/realtime (GA path)
                - "beta" or None: Uses /openai/realtime (beta path, default)

        Returns:
            WebSocket URL string

        Examples:
            beta/default: "wss://.../openai/realtime?api-version=2024-10-01-preview&deployment=gpt-4o-realtime-preview"
            GA/v1: "wss://.../openai/v1/realtime?model=gpt-realtime-deployment"
        """
        api_base = api_base.replace("https://", "wss://")

        # Determine path based on realtime_protocol (case-insensitive)
        _is_ga = realtime_protocol is not None and realtime_protocol.upper() in (
            "GA",
            "V1",
        )
        if _is_ga:
            # GA (v1) path: model is a query param; no api-version needed.
            path = "/openai/v1/realtime"
            return f"{api_base}{path}?model={model}"
        else:
            # Default to beta path for backwards compatibility
            path = "/openai/realtime"
            return f"{api_base}{path}?api-version={api_version}&deployment={model}"

    async def async_realtime(
        self,
        model: str,
        websocket: Any,
        logging_obj: LiteLLMLogging,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        api_version: Optional[str] = None,
        azure_ad_token: Optional[str] = None,
        client: Optional[Any] = None,
        timeout: Optional[float] = None,
        realtime_protocol: Optional[str] = None,
        user_api_key_dict: Optional[Any] = None,
        litellm_metadata: Optional[dict] = None,
    ):
        """
        Open a websocket to Azure's realtime endpoint and bidirectionally
        forward traffic between it and the client websocket.

        Raises:
            ValueError: if api_base is missing, or if api_version is missing
                on the beta (non-GA) protocol path.
        """
        import websockets
        from websockets.asyncio.client import ClientConnection

        if api_base is None:
            raise ValueError("api_base is required for Azure OpenAI calls")
        # api_version is only optional on the GA ("GA"/"v1") protocol path,
        # which does not put api-version in the URL.
        if api_version is None and (
            realtime_protocol is None or realtime_protocol.upper() not in ("GA", "V1")
        ):
            raise ValueError("api_version is required for Azure OpenAI calls")

        url = self._construct_url(
            api_base, model, api_version, realtime_protocol=realtime_protocol
        )

        try:
            ssl_context = get_shared_realtime_ssl_context()
            async with websockets.connect(  # type: ignore
                url,
                additional_headers={
                    "api-key": api_key,  # type: ignore
                },
                # Cap message size to the proxy-wide realtime limit.
                max_size=REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
                ssl=ssl_context,
            ) as backend_ws:
                realtime_streaming = RealTimeStreaming(
                    websocket,
                    cast(ClientConnection, backend_ws),
                    logging_obj,
                    user_api_key_dict=user_api_key_dict,
                    request_data={"litellm_metadata": litellm_metadata or {}},
                )
                # Pump messages in both directions until one side disconnects.
                await realtime_streaming.bidirectional_forward()

        except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
            # Handshake rejected by Azure: close the client socket with the
            # backend's status code.
            await websocket.close(code=e.status_code, reason=str(e))
        except Exception:
            # Best-effort: log and swallow so the proxy tears down cleanly.
            verbose_proxy_logger.exception(
                "Error in AzureOpenAIRealtime.async_realtime"
            )
            pass
|
||||
@@ -0,0 +1,46 @@
|
||||
"""Azure OpenAI realtime HTTP transformation config (client_secrets + realtime_calls)."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.realtime.http_transformation import BaseRealtimeHTTPConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class AzureRealtimeHTTPConfig(BaseRealtimeHTTPConfig):
    """HTTP-side Azure realtime config: client_secrets minting + realtime calls."""

    def get_api_base(self, api_base: Optional[str], **kwargs) -> str:
        """Resolve api_base from the argument, global litellm config, or AZURE_API_BASE."""
        resolved = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
        return resolved or ""

    def get_api_key(self, api_key: Optional[str], **kwargs) -> str:
        """Resolve api_key from the argument, global litellm config, or AZURE_API_KEY."""
        resolved = api_key or litellm.api_key or get_secret_str("AZURE_API_KEY")
        return resolved or ""

    def get_complete_url(
        self, api_base: Optional[str], model: str, api_version: Optional[str] = None
    ) -> str:
        """URL for minting ephemeral realtime client secrets."""
        root = self.get_api_base(api_base).rstrip("/")
        resolved_version = (
            api_version or get_secret_str("AZURE_API_VERSION") or "2024-12-17"
        )
        return f"{root}/openai/realtime/client_secrets?api-version={resolved_version}"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Return a copy of headers with Azure api-key auth and JSON content type."""
        merged = dict(headers)
        merged["api-key"] = api_key or ""
        merged["Content-Type"] = "application/json"
        return merged

    def get_realtime_calls_url(
        self, api_base: Optional[str], model: str, api_version: Optional[str] = None
    ) -> str:
        """URL for the realtime calls endpoint."""
        root = self.get_api_base(api_base).rstrip("/")
        resolved_version = (
            api_version or get_secret_str("AZURE_API_VERSION") or "2024-12-17"
        )
        return f"{root}/openai/realtime/calls?api-version={resolved_version}"

    def get_realtime_calls_headers(self, ephemeral_key: str) -> dict:
        """Auth header for realtime call requests using a minted ephemeral key."""
        return {"api-key": ephemeral_key}
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
Support for Azure OpenAI O-series models (o1, o3, etc.) in Responses API
|
||||
|
||||
https://platform.openai.com/docs/guides/reasoning
|
||||
|
||||
Translations handled by LiteLLM:
|
||||
- temperature => drop param (if user opts in to dropping param)
|
||||
- Other parameters follow base Azure OpenAI Responses API behavior
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
|
||||
from litellm.utils import supports_reasoning
|
||||
|
||||
from .transformation import AzureOpenAIResponsesAPIConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AzureOpenAIOSeriesResponsesAPIConfig(AzureOpenAIResponsesAPIConfig):
    """
    Configuration for Azure OpenAI O-series models in Responses API.

    O-series models (o1, o3, etc.) do not support the temperature parameter
    in the responses API, so we need to drop it when drop_params is enabled.
    """

    def get_supported_openai_params(self, model: str) -> list:
        """
        Supported params for O-series = the base Azure set minus `temperature`,
        which the O-series responses API rejects.
        """
        unsupported = {"temperature"}
        return [
            param
            for param in super().get_supported_openai_params(model)
            if param not in unsupported
        ]

    def map_openai_params(
        self,
        response_api_optional_params: ResponsesAPIOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """
        Copy the request params; when drop_params is enabled, strip the
        unsupported `temperature` parameter for O-series models.
        """
        params = dict(response_api_optional_params)
        if drop_params and "temperature" in params:
            verbose_logger.debug(
                f"Dropping unsupported parameter 'temperature' for Azure OpenAI O-series responses API model {model}"
            )
            del params["temperature"]
        return params

    def is_o_series_model(self, model: str) -> bool:
        """
        Check if the model is an O-series model.

        True when the model name carries the o_series tag, or when the model
        supports reasoning (used here as the O-series signal).
        """
        return "o_series" in model.lower() or supports_reasoning(model)
|
||||
@@ -0,0 +1,359 @@
|
||||
from copy import deepcopy
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from openai.types.responses import ResponseReasoningItem
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from litellm.types.llms.openai import *
|
||||
from litellm.types.responses.main import *
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AzureOpenAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
|
||||
# Parameters not supported by Azure Responses API
|
||||
AZURE_UNSUPPORTED_PARAMS = ["context_management"]
|
||||
|
||||
    @property
    def custom_llm_provider(self) -> LlmProviders:
        # Identifies this config as the Azure provider.
        return LlmProviders.AZURE
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
Azure Responses API does not support context_management (compaction).
|
||||
"""
|
||||
base_supported_params = super().get_supported_openai_params(model)
|
||||
return [
|
||||
param
|
||||
for param in base_supported_params
|
||||
if param not in self.AZURE_UNSUPPORTED_PARAMS
|
||||
]
|
||||
|
||||
    def validate_environment(
        self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        # Delegate Azure auth-header construction to the shared helper.
        return BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params
        )
|
||||
|
||||
def get_stripped_model_name(self, model: str) -> str:
|
||||
# if "responses/" is in the model name, remove it
|
||||
if "responses/" in model:
|
||||
model = model.replace("responses/", "")
|
||||
if "o_series" in model:
|
||||
model = model.replace("o_series/", "")
|
||||
return model
|
||||
|
||||
def _handle_reasoning_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle reasoning items to filter out the status field.
|
||||
Issue: https://github.com/BerriAI/litellm/issues/13484
|
||||
|
||||
Azure OpenAI API does not accept 'status' field in reasoning input items.
|
||||
"""
|
||||
if item.get("type") == "reasoning":
|
||||
try:
|
||||
# Ensure required fields are present for ResponseReasoningItem
|
||||
item_data = dict(item)
|
||||
if "summary" not in item_data:
|
||||
item_data["summary"] = (
|
||||
item_data.get("reasoning_content", "")[:100] + "..."
|
||||
if len(item_data.get("reasoning_content", "")) > 100
|
||||
else item_data.get("reasoning_content", "")
|
||||
)
|
||||
|
||||
# Create ResponseReasoningItem object from the item data
|
||||
reasoning_item = ResponseReasoningItem(**item_data)
|
||||
|
||||
# Convert back to dict with exclude_none=True to exclude None fields
|
||||
dict_reasoning_item = reasoning_item.model_dump(exclude_none=True)
|
||||
dict_reasoning_item.pop("status", None)
|
||||
|
||||
return dict_reasoning_item
|
||||
except Exception as e:
|
||||
verbose_logger.debug(
|
||||
f"Failed to create ResponseReasoningItem, falling back to manual filtering: {e}"
|
||||
)
|
||||
# Fallback: manually filter out known None fields
|
||||
filtered_item = {
|
||||
k: v
|
||||
for k, v in item.items()
|
||||
if v is not None
|
||||
or k not in {"status", "content", "encrypted_content"}
|
||||
}
|
||||
return filtered_item
|
||||
return item
|
||||
|
||||
def _validate_input_param(
|
||||
self, input: Union[str, ResponseInputParam]
|
||||
) -> Union[str, ResponseInputParam]:
|
||||
"""
|
||||
Override parent method to also filter out 'status' field from message items.
|
||||
Azure OpenAI API does not accept 'status' field in input messages.
|
||||
"""
|
||||
from typing import cast
|
||||
|
||||
# First call parent's validation
|
||||
validated_input = super()._validate_input_param(input)
|
||||
|
||||
# Then filter out status from message items
|
||||
if isinstance(validated_input, list):
|
||||
filtered_input: List[Any] = []
|
||||
for item in validated_input:
|
||||
if isinstance(item, dict) and item.get("type") == "message":
|
||||
# Filter out status field from message items
|
||||
filtered_item = {k: v for k, v in item.items() if k != "status"}
|
||||
filtered_input.append(filtered_item)
|
||||
else:
|
||||
filtered_input.append(item)
|
||||
return cast(ResponseInputParam, filtered_input)
|
||||
|
||||
return validated_input
|
||||
|
||||
def transform_responses_api_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Union[str, ResponseInputParam],
|
||||
response_api_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Dict:
|
||||
"""No transform applied since inputs are in OpenAI spec already"""
|
||||
stripped_model_name = self.get_stripped_model_name(model)
|
||||
|
||||
# Azure Responses API requires flattened tools (params at top level, not nested in 'function')
|
||||
if "tools" in response_api_optional_request_params and isinstance(
|
||||
response_api_optional_request_params["tools"], list
|
||||
):
|
||||
new_tools: List[Dict[str, Any]] = []
|
||||
for tool in response_api_optional_request_params["tools"]:
|
||||
if isinstance(tool, dict) and "function" in tool:
|
||||
new_tool: Dict[str, Any] = deepcopy(tool)
|
||||
function_data = new_tool.pop("function")
|
||||
new_tool.update(function_data)
|
||||
new_tools.append(new_tool)
|
||||
else:
|
||||
new_tools.append(tool)
|
||||
response_api_optional_request_params["tools"] = new_tools
|
||||
|
||||
return super().transform_responses_api_request(
|
||||
model=stripped_model_name,
|
||||
input=input,
|
||||
response_api_optional_request_params=response_api_optional_request_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Constructs a complete URL for the Azure Responses API request.

        Args:
        - api_base: Base URL, e.g.,
            "https://litellm8397336933.openai.azure.com"
            OR
            "https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
        - litellm_params: Additional parameters, including "api_version".

        Returns:
        - A complete URL string, e.g.,
        "https://litellm8397336933.openai.azure.com/openai/responses?api-version=2024-05-01-preview"
        """
        from litellm.constants import AZURE_DEFAULT_RESPONSES_API_VERSION

        # Delegate route + api-version assembly to the shared Azure helper,
        # defaulting the api-version when litellm_params does not supply one.
        return BaseAzureLLM._get_base_azure_url(
            api_base=api_base,
            litellm_params=litellm_params,
            route="/openai/responses",
            default_api_version=AZURE_DEFAULT_RESPONSES_API_VERSION,
        )
|
||||
|
||||
#########################################################
|
||||
########## DELETE RESPONSE API TRANSFORMATION ##############
|
||||
#########################################################
|
||||
def _construct_url_for_response_id_in_path(
|
||||
self, api_base: str, response_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Constructs a URL for the API request with the response_id in the path.
|
||||
"""
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
# Parse the URL to separate its components
|
||||
parsed_url = urlparse(api_base)
|
||||
|
||||
# Insert the response_id at the end of the path component
|
||||
# Remove trailing slash if present to avoid double slashes
|
||||
path = parsed_url.path.rstrip("/")
|
||||
new_path = f"{path}/{response_id}"
|
||||
|
||||
# Reconstruct the URL with all original components but with the modified path
|
||||
constructed_url = urlunparse(
|
||||
(
|
||||
parsed_url.scheme, # http, https
|
||||
parsed_url.netloc, # domain name, port
|
||||
new_path, # path with response_id added
|
||||
parsed_url.params, # parameters
|
||||
parsed_url.query, # query string
|
||||
parsed_url.fragment, # fragment
|
||||
)
|
||||
)
|
||||
return constructed_url
|
||||
|
||||
def transform_delete_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Transform the delete response API request into a URL and data
|
||||
|
||||
Azure OpenAI API expects the following request:
|
||||
- DELETE /openai/responses/{response_id}?api-version=xxx
|
||||
|
||||
This function handles URLs with query parameters by inserting the response_id
|
||||
at the correct location (before any query parameters).
|
||||
"""
|
||||
delete_url = self._construct_url_for_response_id_in_path(
|
||||
api_base=api_base, response_id=response_id
|
||||
)
|
||||
|
||||
data: Dict = {}
|
||||
verbose_logger.debug(f"delete response url={delete_url}")
|
||||
return delete_url, data
|
||||
|
||||
#########################################################
|
||||
########## GET RESPONSE API TRANSFORMATION ###############
|
||||
#########################################################
|
||||
def transform_get_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Transform the get response API request into a URL and data
|
||||
|
||||
OpenAI API expects the following request
|
||||
- GET /v1/responses/{response_id}
|
||||
"""
|
||||
get_url = self._construct_url_for_response_id_in_path(
|
||||
api_base=api_base, response_id=response_id
|
||||
)
|
||||
data: Dict = {}
|
||||
verbose_logger.debug(f"get response url={get_url}")
|
||||
return get_url, data
|
||||
|
||||
def transform_list_input_items_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
) -> Tuple[str, Dict]:
|
||||
url = (
|
||||
self._construct_url_for_response_id_in_path(
|
||||
api_base=api_base, response_id=response_id
|
||||
)
|
||||
+ "/input_items"
|
||||
)
|
||||
params: Dict[str, Any] = {}
|
||||
if after is not None:
|
||||
params["after"] = after
|
||||
if before is not None:
|
||||
params["before"] = before
|
||||
if include:
|
||||
params["include"] = ",".join(include)
|
||||
if limit is not None:
|
||||
params["limit"] = limit
|
||||
if order is not None:
|
||||
params["order"] = order
|
||||
verbose_logger.debug(f"list input items url={url}")
|
||||
return url, params
|
||||
|
||||
#########################################################
|
||||
########## CANCEL RESPONSE API TRANSFORMATION ##########
|
||||
#########################################################
|
||||
def transform_cancel_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Transform the cancel response API request into a URL and data
|
||||
|
||||
Azure OpenAI API expects the following request:
|
||||
- POST /openai/responses/{response_id}/cancel?api-version=xxx
|
||||
|
||||
This function handles URLs with query parameters by inserting the response_id
|
||||
at the correct location (before any query parameters).
|
||||
"""
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
# Parse the URL to separate its components
|
||||
parsed_url = urlparse(api_base)
|
||||
|
||||
# Insert the response_id and /cancel at the end of the path component
|
||||
# Remove trailing slash if present to avoid double slashes
|
||||
path = parsed_url.path.rstrip("/")
|
||||
new_path = f"{path}/{response_id}/cancel"
|
||||
|
||||
# Reconstruct the URL with all original components but with the modified path
|
||||
cancel_url = urlunparse(
|
||||
(
|
||||
parsed_url.scheme, # http, https
|
||||
parsed_url.netloc, # domain name, port
|
||||
new_path, # path with response_id and /cancel added
|
||||
parsed_url.params, # parameters
|
||||
parsed_url.query, # query string
|
||||
parsed_url.fragment, # fragment
|
||||
)
|
||||
)
|
||||
|
||||
data: Dict = {}
|
||||
verbose_logger.debug(f"cancel response url={cancel_url}")
|
||||
return cancel_url, data
|
||||
|
||||
def transform_cancel_response_api_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ResponsesAPIResponse:
|
||||
"""
|
||||
Transform the cancel response API response into a ResponsesAPIResponse
|
||||
"""
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
from litellm.llms.azure.chat.gpt_transformation import AzureOpenAIError
|
||||
|
||||
raise AzureOpenAIError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
return ResponsesAPIResponse(**raw_response_json)
|
||||
@@ -0,0 +1,7 @@
|
||||
"""Azure Text-to-Speech module"""
|
||||
|
||||
from .transformation import AzureAVATextToSpeechConfig
|
||||
|
||||
__all__ = [
|
||||
"AzureAVATextToSpeechConfig",
|
||||
]
|
||||
@@ -0,0 +1,504 @@
|
||||
"""
|
||||
Azure AVA (Cognitive Services) Text-to-Speech transformation
|
||||
|
||||
Maps OpenAI TTS spec to Azure Cognitive Services TTS API
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.text_to_speech.transformation import (
|
||||
BaseTextToSpeechConfig,
|
||||
TextToSpeechRequestData,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HttpxBinaryResponseContent = Any
|
||||
|
||||
|
||||
class AzureAVATextToSpeechConfig(BaseTextToSpeechConfig):
    """
    Configuration for Azure AVA (Cognitive Services) Text-to-Speech.

    Maps the OpenAI TTS spec onto Azure's Cognitive Services TTS REST API.

    Reference: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/rest-text-to-speech
    """

    # Voice used when the caller supplies none.
    DEFAULT_VOICE = "en-US-AriaNeural"
    # Azure endpoint domains
    COGNITIVE_SERVICES_DOMAIN = "api.cognitive.microsoft.com"
    TTS_SPEECH_DOMAIN = "tts.speech.microsoft.com"
    # Path appended to the TTS host to reach the synthesis endpoint.
    TTS_ENDPOINT_PATH = "/cognitiveservices/v1"

    # Voice name mappings from OpenAI voices to Azure voices
    VOICE_MAPPINGS = {
        "alloy": "en-US-JennyNeural",
        "echo": "en-US-GuyNeural",
        "fable": "en-GB-RyanNeural",
        "onyx": "en-US-DavisNeural",
        "nova": "en-US-AmberNeural",
        "shimmer": "en-US-AriaNeural",
    }

    # Response format mappings from OpenAI to Azure
    FORMAT_MAPPINGS = {
        "mp3": "audio-24khz-48kbitrate-mono-mp3",
        "opus": "ogg-48khz-16bit-mono-opus",
        "aac": "audio-24khz-48kbitrate-mono-mp3",  # Azure doesn't have AAC, use MP3
        "flac": "audio-24khz-48kbitrate-mono-mp3",  # Azure doesn't have FLAC, use MP3
        "wav": "riff-24khz-16bit-mono-pcm",
        "pcm": "raw-24khz-16bit-mono-pcm",
    }
|
||||
|
||||
    def dispatch_text_to_speech(
        self,
        model: str,
        input: str,
        voice: Optional[Union[str, Dict]],
        optional_params: Dict,
        litellm_params_dict: Dict,
        logging_obj: "LiteLLMLoggingObj",
        timeout: Union[float, httpx.Timeout],
        extra_headers: Optional[Dict[str, Any]],
        base_llm_http_handler: Any,
        aspeech: bool,
        api_base: Optional[str],
        api_key: Optional[str],
        **kwargs: Any,
    ) -> Union[
        "HttpxBinaryResponseContent",
        Coroutine[Any, Any, "HttpxBinaryResponseContent"],
    ]:
        """
        Dispatch method to handle Azure AVA TTS requests.

        Resolves Azure credentials (api_base / api_key) from explicit args,
        litellm params, module-level settings, and environment secrets — in
        that order of precedence — then delegates the HTTP call to
        ``base_llm_http_handler.text_to_speech_handler``.

        Args:
            base_llm_http_handler: The BaseLLMHTTPHandler instance from main.py.
            aspeech: When True the handler is invoked asynchronously and a
                coroutine producing the binary response is returned.

        Returns:
            The binary audio response, or a coroutine producing it when
            ``aspeech`` is True.
        """
        # Resolve api_base: explicit arg > litellm params > module setting > env.
        api_base = (
            api_base
            or litellm_params_dict.get("api_base")
            or litellm.api_base
            or get_secret_str("AZURE_API_BASE")
        )

        # Resolve api_key with the same precedence (Azure-specific fallbacks).
        api_key = (
            api_key
            or litellm_params_dict.get("api_key")
            or litellm.api_key
            or litellm.azure_key
            or get_secret_str("AZURE_OPENAI_API_KEY")
            or get_secret_str("AZURE_API_KEY")
        )

        # Azure AVA requires a plain string voice; unwrap the dict form.
        voice_str: Optional[str] = None
        if isinstance(voice, str):
            voice_str = voice
        elif isinstance(voice, dict):
            # Extract voice name from dict if needed
            voice_str = voice.get("name") if voice else None

        # Propagate the resolved credentials to the handler via litellm params.
        litellm_params_dict.update(
            {
                "api_key": api_key,
                "api_base": api_base,
            }
        )
        # Call the text_to_speech_handler
        response = base_llm_http_handler.text_to_speech_handler(
            model=model,
            input=input,
            voice=voice_str,
            text_to_speech_provider_config=self,
            text_to_speech_optional_params=optional_params,
            custom_llm_provider="azure",
            litellm_params=litellm_params_dict,
            logging_obj=logging_obj,
            timeout=timeout,
            extra_headers=extra_headers,
            client=None,
            _is_async=aspeech,
        )

        return response
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
Azure AVA TTS supports these OpenAI parameters
|
||||
|
||||
Note: Azure also supports additional SSML-specific parameters (style, styledegree, role)
|
||||
which can be passed but are not part of the OpenAI spec
|
||||
"""
|
||||
return ["voice", "response_format", "speed"]
|
||||
|
||||
def _convert_speed_to_azure_rate(self, speed: float) -> str:
|
||||
"""
|
||||
Convert OpenAI speed value to Azure SSML prosody rate percentage
|
||||
|
||||
Args:
|
||||
speed: OpenAI speed value (0.25-4.0, default 1.0)
|
||||
|
||||
Returns:
|
||||
Azure rate string with percentage (e.g., "+50%", "-50%", "+0%")
|
||||
|
||||
Examples:
|
||||
speed=1.0 -> "+0%" (default)
|
||||
speed=2.0 -> "+100%"
|
||||
speed=0.5 -> "-50%"
|
||||
"""
|
||||
rate_percentage = int((speed - 1.0) * 100)
|
||||
return f"{rate_percentage:+d}%"
|
||||
|
||||
def _build_express_as_element(
|
||||
self,
|
||||
content: str,
|
||||
style: Optional[str] = None,
|
||||
styledegree: Optional[str] = None,
|
||||
role: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Build mstts:express-as element with optional style, styledegree, and role attributes
|
||||
|
||||
Args:
|
||||
content: The inner content to wrap
|
||||
style: Speaking style (e.g., "cheerful", "sad", "angry")
|
||||
styledegree: Style intensity (0.01 to 2)
|
||||
role: Voice role (e.g., "Girl", "Boy", "SeniorFemale", "SeniorMale")
|
||||
|
||||
Returns:
|
||||
Content wrapped in mstts:express-as if any attributes provided, otherwise raw content
|
||||
"""
|
||||
if not (style or styledegree or role):
|
||||
return content
|
||||
|
||||
express_as_attrs = []
|
||||
if style:
|
||||
express_as_attrs.append(f"style='{style}'")
|
||||
if styledegree:
|
||||
express_as_attrs.append(f"styledegree='{styledegree}'")
|
||||
if role:
|
||||
express_as_attrs.append(f"role='{role}'")
|
||||
|
||||
express_as_attrs_str = " ".join(express_as_attrs)
|
||||
return f"<mstts:express-as {express_as_attrs_str}>{content}</mstts:express-as>"
|
||||
|
||||
def _get_voice_language(
|
||||
self,
|
||||
voice_name: Optional[str],
|
||||
explicit_lang: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the language for the voice element's xml:lang attribute
|
||||
|
||||
Args:
|
||||
voice_name: The Azure voice name (e.g., "en-US-AriaNeural")
|
||||
explicit_lang: Explicitly provided language code (takes precedence)
|
||||
|
||||
Returns:
|
||||
Language code if available (e.g., "es-ES"), or None
|
||||
|
||||
Examples:
|
||||
- explicit_lang="es-ES" → "es-ES" (explicit takes precedence)
|
||||
- voice_name="en-US-AriaNeural", explicit_lang=None → None (use default from voice)
|
||||
- voice_name="en-US-AvaMultilingualNeural", explicit_lang="fr-FR" → "fr-FR"
|
||||
"""
|
||||
# If explicit language is provided, use it (for multilingual voices)
|
||||
if explicit_lang:
|
||||
return explicit_lang
|
||||
|
||||
# For non-multilingual voices, we don't need to set xml:lang on the voice element
|
||||
# The voice name already encodes the language (e.g., en-US-AriaNeural)
|
||||
# Only return a language if explicitly set
|
||||
return None
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
model: str,
|
||||
optional_params: Dict,
|
||||
voice: Optional[Union[str, Dict]] = None,
|
||||
drop_params: bool = False,
|
||||
kwargs: Dict = {},
|
||||
) -> Tuple[Optional[str], Dict]:
|
||||
"""
|
||||
Map OpenAI parameters to Azure AVA TTS parameters
|
||||
"""
|
||||
mapped_params = {}
|
||||
##########################################################
|
||||
# Map voice
|
||||
# OpenAI uses voice as a required param, hence not in optional_params
|
||||
##########################################################
|
||||
# If it's already an Azure voice, use it directly
|
||||
mapped_voice: Optional[str] = None
|
||||
if isinstance(voice, str):
|
||||
if voice in self.VOICE_MAPPINGS:
|
||||
mapped_voice = self.VOICE_MAPPINGS[voice]
|
||||
else:
|
||||
# Assume it's already an Azure voice name
|
||||
mapped_voice = voice
|
||||
|
||||
# Map response format
|
||||
if "response_format" in optional_params:
|
||||
format_name = optional_params["response_format"]
|
||||
if format_name in self.FORMAT_MAPPINGS:
|
||||
mapped_params["output_format"] = self.FORMAT_MAPPINGS[format_name]
|
||||
else:
|
||||
# Try to use it directly as Azure format
|
||||
mapped_params["output_format"] = format_name
|
||||
else:
|
||||
# Default to MP3
|
||||
mapped_params["output_format"] = "audio-24khz-48kbitrate-mono-mp3"
|
||||
|
||||
# Map speed (OpenAI: 0.25-4.0, Azure: prosody rate)
|
||||
if "speed" in optional_params:
|
||||
speed = optional_params["speed"]
|
||||
if speed is not None:
|
||||
mapped_params["rate"] = self._convert_speed_to_azure_rate(speed=speed)
|
||||
|
||||
# Pass through Azure-specific SSML parameters
|
||||
if "style" in kwargs:
|
||||
mapped_params["style"] = kwargs["style"]
|
||||
|
||||
if "styledegree" in kwargs:
|
||||
mapped_params["styledegree"] = kwargs["styledegree"]
|
||||
|
||||
if "role" in kwargs:
|
||||
mapped_params["role"] = kwargs["role"]
|
||||
|
||||
if "lang" in kwargs:
|
||||
mapped_params["lang"] = kwargs["lang"]
|
||||
return mapped_voice, mapped_params
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate Azure environment and set up authentication headers
|
||||
"""
|
||||
validated_headers = headers.copy()
|
||||
|
||||
# Azure AVA TTS requires either:
|
||||
# 1. Ocp-Apim-Subscription-Key header, or
|
||||
# 2. Authorization: Bearer <token> header
|
||||
|
||||
# We'll use the token-based auth via our token handler
|
||||
# The token will be added later in the handler
|
||||
|
||||
if api_key:
|
||||
# If subscription key is provided, use it directly
|
||||
validated_headers["Ocp-Apim-Subscription-Key"] = api_key
|
||||
|
||||
# Content-Type for SSML
|
||||
validated_headers["Content-Type"] = "application/ssml+xml"
|
||||
|
||||
# User-Agent
|
||||
validated_headers["User-Agent"] = "litellm"
|
||||
|
||||
return validated_headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
model: str,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for Azure AVA TTS request
|
||||
|
||||
Azure TTS endpoint format:
|
||||
https://{region}.tts.speech.microsoft.com/cognitiveservices/v1
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"api_base is required for Azure AVA TTS. "
|
||||
f"Format: https://{{region}}.{self.COGNITIVE_SERVICES_DOMAIN} or "
|
||||
f"https://{{region}}.{self.TTS_SPEECH_DOMAIN}"
|
||||
)
|
||||
|
||||
# Remove trailing slash and parse URL
|
||||
api_base = api_base.rstrip("/")
|
||||
parsed_url = urlparse(api_base)
|
||||
hostname = parsed_url.hostname or ""
|
||||
|
||||
# Check if it's a Cognitive Services endpoint (convert to TTS endpoint)
|
||||
if self._is_cognitive_services_endpoint(hostname=hostname):
|
||||
region = self._extract_region_from_hostname(
|
||||
hostname=hostname, domain=self.COGNITIVE_SERVICES_DOMAIN
|
||||
)
|
||||
return self._build_tts_url(region=region)
|
||||
|
||||
# Check if it's already a TTS endpoint
|
||||
if self._is_tts_endpoint(hostname=hostname):
|
||||
if not api_base.endswith(self.TTS_ENDPOINT_PATH):
|
||||
return f"{api_base}{self.TTS_ENDPOINT_PATH}"
|
||||
return api_base
|
||||
|
||||
# Assume it's a custom endpoint, append the path
|
||||
return f"{api_base}{self.TTS_ENDPOINT_PATH}"
|
||||
|
||||
def _is_cognitive_services_endpoint(self, hostname: str) -> bool:
|
||||
"""Check if hostname is a Cognitive Services endpoint"""
|
||||
return hostname == self.COGNITIVE_SERVICES_DOMAIN or hostname.endswith(
|
||||
f".{self.COGNITIVE_SERVICES_DOMAIN}"
|
||||
)
|
||||
|
||||
def _is_tts_endpoint(self, hostname: str) -> bool:
|
||||
"""Check if hostname is a TTS endpoint"""
|
||||
return hostname == self.TTS_SPEECH_DOMAIN or hostname.endswith(
|
||||
f".{self.TTS_SPEECH_DOMAIN}"
|
||||
)
|
||||
|
||||
def _extract_region_from_hostname(self, hostname: str, domain: str) -> str:
|
||||
"""
|
||||
Extract region from hostname
|
||||
|
||||
Examples:
|
||||
eastus.api.cognitive.microsoft.com -> eastus
|
||||
api.cognitive.microsoft.com -> ""
|
||||
"""
|
||||
if hostname.endswith(f".{domain}"):
|
||||
return hostname[: -len(f".{domain}")]
|
||||
return ""
|
||||
|
||||
def _build_tts_url(self, region: str) -> str:
|
||||
"""Build the complete TTS URL with region"""
|
||||
if region:
|
||||
return f"https://{region}.{self.TTS_SPEECH_DOMAIN}{self.TTS_ENDPOINT_PATH}"
|
||||
return f"https://{self.TTS_SPEECH_DOMAIN}{self.TTS_ENDPOINT_PATH}"
|
||||
|
||||
def is_ssml_input(self, input: str) -> bool:
|
||||
"""
|
||||
Returns True if input is SSML, False otherwise
|
||||
|
||||
Based on https://www.w3.org/TR/speech-synthesis/ all SSML must start with <speak>
|
||||
"""
|
||||
return "<speak>" in input or "<speak " in input
|
||||
|
||||
def transform_text_to_speech_request(
|
||||
self,
|
||||
model: str,
|
||||
input: str,
|
||||
voice: Optional[str],
|
||||
optional_params: Dict,
|
||||
litellm_params: Dict,
|
||||
headers: dict,
|
||||
) -> TextToSpeechRequestData:
|
||||
"""
|
||||
Transform OpenAI TTS request to Azure AVA TTS SSML format
|
||||
|
||||
Note: optional_params should already be mapped via map_openai_params in main.py
|
||||
|
||||
Supports Azure-specific SSML features:
|
||||
- style: Speaking style (e.g., "cheerful", "sad", "angry")
|
||||
- styledegree: Style intensity (0.01 to 2)
|
||||
- role: Voice role (e.g., "Girl", "Boy", "SeniorFemale", "SeniorMale")
|
||||
- lang: Language code for multilingual voices (e.g., "es-ES", "fr-FR")
|
||||
|
||||
Auto-detects SSML:
|
||||
- If input contains <speak>, it's passed through as-is without transformation
|
||||
|
||||
Returns:
|
||||
TextToSpeechRequestData: Contains SSML body and Azure-specific headers
|
||||
"""
|
||||
# Get voice (already mapped in main.py, or use default)
|
||||
azure_voice = voice or self.DEFAULT_VOICE
|
||||
|
||||
# Get output format (already mapped in main.py)
|
||||
output_format = optional_params.get(
|
||||
"output_format", "audio-24khz-48kbitrate-mono-mp3"
|
||||
)
|
||||
headers["X-Microsoft-OutputFormat"] = output_format
|
||||
|
||||
# Auto-detect SSML: if input contains <speak>, pass it through as-is
|
||||
# Similar to Vertex AI behavior - check if input looks like SSML
|
||||
if self.is_ssml_input(input=input):
|
||||
return TextToSpeechRequestData(
|
||||
ssml_body=input,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
# Build SSML from plain text
|
||||
rate = optional_params.get("rate", "0%")
|
||||
style = optional_params.get("style")
|
||||
styledegree = optional_params.get("styledegree")
|
||||
role = optional_params.get("role")
|
||||
lang = optional_params.get("lang")
|
||||
|
||||
# Escape XML special characters in input text
|
||||
escaped_input = (
|
||||
input.replace("&", "&")
|
||||
.replace("<", "<")
|
||||
.replace(">", ">")
|
||||
.replace('"', """)
|
||||
.replace("'", "'")
|
||||
)
|
||||
|
||||
# Determine if we need mstts namespace (for express-as element)
|
||||
use_mstts = style or role or styledegree
|
||||
|
||||
# Build the xmlns attributes
|
||||
if use_mstts:
|
||||
xmlns = "xmlns='http://www.w3.org/2001/10/synthesis' xmlns:mstts='https://www.w3.org/2001/mstts'"
|
||||
else:
|
||||
xmlns = "xmlns='http://www.w3.org/2001/10/synthesis'"
|
||||
|
||||
# Build the inner content with prosody
|
||||
prosody_content = f"<prosody rate='{rate}'>{escaped_input}</prosody>"
|
||||
|
||||
# Wrap in mstts:express-as if style or role is specified
|
||||
voice_content = self._build_express_as_element(
|
||||
content=prosody_content,
|
||||
style=style,
|
||||
styledegree=styledegree,
|
||||
role=role,
|
||||
)
|
||||
|
||||
# Build voice element with optional xml:lang attribute
|
||||
voice_lang = self._get_voice_language(
|
||||
voice_name=azure_voice,
|
||||
explicit_lang=lang,
|
||||
)
|
||||
voice_lang_attr = f" xml:lang='{voice_lang}'" if voice_lang else ""
|
||||
|
||||
ssml_body = f"""<speak version='1.0' {xmlns} xml:lang='en-US'>
|
||||
<voice name='{azure_voice}'{voice_lang_attr}>
|
||||
{voice_content}
|
||||
</voice>
|
||||
</speak>"""
|
||||
|
||||
return {
|
||||
"ssml_body": ssml_body,
|
||||
"headers": headers,
|
||||
}
|
||||
|
||||
def transform_text_to_speech_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: "LiteLLMLoggingObj",
|
||||
) -> "HttpxBinaryResponseContent":
|
||||
"""
|
||||
Transform Azure AVA TTS response to standard format
|
||||
|
||||
Azure returns the audio data directly in the response body
|
||||
"""
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
|
||||
# Azure returns audio data directly in the response body
|
||||
# Wrap it in HttpxBinaryResponseContent for consistent return type
|
||||
return HttpxBinaryResponseContent(raw_response)
|
||||
@@ -0,0 +1,25 @@
|
||||
from typing import Optional
|
||||
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.openai.vector_stores.transformation import OpenAIVectorStoreConfig
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
|
||||
class AzureOpenAIVectorStoreConfig(OpenAIVectorStoreConfig):
    """
    Azure OpenAI flavor of the vector-store config.

    Reuses the OpenAI transformation logic but routes requests to the Azure
    /openai/vector_stores endpoint and authenticates with Azure headers.
    """

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """Build the Azure vector-stores URL from the base URL + litellm params."""
        url = BaseAzureLLM._get_base_azure_url(
            api_base=api_base,
            litellm_params=litellm_params,
            route="/openai/vector_stores",
        )
        return url

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Attach Azure auth headers via the shared Azure helper."""
        validated = BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params
        )
        return validated
|
||||
@@ -0,0 +1,93 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from litellm.types.videos.main import VideoCreateOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.llms.openai.videos.transformation import OpenAIVideoConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
from ...base_llm.videos.transformation import BaseVideoConfig as _BaseVideoConfig
|
||||
from ...base_llm.chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseVideoConfig = _BaseVideoConfig
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseVideoConfig = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
class AzureVideoConfig(OpenAIVideoConfig):
    """
    Configuration class for Azure OpenAI video generation.

    Inherits the OpenAI video transformation logic and overrides only the
    pieces that differ on Azure: authentication headers and URL construction.
    (Docstring fixed: previously claimed to be the OpenAI config; the
    redundant ``__init__`` that only called ``super().__init__()`` is removed.)
    """

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the list of supported OpenAI parameters for video generation.
        """
        return [
            "model",
            "prompt",
            "input_reference",
            "seconds",
            "size",
            "user",
            "extra_headers",
        ]

    def map_openai_params(
        self,
        video_create_optional_params: VideoCreateOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """No mapping applied since inputs are in OpenAI spec already"""
        return dict(video_create_optional_params)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        litellm_params: Optional[GenericLiteLLMParams] = None,
    ) -> dict:
        """
        Validate Azure environment and set up authentication headers.

        Uses _base_validate_azure_environment to properly handle credentials
        from litellm_credential_name and to set the Azure-style "api-key"
        header (rather than "Authorization: Bearer").
        """
        # If litellm_params is not provided, create a fresh one.
        if litellm_params is None:
            litellm_params = GenericLiteLLMParams()

        # An explicitly passed api_key only fills in when litellm_params has none.
        if api_key and not litellm_params.api_key:
            litellm_params.api_key = api_key

        return BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params
        )

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Constructs a complete URL for the API request.
        """
        return BaseAzureLLM._get_base_azure_url(
            api_base=api_base,
            litellm_params=litellm_params,
            route="/openai/v1/videos",
            # No default api-version for the videos route.
            default_api_version="",
        )
|
||||
@@ -0,0 +1 @@
|
||||
`/chat/completions` calls are routed via `openai.py`.
|
||||
@@ -0,0 +1,11 @@
|
||||
from litellm.llms.azure_ai.agents.handler import azure_ai_agents_handler
|
||||
from litellm.llms.azure_ai.agents.transformation import (
|
||||
AzureAIAgentsConfig,
|
||||
AzureAIAgentsError,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureAIAgentsConfig",
|
||||
"AzureAIAgentsError",
|
||||
"azure_ai_agents_handler",
|
||||
]
|
||||
@@ -0,0 +1,659 @@
|
||||
"""
|
||||
Handler for Azure Foundry Agent Service API.
|
||||
|
||||
This handler executes the multi-step agent flow:
|
||||
1. Create thread (or use existing)
|
||||
2. Add messages to thread
|
||||
3. Create and poll a run
|
||||
4. Retrieve the assistant's response messages
|
||||
|
||||
Model format: azure_ai/agents/<agent_id>
|
||||
API Base format: https://<AIFoundryResourceName>.services.ai.azure.com/api/projects/<ProjectName>
|
||||
|
||||
Authentication: Uses Azure AD Bearer tokens (not API keys)
|
||||
Get token via: az account get-access-token --resource 'https://ai.azure.com'
|
||||
|
||||
Supports both polling-based and native streaming (SSE) modes.
|
||||
|
||||
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.azure_ai.agents.transformation import (
|
||||
AzureAIAgentsConfig,
|
||||
AzureAIAgentsError,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HTTPHandler = Any
|
||||
AsyncHTTPHandler = Any
|
||||
|
||||
|
||||
class AzureAIAgentsHandler:
|
||||
"""
|
||||
Handler for Azure AI Agent Service.
|
||||
|
||||
Executes the complete agent flow which requires multiple API calls.
|
||||
"""
|
||||
|
||||
    def __init__(self):
        # Transformation config shared by all handler calls.
        self.config = AzureAIAgentsConfig()
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# URL Builders
|
||||
# -------------------------------------------------------------------------
|
||||
# Azure Foundry Agents API uses /assistants, /threads, etc. directly
|
||||
# See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
# -------------------------------------------------------------------------
|
||||
def _build_thread_url(self, api_base: str, api_version: str) -> str:
|
||||
return f"{api_base}/threads?api-version={api_version}"
|
||||
|
||||
def _build_messages_url(
|
||||
self, api_base: str, thread_id: str, api_version: str
|
||||
) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/messages?api-version={api_version}"
|
||||
|
||||
def _build_runs_url(self, api_base: str, thread_id: str, api_version: str) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/runs?api-version={api_version}"
|
||||
|
||||
def _build_run_status_url(
|
||||
self, api_base: str, thread_id: str, run_id: str, api_version: str
|
||||
) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/runs/{run_id}?api-version={api_version}"
|
||||
|
||||
def _build_list_messages_url(
|
||||
self, api_base: str, thread_id: str, api_version: str
|
||||
) -> str:
|
||||
return f"{api_base}/threads/{thread_id}/messages?api-version={api_version}"
|
||||
|
||||
def _build_create_thread_and_run_url(self, api_base: str, api_version: str) -> str:
|
||||
"""URL for the create-thread-and-run endpoint (supports streaming)."""
|
||||
return f"{api_base}/threads/runs?api-version={api_version}"
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Response Helpers
|
||||
# -------------------------------------------------------------------------
|
||||
def _extract_content_from_messages(self, messages_data: dict) -> str:
|
||||
"""Extract assistant content from the messages response."""
|
||||
for msg in messages_data.get("data", []):
|
||||
if msg.get("role") == "assistant":
|
||||
for content_item in msg.get("content", []):
|
||||
if content_item.get("type") == "text":
|
||||
return content_item.get("text", {}).get("value", "")
|
||||
return ""
|
||||
|
||||
def _build_model_response(
    self,
    model: str,
    content: str,
    model_response: ModelResponse,
    thread_id: str,
    messages: List[Dict[str, Any]],
) -> ModelResponse:
    """Build the ModelResponse from agent output.

    Populates *model_response* in place: a single "stop" choice carrying
    *content*, the model name, the thread_id (in _hidden_params, so callers
    can continue the conversation), and a best-effort token-usage estimate.

    Returns:
        The same *model_response* object, mutated.
    """
    from litellm.types.utils import Choices, Message, Usage

    model_response.choices = [
        Choices(
            finish_reason="stop",
            index=0,
            message=Message(content=content, role="assistant"),
        )
    ]
    model_response.model = model

    # Store thread_id for conversation continuity
    if (
        not hasattr(model_response, "_hidden_params")
        or model_response._hidden_params is None
    ):
        model_response._hidden_params = {}
    model_response._hidden_params["thread_id"] = thread_id

    # Estimate token usage. The agent API does not return usage, so we
    # approximate with the gpt-3.5-turbo tokenizer regardless of the actual
    # model; counts are estimates, not billing-accurate numbers.
    try:
        from litellm.utils import token_counter

        prompt_tokens = token_counter(model="gpt-3.5-turbo", messages=messages)
        completion_tokens = token_counter(
            model="gpt-3.5-turbo", text=content, count_response_tokens=True
        )
        setattr(
            model_response,
            "usage",
            Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )
    except Exception as e:
        # Usage is best-effort; never fail the request over token counting.
        verbose_logger.warning(f"Failed to calculate token usage: {str(e)}")

    return model_response
|
||||
|
||||
def _prepare_completion_params(
    self,
    model: str,
    api_base: str,
    api_key: str,
    optional_params: dict,
    headers: Optional[dict],
) -> tuple:
    """Prepare common parameters for completion.

    Azure Foundry Agents API uses Bearer token authentication:
    - Authorization: Bearer <token> (Azure AD token from 'az account get-access-token --resource https://ai.azure.com')

    See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart

    Returns:
        Tuple of (headers, api_version, agent_id, thread_id, api_base).

    NOTE(review): when *headers* is caller-supplied it is mutated in place
    (Content-Type/Authorization added) — confirm callers do not reuse it.
    """
    if headers is None:
        headers = {}
    headers["Content-Type"] = "application/json"

    # Azure Foundry Agents uses Bearer token authentication
    # The api_key here is expected to be an Azure AD token
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    api_version = optional_params.get(
        "api_version", self.config.DEFAULT_API_VERSION
    )
    agent_id = self.config._get_agent_id(model, optional_params)
    # thread_id is optional; when absent the flow creates a fresh thread.
    thread_id = optional_params.get("thread_id")
    # Normalize so URL builders can append paths without double slashes.
    api_base = api_base.rstrip("/")

    verbose_logger.debug(
        f"Azure AI Agents completion - api_base: {api_base}, agent_id: {agent_id}"
    )

    return headers, api_version, agent_id, thread_id, api_base
|
||||
|
||||
def _check_response(
    self, response: httpx.Response, expected_codes: List[int], error_msg: str
):
    """Raise AzureAIAgentsError unless the status code is in *expected_codes*.

    The raised error carries the upstream status code and embeds the raw
    response text after *error_msg*.
    """
    if response.status_code in expected_codes:
        return
    raise AzureAIAgentsError(
        status_code=response.status_code,
        message=f"{error_msg}: {response.text}",
    )
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Sync Completion
|
||||
# -------------------------------------------------------------------------
|
||||
def completion(
    self,
    model: str,
    messages: List[Dict[str, Any]],
    api_base: str,
    api_key: str,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: float,
    client: Optional[HTTPHandler] = None,
    headers: Optional[dict] = None,
) -> ModelResponse:
    """Execute synchronous completion using Azure Agent Service.

    Runs the full multi-call agent flow (create thread -> add messages ->
    create run -> poll -> fetch messages) and folds the result into
    *model_response*.

    Note: ``timeout`` and ``logging_obj`` are accepted for interface parity
    with other handlers but are not read in this method.
    """
    from litellm.llms.custom_httpx.http_handler import _get_httpx_client

    if client is None:
        client = _get_httpx_client(
            params={"ssl_verify": litellm_params.get("ssl_verify", None)}
        )

    (
        headers,
        api_version,
        agent_id,
        thread_id,
        api_base,
    ) = self._prepare_completion_params(
        model, api_base, api_key, optional_params, headers
    )

    # Small transport shim so the flow logic below is verb-agnostic.
    def make_request(
        method: str, url: str, json_data: Optional[dict] = None
    ) -> httpx.Response:
        if method == "GET":
            return client.get(url=url, headers=headers)
        return client.post(
            url=url,
            headers=headers,
            data=json.dumps(json_data) if json_data else None,
        )

    # Execute the agent flow
    thread_id, content = self._execute_agent_flow_sync(
        make_request=make_request,
        api_base=api_base,
        api_version=api_version,
        agent_id=agent_id,
        thread_id=thread_id,
        messages=messages,
        optional_params=optional_params,
    )

    return self._build_model_response(
        model, content, model_response, thread_id, messages
    )
|
||||
|
||||
def _execute_agent_flow_sync(
    self,
    make_request: Callable,
    api_base: str,
    api_version: str,
    agent_id: str,
    thread_id: Optional[str],
    messages: List[Dict[str, Any]],
    optional_params: dict,
) -> Tuple[str, str]:
    """Execute the agent flow synchronously. Returns (thread_id, content).

    Five ordered API calls: create thread (if needed), append messages,
    create a run, poll the run until terminal, then list messages and
    extract the assistant's reply.

    Raises:
        AzureAIAgentsError: on any non-success HTTP status, a failed/
            cancelled/expired run (500), or a polling timeout (408).
    """

    # Step 1: Create thread if not provided
    if not thread_id:
        verbose_logger.debug(
            f"Creating thread at: {self._build_thread_url(api_base, api_version)}"
        )
        response = make_request(
            "POST", self._build_thread_url(api_base, api_version), {}
        )
        self._check_response(response, [200, 201], "Failed to create thread")
        thread_id = response.json()["id"]
        verbose_logger.debug(f"Created thread: {thread_id}")

    # At this point thread_id is guaranteed to be a string
    assert thread_id is not None

    # Step 2: Add messages to thread.
    # Only user/system messages are forwarded, and both are sent with role
    # "user" — the agent carries its own instructions server-side.
    for msg in messages:
        if msg.get("role") in ["user", "system"]:
            url = self._build_messages_url(api_base, thread_id, api_version)
            response = make_request(
                "POST", url, {"role": "user", "content": msg.get("content", "")}
            )
            self._check_response(response, [200, 201], "Failed to add message")

    # Step 3: Create run
    run_payload: Dict[str, Any] = {"assistant_id": agent_id}
    if "instructions" in optional_params:
        run_payload["instructions"] = optional_params["instructions"]

    response = make_request(
        "POST", self._build_runs_url(api_base, thread_id, api_version), run_payload
    )
    self._check_response(response, [200, 201], "Failed to create run")
    run_id = response.json()["id"]
    verbose_logger.debug(f"Created run: {run_id}")

    # Step 4: Poll for completion (at most MAX_POLL_ATTEMPTS polls, spaced
    # POLL_INTERVAL_SECONDS apart; the for/else raises on exhaustion).
    status_url = self._build_run_status_url(
        api_base, thread_id, run_id, api_version
    )
    for _ in range(self.config.MAX_POLL_ATTEMPTS):
        response = make_request("GET", status_url)
        self._check_response(response, [200], "Failed to get run status")

        status = response.json().get("status")
        verbose_logger.debug(f"Run status: {status}")

        if status == "completed":
            break
        elif status in ["failed", "cancelled", "expired"]:
            error_msg = (
                response.json()
                .get("last_error", {})
                .get("message", "Unknown error")
            )
            raise AzureAIAgentsError(
                status_code=500, message=f"Run {status}: {error_msg}"
            )

        time.sleep(self.config.POLL_INTERVAL_SECONDS)
    else:
        raise AzureAIAgentsError(
            status_code=408, message="Run timed out waiting for completion"
        )

    # Step 5: Get messages
    response = make_request(
        "GET", self._build_list_messages_url(api_base, thread_id, api_version)
    )
    self._check_response(response, [200], "Failed to get messages")

    content = self._extract_content_from_messages(response.json())
    return thread_id, content
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Async Completion
|
||||
# -------------------------------------------------------------------------
|
||||
async def acompletion(
    self,
    model: str,
    messages: List[Dict[str, Any]],
    api_base: str,
    api_key: str,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: float,
    client: Optional[AsyncHTTPHandler] = None,
    headers: Optional[dict] = None,
) -> ModelResponse:
    """Execute asynchronous completion using Azure Agent Service.

    Async mirror of :meth:`completion`: same multi-call agent flow, driven
    through an AsyncHTTPHandler.

    Note: ``timeout`` and ``logging_obj`` are accepted for interface parity
    with other handlers but are not read in this method.
    """
    import litellm
    from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

    if client is None:
        client = get_async_httpx_client(
            llm_provider=litellm.LlmProviders.AZURE_AI,
            params={"ssl_verify": litellm_params.get("ssl_verify", None)},
        )

    (
        headers,
        api_version,
        agent_id,
        thread_id,
        api_base,
    ) = self._prepare_completion_params(
        model, api_base, api_key, optional_params, headers
    )

    # Small transport shim so the flow logic below is verb-agnostic.
    async def make_request(
        method: str, url: str, json_data: Optional[dict] = None
    ) -> httpx.Response:
        if method == "GET":
            return await client.get(url=url, headers=headers)
        return await client.post(
            url=url,
            headers=headers,
            data=json.dumps(json_data) if json_data else None,
        )

    # Execute the agent flow
    thread_id, content = await self._execute_agent_flow_async(
        make_request=make_request,
        api_base=api_base,
        api_version=api_version,
        agent_id=agent_id,
        thread_id=thread_id,
        messages=messages,
        optional_params=optional_params,
    )

    return self._build_model_response(
        model, content, model_response, thread_id, messages
    )
|
||||
|
||||
async def _execute_agent_flow_async(
    self,
    make_request: Callable,
    api_base: str,
    api_version: str,
    agent_id: str,
    thread_id: Optional[str],
    messages: List[Dict[str, Any]],
    optional_params: dict,
) -> Tuple[str, str]:
    """Execute the agent flow asynchronously. Returns (thread_id, content).

    Async mirror of :meth:`_execute_agent_flow_sync`; the two must stay in
    lockstep. Same five steps and error semantics (AzureAIAgentsError with
    500 for failed runs, 408 for polling timeout).
    """

    # Step 1: Create thread if not provided
    if not thread_id:
        verbose_logger.debug(
            f"Creating thread at: {self._build_thread_url(api_base, api_version)}"
        )
        response = await make_request(
            "POST", self._build_thread_url(api_base, api_version), {}
        )
        self._check_response(response, [200, 201], "Failed to create thread")
        thread_id = response.json()["id"]
        verbose_logger.debug(f"Created thread: {thread_id}")

    # At this point thread_id is guaranteed to be a string
    assert thread_id is not None

    # Step 2: Add messages to thread.
    # Only user/system messages are forwarded, and both are sent with role
    # "user" — the agent carries its own instructions server-side.
    for msg in messages:
        if msg.get("role") in ["user", "system"]:
            url = self._build_messages_url(api_base, thread_id, api_version)
            response = await make_request(
                "POST", url, {"role": "user", "content": msg.get("content", "")}
            )
            self._check_response(response, [200, 201], "Failed to add message")

    # Step 3: Create run
    run_payload: Dict[str, Any] = {"assistant_id": agent_id}
    if "instructions" in optional_params:
        run_payload["instructions"] = optional_params["instructions"]

    response = await make_request(
        "POST", self._build_runs_url(api_base, thread_id, api_version), run_payload
    )
    self._check_response(response, [200, 201], "Failed to create run")
    run_id = response.json()["id"]
    verbose_logger.debug(f"Created run: {run_id}")

    # Step 4: Poll for completion (at most MAX_POLL_ATTEMPTS polls, spaced
    # POLL_INTERVAL_SECONDS apart; the for/else raises on exhaustion).
    status_url = self._build_run_status_url(
        api_base, thread_id, run_id, api_version
    )
    for _ in range(self.config.MAX_POLL_ATTEMPTS):
        response = await make_request("GET", status_url)
        self._check_response(response, [200], "Failed to get run status")

        status = response.json().get("status")
        verbose_logger.debug(f"Run status: {status}")

        if status == "completed":
            break
        elif status in ["failed", "cancelled", "expired"]:
            error_msg = (
                response.json()
                .get("last_error", {})
                .get("message", "Unknown error")
            )
            raise AzureAIAgentsError(
                status_code=500, message=f"Run {status}: {error_msg}"
            )

        await asyncio.sleep(self.config.POLL_INTERVAL_SECONDS)
    else:
        raise AzureAIAgentsError(
            status_code=408, message="Run timed out waiting for completion"
        )

    # Step 5: Get messages
    response = await make_request(
        "GET", self._build_list_messages_url(api_base, thread_id, api_version)
    )
    self._check_response(response, [200], "Failed to get messages")

    content = self._extract_content_from_messages(response.json())
    return thread_id, content
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Streaming Completion (Native SSE)
|
||||
# -------------------------------------------------------------------------
|
||||
async def acompletion_stream(
    self,
    model: str,
    messages: List[Dict[str, Any]],
    api_base: str,
    api_key: str,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: float,
    headers: Optional[dict] = None,
) -> AsyncIterator:
    """Execute async streaming completion using Azure Agent Service with native SSE.

    Uses the create-thread-and-run endpoint with ``stream: true`` and yields
    OpenAI-compatible chunks via :meth:`_process_sse_stream`.

    Note: ``timeout`` and ``logging_obj`` are accepted for interface parity
    but are not read in this method.
    """
    import litellm
    from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

    (
        headers,
        api_version,
        agent_id,
        thread_id,
        api_base,
    ) = self._prepare_completion_params(
        model, api_base, api_key, optional_params, headers
    )

    # Build payload for create-thread-and-run with streaming.
    # System messages are downgraded to role "user", matching the non-stream flow.
    thread_messages = []
    for msg in messages:
        if msg.get("role") in ["user", "system"]:
            thread_messages.append(
                {"role": "user", "content": msg.get("content", "")}
            )

    payload: Dict[str, Any] = {
        "assistant_id": agent_id,
        "stream": True,
    }

    # Add thread with messages if we don't have an existing thread.
    # NOTE(review): when a thread_id IS provided, neither the thread_id nor
    # thread_messages is placed in the payload, so the new messages appear to
    # be dropped for existing threads — confirm against the endpoint's
    # semantics before relying on continued-conversation streaming.
    if not thread_id:
        payload["thread"] = {"messages": thread_messages}

    if "instructions" in optional_params:
        payload["instructions"] = optional_params["instructions"]

    url = self._build_create_thread_and_run_url(api_base, api_version)
    verbose_logger.debug(f"Azure AI Agents streaming - URL: {url}")

    # Use LiteLLM's async HTTP client for streaming
    client = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.AZURE_AI,
        params={"ssl_verify": litellm_params.get("ssl_verify", None)},
    )

    response = await client.post(
        url=url,
        headers=headers,
        data=json.dumps(payload),
        stream=True,
    )

    if response.status_code not in [200, 201]:
        # Must read the streamed body before we can surface the error text.
        error_text = await response.aread()
        raise AzureAIAgentsError(
            status_code=response.status_code,
            message=f"Streaming request failed: {error_text.decode()}",
        )

    async for chunk in self._process_sse_stream(response, model):
        yield chunk
|
||||
|
||||
async def _process_sse_stream(
    self,
    response: httpx.Response,
    model: str,
) -> AsyncIterator:
    """Process SSE stream and yield OpenAI-compatible streaming chunks.

    Parses ``event:``/``data:`` line pairs from the agent's SSE stream:
    - ``thread.created`` -> captures thread_id (surfaced in _hidden_params)
    - ``thread.message.delta`` -> yields a content chunk per text delta
    - ``data: [DONE]`` -> yields a final chunk with finish_reason="stop"
    """
    from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices

    # One synthetic response id / timestamp shared by every chunk of this stream.
    response_id = f"chatcmpl-{uuid.uuid4().hex[:8]}"
    created = int(time.time())
    thread_id = None

    # Tracks the most recent "event:" line; the following "data:" lines
    # are interpreted in its context.
    current_event = None

    async for line in response.aiter_lines():
        line = line.strip()

        if line.startswith("event:"):
            current_event = line[6:].strip()
            continue

        if line.startswith("data:"):
            data_str = line[5:].strip()

            if data_str == "[DONE]":
                # Send final chunk with finish_reason
                final_chunk = ModelResponseStream(
                    id=response_id,
                    created=created,
                    model=model,
                    object="chat.completion.chunk",
                    choices=[
                        StreamingChoices(
                            finish_reason="stop",
                            index=0,
                            delta=Delta(content=None),
                        )
                    ],
                )
                if thread_id:
                    final_chunk._hidden_params = {"thread_id": thread_id}
                yield final_chunk
                return

            try:
                data = json.loads(data_str)
            except json.JSONDecodeError:
                # Ignore malformed/partial data lines rather than abort the stream.
                continue

            # Extract thread_id from thread.created event
            if current_event == "thread.created" and "id" in data:
                thread_id = data["id"]
                verbose_logger.debug(f"Stream created thread: {thread_id}")

            # Process message deltas - this is where the actual content comes
            if current_event == "thread.message.delta":
                delta_content = data.get("delta", {}).get("content", [])
                for content_item in delta_content:
                    if content_item.get("type") == "text":
                        text_value = content_item.get("text", {}).get("value", "")
                        if text_value:
                            chunk = ModelResponseStream(
                                id=response_id,
                                created=created,
                                model=model,
                                object="chat.completion.chunk",
                                choices=[
                                    StreamingChoices(
                                        finish_reason=None,
                                        index=0,
                                        delta=Delta(
                                            content=text_value, role="assistant"
                                        ),
                                    )
                                ],
                            )
                            if thread_id:
                                chunk._hidden_params = {"thread_id": thread_id}
                            yield chunk
|
||||
|
||||
|
||||
# Module-level singleton: the handler holds no per-request state beyond its
# config object, so one shared instance serves all callers.
azure_ai_agents_handler = AzureAIAgentsHandler()
|
||||
@@ -0,0 +1,402 @@
|
||||
"""
|
||||
Transformation for Azure Foundry Agent Service API.
|
||||
|
||||
Azure Foundry Agent Service provides an Assistants-like API for running agents.
|
||||
This follows the OpenAI Assistants pattern: create thread -> add messages -> create/poll run.
|
||||
|
||||
Model format: azure_ai/agents/<agent_id>
|
||||
|
||||
API Base format: https://<AIFoundryResourceName>.services.ai.azure.com/api/projects/<ProjectName>
|
||||
|
||||
Authentication: Uses Azure AD Bearer tokens (not API keys)
|
||||
Get token via: az account get-access-token --resource 'https://ai.azure.com'
|
||||
|
||||
The API uses these endpoints:
|
||||
- POST /threads - Create a thread
|
||||
- POST /threads/{thread_id}/messages - Add message to thread
|
||||
- POST /threads/{thread_id}/runs - Create a run
|
||||
- GET /threads/{thread_id}/runs/{run_id} - Poll run status
|
||||
- GET /threads/{thread_id}/messages - List messages in thread
|
||||
|
||||
See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HTTPHandler = Any
|
||||
AsyncHTTPHandler = Any
|
||||
|
||||
|
||||
class AzureAIAgentsError(BaseLLMException):
    """Exception class for Azure AI Agent Service API errors.

    Thin marker subclass: construction (status_code/message) and behavior
    come entirely from BaseLLMException.
    """

    pass
|
||||
|
||||
|
||||
class AzureAIAgentsConfig(BaseConfig):
    """
    Configuration for Azure AI Agent Service API.

    Azure AI Agent Service is a fully managed service for building AI agents
    that can understand natural language and perform tasks.

    Model format: azure_ai/agents/<agent_id>

    The flow is:
    1. Create a thread
    2. Add user messages to the thread
    3. Create and poll a run
    4. Retrieve the assistant's response messages
    """

    # Default API version for Azure Foundry Agent Service
    # GA version: 2025-05-01, Preview: 2025-05-15-preview
    # See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
    DEFAULT_API_VERSION = "2025-05-01"

    # Polling configuration: up to MAX_POLL_ATTEMPTS status checks spaced
    # POLL_INTERVAL_SECONDS apart (60 x 1s => runs may wait up to ~60s).
    MAX_POLL_ATTEMPTS = 60
    POLL_INTERVAL_SECONDS = 1.0
|
||||
|
||||
def __init__(self, **kwargs):
    # No config of its own; defer entirely to BaseConfig.
    super().__init__(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def is_azure_ai_agents_route(model: str) -> bool:
|
||||
"""
|
||||
Check if the model is an Azure AI Agents route.
|
||||
|
||||
Model format: azure_ai/agents/<agent_id>
|
||||
"""
|
||||
return "agents/" in model
|
||||
|
||||
@staticmethod
|
||||
def get_agent_id_from_model(model: str) -> str:
|
||||
"""
|
||||
Extract agent ID from the model string.
|
||||
|
||||
Model format: azure_ai/agents/<agent_id> -> <agent_id>
|
||||
or: agents/<agent_id> -> <agent_id>
|
||||
"""
|
||||
if "agents/" in model:
|
||||
# Split on "agents/" and take the part after it
|
||||
parts = model.split("agents/", 1)
|
||||
if len(parts) == 2:
|
||||
return parts[1]
|
||||
return model
|
||||
|
||||
def _get_openai_compatible_provider_info(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
) -> Tuple[Optional[str], Optional[str]]:
    """Resolve the service base URL and key.

    Explicit arguments win; falsy values fall back to the AZURE_AI_API_BASE /
    AZURE_AI_API_KEY secrets (environment or secret manager).

    Returns:
        Tuple of (api_base, api_key); either may still be None.
    """
    from litellm.secret_managers.main import get_secret_str

    resolved_base = api_base or get_secret_str("AZURE_AI_API_BASE")
    resolved_key = api_key or get_secret_str("AZURE_AI_API_KEY")
    return resolved_base, resolved_key
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
    """Return the OpenAI params this provider accepts.

    The agent runtime controls its own generation settings, so only
    ``stream`` is supported.
    """
    supported: List[str] = ["stream"]
    return supported
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Map OpenAI params onto Azure Agents params.

    Identity mapping: the agent runtime takes no OpenAI sampling params, so
    *optional_params* is passed through untouched.
    """
    return optional_params
|
||||
|
||||
def _get_api_version(self, optional_params: dict) -> str:
|
||||
"""Get API version from optional params or use default."""
|
||||
return optional_params.get("api_version", self.DEFAULT_API_VERSION)
|
||||
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """Return the normalized base URL for the Azure AI Agent Service.

    The handler appends per-operation paths later (/threads,
    /threads/{thread_id}/messages, /threads/{thread_id}/runs); this method
    only validates that a base URL exists and strips any trailing slash.

    Raises:
        ValueError: when *api_base* is None.
    """
    if api_base is None:
        raise ValueError(
            "api_base is required for Azure AI Agents. Set it via AZURE_AI_API_BASE env var or api_base parameter."
        )
    return api_base.rstrip("/")
|
||||
|
||||
def _get_agent_id(self, model: str, optional_params: dict) -> str:
|
||||
"""
|
||||
Get the agent ID from model or optional_params.
|
||||
|
||||
model format: "azure_ai/agents/<agent_id>" or "agents/<agent_id>" or just "<agent_id>"
|
||||
"""
|
||||
agent_id = optional_params.get("agent_id") or optional_params.get(
|
||||
"assistant_id"
|
||||
)
|
||||
if agent_id:
|
||||
return agent_id
|
||||
|
||||
# Extract from model name using the static method
|
||||
return self.get_agent_id_from_model(model)
|
||||
|
||||
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """
    Transform the request for Azure Agents.

    This stores the necessary data for the multi-step agent flow.
    The actual API calls happen in the custom handler.

    Returns:
        A payload dict with agent_id, flattened string messages, api_version,
        and optional thread_id / instructions pass-throughs.
    """
    agent_id = self._get_agent_id(model, optional_params)

    # Convert messages to a format we can use: every message becomes
    # {"role": ..., "content": <plain string>}.
    converted_messages = []
    for msg in messages:
        role = msg.get("role", "user")
        content = msg.get("content", "")

        # Handle content that might be a list (multi-part OpenAI content)
        if isinstance(content, list):
            content = convert_content_list_to_str(msg)

        # Ensure content is a string (defensive: non-str, non-list content)
        if not isinstance(content, str):
            content = str(content)

        converted_messages.append({"role": role, "content": content})

    payload: Dict[str, Any] = {
        "agent_id": agent_id,
        "messages": converted_messages,
        "api_version": self._get_api_version(optional_params),
    }

    # Pass through thread_id if provided (for continuing conversations)
    if "thread_id" in optional_params:
        payload["thread_id"] = optional_params["thread_id"]

    # Pass through any additional instructions
    if "instructions" in optional_params:
        payload["instructions"] = optional_params["instructions"]

    verbose_logger.debug(f"Azure AI Agents request payload: {payload}")
    return payload
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Attach the JSON content type and, when available, a Bearer token.

    Azure Foundry Agents authenticates with Azure AD tokens rather than API
    keys; *api_key* is expected to hold such a token.
    Get token via: az account get-access-token --resource 'https://ai.azure.com'

    See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
    """
    headers["Content-Type"] = "application/json"

    # Only attach Authorization when a token was actually supplied.
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    return headers
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """Map an HTTP error onto the provider-specific exception type."""
    # headers is accepted for interface parity but not forwarded.
    return AzureAIAgentsError(status_code=status_code, message=error_message)
|
||||
|
||||
def should_fake_stream(
    self,
    model: Optional[str],
    stream: Optional[bool],
    custom_llm_provider: Optional[str] = None,
) -> bool:
    """Always fake streaming for this route.

    The agent flow polls a run to completion, so streamed chunks are
    synthesized from the final response rather than proxied live.
    """
    return True
|
||||
|
||||
@property
def has_custom_stream_wrapper(self) -> bool:
    """Azure Agents doesn't have native streaming - uses fake stream.

    NOTE(review): the handler does implement a native SSE path
    (acompletion_stream); confirm whether this flag should stay False.
    """
    return False
|
||||
|
||||
@property
def supports_stream_param_in_request_body(self) -> bool:
    """
    Azure Agents does not use a stream param in request body.
    """
    # Streaming is negotiated via the create-thread-and-run payload in the
    # handler, not via the generic request-body "stream" param.
    return False
|
||||
|
||||
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """
    Transform the Azure Agents response to LiteLLM ModelResponse format.

    Pass-through: the custom handler builds the real ModelResponse itself
    (see the handler's _build_model_response), so this hook returns the
    object unchanged.
    """
    # This is not used since we have a custom handler
    return model_response
|
||||
|
||||
@staticmethod
def completion(
    model: str,
    messages: List,
    api_base: str,
    api_key: Optional[str],
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    optional_params: dict,
    litellm_params: dict,
    timeout: Union[float, int, Any],
    acompletion: bool,
    stream: Optional[bool] = False,
    headers: Optional[dict] = None,
) -> Any:
    """
    Dispatch method for Azure Foundry Agents completion.

    Routes to sync or async completion based on acompletion flag.
    Supports native streaming via SSE when stream=True and acompletion=True.
    (The sync branch does not consult `stream` - sync streaming is not supported.)

    Authentication: Uses Azure AD Bearer tokens.
    - Pass api_key directly as an Azure AD token
    - Or set up Azure AD credentials via environment variables for automatic token retrieval:
      - AZURE_TENANT_ID, AZURE_CLIENT_ID, AZURE_CLIENT_SECRET (Service Principal)

    Raises:
        ValueError: if no api_key was supplied and no Azure AD token could be obtained.

    See: https://learn.microsoft.com/en-us/azure/ai-foundry/agents/quickstart
    """
    from litellm.llms.azure.common_utils import get_azure_ad_token
    from litellm.llms.azure_ai.agents.handler import azure_ai_agents_handler
    from litellm.types.router import GenericLiteLLMParams

    # If no api_key is provided, try to get Azure AD token
    if api_key is None:
        # Try to get Azure AD token using the existing Azure auth mechanisms.
        # Azure Foundry Agents use the ai.azure.com scope instead of the
        # cognitive-services scope, so override `azure_scope` for the lookup.
        # Copy litellm_params so the caller's dict is never mutated.
        azure_auth_params = dict(litellm_params) if litellm_params else {}
        azure_auth_params["azure_scope"] = "https://ai.azure.com/.default"
        api_key = get_azure_ad_token(GenericLiteLLMParams(**azure_auth_params))

    if api_key is None:
        raise ValueError(
            "api_key (Azure AD token) is required for Azure Foundry Agents. "
            "Either pass api_key directly, or set AZURE_TENANT_ID, AZURE_CLIENT_ID, "
            "and AZURE_CLIENT_SECRET environment variables for Service Principal auth. "
            "Manual token: az account get-access-token --resource 'https://ai.azure.com'"
        )
    if acompletion:
        if stream:
            # Native async streaming via SSE - return the async generator directly
            return azure_ai_agents_handler.acompletion_stream(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=api_key,
                logging_obj=logging_obj,
                optional_params=optional_params,
                litellm_params=litellm_params,
                timeout=timeout,
                headers=headers,
            )
        else:
            return azure_ai_agents_handler.acompletion(
                model=model,
                messages=messages,
                api_base=api_base,
                api_key=api_key,
                model_response=model_response,
                logging_obj=logging_obj,
                optional_params=optional_params,
                litellm_params=litellm_params,
                timeout=timeout,
                headers=headers,
            )
    else:
        # Sync completion - streaming not supported for sync
        return azure_ai_agents_handler.completion(
            model=model,
            messages=messages,
            api_base=api_base,
            api_key=api_key,
            model_response=model_response,
            logging_obj=logging_obj,
            optional_params=optional_params,
            litellm_params=litellm_params,
            timeout=timeout,
            headers=headers,
        )
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Azure Anthropic provider - supports Claude models via Azure Foundry
|
||||
"""
|
||||
from .handler import AzureAnthropicChatCompletion
|
||||
from .transformation import AzureAnthropicConfig
|
||||
|
||||
try:
|
||||
from .messages_transformation import AzureAnthropicMessagesConfig
|
||||
|
||||
__all__ = [
|
||||
"AzureAnthropicChatCompletion",
|
||||
"AzureAnthropicConfig",
|
||||
"AzureAnthropicMessagesConfig",
|
||||
]
|
||||
except ImportError:
|
||||
__all__ = ["AzureAnthropicChatCompletion", "AzureAnthropicConfig"]
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Azure AI Anthropic CountTokens API implementation.
|
||||
"""
|
||||
|
||||
from litellm.llms.azure_ai.anthropic.count_tokens.handler import (
|
||||
AzureAIAnthropicCountTokensHandler,
|
||||
)
|
||||
from litellm.llms.azure_ai.anthropic.count_tokens.token_counter import (
|
||||
AzureAIAnthropicTokenCounter,
|
||||
)
|
||||
from litellm.llms.azure_ai.anthropic.count_tokens.transformation import (
|
||||
AzureAIAnthropicCountTokensConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AzureAIAnthropicCountTokensHandler",
|
||||
"AzureAIAnthropicCountTokensConfig",
|
||||
"AzureAIAnthropicTokenCounter",
|
||||
]
|
||||
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Azure AI Anthropic CountTokens API handler.
|
||||
|
||||
Uses httpx for HTTP requests with Azure authentication.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.anthropic.common_utils import AnthropicError
|
||||
from litellm.llms.azure_ai.anthropic.count_tokens.transformation import (
|
||||
AzureAIAnthropicCountTokensConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
|
||||
|
||||
class AzureAIAnthropicCountTokensHandler(AzureAIAnthropicCountTokensConfig):
    """
    Handler for Azure AI Anthropic CountTokens API requests.

    Uses httpx for HTTP requests with Azure authentication. Inherits the
    request/endpoint/header transformation logic from
    AzureAIAnthropicCountTokensConfig and adds the actual network call.
    """

    async def handle_count_tokens_request(
        self,
        model: str,
        messages: List[Dict[str, Any]],
        api_key: str,
        api_base: str,
        litellm_params: Optional[Dict[str, Any]] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using httpx with Azure authentication.

        Args:
            model: The model identifier (e.g., "claude-3-5-sonnet")
            messages: The messages to count tokens for
            api_key: The Azure AI API key
            api_base: The Azure AI API base URL
            litellm_params: Optional LiteLLM parameters
            timeout: Optional timeout for the request (defaults to litellm.request_timeout)
            tools: Optional tool definitions included in the token count
            system: Optional system prompt included in the token count

        Returns:
            Dictionary containing token count response (Anthropic-native shape;
            returned as-is, no transformation applied)

        Raises:
            AnthropicError: If the API request fails (non-200 status, HTTP
                error, or any unexpected processing error, reported as 500)
        """
        try:
            # Validate the request
            self.validate_request(model, messages)

            verbose_logger.debug(
                f"Processing Azure AI Anthropic CountTokens request for model: {model}"
            )

            # Transform request to Anthropic format
            request_body = self.transform_request_to_count_tokens(
                model=model,
                messages=messages,
                tools=tools,
                system=system,
            )

            verbose_logger.debug(f"Transformed request: {request_body}")

            # Get endpoint URL
            endpoint_url = self.get_count_tokens_endpoint(api_base)

            verbose_logger.debug(f"Making request to: {endpoint_url}")

            # Get required headers with Azure authentication
            headers = self.get_required_headers(
                api_key=api_key,
                litellm_params=litellm_params,
            )

            # Use LiteLLM's async httpx client (shared/pooled per provider)
            async_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.AZURE_AI
            )

            # Use provided timeout or fall back to litellm.request_timeout
            request_timeout = (
                timeout if timeout is not None else litellm.request_timeout
            )

            response = await async_client.post(
                endpoint_url,
                headers=headers,
                json=request_body,
                timeout=request_timeout,
            )

            verbose_logger.debug(f"Response status: {response.status_code}")

            if response.status_code != 200:
                error_text = response.text
                verbose_logger.error(f"Azure AI Anthropic API error: {error_text}")
                raise AnthropicError(
                    status_code=response.status_code,
                    message=error_text,
                )

            azure_response = response.json()

            verbose_logger.debug(f"Azure AI Anthropic response: {azure_response}")

            # Return Anthropic-compatible response directly - no transformation needed
            return azure_response

        except AnthropicError:
            # Re-raise Anthropic exceptions as-is
            raise
        except httpx.HTTPStatusError as e:
            # HTTP errors - preserve the actual status code
            verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        except Exception as e:
            # Anything else (validation errors, JSON decoding, etc.) is
            # surfaced as a 500-level AnthropicError.
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise AnthropicError(
                status_code=500,
                message=f"CountTokens processing error: {str(e)}",
            )
@@ -0,0 +1,123 @@
|
||||
"""
|
||||
Azure AI Anthropic Token Counter implementation using the CountTokens API.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.azure_ai.anthropic.count_tokens.handler import (
|
||||
AzureAIAnthropicCountTokensHandler,
|
||||
)
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.types.utils import LlmProviders, TokenCountResponse
|
||||
|
||||
# Global handler instance - reuse across all token counting requests
|
||||
azure_ai_anthropic_count_tokens_handler = AzureAIAnthropicCountTokensHandler()
|
||||
|
||||
|
||||
class AzureAIAnthropicTokenCounter(BaseTokenCounter):
    """Token counter implementation for Azure AI Anthropic provider using the CountTokens API."""

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # Only handle requests routed through the azure_ai provider.
        return custom_llm_provider == LlmProviders.AZURE_AI.value

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count tokens using Azure AI Anthropic's CountTokens API.

        Args:
            model_to_use: The model identifier
            messages: The messages to count tokens for
            contents: Alternative content format (not used for Anthropic)
            deployment: Deployment configuration containing litellm_params
            request_model: The original request model name
            tools: Optional tool definitions included in the count
            system: Optional system prompt included in the count

        Returns:
            TokenCountResponse with token count; a TokenCountResponse with
            error=True and total_tokens=0 when the API call fails; or None
            when messages are missing, credentials are unavailable, or the
            handler returned no result.
        """
        from litellm.llms.anthropic.common_utils import AnthropicError

        if not messages:
            return None

        deployment = deployment or {}
        litellm_params = deployment.get("litellm_params", {})

        # Get Azure AI API key from deployment config or environment
        api_key = litellm_params.get("api_key")
        if not api_key:
            api_key = os.getenv("AZURE_AI_API_KEY")

        # Get API base from deployment config or environment
        api_base = litellm_params.get("api_base")
        if not api_base:
            api_base = os.getenv("AZURE_AI_API_BASE")

        if not api_key:
            verbose_logger.warning("No Azure AI API key found for token counting")
            return None

        if not api_base:
            verbose_logger.warning("No Azure AI API base found for token counting")
            return None

        try:
            result = await azure_ai_anthropic_count_tokens_handler.handle_count_tokens_request(
                model=model_to_use,
                messages=messages,
                api_key=api_key,
                api_base=api_base,
                litellm_params=litellm_params,
                tools=tools,
                system=system,
            )

            if result is not None:
                return TokenCountResponse(
                    total_tokens=result.get("input_tokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type="azure_ai_anthropic_api",
                    original_response=result,
                )
        except AnthropicError as e:
            # API-level failure: report it in the response rather than raising,
            # so callers can fall back to local tokenizers.
            verbose_logger.warning(
                f"Azure AI Anthropic CountTokens API error: status={e.status_code}, message={e.message}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="azure_ai_anthropic_api",
                error=True,
                error_message=e.message,
                status_code=e.status_code,
            )
        except Exception as e:
            # Unexpected failure (network, serialization, ...): same
            # best-effort error response with a generic 500.
            verbose_logger.warning(
                f"Error calling Azure AI Anthropic CountTokens API: {e}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="azure_ai_anthropic_api",
                error=True,
                error_message=str(e),
                status_code=500,
            )

        # Reached only when the handler returned None.
        return None
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Azure AI Anthropic CountTokens API transformation logic.
|
||||
|
||||
Extends the base Anthropic CountTokens transformation with Azure authentication.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from litellm.constants import ANTHROPIC_TOKEN_COUNTING_BETA_VERSION
|
||||
from litellm.llms.anthropic.count_tokens.transformation import (
|
||||
AnthropicCountTokensConfig,
|
||||
)
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
|
||||
class AzureAIAnthropicCountTokensConfig(AnthropicCountTokensConfig):
    """
    Configuration and transformation logic for Azure AI Anthropic CountTokens API.

    Extends AnthropicCountTokensConfig with Azure authentication.
    Azure AI Anthropic uses the same endpoint format but with Azure auth headers.
    """

    def get_required_headers(
        self,
        api_key: str,
        litellm_params: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, str]:
        """
        Get the required headers for the Azure AI Anthropic CountTokens API.

        Azure AI Anthropic uses Anthropic's native API format, which requires the
        x-api-key header for authentication (in addition to Azure's api-key header).

        Args:
            api_key: The Azure AI API key
            litellm_params: Optional LiteLLM parameters for additional auth config
                (read-only; never mutated)

        Returns:
            Dictionary of required headers with both x-api-key and Azure authentication
        """
        # Start with base headers including x-api-key for Anthropic API compatibility
        headers = {
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01",
            "anthropic-beta": ANTHROPIC_TOKEN_COUNTING_BETA_VERSION,
            "x-api-key": api_key,  # Azure AI Anthropic requires this header
        }

        # Also set up Azure auth headers for flexibility.
        # BUGFIX: copy the dict before inserting "api_key" so the
        # caller-supplied litellm_params is never mutated as a side effect.
        litellm_params = dict(litellm_params) if litellm_params else {}
        if "api_key" not in litellm_params:
            litellm_params["api_key"] = api_key

        litellm_params_obj = GenericLiteLLMParams(**litellm_params)

        # Get Azure auth headers (api-key or Authorization)
        azure_headers = BaseAzureLLM._base_validate_azure_environment(
            headers={}, litellm_params=litellm_params_obj
        )

        # Merge Azure auth headers (they win on key collisions)
        headers.update(azure_headers)

        return headers

    def get_count_tokens_endpoint(self, api_base: str) -> str:
        """
        Get the Azure AI Anthropic CountTokens API endpoint.

        Args:
            api_base: The Azure AI API base URL
                (e.g., https://my-resource.services.ai.azure.com or
                https://my-resource.services.ai.azure.com/anthropic)

        Returns:
            The endpoint URL for the CountTokens API
        """
        # Azure AI Anthropic endpoint format:
        # https://<resource>.services.ai.azure.com/anthropic/v1/messages/count_tokens
        api_base = api_base.rstrip("/")

        # Ensure the URL has /anthropic path.
        # NOTE(review): if /anthropic appears mid-path but is not the suffix
        # (e.g. ".../anthropic/deployments"), the base is used as-is -
        # presumably intentional; verify against callers.
        if not api_base.endswith("/anthropic"):
            if "/anthropic" not in api_base:
                api_base = f"{api_base}/anthropic"

        # Add the count_tokens path
        return f"{api_base}/v1/messages/count_tokens"
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
Azure Anthropic handler - reuses AnthropicChatCompletion logic with Azure authentication
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Callable, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.anthropic.chat.handler import AnthropicChatCompletion
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
from .transformation import AzureAnthropicConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AzureAnthropicChatCompletion(AnthropicChatCompletion):
    """
    Azure Anthropic chat completion handler.
    Reuses all Anthropic logic but with Azure authentication.
    """

    def __init__(self) -> None:
        super().__init__()

    def completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        custom_llm_provider: str,
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        print_verbose: Callable,
        encoding,
        api_key,
        logging_obj,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        litellm_params: dict,
        acompletion=None,
        logger_fn=None,
        headers={},  # NOTE(review): mutable default; safe only if never mutated in place - verify
        client=None,
    ):
        """
        Completion method that uses Azure authentication instead of Anthropic's x-api-key.
        All other logic is the same as AnthropicChatCompletion.

        Dispatches across four paths based on `acompletion` and `stream`:
        async-stream, async, sync-stream, and sync POST.
        """

        # Deep-copy inputs so pops/mutations below never leak to the caller.
        optional_params = copy.deepcopy(optional_params)
        stream = optional_params.pop("stream", None)
        json_mode: bool = optional_params.pop("json_mode", False)
        is_vertex_request: bool = optional_params.pop("is_vertex_request", False)
        _is_function_call = False
        messages = copy.deepcopy(messages)

        # Use AzureAnthropicConfig for both azure_anthropic and azure_ai Claude models
        config = AzureAnthropicConfig()

        # Builds Azure auth headers (api-key / Authorization) plus the
        # anthropic-version and anthropic-beta headers.
        headers = config.validate_environment(
            api_key=api_key,
            headers=headers,
            model=model,
            messages=messages,
            optional_params={**optional_params, "is_vertex_request": is_vertex_request},
            litellm_params=litellm_params,
        )

        # Convert OpenAI-format messages/params into the Anthropic request body.
        data = config.transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key=api_key,
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": headers,
            },
        )
        print_verbose(f"_is_function_call: {_is_function_call}")
        if acompletion is True:
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                print_verbose("makes async azure anthropic streaming POST request")
                data["stream"] = stream
                return self.acompletion_stream_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    _is_function_call=_is_function_call,
                    json_mode=json_mode,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                    # Only forward an async-capable client; otherwise let the
                    # parent create its own.
                    client=(
                        client
                        if client is not None and isinstance(client, AsyncHTTPHandler)
                        else None
                    ),
                )
            else:
                return self.acompletion_function(
                    model=model,
                    messages=messages,
                    data=data,
                    api_base=api_base,
                    custom_prompt_dict=custom_prompt_dict,
                    model_response=model_response,
                    print_verbose=print_verbose,
                    encoding=encoding,
                    api_key=api_key,
                    provider_config=config,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=stream,
                    _is_function_call=_is_function_call,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    client=client,
                    json_mode=json_mode,
                    timeout=timeout,
                )
        else:
            ## COMPLETION CALL
            if (
                stream is True
            ):  # if function call - fake the streaming (need complete blocks for output parsing in openai format)
                data["stream"] = stream
                # Import the make_sync_call from parent
                from litellm.llms.anthropic.chat.handler import make_sync_call

                completion_stream, response_headers = make_sync_call(
                    client=client,
                    api_base=api_base,
                    headers=headers,  # type: ignore
                    data=json.dumps(data),
                    model=model,
                    messages=messages,
                    logging_obj=logging_obj,
                    timeout=timeout,
                    json_mode=json_mode,
                )
                from litellm.llms.anthropic.common_utils import (
                    process_anthropic_headers,
                )

                return CustomStreamWrapper(
                    completion_stream=completion_stream,
                    model=model,
                    custom_llm_provider="azure_ai",
                    logging_obj=logging_obj,
                    _response_headers=process_anthropic_headers(response_headers),
                )

            else:
                # Non-streaming sync path: ensure we have a sync HTTP client.
                if client is None or not isinstance(client, HTTPHandler):
                    from litellm.llms.custom_httpx.http_handler import _get_httpx_client

                    client = _get_httpx_client(params={"timeout": timeout})
                else:
                    client = client

                try:
                    response = client.post(
                        api_base,
                        headers=headers,
                        data=json.dumps(data),
                        timeout=timeout,
                    )
                except Exception as e:
                    from litellm.llms.anthropic.common_utils import AnthropicError

                    # Best-effort extraction of status/headers/body from
                    # whichever exception type the HTTP client raised.
                    status_code = getattr(e, "status_code", 500)
                    error_headers = getattr(e, "headers", None)
                    error_text = getattr(e, "text", str(e))
                    error_response = getattr(e, "response", None)
                    if error_headers is None and error_response:
                        error_headers = getattr(error_response, "headers", None)
                    if error_response and hasattr(error_response, "text"):
                        error_text = getattr(error_response, "text", error_text)
                    raise AnthropicError(
                        message=error_text,
                        status_code=status_code,
                        headers=error_headers,
                    )

                # Convert the raw Anthropic response into a ModelResponse.
                return config.transform_response(
                    model=model,
                    raw_response=response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    api_key=api_key,
                    request_data=data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                    json_mode=json_mode,
                )
@@ -0,0 +1,166 @@
|
||||
"""
|
||||
Azure Anthropic messages transformation config - extends AnthropicMessagesConfig with Azure authentication
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages.transformation import (
|
||||
AnthropicMessagesConfig,
|
||||
)
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AzureAnthropicMessagesConfig(AnthropicMessagesConfig):
    """
    Azure Anthropic messages configuration that extends AnthropicMessagesConfig.
    The only difference is authentication - Azure uses x-api-key header (not api-key)
    and Azure endpoint format.
    """

    def validate_anthropic_messages_environment(
        self,
        headers: dict,
        model: str,
        messages: List[Any],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Tuple[dict, Optional[str]]:
        """
        Validate environment and set up Azure authentication headers for /v1/messages endpoint.
        Azure Anthropic uses x-api-key header (not api-key).

        Returns:
            Tuple of (headers dict with auth/version/content-type set, api_base unchanged)
        """
        # Convert dict to GenericLiteLLMParams if needed
        if isinstance(litellm_params, dict):
            if api_key and "api_key" not in litellm_params:
                # Copy so the caller's dict is never mutated.
                litellm_params = {**litellm_params, "api_key": api_key}
            litellm_params_obj = GenericLiteLLMParams(**litellm_params)
        else:
            litellm_params_obj = litellm_params or GenericLiteLLMParams()
            if api_key and not litellm_params_obj.api_key:
                litellm_params_obj.api_key = api_key

        # Use Azure authentication logic (sets api-key or Authorization)
        headers = BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params_obj
        )

        # Azure Anthropic uses x-api-key header (not api-key)
        # Convert api-key to x-api-key if present
        if "api-key" in headers and "x-api-key" not in headers:
            headers["x-api-key"] = headers.pop("api-key")

        # Set anthropic-version header
        if "anthropic-version" not in headers:
            headers["anthropic-version"] = "2023-06-01"

        # Set content-type header
        if "content-type" not in headers:
            headers["content-type"] = "application/json"

        headers = self._update_headers_with_anthropic_beta(
            headers=headers,
            optional_params=optional_params,
        )

        return headers, api_base

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for Azure Anthropic /v1/messages endpoint.
        Azure Foundry endpoint format: https://<resource-name>.services.ai.azure.com/anthropic/v1/messages

        Raises:
            ValueError: if no api_base is provided and AZURE_API_BASE is unset.
        """
        from litellm.secret_managers.main import get_secret_str

        api_base = api_base or get_secret_str("AZURE_API_BASE")
        if api_base is None:
            raise ValueError(
                "Missing Azure API Base - Please set `api_base` or `AZURE_API_BASE` environment variable. "
                "Expected format: https://<resource-name>.services.ai.azure.com/anthropic"
            )

        # Ensure the URL ends with /v1/messages
        api_base = api_base.rstrip("/")
        # BUGFIX: a separate endswith("/anthropic/v1/messages") branch was
        # unreachable - any URL ending in "/anthropic/v1/messages" also ends
        # in "/v1/messages". A single suffix check covers both cases.
        if not api_base.endswith("/v1/messages"):
            if "/anthropic" in api_base:
                # /anthropic exists, ensure we end with /anthropic/v1/messages:
                # truncate to the base URL up to and including /anthropic
                parts = api_base.split("/anthropic", 1)
                api_base = parts[0] + "/anthropic"
            else:
                # /anthropic not in path, add it
                api_base = api_base + "/anthropic"
            # Add /v1/messages
            api_base = api_base + "/v1/messages"

        return api_base

    def _remove_scope_from_cache_control(
        self, anthropic_messages_request: Dict
    ) -> None:
        """
        Remove `scope` field from cache_control for Azure AI Foundry.

        Azure AI Foundry's Anthropic endpoint does not support the `scope` field
        (e.g., "global" for cross-request caching). Only `type` and `ttl` are supported.

        Processes both `system` and `messages` content blocks. Mutates the
        request dict in place; returns None.
        """

        def _sanitize(cache_control: Any) -> None:
            # Drop the unsupported key; tolerate non-dict values silently.
            if isinstance(cache_control, dict):
                cache_control.pop("scope", None)

        def _process_content_list(content: list) -> None:
            for item in content:
                if isinstance(item, dict) and "cache_control" in item:
                    _sanitize(item["cache_control"])

        if "system" in anthropic_messages_request:
            system = anthropic_messages_request["system"]
            if isinstance(system, list):
                _process_content_list(system)

        if "messages" in anthropic_messages_request:
            for message in anthropic_messages_request["messages"]:
                if isinstance(message, dict) and "content" in message:
                    content = message["content"]
                    if isinstance(content, list):
                        _process_content_list(content)

    def transform_anthropic_messages_request(
        self,
        model: str,
        messages: List[Dict],
        anthropic_messages_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """Build the Anthropic request body, then strip Foundry-unsupported fields."""
        anthropic_messages_request = super().transform_anthropic_messages_request(
            model=model,
            messages=messages,
            anthropic_messages_optional_request_params=anthropic_messages_optional_request_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        self._remove_scope_from_cache_control(anthropic_messages_request)
        return anthropic_messages_request
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Azure Anthropic transformation config - extends AnthropicConfig with Azure authentication
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
|
||||
from litellm.llms.azure.common_utils import BaseAzureLLM
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AzureAnthropicConfig(AnthropicConfig):
    """
    Azure Anthropic configuration that extends AnthropicConfig.
    The only difference is authentication - Azure uses api-key header or Azure AD token
    instead of x-api-key header.
    """

    @property
    def custom_llm_provider(self) -> Optional[str]:
        # Routed through the azure_ai provider name for logging/cost tracking.
        return "azure_ai"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: Union[dict, GenericLiteLLMParams],
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> Dict:
        """
        Validate environment and set up Azure authentication headers.
        Azure supports:
        1. API key via 'api-key' header
        2. Azure AD token via 'Authorization: Bearer <token>' header

        Returns:
            Merged header dict: anthropic beta/version headers plus Azure
            auth headers, with the Azure-derived entries taking precedence.
        """
        # Convert dict to GenericLiteLLMParams if needed
        if isinstance(litellm_params, dict):
            # Ensure api_key is included if provided (copy, don't mutate caller's dict)
            if api_key and "api_key" not in litellm_params:
                litellm_params = {**litellm_params, "api_key": api_key}
            litellm_params_obj = GenericLiteLLMParams(**litellm_params)
        else:
            litellm_params_obj = litellm_params or GenericLiteLLMParams()
            # Set api_key if provided and not already set
            if api_key and not litellm_params_obj.api_key:
                litellm_params_obj.api_key = api_key

        # Use Azure authentication logic
        headers = BaseAzureLLM._base_validate_azure_environment(
            headers=headers, litellm_params=litellm_params_obj
        )

        # Get tools and other anthropic-specific setup; these feed the
        # anthropic-beta header computation below.
        tools = optional_params.get("tools")
        prompt_caching_set = self.is_cache_control_set(messages=messages)
        computer_tool_used = self.is_computer_tool_used(tools=tools)
        mcp_server_used = self.is_mcp_server_used(
            mcp_servers=optional_params.get("mcp_servers")
        )
        pdf_used = self.is_pdf_used(messages=messages)
        file_id_used = self.is_file_id_used(messages=messages)
        user_anthropic_beta_headers = self._get_user_anthropic_beta_headers(
            anthropic_beta_header=headers.get("anthropic-beta")
        )

        # Get anthropic headers (but we'll replace x-api-key with Azure auth)
        anthropic_headers = self.get_anthropic_headers(
            computer_tool_used=computer_tool_used,
            prompt_caching_set=prompt_caching_set,
            pdf_used=pdf_used,
            api_key=api_key or "",  # Azure auth is already in headers
            file_id_used=file_id_used,
            is_vertex_request=optional_params.get("is_vertex_request", False),
            user_anthropic_beta_headers=user_anthropic_beta_headers,
            mcp_server_used=mcp_server_used,
        )
        # Merge headers - Azure auth (api-key or Authorization) takes precedence
        # because `headers` is spread last.
        headers = {**anthropic_headers, **headers}

        # Ensure anthropic-version header is set
        if "anthropic-version" not in headers:
            headers["anthropic-version"] = "2023-06-01"

        return headers

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform request using parent AnthropicConfig, then remove unsupported params.
        Azure Anthropic doesn't support extra_body, max_retries, or stream_options parameters.
        """
        # Call parent transform_request
        data = super().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

        # Remove unsupported parameters for Azure AI Anthropic
        data.pop("extra_body", None)
        data.pop("max_retries", None)
        data.pop("stream_options", None)

        return data
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user