chore: initial public snapshot for github upload
This commit is contained in:
1
llm-gateway-competitors/litellm
Submodule
1
llm-gateway-competitors/litellm
Submodule
Submodule llm-gateway-competitors/litellm added at cd37ee1459
1
llm-gateway-competitors/litellm-sparse
Submodule
1
llm-gateway-competitors/litellm-sparse
Submodule
Submodule llm-gateway-competitors/litellm-sparse added at 58e74a631c
@@ -0,0 +1,26 @@
|
||||
Portions of this software are licensed as follows:
|
||||
|
||||
* All content that resides under the "enterprise/" directory of this repository, if that directory exists, is licensed under the license defined in "enterprise/LICENSE".
|
||||
* Content outside of the above mentioned directories or restrictions above is available under the MIT license as defined below.
|
||||
---
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Berri AI
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
@@ -0,0 +1,555 @@
|
||||
Metadata-Version: 2.1
|
||||
Name: litellm
|
||||
Version: 1.82.2
|
||||
Summary: Library to easily interface with LLM API providers
|
||||
License: MIT
|
||||
Author: BerriAI
|
||||
Requires-Python: >=3.9,<4.0
|
||||
Classifier: License :: OSI Approved :: MIT License
|
||||
Classifier: Programming Language :: Python :: 3
|
||||
Classifier: Programming Language :: Python :: 3.9
|
||||
Classifier: Programming Language :: Python :: 3.10
|
||||
Classifier: Programming Language :: Python :: 3.11
|
||||
Classifier: Programming Language :: Python :: 3.12
|
||||
Classifier: Programming Language :: Python :: 3.13
|
||||
Provides-Extra: caching
|
||||
Provides-Extra: extra-proxy
|
||||
Provides-Extra: google
|
||||
Provides-Extra: grpc
|
||||
Provides-Extra: mlflow
|
||||
Provides-Extra: proxy
|
||||
Provides-Extra: semantic-router
|
||||
Provides-Extra: utils
|
||||
Requires-Dist: PyJWT (>=2.10.1,<3.0.0) ; (python_version >= "3.9") and (extra == "proxy")
|
||||
Requires-Dist: a2a-sdk (>=0.3.22,<0.4.0) ; (python_version >= "3.10") and (extra == "extra-proxy")
|
||||
Requires-Dist: aiohttp (>=3.10)
|
||||
Requires-Dist: apscheduler (>=3.10.4,<4.0.0) ; extra == "proxy"
|
||||
Requires-Dist: azure-identity (>=1.15.0,<2.0.0) ; (python_version >= "3.9") and (extra == "proxy" or extra == "extra-proxy")
|
||||
Requires-Dist: azure-keyvault-secrets (>=4.8.0,<5.0.0) ; extra == "extra-proxy"
|
||||
Requires-Dist: azure-storage-blob (>=12.25.1,<13.0.0) ; extra == "proxy"
|
||||
Requires-Dist: backoff ; extra == "proxy"
|
||||
Requires-Dist: boto3 (>=1.40.76,<2.0.0) ; extra == "proxy"
|
||||
Requires-Dist: click
|
||||
Requires-Dist: cryptography ; extra == "proxy"
|
||||
Requires-Dist: diskcache (>=5.6.1,<6.0.0) ; extra == "caching"
|
||||
Requires-Dist: fastapi (>=0.120.1) ; extra == "proxy"
|
||||
Requires-Dist: fastapi-sso (>=0.16.0,<0.17.0) ; extra == "proxy"
|
||||
Requires-Dist: fastuuid (>=0.13.0)
|
||||
Requires-Dist: google-cloud-aiplatform (>=1.38.0) ; extra == "google"
|
||||
Requires-Dist: google-cloud-iam (>=2.19.1,<3.0.0) ; extra == "extra-proxy"
|
||||
Requires-Dist: google-cloud-kms (>=2.21.3,<3.0.0) ; extra == "extra-proxy"
|
||||
Requires-Dist: grpcio (>=1.62.3,!=1.68.*,!=1.69.*,!=1.70.*,!=1.71.0,!=1.71.1,!=1.72.0,!=1.72.1,!=1.73.0) ; (python_version < "3.14") and (extra == "grpc")
|
||||
Requires-Dist: grpcio (>=1.75.0) ; (python_version >= "3.14") and (extra == "grpc")
|
||||
Requires-Dist: gunicorn (>=23.0.0,<24.0.0) ; extra == "proxy"
|
||||
Requires-Dist: httpx (>=0.23.0)
|
||||
Requires-Dist: importlib-metadata (>=6.8.0)
|
||||
Requires-Dist: jinja2 (>=3.1.2,<4.0.0)
|
||||
Requires-Dist: jsonschema (>=4.23.0,<5.0.0)
|
||||
Requires-Dist: litellm-enterprise (>=0.1.33,<0.2.0) ; extra == "proxy"
|
||||
Requires-Dist: litellm-proxy-extras (>=0.4.56,<0.5.0) ; extra == "proxy"
|
||||
Requires-Dist: mcp (>=1.25.0,<2.0.0) ; (python_version >= "3.10") and (extra == "proxy")
|
||||
Requires-Dist: mlflow (>3.1.4) ; (python_version >= "3.10") and (extra == "mlflow")
|
||||
Requires-Dist: numpydoc ; extra == "utils"
|
||||
Requires-Dist: openai (>=2.8.0)
|
||||
Requires-Dist: orjson (>=3.9.7,<4.0.0) ; extra == "proxy"
|
||||
Requires-Dist: polars (>=1.31.0,<2.0.0) ; (python_version >= "3.10") and (extra == "proxy")
|
||||
Requires-Dist: prisma (>=0.11.0,<0.12.0) ; extra == "extra-proxy"
|
||||
Requires-Dist: pydantic (>=2.5.0,<3.0.0)
|
||||
Requires-Dist: pynacl (>=1.5.0,<2.0.0) ; extra == "proxy"
|
||||
Requires-Dist: pyroscope-io (>=0.8,<0.9) ; (sys_platform != "win32") and (extra == "proxy")
|
||||
Requires-Dist: python-dotenv (>=0.2.0)
|
||||
Requires-Dist: python-multipart (>=0.0.20) ; extra == "proxy"
|
||||
Requires-Dist: pyyaml (>=6.0.1,<7.0.0) ; extra == "proxy"
|
||||
Requires-Dist: redisvl (>=0.4.1,<0.5.0) ; (python_version >= "3.9" and python_version < "3.14") and (extra == "extra-proxy")
|
||||
Requires-Dist: resend (>=0.8.0) ; extra == "extra-proxy"
|
||||
Requires-Dist: rich (>=13.7.1,<14.0.0) ; extra == "proxy"
|
||||
Requires-Dist: rq ; extra == "proxy"
|
||||
Requires-Dist: semantic-router (>=0.1.12) ; (python_version >= "3.9" and python_version < "3.14") and (extra == "semantic-router")
|
||||
Requires-Dist: soundfile (>=0.12.1,<0.13.0) ; extra == "proxy"
|
||||
Requires-Dist: tiktoken (>=0.7.0)
|
||||
Requires-Dist: tokenizers
|
||||
Requires-Dist: uvicorn (>=0.32.1,<1.0.0) ; extra == "proxy"
|
||||
Requires-Dist: uvloop (>=0.21.0,<0.22.0) ; (sys_platform != "win32") and (extra == "proxy")
|
||||
Requires-Dist: websockets (>=15.0.1,<16.0.0) ; extra == "proxy"
|
||||
Project-URL: Documentation, https://docs.litellm.ai
|
||||
Project-URL: Homepage, https://litellm.ai
|
||||
Project-URL: Repository, https://github.com/BerriAI/litellm
|
||||
Project-URL: documentation, https://docs.litellm.ai
|
||||
Project-URL: homepage, https://litellm.ai
|
||||
Project-URL: repository, https://github.com/BerriAI/litellm
|
||||
Description-Content-Type: text/markdown
|
||||
|
||||
<h1 align="center">
|
||||
🚅 LiteLLM
|
||||
</h1>
|
||||
<p align="center">
|
||||
<p align="center">Call 100+ LLMs in OpenAI format. [Bedrock, Azure, OpenAI, VertexAI, Anthropic, Groq, etc.]
|
||||
</p>
|
||||
<p align="center">
|
||||
<a href="https://render.com/deploy?repo=https://github.com/BerriAI/litellm" target="_blank" rel="nofollow"><img src="https://render.com/images/deploy-to-render-button.svg" alt="Deploy to Render"></a>
|
||||
<a href="https://railway.app/template/HLP0Ub?referralCode=jch2ME">
|
||||
<img src="https://railway.app/button.svg" alt="Deploy on Railway">
|
||||
</a>
|
||||
</p>
|
||||
</p>
|
||||
<h4 align="center"><a href="https://docs.litellm.ai/docs/simple_proxy" target="_blank">LiteLLM Proxy Server (AI Gateway)</a> | <a href="https://docs.litellm.ai/docs/enterprise#hosted-litellm-proxy" target="_blank"> Hosted Proxy</a> | <a href="https://docs.litellm.ai/docs/enterprise" target="_blank">Enterprise Tier</a></h4>
|
||||
<h4 align="center">
|
||||
<a href="https://pypi.org/project/litellm/" target="_blank">
|
||||
<img src="https://img.shields.io/pypi/v/litellm.svg" alt="PyPI Version">
|
||||
</a>
|
||||
<a href="https://www.ycombinator.com/companies/berriai">
|
||||
<img src="https://img.shields.io/badge/Y%20Combinator-W23-orange?style=flat-square" alt="Y Combinator W23">
|
||||
</a>
|
||||
<a href="https://wa.link/huol9n">
|
||||
<img src="https://img.shields.io/static/v1?label=Chat%20on&message=WhatsApp&color=success&logo=WhatsApp&style=flat-square" alt="Whatsapp">
|
||||
</a>
|
||||
<a href="https://discord.gg/wuPM9dRgDw">
|
||||
<img src="https://img.shields.io/static/v1?label=Chat%20on&message=Discord&color=blue&logo=Discord&style=flat-square" alt="Discord">
|
||||
</a>
|
||||
<a href="https://www.litellm.ai/support">
|
||||
<img src="https://img.shields.io/static/v1?label=Chat%20on&message=Slack&color=black&logo=Slack&style=flat-square" alt="Slack">
|
||||
</a>
|
||||
</h4>
|
||||
|
||||
<img width="2688" height="1600" alt="Group 7154 (1)" src="https://github.com/user-attachments/assets/c5ee0412-6fb5-4fb6-ab5b-bafae4209ca6" />
|
||||
|
||||
|
||||
## Use LiteLLM for
|
||||
|
||||
<details open>
|
||||
<summary><b>LLMs</b> - Call 100+ LLMs (Python SDK + AI Gateway)</summary>
|
||||
|
||||
[**All Supported Endpoints**](https://docs.litellm.ai/docs/supported_endpoints) - `/chat/completions`, `/responses`, `/embeddings`, `/images`, `/audio`, `/batches`, `/rerank`, `/a2a`, `/messages` and more.
|
||||
|
||||
### Python SDK
|
||||
|
||||
```shell
|
||||
pip install litellm
|
||||
```
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = "your-openai-key"
|
||||
os.environ["ANTHROPIC_API_KEY"] = "your-anthropic-key"
|
||||
|
||||
# OpenAI
|
||||
response = completion(model="openai/gpt-4o", messages=[{"role": "user", "content": "Hello!"}])
|
||||
|
||||
# Anthropic
|
||||
response = completion(model="anthropic/claude-sonnet-4-20250514", messages=[{"role": "user", "content": "Hello!"}])
|
||||
```
|
||||
|
||||
### AI Gateway (Proxy Server)
|
||||
|
||||
[**Getting Started - E2E Tutorial**](https://docs.litellm.ai/docs/proxy/docker_quick_start) - Setup virtual keys, make your first request
|
||||
|
||||
```shell
|
||||
pip install 'litellm[proxy]'
|
||||
litellm --model gpt-4o
|
||||
```
|
||||
|
||||
```python
|
||||
import openai
|
||||
|
||||
client = openai.OpenAI(api_key="anything", base_url="http://0.0.0.0:4000")
|
||||
response = client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": "Hello!"}]
|
||||
)
|
||||
```
|
||||
|
||||
[**Docs: LLM Providers**](https://docs.litellm.ai/docs/providers)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Agents</b> - Invoke A2A Agents (Python SDK + AI Gateway)</summary>
|
||||
|
||||
[**Supported Providers**](https://docs.litellm.ai/docs/a2a#add-a2a-agents) - LangGraph, Vertex AI Agent Engine, Azure AI Foundry, Bedrock AgentCore, Pydantic AI
|
||||
|
||||
### Python SDK - A2A Protocol
|
||||
|
||||
```python
|
||||
from litellm.a2a_protocol import A2AClient
|
||||
from a2a.types import SendMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
client = A2AClient(base_url="http://localhost:10001")
|
||||
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
)
|
||||
)
|
||||
response = await client.send_message(request)
|
||||
```
|
||||
|
||||
### AI Gateway (Proxy Server)
|
||||
|
||||
**Step 1.** [Add your Agent to the AI Gateway](https://docs.litellm.ai/docs/a2a#adding-your-agent)
|
||||
|
||||
**Step 2.** Call Agent via A2A SDK
|
||||
|
||||
```python
|
||||
from a2a.client import A2ACardResolver, A2AClient
|
||||
from a2a.types import MessageSendParams, SendMessageRequest
|
||||
from uuid import uuid4
|
||||
import httpx
|
||||
|
||||
base_url = "http://localhost:4000/a2a/my-agent" # LiteLLM proxy + agent name
|
||||
headers = {"Authorization": "Bearer sk-1234"} # LiteLLM Virtual Key
|
||||
|
||||
async with httpx.AsyncClient(headers=headers) as httpx_client:
|
||||
resolver = A2ACardResolver(httpx_client=httpx_client, base_url=base_url)
|
||||
agent_card = await resolver.get_agent_card()
|
||||
client = A2AClient(httpx_client=httpx_client, agent_card=agent_card)
|
||||
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
)
|
||||
)
|
||||
response = await client.send_message(request)
|
||||
```
|
||||
|
||||
[**Docs: A2A Agent Gateway**](https://docs.litellm.ai/docs/a2a)
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>MCP Tools</b> - Connect MCP servers to any LLM (Python SDK + AI Gateway)</summary>
|
||||
|
||||
### Python SDK - MCP Bridge
|
||||
|
||||
```python
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
from litellm import experimental_mcp_client
|
||||
import litellm
|
||||
|
||||
server_params = StdioServerParameters(command="python", args=["mcp_server.py"])
|
||||
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
await session.initialize()
|
||||
|
||||
# Load MCP tools in OpenAI format
|
||||
tools = await experimental_mcp_client.load_mcp_tools(session=session, format="openai")
|
||||
|
||||
# Use with any LiteLLM model
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": "What's 3 + 5?"}],
|
||||
tools=tools
|
||||
)
|
||||
```
|
||||
|
||||
### AI Gateway - MCP Gateway
|
||||
|
||||
**Step 1.** [Add your MCP Server to the AI Gateway](https://docs.litellm.ai/docs/mcp#adding-your-mcp)
|
||||
|
||||
**Step 2.** Call MCP tools via `/chat/completions`
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://0.0.0.0:4000/v1/chat/completions' \
|
||||
-H 'Authorization: Bearer sk-1234' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"model": "gpt-4o",
|
||||
"messages": [{"role": "user", "content": "Summarize the latest open PR"}],
|
||||
"tools": [{
|
||||
"type": "mcp",
|
||||
"server_url": "litellm_proxy/mcp/github",
|
||||
"server_label": "github_mcp",
|
||||
"require_approval": "never"
|
||||
}]
|
||||
}'
|
||||
```
|
||||
|
||||
### Use with Cursor IDE
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"LiteLLM": {
|
||||
"url": "http://localhost:4000/mcp/",
|
||||
"headers": {
|
||||
"x-litellm-api-key": "Bearer sk-1234"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
[**Docs: MCP Gateway**](https://docs.litellm.ai/docs/mcp)
|
||||
|
||||
</details>
|
||||
|
||||
---
|
||||
|
||||
## How to use LiteLLM
|
||||
|
||||
You can use LiteLLM through either the Proxy Server or the Python SDK. Both give you a unified interface to access multiple LLMs (100+ LLMs). Choose the option that best fits your needs:
|
||||
|
||||
<table style={{width: '100%', tableLayout: 'fixed'}}>
|
||||
<thead>
|
||||
<tr>
|
||||
<th style={{width: '14%'}}></th>
|
||||
<th style={{width: '43%'}}><strong><a href="https://docs.litellm.ai/docs/simple_proxy">LiteLLM AI Gateway</a></strong></th>
|
||||
<th style={{width: '43%'}}><strong><a href="https://docs.litellm.ai/docs/">LiteLLM Python SDK</a></strong></th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr>
|
||||
<td style={{width: '14%'}}><strong>Use Case</strong></td>
|
||||
<td style={{width: '43%'}}>Central service (LLM Gateway) to access multiple LLMs</td>
|
||||
<td style={{width: '43%'}}>Use LiteLLM directly in your Python code</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style={{width: '14%'}}><strong>Who Uses It?</strong></td>
|
||||
<td style={{width: '43%'}}>Gen AI Enablement / ML Platform Teams</td>
|
||||
<td style={{width: '43%'}}>Developers building LLM projects</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td style={{width: '14%'}}><strong>Key Features</strong></td>
|
||||
<td style={{width: '43%'}}>Centralized API gateway with authentication and authorization, multi-tenant cost tracking and spend management per project/user, per-project customization (logging, guardrails, caching), virtual keys for secure access control, admin dashboard UI for monitoring and management</td>
|
||||
<td style={{width: '43%'}}>Direct Python library integration in your codebase, Router with retry/fallback logic across multiple deployments (e.g. Azure/OpenAI) - <a href="https://docs.litellm.ai/docs/routing">Router</a>, application-level load balancing and cost tracking, exception handling with OpenAI-compatible errors, observability callbacks (Lunary, MLflow, Langfuse, etc.)</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
LiteLLM Performance: **8ms P95 latency** at 1k RPS (See benchmarks [here](https://docs.litellm.ai/docs/benchmarks))
|
||||
|
||||
[**Jump to LiteLLM Proxy (LLM Gateway) Docs**](https://docs.litellm.ai/docs/simple_proxy) <br>
|
||||
[**Jump to Supported LLM Providers**](https://docs.litellm.ai/docs/providers)
|
||||
|
||||
**Stable Release:** Use docker images with the `-stable` tag. These have undergone 12 hour load tests, before being published. [More information about the release cycle here](https://docs.litellm.ai/docs/proxy/release_cycle)
|
||||
|
||||
We are adding support for more providers. If a provider or LLM platform you need is missing, raise a [feature request](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=enhancement&projects=&template=feature_request.yml&title=%5BFeature%5D%3A+).
|
||||
|
||||
## OSS Adopters
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img height="60" alt="Stripe" src="https://github.com/user-attachments/assets/f7296d4f-9fbd-460d-9d05-e4df31697c4b" /></td>
|
||||
<td><img height="60" alt="Google ADK" src="https://github.com/user-attachments/assets/caf270a2-5aee-45c4-8222-41a2070c4f19" /></td>
|
||||
<td><img height="60" alt="Greptile" src="https://github.com/user-attachments/assets/0be4bd8a-7cfa-48d3-9090-f415fe948280" /></td>
|
||||
<td><img height="60" alt="OpenHands" src="https://github.com/user-attachments/assets/a6150c4c-149e-4cae-888b-8b92be6e003f" /></td>
|
||||
<td><h2>Netflix</h2></td>
|
||||
<td><img height="60" alt="OpenAI Agents SDK" src="https://github.com/user-attachments/assets/c02f7be0-8c2e-4d27-aea7-7c024bfaebc0" /></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Supported Providers ([Website Supported Models](https://models.litellm.ai/) | [Docs](https://docs.litellm.ai/docs/providers))
|
||||
|
||||
| Provider | `/chat/completions` | `/messages` | `/responses` | `/embeddings` | `/image/generations` | `/audio/transcriptions` | `/audio/speech` | `/moderations` | `/batches` | `/rerank` |
|
||||
|-------------------------------------------------------------------------------------|---------------------|-------------|--------------|---------------|----------------------|-------------------------|-----------------|----------------|-----------|-----------|
|
||||
| [Abliteration (`abliteration`)](https://docs.litellm.ai/docs/providers/abliteration) | ✅ | | | | | | | | | |
|
||||
| [AI/ML API (`aiml`)](https://docs.litellm.ai/docs/providers/aiml) | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
||||
| [AI21 (`ai21`)](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [AI21 Chat (`ai21_chat`)](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Aleph Alpha](https://docs.litellm.ai/docs/providers/aleph_alpha) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Amazon Nova](https://docs.litellm.ai/docs/providers/amazon_nova) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Anthropic (`anthropic`)](https://docs.litellm.ai/docs/providers/anthropic) | ✅ | ✅ | ✅ | | | | | | ✅ | |
|
||||
| [Anthropic Text (`anthropic_text`)](https://docs.litellm.ai/docs/providers/anthropic) | ✅ | ✅ | ✅ | | | | | | ✅ | |
|
||||
| [Anyscale](https://docs.litellm.ai/docs/providers/anyscale) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [AssemblyAI (`assemblyai`)](https://docs.litellm.ai/docs/pass_through/assembly_ai) | ✅ | ✅ | ✅ | | | ✅ | | | | |
|
||||
| [Auto Router (`auto_router`)](https://docs.litellm.ai/docs/proxy/auto_routing) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [AWS - Bedrock (`bedrock`)](https://docs.litellm.ai/docs/providers/bedrock) | ✅ | ✅ | ✅ | ✅ | | | | | | ✅ |
|
||||
| [AWS - Sagemaker (`sagemaker`)](https://docs.litellm.ai/docs/providers/aws_sagemaker) | ✅ | ✅ | ✅ | ✅ | | | | | | |
|
||||
| [Azure (`azure`)](https://docs.litellm.ai/docs/providers/azure) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Azure AI (`azure_ai`)](https://docs.litellm.ai/docs/providers/azure_ai) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Azure Text (`azure_text`)](https://docs.litellm.ai/docs/providers/azure) | ✅ | ✅ | ✅ | | | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Baseten (`baseten`)](https://docs.litellm.ai/docs/providers/baseten) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Bytez (`bytez`)](https://docs.litellm.ai/docs/providers/bytez) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Cerebras (`cerebras`)](https://docs.litellm.ai/docs/providers/cerebras) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Clarifai (`clarifai`)](https://docs.litellm.ai/docs/providers/clarifai) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Cloudflare AI Workers (`cloudflare`)](https://docs.litellm.ai/docs/providers/cloudflare_workers) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Codestral (`codestral`)](https://docs.litellm.ai/docs/providers/codestral) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Cohere (`cohere`)](https://docs.litellm.ai/docs/providers/cohere) | ✅ | ✅ | ✅ | ✅ | | | | | | ✅ |
|
||||
| [Cohere Chat (`cohere_chat`)](https://docs.litellm.ai/docs/providers/cohere) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [CometAPI (`cometapi`)](https://docs.litellm.ai/docs/providers/cometapi) | ✅ | ✅ | ✅ | ✅ | | | | | | |
|
||||
| [CompactifAI (`compactifai`)](https://docs.litellm.ai/docs/providers/compactifai) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Custom (`custom`)](https://docs.litellm.ai/docs/providers/custom_llm_server) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Custom OpenAI (`custom_openai`)](https://docs.litellm.ai/docs/providers/openai_compatible) | ✅ | ✅ | ✅ | | | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Dashscope (`dashscope`)](https://docs.litellm.ai/docs/providers/dashscope) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Databricks (`databricks`)](https://docs.litellm.ai/docs/providers/databricks) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [DataRobot (`datarobot`)](https://docs.litellm.ai/docs/providers/datarobot) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Deepgram (`deepgram`)](https://docs.litellm.ai/docs/providers/deepgram) | ✅ | ✅ | ✅ | | | ✅ | | | | |
|
||||
| [DeepInfra (`deepinfra`)](https://docs.litellm.ai/docs/providers/deepinfra) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Deepseek (`deepseek`)](https://docs.litellm.ai/docs/providers/deepseek) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [ElevenLabs (`elevenlabs`)](https://docs.litellm.ai/docs/providers/elevenlabs) | ✅ | ✅ | ✅ | | | ✅ | ✅ | | | |
|
||||
| [Empower (`empower`)](https://docs.litellm.ai/docs/providers/empower) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Fal AI (`fal_ai`)](https://docs.litellm.ai/docs/providers/fal_ai) | ✅ | ✅ | ✅ | | ✅ | | | | | |
|
||||
| [Featherless AI (`featherless_ai`)](https://docs.litellm.ai/docs/providers/featherless_ai) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Fireworks AI (`fireworks_ai`)](https://docs.litellm.ai/docs/providers/fireworks_ai) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [FriendliAI (`friendliai`)](https://docs.litellm.ai/docs/providers/friendliai) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Galadriel (`galadriel`)](https://docs.litellm.ai/docs/providers/galadriel) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [GitHub Copilot (`github_copilot`)](https://docs.litellm.ai/docs/providers/github_copilot) | ✅ | ✅ | ✅ | ✅ | | | | | | |
|
||||
| [GitHub Models (`github`)](https://docs.litellm.ai/docs/providers/github) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Google - PaLM](https://docs.litellm.ai/docs/providers/palm) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Google - Vertex AI (`vertex_ai`)](https://docs.litellm.ai/docs/providers/vertex) | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
||||
| [Google AI Studio - Gemini (`gemini`)](https://docs.litellm.ai/docs/providers/gemini) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [GradientAI (`gradient_ai`)](https://docs.litellm.ai/docs/providers/gradient_ai) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Groq AI (`groq`)](https://docs.litellm.ai/docs/providers/groq) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Heroku (`heroku`)](https://docs.litellm.ai/docs/providers/heroku) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Hosted VLLM (`hosted_vllm`)](https://docs.litellm.ai/docs/providers/vllm) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Huggingface (`huggingface`)](https://docs.litellm.ai/docs/providers/huggingface) | ✅ | ✅ | ✅ | ✅ | | | | | | ✅ |
|
||||
| [Hyperbolic (`hyperbolic`)](https://docs.litellm.ai/docs/providers/hyperbolic) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [IBM - Watsonx.ai (`watsonx`)](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | ✅ | | | | | | |
|
||||
| [Infinity (`infinity`)](https://docs.litellm.ai/docs/providers/infinity) | | | | ✅ | | | | | | |
|
||||
| [Jina AI (`jina_ai`)](https://docs.litellm.ai/docs/providers/jina_ai) | | | | ✅ | | | | | | |
|
||||
| [Lambda AI (`lambda_ai`)](https://docs.litellm.ai/docs/providers/lambda_ai) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Lemonade (`lemonade`)](https://docs.litellm.ai/docs/providers/lemonade) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [LiteLLM Proxy (`litellm_proxy`)](https://docs.litellm.ai/docs/providers/litellm_proxy) | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
|
||||
| [Llamafile (`llamafile`)](https://docs.litellm.ai/docs/providers/llamafile) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [LM Studio (`lm_studio`)](https://docs.litellm.ai/docs/providers/lm_studio) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Maritalk (`maritalk`)](https://docs.litellm.ai/docs/providers/maritalk) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Meta - Llama API (`meta_llama`)](https://docs.litellm.ai/docs/providers/meta_llama) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Mistral AI API (`mistral`)](https://docs.litellm.ai/docs/providers/mistral) | ✅ | ✅ | ✅ | ✅ | | | | | | |
|
||||
| [Moonshot (`moonshot`)](https://docs.litellm.ai/docs/providers/moonshot) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Morph (`morph`)](https://docs.litellm.ai/docs/providers/morph) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Nebius AI Studio (`nebius`)](https://docs.litellm.ai/docs/providers/nebius) | ✅ | ✅ | ✅ | ✅ | | | | | | |
|
||||
| [NLP Cloud (`nlp_cloud`)](https://docs.litellm.ai/docs/providers/nlp_cloud) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Novita AI (`novita`)](https://novita.ai/models/llm?utm_source=github_litellm&utm_medium=github_readme&utm_campaign=github_link) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Nscale (`nscale`)](https://docs.litellm.ai/docs/providers/nscale) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Nvidia NIM (`nvidia_nim`)](https://docs.litellm.ai/docs/providers/nvidia_nim) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [OCI (`oci`)](https://docs.litellm.ai/docs/providers/oci) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Ollama (`ollama`)](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | ✅ | | | | | | |
|
||||
| [Ollama Chat (`ollama_chat`)](https://docs.litellm.ai/docs/providers/ollama) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Oobabooga (`oobabooga`)](https://docs.litellm.ai/docs/providers/openai_compatible) | ✅ | ✅ | ✅ | | | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [OpenAI (`openai`)](https://docs.litellm.ai/docs/providers/openai) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [OpenAI-like (`openai_like`)](https://docs.litellm.ai/docs/providers/openai_compatible) | | | | ✅ | | | | | | |
|
||||
| [OpenRouter (`openrouter`)](https://docs.litellm.ai/docs/providers/openrouter) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [OVHCloud AI Endpoints (`ovhcloud`)](https://docs.litellm.ai/docs/providers/ovhcloud) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Perplexity AI (`perplexity`)](https://docs.litellm.ai/docs/providers/perplexity) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Petals (`petals`)](https://docs.litellm.ai/docs/providers/petals) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Predibase (`predibase`)](https://docs.litellm.ai/docs/providers/predibase) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Recraft (`recraft`)](https://docs.litellm.ai/docs/providers/recraft) | | | | | ✅ | | | | | |
|
||||
| [Replicate (`replicate`)](https://docs.litellm.ai/docs/providers/replicate) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Sagemaker Chat (`sagemaker_chat`)](https://docs.litellm.ai/docs/providers/aws_sagemaker) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Sambanova (`sambanova`)](https://docs.litellm.ai/docs/providers/sambanova) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Snowflake (`snowflake`)](https://docs.litellm.ai/docs/providers/snowflake) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Text Completion Codestral (`text-completion-codestral`)](https://docs.litellm.ai/docs/providers/codestral) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Text Completion OpenAI (`text-completion-openai`)](https://docs.litellm.ai/docs/providers/text_completion_openai) | ✅ | ✅ | ✅ | | | ✅ | ✅ | ✅ | ✅ | |
|
||||
| [Together AI (`together_ai`)](https://docs.litellm.ai/docs/providers/togetherai) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Topaz (`topaz`)](https://docs.litellm.ai/docs/providers/topaz) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Triton (`triton`)](https://docs.litellm.ai/docs/providers/triton-inference-server) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [V0 (`v0`)](https://docs.litellm.ai/docs/providers/v0) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Vercel AI Gateway (`vercel_ai_gateway`)](https://docs.litellm.ai/docs/providers/vercel_ai_gateway) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [VLLM (`vllm`)](https://docs.litellm.ai/docs/providers/vllm) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Volcengine (`volcengine`)](https://docs.litellm.ai/docs/providers/volcano) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Voyage AI (`voyage`)](https://docs.litellm.ai/docs/providers/voyage) | | | | ✅ | | | | | | |
|
||||
| [WandB Inference (`wandb`)](https://docs.litellm.ai/docs/providers/wandb_inference) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Watsonx Text (`watsonx_text`)](https://docs.litellm.ai/docs/providers/watsonx) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [xAI (`xai`)](https://docs.litellm.ai/docs/providers/xai) | ✅ | ✅ | ✅ | | | | | | | |
|
||||
| [Xinference (`xinference`)](https://docs.litellm.ai/docs/providers/xinference) | | | | ✅ | | | | | | |
|
||||
|
||||
[**Read the Docs**](https://docs.litellm.ai/docs/)
|
||||
|
||||
## Run in Developer mode
|
||||
### Services
|
||||
1. Set up a .env file in the repository root
|
||||
2. Run dependent services: `docker-compose up db prometheus`
|
||||
|
||||
### Backend
|
||||
1. (In root) create virtual environment `python -m venv .venv`
|
||||
2. Activate virtual environment `source .venv/bin/activate`
|
||||
3. Install dependencies `pip install -e ".[all]"`
|
||||
4. `pip install prisma`
|
||||
5. `prisma generate`
|
||||
6. Start proxy backend `python litellm/proxy/proxy_cli.py`
|
||||
|
||||
### Frontend
|
||||
1. Navigate to `ui/litellm-dashboard`
|
||||
2. Install dependencies `npm install`
|
||||
3. Run `npm run dev` to start the dashboard
|
||||
|
||||
# Enterprise
|
||||
For companies that need better security, user management and professional support
|
||||
|
||||
[Talk to founders](https://calendly.com/d/cx9p-5yf-2nm/litellm-introductions)
|
||||
|
||||
This covers:
|
||||
- ✅ **Features under the [LiteLLM Commercial License](https://docs.litellm.ai/docs/proxy/enterprise):**
|
||||
- ✅ **Feature Prioritization**
|
||||
- ✅ **Custom Integrations**
|
||||
- ✅ **Professional Support - Dedicated discord + slack**
|
||||
- ✅ **Custom SLAs**
|
||||
- ✅ **Secure access with Single Sign-On**
|
||||
|
||||
# Contributing
|
||||
|
||||
We welcome contributions to LiteLLM! Whether you're fixing bugs, adding features, or improving documentation, we appreciate your help.
|
||||
|
||||
## Quick Start for Contributors
|
||||
|
||||
This requires poetry to be installed.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/BerriAI/litellm.git
|
||||
cd litellm
|
||||
make install-dev # Install development dependencies
|
||||
make format # Format your code
|
||||
make lint # Run all linting checks
|
||||
make test-unit # Run unit tests
|
||||
make format-check # Check formatting only
|
||||
```
|
||||
|
||||
For detailed contributing guidelines, see [CONTRIBUTING.md](CONTRIBUTING.md).
|
||||
|
||||
## Code Quality / Linting
|
||||
|
||||
LiteLLM follows the [Google Python Style Guide](https://google.github.io/styleguide/pyguide.html).
|
||||
|
||||
Our automated checks include:
|
||||
- **Black** for code formatting
|
||||
- **Ruff** for linting and code quality
|
||||
- **MyPy** for type checking
|
||||
- **Circular import detection**
|
||||
- **Import safety checks**
|
||||
|
||||
|
||||
All these checks must pass before your PR can be merged.
|
||||
|
||||
|
||||
# Support / talk with founders
|
||||
|
||||
- [Schedule Demo 👋](https://calendly.com/d/4mp-gd3-k5k/berriai-1-1-onboarding-litellm-hosted-version)
|
||||
- [Community Discord 💭](https://discord.gg/wuPM9dRgDw)
|
||||
- [Community Slack 💭](https://www.litellm.ai/support)
|
||||
- Our numbers 📞 +1 (770) 8783-106 / +1 (412) 618-6238
|
||||
- Our emails ✉️ ishaan@berri.ai / krrish@berri.ai
|
||||
|
||||
# Why did we build this
|
||||
|
||||
- **Need for simplicity**: Our code started to get extremely complicated managing & translating calls between Azure, OpenAI and Cohere.
|
||||
|
||||
# Contributors
|
||||
|
||||
<!-- ALL-CONTRIBUTORS-LIST:START - Do not remove or modify this section -->
|
||||
<!-- prettier-ignore-start -->
|
||||
<!-- markdownlint-disable -->
|
||||
|
||||
<!-- markdownlint-restore -->
|
||||
<!-- prettier-ignore-end -->
|
||||
|
||||
<!-- ALL-CONTRIBUTORS-LIST:END -->
|
||||
|
||||
<a href="https://github.com/BerriAI/litellm/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=BerriAI/litellm" />
|
||||
</a>
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,4 @@
|
||||
Wheel-Version: 1.0
|
||||
Generator: poetry-core 1.9.1
|
||||
Root-Is-Purelib: true
|
||||
Tag: py3-none-any
|
||||
@@ -0,0 +1,4 @@
|
||||
[console_scripts]
|
||||
litellm=litellm:run_server
|
||||
litellm-proxy=litellm.proxy.client.cli:cli
|
||||
|
||||
2170
llm-gateway-competitors/litellm-wheel-src/litellm/__init__.py
Normal file
2170
llm-gateway-competitors/litellm-wheel-src/litellm/__init__.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,449 @@
|
||||
"""
|
||||
Lazy Import System
|
||||
|
||||
This module implements lazy loading for LiteLLM attributes. Instead of importing
|
||||
everything when the module loads, we only import things when they're actually used.
|
||||
|
||||
How it works:
|
||||
1. When someone accesses `litellm.some_attribute`, Python calls __getattr__ in __init__.py
|
||||
2. __getattr__ looks up the attribute name in a registry
|
||||
3. The registry points to a handler function (like _lazy_import_utils)
|
||||
4. The handler function imports the module and returns the attribute
|
||||
5. The result is cached so we don't import it again
|
||||
|
||||
This makes importing litellm much faster because we don't load heavy dependencies
|
||||
until they're actually needed.
|
||||
"""
|
||||
import importlib
|
||||
import sys
|
||||
from typing import Any, Optional, cast, Callable
|
||||
|
||||
# Import all the data structures that define what can be lazy-loaded
|
||||
# These are just lists of names and maps of where to find them
|
||||
from ._lazy_imports_registry import (
|
||||
# Name tuples
|
||||
COST_CALCULATOR_NAMES,
|
||||
LITELLM_LOGGING_NAMES,
|
||||
UTILS_NAMES,
|
||||
TOKEN_COUNTER_NAMES,
|
||||
LLM_CLIENT_CACHE_NAMES,
|
||||
BEDROCK_TYPES_NAMES,
|
||||
TYPES_UTILS_NAMES,
|
||||
CACHING_NAMES,
|
||||
HTTP_HANDLER_NAMES,
|
||||
DOTPROMPT_NAMES,
|
||||
LLM_CONFIG_NAMES,
|
||||
TYPES_NAMES,
|
||||
LLM_PROVIDER_LOGIC_NAMES,
|
||||
UTILS_MODULE_NAMES,
|
||||
# Import maps
|
||||
_UTILS_IMPORT_MAP,
|
||||
_COST_CALCULATOR_IMPORT_MAP,
|
||||
_TYPES_UTILS_IMPORT_MAP,
|
||||
_TOKEN_COUNTER_IMPORT_MAP,
|
||||
_BEDROCK_TYPES_IMPORT_MAP,
|
||||
_CACHING_IMPORT_MAP,
|
||||
_LITELLM_LOGGING_IMPORT_MAP,
|
||||
_DOTPROMPT_IMPORT_MAP,
|
||||
_TYPES_IMPORT_MAP,
|
||||
_LLM_CONFIGS_IMPORT_MAP,
|
||||
_LLM_PROVIDER_LOGIC_IMPORT_MAP,
|
||||
_UTILS_MODULE_IMPORT_MAP,
|
||||
)
|
||||
|
||||
|
||||
def _get_litellm_globals() -> dict:
|
||||
"""
|
||||
Get the globals dictionary of the litellm module.
|
||||
|
||||
This is where we cache imported attributes so we don't import them twice.
|
||||
When you do `litellm.some_function`, it gets stored in this dictionary.
|
||||
"""
|
||||
return sys.modules["litellm"].__dict__
|
||||
|
||||
|
||||
def _get_utils_globals() -> dict:
|
||||
"""
|
||||
Get the globals dictionary of the utils module.
|
||||
|
||||
This is where we cache imported attributes so we don't import them twice.
|
||||
When you do `litellm.utils.some_function`, it gets stored in this dictionary.
|
||||
"""
|
||||
return sys.modules["litellm.utils"].__dict__
|
||||
|
||||
|
||||
# These are special lazy loaders for things that are used internally.
# They're separate from the main lazy import system because they have
# specific use cases.

# Cached default encoding; populated on first use so the heavy tiktoken
# dependency is not imported at `litellm` import time.
_default_encoding: Optional[Any] = None


def _get_default_encoding() -> Any:
    """Return the default OpenAI encoding, importing it on first call.

    Importing ``litellm.litellm_core_utils.default_encoding`` pulls in
    tiktoken, so the import is deferred until the encoding is actually
    needed. The result is cached in the module-level ``_default_encoding``.
    """
    global _default_encoding
    if _default_encoding is not None:
        return _default_encoding

    from litellm.litellm_core_utils.default_encoding import encoding

    _default_encoding = encoding
    return _default_encoding
|
||||
|
||||
|
||||
# Cached reference to get_modified_max_tokens; avoids importing the
# token_counter module at `litellm` import time.
_get_modified_max_tokens_func: Optional[Any] = None


def _get_modified_max_tokens() -> Any:
    """Return ``get_modified_max_tokens``, importing it on first call.

    Defers the import of ``litellm.litellm_core_utils.token_counter`` until
    the function is actually needed; the resolved function is cached in
    ``_get_modified_max_tokens_func`` for subsequent calls.
    """
    global _get_modified_max_tokens_func
    if _get_modified_max_tokens_func is not None:
        return _get_modified_max_tokens_func

    from litellm.litellm_core_utils.token_counter import (
        get_modified_max_tokens as _get_modified_max_tokens_imported,
    )

    _get_modified_max_tokens_func = _get_modified_max_tokens_imported
    return _get_modified_max_tokens_func
|
||||
|
||||
|
||||
# Cached reference to the token_counter function (exposed as
# "token_counter_new"); avoids importing its module at `litellm` import time.
_token_counter_new_func: Optional[Any] = None


def _get_token_counter_new() -> Any:
    """Return the ``token_counter`` function, importing it on first call.

    Defers the import of ``litellm.litellm_core_utils.token_counter`` until
    the counter is actually needed; the resolved function is cached in
    ``_token_counter_new_func`` for subsequent calls.
    """
    global _token_counter_new_func
    if _token_counter_new_func is not None:
        return _token_counter_new_func

    from litellm.litellm_core_utils.token_counter import (
        token_counter as _token_counter_imported,
    )

    _token_counter_new_func = _token_counter_imported
    return _token_counter_new_func
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# MAIN LAZY IMPORT SYSTEM
|
||||
# ============================================================================
|
||||
|
||||
# Registry mapping lazily-importable attribute names to handler functions.
# Built on first access, then reused.
# Example: {"ModelResponse": _lazy_import_utils, "Cache": _lazy_import_caching, ...}
_LAZY_IMPORT_REGISTRY: Optional[dict[str, Callable[[str], Any]]] = None


def _get_lazy_import_registry() -> dict[str, Callable[[str], Any]]:
    """Build (once) and return the attribute-name -> handler registry.

    Every name in a category's name tuple is mapped to that category's
    handler function — e.g. all names in UTILS_NAMES map to
    ``_lazy_import_utils``. Categories are processed in the same order as
    before, so a name appearing in several categories keeps the last
    category's handler.

    Returns:
        Dictionary like {"ModelResponse": _lazy_import_utils, ...}
    """
    global _LAZY_IMPORT_REGISTRY
    if _LAZY_IMPORT_REGISTRY is None:
        # (name tuple, handler) pairs — one per lazy-import category.
        category_handlers = (
            (COST_CALCULATOR_NAMES, _lazy_import_cost_calculator),
            (LITELLM_LOGGING_NAMES, _lazy_import_litellm_logging),
            (UTILS_NAMES, _lazy_import_utils),
            (TOKEN_COUNTER_NAMES, _lazy_import_token_counter),
            (LLM_CLIENT_CACHE_NAMES, _lazy_import_llm_client_cache),
            (BEDROCK_TYPES_NAMES, _lazy_import_bedrock_types),
            (TYPES_UTILS_NAMES, _lazy_import_types_utils),
            (CACHING_NAMES, _lazy_import_caching),
            (HTTP_HANDLER_NAMES, _lazy_import_http_handlers),
            (DOTPROMPT_NAMES, _lazy_import_dotprompt),
            (LLM_CONFIG_NAMES, _lazy_import_llm_configs),
            (TYPES_NAMES, _lazy_import_types),
            (LLM_PROVIDER_LOGIC_NAMES, _lazy_import_llm_provider_logic),
            (UTILS_MODULE_NAMES, _lazy_import_utils_module),
        )
        registry: dict[str, Callable[[str], Any]] = {}
        for names, handler in category_handlers:
            for attr_name in names:
                registry[attr_name] = handler
        _LAZY_IMPORT_REGISTRY = registry

    return _LAZY_IMPORT_REGISTRY
|
||||
|
||||
|
||||
def _generic_lazy_import(
    name: str, import_map: dict[str, tuple[str, str]], category: str
) -> Any:
    """Resolve and cache one lazily-imported attribute.

    This is the workhorse shared by most handler functions: it looks
    ``name`` up in ``import_map`` to find ``(module_path, attr_name)``,
    imports that module (relative paths are resolved against the
    ``litellm`` package), fetches the attribute, caches it in the litellm
    module globals, and returns it. A cached value is returned directly
    on subsequent calls.

    Args:
        name: Attribute being accessed (e.g. "ModelResponse").
        import_map: {attr_name: (module_path, attr_name_in_module)},
            e.g. {"ModelResponse": (".utils", "ModelResponse")}.
        category: Label used only in error messages (e.g. "Utils").

    Raises:
        AttributeError: If ``name`` is not present in ``import_map``.
    """
    if name not in import_map:
        raise AttributeError(f"{category} lazy import: unknown attribute {name!r}")

    # Cache lives in the litellm module globals; hit it first.
    cache = _get_litellm_globals()
    if name in cache:
        return cache[name]

    module_path, attr_name = import_map[name]

    # Relative module paths resolve against the litellm package; absolute
    # ones import as-is. importlib caches modules in sys.modules, so
    # repeated imports of the same module are cheap.
    package = "litellm" if module_path.startswith(".") else None
    module = importlib.import_module(module_path, package=package)

    value = getattr(module, attr_name)
    cache[name] = value
    return value
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# HANDLER FUNCTIONS
|
||||
# ============================================================================
|
||||
# These functions are called when someone accesses a lazy-loaded attribute.
|
||||
# Most of them just call _generic_lazy_import with their specific import map.
|
||||
# The registry (above) maps attribute names to these handler functions.
|
||||
|
||||
|
||||
def _lazy_import_utils(name: str) -> Any:
    """Resolve utils-module attributes (ModelResponse, token_counter, etc.)."""
    return _generic_lazy_import(
        name=name, import_map=_UTILS_IMPORT_MAP, category="Utils"
    )
|
||||
|
||||
|
||||
def _lazy_import_cost_calculator(name: str) -> Any:
    """Resolve cost-calculator functions (completion_cost, cost_per_token, etc.)."""
    return _generic_lazy_import(
        name=name, import_map=_COST_CALCULATOR_IMPORT_MAP, category="Cost calculator"
    )
|
||||
|
||||
|
||||
def _lazy_import_token_counter(name: str) -> Any:
    """Resolve token-counter utilities."""
    return _generic_lazy_import(
        name=name, import_map=_TOKEN_COUNTER_IMPORT_MAP, category="Token counter"
    )
|
||||
|
||||
|
||||
def _lazy_import_bedrock_types(name: str) -> Any:
    """Resolve Bedrock type aliases."""
    return _generic_lazy_import(
        name=name, import_map=_BEDROCK_TYPES_IMPORT_MAP, category="Bedrock types"
    )
|
||||
|
||||
|
||||
def _lazy_import_types_utils(name: str) -> Any:
    """Resolve types from litellm.types.utils (BudgetConfig, ImageObject, etc.)."""
    return _generic_lazy_import(
        name=name, import_map=_TYPES_UTILS_IMPORT_MAP, category="Types utils"
    )
|
||||
|
||||
|
||||
def _lazy_import_caching(name: str) -> Any:
    """Resolve caching classes (Cache, DualCache, RedisCache, etc.)."""
    return _generic_lazy_import(
        name=name, import_map=_CACHING_IMPORT_MAP, category="Caching"
    )
|
||||
|
||||
|
||||
def _lazy_import_dotprompt(name: str) -> Any:
    """Resolve dotprompt integration globals."""
    return _generic_lazy_import(
        name=name, import_map=_DOTPROMPT_IMPORT_MAP, category="Dotprompt"
    )
|
||||
|
||||
|
||||
def _lazy_import_types(name: str) -> Any:
    """Resolve type classes (GuardrailItem, etc.)."""
    return _generic_lazy_import(
        name=name, import_map=_TYPES_IMPORT_MAP, category="Types"
    )
|
||||
|
||||
|
||||
def _lazy_import_llm_configs(name: str) -> Any:
    """Resolve LLM config classes (AnthropicConfig, OpenAILikeChatConfig, etc.)."""
    return _generic_lazy_import(
        name=name, import_map=_LLM_CONFIGS_IMPORT_MAP, category="LLM config"
    )
|
||||
|
||||
|
||||
def _lazy_import_litellm_logging(name: str) -> Any:
    """Resolve litellm_logging attributes (Logging, modify_integration)."""
    return _generic_lazy_import(
        name=name, import_map=_LITELLM_LOGGING_IMPORT_MAP, category="Litellm logging"
    )
|
||||
|
||||
|
||||
def _lazy_import_llm_provider_logic(name: str) -> Any:
    """Resolve LLM provider logic functions (get_llm_provider, etc.)."""
    return _generic_lazy_import(
        name=name,
        import_map=_LLM_PROVIDER_LOGIC_IMPORT_MAP,
        category="LLM provider logic",
    )
|
||||
|
||||
|
||||
def _lazy_import_utils_module(name: str) -> Any:
    """Resolve lazily-loaded attributes that live on ``litellm.utils`` itself.

    Same flow as ``_generic_lazy_import``, except the resolved attribute is
    cached in the ``litellm.utils`` module globals (``_get_utils_globals``)
    rather than the top-level litellm globals.

    Raises:
        AttributeError: If ``name`` is not in _UTILS_MODULE_IMPORT_MAP.
    """
    if name not in _UTILS_MODULE_IMPORT_MAP:
        raise AttributeError(f"Utils module lazy import: unknown attribute {name!r}")

    # Cache lives in the utils module globals; hit it first.
    cache = _get_utils_globals()
    if name in cache:
        return cache[name]

    module_path, attr_name = _UTILS_MODULE_IMPORT_MAP[name]

    # Relative paths resolve against the litellm package, absolute ones as-is.
    package = "litellm" if module_path.startswith(".") else None
    module = importlib.import_module(module_path, package=package)

    value = getattr(module, attr_name)
    cache[name] = value
    return value
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# SPECIAL HANDLERS
|
||||
# ============================================================================
|
||||
# These handlers have custom logic that doesn't fit the generic pattern
|
||||
|
||||
|
||||
def _lazy_import_llm_client_cache(name: str) -> Any:
    """Resolve the LLM client cache class or its singleton instance.

    Unlike the generic handlers, two related names are served here:

    - "LLMClientCache": the class itself.
    - "in_memory_llm_clients_cache": a lazily-created singleton instance.

    Both are cached in the litellm module globals after first resolution,
    so the singleton is only constructed once.

    Raises:
        AttributeError: For any other attribute name.
    """
    cache = _get_litellm_globals()
    if name in cache:
        return cache[name]

    llm_client_cache_cls = importlib.import_module(
        "litellm.caching.llm_caching_handler"
    ).LLMClientCache

    if name == "LLMClientCache":
        cache["LLMClientCache"] = llm_client_cache_cls
        return llm_client_cache_cls

    if name == "in_memory_llm_clients_cache":
        singleton = llm_client_cache_cls()
        cache["in_memory_llm_clients_cache"] = singleton
        return singleton

    raise AttributeError(f"LLM client cache lazy import: unknown attribute {name!r}")
|
||||
|
||||
|
||||
def _lazy_import_http_handlers(name: str) -> Any:
    """Create (once) and return the module-level HTTP client instances.

    Unlike the generic handlers, these attributes are live client objects
    built via factory functions rather than plain imports. Each is
    configured with the module-level ``request_timeout`` (if set) and
    cached in the litellm globals so it is only created once:

    - "module_level_aclient": async httpx client.
    - "module_level_client": sync HTTPHandler.

    Raises:
        AttributeError: For any other attribute name.
    """
    cache = _get_litellm_globals()

    if name == "module_level_aclient":
        from litellm.llms.custom_httpx.http_handler import get_async_httpx_client

        async_client = get_async_httpx_client(
            # Module-level client gets its own provider id for bookkeeping.
            llm_provider=cast(Any, "litellm_module_level_client"),
            params={
                "timeout": cache.get("request_timeout"),
                "client_alias": "module level aclient",
            },
        )
        cache["module_level_aclient"] = async_client
        return async_client

    if name == "module_level_client":
        from litellm.llms.custom_httpx.http_handler import HTTPHandler

        sync_client = HTTPHandler(timeout=cache.get("request_timeout"))
        cache["module_level_client"] = sync_client
        return sync_client

    raise AttributeError(f"HTTP handlers lazy import: unknown attribute {name!r}")
|
||||
File diff suppressed because it is too large
Load Diff
352
llm-gateway-competitors/litellm-wheel-src/litellm/_logging.py
Normal file
352
llm-gateway-competitors/litellm-wheel-src/litellm/_logging.py
Normal file
@@ -0,0 +1,352 @@
|
||||
import ast
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from logging import Formatter
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
|
||||
# Legacy verbosity flag; superseded by os.environ["LITELLM_LOG"] = "DEBUG".
set_verbose = False

if set_verbose is True:
    logging.warning(
        "`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
    )
# NOTE: bool() on any non-empty string is True, so JSON_LOGS="false" still
# enables JSON logs; kept as-is for backward compatibility.
json_logs = bool(os.getenv("JSON_LOGS", False))
# Create a handler for the logger (you may need to adapt this based on your needs)
log_level = os.getenv("LITELLM_LOG", "DEBUG")
# Fix: the resolved logging level is an int (logging.DEBUG etc.), not a str,
# and an unrecognized LITELLM_LOG value should fall back to DEBUG instead of
# raising AttributeError at import time.
numeric_level: int = getattr(logging, log_level.upper(), logging.DEBUG)
handler = logging.StreamHandler()
handler.setLevel(numeric_level)
|
||||
|
||||
|
||||
def _try_parse_json_message(message: str) -> Optional[Dict[str, Any]]:
    """Parse a log message that is entirely JSON into a dict.

    Only messages that look like a JSON object/array (leading '{' or '[')
    are attempted. Parsing goes through the shared safe_json_loads for
    consistent error handling; anything that doesn't parse to a dict
    yields None.
    """
    if not message or not isinstance(message, str):
        return None
    candidate = message.strip()
    if not candidate.startswith(("{", "[")):
        return None
    parsed = safe_json_loads(message, default=None)
    return parsed if isinstance(parsed, dict) else None
|
||||
|
||||
|
||||
def _try_parse_embedded_python_dict(message: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Try to find and parse a Python dict repr (e.g. str(d) or repr(d)) embedded in
|
||||
the message. Handles patterns like:
|
||||
"get_available_deployment for model: X, Selected deployment: {'model_name': '...', ...} for model: X"
|
||||
Uses ast.literal_eval for safe parsing. Returns the parsed dict or None.
|
||||
"""
|
||||
if not message or not isinstance(message, str) or "{" not in message:
|
||||
return None
|
||||
i = 0
|
||||
while i < len(message):
|
||||
start = message.find("{", i)
|
||||
if start == -1:
|
||||
break
|
||||
depth = 0
|
||||
for j in range(start, len(message)):
|
||||
c = message[j]
|
||||
if c == "{":
|
||||
depth += 1
|
||||
elif c == "}":
|
||||
depth -= 1
|
||||
if depth == 0:
|
||||
substr = message[start : j + 1]
|
||||
try:
|
||||
result = ast.literal_eval(substr)
|
||||
if isinstance(result, dict) and len(result) > 0:
|
||||
return result
|
||||
except (ValueError, SyntaxError, TypeError):
|
||||
pass
|
||||
break
|
||||
i = start + 1
|
||||
return None
|
||||
|
||||
|
||||
# Standard LogRecord attribute names - used to identify 'extra' fields.
|
||||
# Derived at runtime so we automatically include version-specific attrs (e.g. taskName).
|
||||
def _get_standard_record_attrs() -> frozenset:
|
||||
"""Standard LogRecord attribute names - excludes extra keys from logger.debug(..., extra={...})."""
|
||||
return frozenset(logging.LogRecord("", 0, "", 0, "", (), None).__dict__.keys())
|
||||
|
||||
|
||||
_STANDARD_RECORD_ATTRS = _get_standard_record_attrs()
|
||||
|
||||
|
||||
class JsonFormatter(Formatter):
    """Formatter that renders each log record as a single JSON object."""

    def __init__(self):
        super().__init__()

    def formatTime(self, record, datefmt=None):
        """Render the record timestamp in ISO 8601 via datetime."""
        return datetime.fromtimestamp(record.created).isoformat()

    def format(self, record):
        """Serialize *record* to JSON, promoting embedded dicts and extras.

        Sub-fields of an embedded JSON object or Python dict repr in the
        message, plus any `extra={...}` attributes on the record, become
        first-class JSON properties (existing keys always win).
        """
        message_str = record.getMessage()
        json_record: Dict[str, Any] = {
            "message": message_str,
            "level": record.levelname,
            "timestamp": self.formatTime(record),
        }

        # Promote fields from an embedded JSON object / Python dict repr.
        embedded = _try_parse_json_message(message_str)
        if embedded is None:
            embedded = _try_parse_embedded_python_dict(message_str)
        if embedded is not None:
            for key, value in embedded.items():
                json_record.setdefault(key, value)

        # Attach attributes passed via logger.debug("msg", extra={...}).
        for key, value in record.__dict__.items():
            if key not in _STANDARD_RECORD_ATTRS:
                json_record.setdefault(key, value)

        if record.exc_info:
            json_record["stacktrace"] = self.formatException(record.exc_info)

        return safe_dumps(json_record)
|
||||
|
||||
|
||||
# Function to set up exception handlers for JSON logging
def _setup_json_exception_handlers(formatter):
    """Route uncaught sync and asyncio exceptions through JSON-formatted logging.

    Side effects (process-global): replaces ``sys.excepthook`` and, when an
    event loop can be obtained, the asyncio loop's exception handler.

    Args:
        formatter: A logging.Formatter (typically JsonFormatter) applied to
            the stream handler that emits the exception records.
    """
    # Create a handler with JSON formatting for exceptions
    error_handler = logging.StreamHandler()
    error_handler.setFormatter(formatter)

    # Setup excepthook for uncaught exceptions
    def json_excepthook(exc_type, exc_value, exc_traceback):
        # Synthesize an ERROR record so the uncaught exception goes through
        # the same formatting path as ordinary log records.
        record = logging.LogRecord(
            name="LiteLLM",
            level=logging.ERROR,
            pathname="",
            lineno=0,
            msg=str(exc_value),
            args=(),
            exc_info=(exc_type, exc_value, exc_traceback),
        )
        error_handler.handle(record)

    sys.excepthook = json_excepthook

    # Configure asyncio exception handler if possible
    try:
        import asyncio

        def async_json_exception_handler(loop, context):
            exception = context.get("exception")
            if exception:
                # NOTE: exc_info is None here — only the message is logged,
                # not the traceback, for asyncio-originated exceptions.
                record = logging.LogRecord(
                    name="LiteLLM",
                    level=logging.ERROR,
                    pathname="",
                    lineno=0,
                    msg=str(exception),
                    args=(),
                    exc_info=None,
                )
                error_handler.handle(record)
            else:
                # Context without an exception object: fall back to
                # asyncio's default handling.
                loop.default_exception_handler(context)

        asyncio.get_event_loop().set_exception_handler(async_json_exception_handler)
    except Exception:
        # Best-effort: if no event loop can be obtained (or get_event_loop
        # raises on newer Python versions), skip asyncio integration.
        pass
|
||||
|
||||
|
||||
# Create a formatter and set it for the handler
if json_logs:
    # JSON mode: structured record output plus JSON-formatted uncaught exceptions.
    handler.setFormatter(JsonFormatter())
    _setup_json_exception_handlers(JsonFormatter())
else:
    # Human-readable mode: colorized "HH:MM:SS - name:LEVEL: file:line - message".
    formatter = logging.Formatter(
        "\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
        datefmt="%H:%M:%S",
    )

    handler.setFormatter(formatter)

# Package loggers shared across litellm (proxy, router, core).
verbose_proxy_logger = logging.getLogger("LiteLLM Proxy")
verbose_router_logger = logging.getLogger("LiteLLM Router")
verbose_logger = logging.getLogger("LiteLLM")

# Add the handler to the logger
verbose_router_logger.addHandler(handler)
verbose_proxy_logger.addHandler(handler)
verbose_logger.addHandler(handler)
|
||||
|
||||
|
||||
def _suppress_loggers():
|
||||
"""Suppress noisy loggers at INFO level"""
|
||||
# Suppress httpx request logging at INFO level
|
||||
httpx_logger = logging.getLogger("httpx")
|
||||
httpx_logger.setLevel(logging.WARNING)
|
||||
|
||||
# Suppress APScheduler logging at INFO level
|
||||
apscheduler_executors_logger = logging.getLogger("apscheduler.executors.default")
|
||||
apscheduler_executors_logger.setLevel(logging.WARNING)
|
||||
apscheduler_scheduler_logger = logging.getLogger("apscheduler.scheduler")
|
||||
apscheduler_scheduler_logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
# Call the suppression function
_suppress_loggers()

# Loggers (re)configured by _initialize_loggers_with_handler: the root
# logger plus the three package loggers defined above.
ALL_LOGGERS = [
    logging.getLogger(),
    verbose_logger,
    verbose_router_logger,
    verbose_proxy_logger,
]
|
||||
|
||||
|
||||
def _get_loggers_to_initialize():
    """Return the loggers that should receive the JSON handler.

    Starts from ALL_LOGGERS and appends third-party integration loggers
    (currently langfuse) when they are configured as litellm callbacks.
    """
    import litellm

    target_loggers = list(ALL_LOGGERS)

    # Include the langfuse logger only when langfuse is an active callback.
    configured_callbacks = set(litellm.success_callback + litellm.failure_callback)
    if configured_callbacks & {"langfuse", "langfuse_otel"}:
        target_loggers.append(logging.getLogger("langfuse"))

    return target_loggers
|
||||
|
||||
|
||||
def _initialize_loggers_with_handler(handler: logging.Handler):
    """Attach *handler* exclusively to every target logger.

    Existing handlers are removed and propagation to parent/root is
    disabled so each record is emitted exactly once (prevents duplicate
    JSON log lines).
    """
    for target_logger in _get_loggers_to_initialize():
        target_logger.handlers.clear()
        target_logger.addHandler(handler)
        target_logger.propagate = False
|
||||
|
||||
|
||||
def _get_uvicorn_json_log_config():
    """Build a uvicorn ``log_config`` dict that JSON-formats every logger.

    Covers uvicorn's main, error, and access loggers so the proxy's access
    and error output stays JSON when json_logs is enabled. All loggers use
    the module-level ``log_level`` and do not propagate.
    """
    json_formatter_class = "litellm._logging.JsonFormatter"

    # Use the module-level log_level variable for consistency
    uvicorn_log_level = log_level.upper()

    # Every formatter name uvicorn expects maps to the same JSON formatter.
    formatters = {
        formatter_name: {"()": json_formatter_class}
        for formatter_name in ("json", "default", "access")
    }

    def _stdout_handler(formatter_name):
        """Stream handler to stdout using the given formatter name."""
        return {
            "formatter": formatter_name,
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
        }

    def _logger_entry(handler_name):
        """Non-propagating logger entry bound to one handler at log level."""
        return {
            "handlers": [handler_name],
            "level": uvicorn_log_level,
            "propagate": False,
        }

    return {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": formatters,
        "handlers": {
            "default": _stdout_handler("json"),
            "access": _stdout_handler("access"),
        },
        "loggers": {
            "uvicorn": _logger_entry("default"),
            "uvicorn.error": _logger_entry("default"),
            "uvicorn.access": _logger_entry("access"),
        },
    }
|
||||
|
||||
|
||||
def _turn_on_json():
    """Switch litellm logging to JSON output.

    Installs a fresh JSON-formatting stream handler on all target loggers
    and routes uncaught exceptions through JSON formatting as well.
    """
    json_handler = logging.StreamHandler()
    json_handler.setFormatter(JsonFormatter())
    _initialize_loggers_with_handler(json_handler)
    # Set up exception handlers
    _setup_json_exception_handlers(JsonFormatter())
|
||||
|
||||
|
||||
def _turn_on_debug():
    """Set the package, router, and proxy loggers to DEBUG."""
    for pkg_logger in (verbose_logger, verbose_router_logger, verbose_proxy_logger):
        pkg_logger.setLevel(level=logging.DEBUG)
|
||||
|
||||
|
||||
def _disable_debugging():
    """Disable the package, router, and proxy loggers entirely."""
    for pkg_logger in (verbose_logger, verbose_router_logger, verbose_proxy_logger):
        pkg_logger.disabled = True
|
||||
|
||||
|
||||
def _enable_debugging():
    """Re-enable the package, router, and proxy loggers."""
    for pkg_logger in (verbose_logger, verbose_router_logger, verbose_proxy_logger):
        pkg_logger.disabled = False
|
||||
|
||||
|
||||
def print_verbose(print_statement):
    """Print *print_statement* only when the legacy set_verbose flag is on.

    Never raises: any failure (including printing errors) is swallowed.
    """
    try:
        if not set_verbose:
            return
        print(print_statement)  # noqa
    except Exception:
        pass
|
||||
|
||||
|
||||
def _is_debugging_on() -> bool:
    """Return True when DEBUG logging is enabled.

    Either the package logger is at DEBUG level or the legacy
    ``set_verbose`` flag is on.
    """
    if set_verbose is True:
        return True
    return verbose_logger.isEnabledFor(logging.DEBUG)
|
||||
598
llm-gateway-competitors/litellm-wheel-src/litellm/_redis.py
Normal file
598
llm-gateway-competitors/litellm-wheel-src/litellm/_redis.py
Normal file
@@ -0,0 +1,598 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | Give Feedback / Get Help |
|
||||
# | https://github.com/BerriAI/litellm/issues/new |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import inspect
|
||||
import json
|
||||
|
||||
# s/o [@Frank Colson](https://www.linkedin.com/in/frank-colson-422b9b183/) for this redis implementation
|
||||
import os
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
import redis # type: ignore
|
||||
import redis.asyncio as async_redis # type: ignore
|
||||
|
||||
from litellm import get_secret, get_secret_str
|
||||
from litellm.constants import REDIS_CONNECTION_POOL_TIMEOUT, REDIS_SOCKET_TIMEOUT
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
|
||||
from ._logging import verbose_logger
|
||||
|
||||
|
||||
def _get_redis_kwargs():
    """Return the names of ``redis.Redis`` constructor kwargs LiteLLM accepts.

    Non-primitive constructor args (``connection_pool``, ``retry``) are
    excluded; a few LiteLLM-specific extras (``url``, ``redis_connect_func``,
    GCP settings) are appended.
    """
    arg_spec = inspect.getfullargspec(redis.Redis)

    # Only primitive arguments may be forwarded.
    excluded = {"self", "connection_pool", "retry"}
    extras = [
        "url",
        "redis_connect_func",
        "gcp_service_account",
        "gcp_ssl_ca_certs",
    ]

    allowed = [name for name in arg_spec.args if name not in excluded]
    allowed.extend(extras)
    return allowed
|
||||
|
||||
|
||||
def _get_redis_url_kwargs(client=None):
    """Return the kwarg names accepted when building a redis client from a URL.

    Args:
        client: the ``from_url`` factory to introspect. Defaults to the sync
            ``redis.Redis.from_url``; callers may pass the async equivalent
            (see ``get_redis_async_client``).

    Returns:
        List of primitive argument names (plus ``url``) that may be forwarded.
    """
    if client is None:
        client = redis.Redis.from_url
    # Bug fix: ``client`` used to be ignored — the sync ``redis.Redis.from_url``
    # signature was always inspected, even when the caller passed the async
    # client's factory.
    arg_spec = inspect.getfullargspec(client)

    # Only allow primitive arguments
    exclude_args = {
        "self",
        "connection_pool",
        "retry",
    }

    include_args = ["url"]

    available_args = [x for x in arg_spec.args if x not in exclude_args] + include_args

    return available_args
|
||||
|
||||
|
||||
def _get_redis_cluster_kwargs(client=None):
    """Return the kwarg names that may be forwarded to ``redis.RedisCluster``.

    Args:
        client: the cluster class/factory to introspect; defaults to
            ``redis.RedisCluster``.

    Returns:
        List of allowed argument names, including auth/TLS and GCP extras.
    """
    if client is None:
        client = redis.RedisCluster
    # Bug fix: ``client`` previously defaulted to ``redis.Redis.from_url`` and
    # was then never used — ``redis.RedisCluster`` was hard-coded below.
    arg_spec = inspect.getfullargspec(client)

    # Only allow primitive arguments
    exclude_args = {"self", "connection_pool", "retry", "host", "port", "startup_nodes"}

    available_args = [x for x in arg_spec.args if x not in exclude_args]
    # Explicitly re-allow auth/TLS settings plus LiteLLM-specific extras that
    # introspection may not surface.
    available_args.append("password")
    available_args.append("username")
    available_args.append("ssl")
    available_args.append("ssl_cert_reqs")
    available_args.append("ssl_check_hostname")
    available_args.append("ssl_ca_certs")
    available_args.append(
        "redis_connect_func"
    )  # Needed for sync clusters and IAM detection
    available_args.append("gcp_service_account")
    available_args.append("gcp_ssl_ca_certs")
    available_args.append("max_connections")

    return available_args
|
||||
|
||||
|
||||
def _get_redis_env_kwarg_mapping():
    """Map environment-variable names (``REDIS_<KWARG>``) to redis kwarg names."""
    prefix = "REDIS_"
    mapping = {}
    for kwarg_name in _get_redis_kwargs():
        mapping[prefix + kwarg_name.upper()] = kwarg_name
    return mapping
|
||||
|
||||
|
||||
def _redis_kwargs_from_environment():
    """Collect redis client kwargs from ``REDIS_*`` env vars / secrets.

    Only variables that resolve to a non-None secret value are included.
    """
    collected = {}
    for env_name, kwarg_name in _get_redis_env_kwarg_mapping().items():
        secret_value = get_secret(env_name, default_value=None)  # type: ignore
        if secret_value is not None:
            collected[kwarg_name] = secret_value
    return collected
|
||||
|
||||
|
||||
def _generate_gcp_iam_access_token(service_account: str) -> str:
    """
    Generate GCP IAM access token for Redis authentication.

    Args:
        service_account: GCP service account in format 'projects/-/serviceAccounts/name@project.iam.gserviceaccount.com'

    Returns:
        Access token string for GCP IAM authentication
    """
    # google-cloud-iam is an optional dependency; surface a clear install hint.
    try:
        from google.cloud import iam_credentials_v1
    except ImportError:
        raise ImportError(
            "google-cloud-iam is required for GCP IAM Redis authentication. "
            "Install it with: pip install google-cloud-iam"
        )

    iam_client = iam_credentials_v1.IAMCredentialsClient()
    token_request = iam_credentials_v1.GenerateAccessTokenRequest(
        name=service_account,
        scope=["https://www.googleapis.com/auth/cloud-platform"],
    )
    token_response = iam_client.generate_access_token(request=token_request)
    return str(token_response.access_token)
|
||||
|
||||
|
||||
def create_gcp_iam_redis_connect_func(
    service_account: str,
    ssl_ca_certs: Optional[str] = None,
) -> Callable:
    """
    Creates a custom Redis connection function for GCP IAM authentication.

    Args:
        service_account: GCP service account in format 'projects/-/serviceAccounts/name@project.iam.gserviceaccount.com'
        ssl_ca_certs: Path to SSL CA certificate file for secure connections

    Returns:
        A connection function that can be used with Redis clients
    """
    # NOTE(review): ``ssl_ca_certs`` is accepted but never referenced inside
    # ``iam_connect`` — TLS setup appears to happen elsewhere (the caller keeps
    # the value in redis_kwargs). Confirm whether the parameter is still needed.

    def iam_connect(self):
        """Initialize the connection and authenticate using GCP IAM"""
        # Imported lazily so redis internals are only touched at connect time.
        from redis.exceptions import (
            AuthenticationError,
            AuthenticationWrongNumberOfArgsError,
        )
        from redis.utils import str_if_bytes

        # Run the protocol parser's connect hook before issuing AUTH.
        self._parser.on_connect(self)

        # Mint a fresh IAM token for every new connection and use it as the
        # (single-argument) AUTH credential.
        auth_args = (_generate_gcp_iam_access_token(service_account),)
        self.send_command("AUTH", *auth_args, check_health=False)

        try:
            auth_response = self.read_response()
        except AuthenticationWrongNumberOfArgsError:
            # Fallback to password auth if IAM fails
            if hasattr(self, "password") and self.password:
                self.send_command("AUTH", self.password, check_health=False)
                auth_response = self.read_response()
            else:
                raise

        # AUTH must answer "OK"; anything else means the token was rejected.
        if str_if_bytes(auth_response) != "OK":
            raise AuthenticationError("GCP IAM authentication failed")

    return iam_connect
|
||||
|
||||
|
||||
def get_redis_url_from_environment():
    """Build a redis connection URL from ``REDIS_*`` environment variables.

    Precedence: a fully-formed ``REDIS_URL`` wins; otherwise the URL is
    assembled from ``REDIS_HOST``/``REDIS_PORT`` plus optional ``REDIS_SSL``,
    ``REDIS_USERNAME`` and ``REDIS_PASSWORD``.

    Returns:
        A ``redis://`` (or ``rediss://`` when ``REDIS_SSL=true``) URL string.

    Raises:
        ValueError: if neither ``REDIS_URL`` nor both ``REDIS_HOST`` and
            ``REDIS_PORT`` are set.
    """
    if "REDIS_URL" in os.environ:
        return os.environ["REDIS_URL"]

    if "REDIS_HOST" not in os.environ or "REDIS_PORT" not in os.environ:
        raise ValueError(
            "Either 'REDIS_URL' or both 'REDIS_HOST' and 'REDIS_PORT' must be specified for Redis."
        )

    if "REDIS_SSL" in os.environ and os.environ["REDIS_SSL"].lower() == "true":
        redis_protocol = "rediss"
    else:
        redis_protocol = "redis"

    # Build authentication part of URL
    auth_part = ""
    if "REDIS_USERNAME" in os.environ and "REDIS_PASSWORD" in os.environ:
        auth_part = f"{os.environ['REDIS_USERNAME']}:{os.environ['REDIS_PASSWORD']}@"
    elif "REDIS_PASSWORD" in os.environ:
        # Bug fix: a password-only URL needs a leading ':' — in the
        # 'redis://user:pass@host' scheme, 'redis://pw@host' would be parsed
        # as a *username*; the correct form is 'redis://:pw@host'.
        auth_part = f":{os.environ['REDIS_PASSWORD']}@"

    return f"{redis_protocol}://{auth_part}{os.environ['REDIS_HOST']}:{os.environ['REDIS_PORT']}"
|
||||
|
||||
|
||||
def _get_redis_client_logic(**env_overrides):
    """
    Common functionality across sync + async redis client implementations

    Merges kwargs from the environment with caller overrides, resolves
    ``os.environ/<key>`` indirections, decodes JSON-encoded cluster/sentinel
    node lists, and wires up GCP IAM auth when configured.

    Returns:
        dict of kwargs ready to pass to a redis client constructor.

    Raises:
        ValueError: when no url/host/cluster/sentinel target can be derived.
    """
    ### check if "os.environ/<key-name>" passed in
    for k, v in env_overrides.items():
        if isinstance(v, str) and v.startswith("os.environ/"):
            v = v.replace("os.environ/", "")
            value = get_secret(v)  # type: ignore
            env_overrides[k] = value

    # Caller overrides take precedence over REDIS_* environment settings.
    redis_kwargs = {
        **_redis_kwargs_from_environment(),
        **env_overrides,
    }

    _startup_nodes: Optional[Union[str, list]] = redis_kwargs.get("startup_nodes", None) or get_secret(  # type: ignore
        "REDIS_CLUSTER_NODES"
    )

    # Env-supplied node lists arrive as JSON strings; decode them here.
    if _startup_nodes is not None and isinstance(_startup_nodes, str):
        redis_kwargs["startup_nodes"] = json.loads(_startup_nodes)

    _sentinel_nodes: Optional[Union[str, list]] = redis_kwargs.get("sentinel_nodes", None) or get_secret(  # type: ignore
        "REDIS_SENTINEL_NODES"
    )

    if _sentinel_nodes is not None and isinstance(_sentinel_nodes, str):
        redis_kwargs["sentinel_nodes"] = json.loads(_sentinel_nodes)

    _sentinel_password: Optional[str] = redis_kwargs.get(
        "sentinel_password", None
    ) or get_secret_str("REDIS_SENTINEL_PASSWORD")

    if _sentinel_password is not None:
        redis_kwargs["sentinel_password"] = _sentinel_password

    _service_name: Optional[str] = redis_kwargs.get("service_name", None) or get_secret(  # type: ignore
        "REDIS_SERVICE_NAME"
    )

    if _service_name is not None:
        redis_kwargs["service_name"] = _service_name

    # Handle GCP IAM authentication
    _gcp_service_account = redis_kwargs.get("gcp_service_account") or get_secret_str(
        "REDIS_GCP_SERVICE_ACCOUNT"
    )
    _gcp_ssl_ca_certs = redis_kwargs.get("gcp_ssl_ca_certs") or get_secret_str(
        "REDIS_GCP_SSL_CA_CERTS"
    )

    if _gcp_service_account is not None:
        verbose_logger.debug(
            "Setting up GCP IAM authentication for Redis with service account."
        )
        redis_kwargs["redis_connect_func"] = create_gcp_iam_redis_connect_func(
            service_account=_gcp_service_account, ssl_ca_certs=_gcp_ssl_ca_certs
        )
        # Store GCP service account in redis_connect_func for async cluster access
        redis_kwargs["redis_connect_func"]._gcp_service_account = _gcp_service_account

    # Remove GCP-specific kwargs that shouldn't be passed to Redis client
    redis_kwargs.pop("gcp_service_account", None)
    redis_kwargs.pop("gcp_ssl_ca_certs", None)

    # Only enable SSL if explicitly requested AND SSL CA certs are provided
    if _gcp_ssl_ca_certs and redis_kwargs.get("ssl", False):
        redis_kwargs["ssl_ca_certs"] = _gcp_ssl_ca_certs

    # A URL supersedes discrete host/port/db/password settings; the remaining
    # branches just validate that *some* connection target exists.
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        redis_kwargs.pop("host", None)
        redis_kwargs.pop("port", None)
        redis_kwargs.pop("db", None)
        redis_kwargs.pop("password", None)
    elif "startup_nodes" in redis_kwargs and redis_kwargs["startup_nodes"] is not None:
        pass
    elif (
        "sentinel_nodes" in redis_kwargs and redis_kwargs["sentinel_nodes"] is not None
    ):
        pass
    elif "host" not in redis_kwargs or redis_kwargs["host"] is None:
        raise ValueError("Either 'host' or 'url' must be specified for redis.")

    # litellm.print_verbose(f"redis_kwargs: {redis_kwargs}")
    return redis_kwargs
|
||||
|
||||
|
||||
def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
    """Build a sync ``redis.RedisCluster`` from the prepared kwargs.

    ``REDIS_CLUSTER_NODES`` (a JSON list of node dicts) overrides any
    ``startup_nodes`` already present in *redis_kwargs*.
    """
    _redis_cluster_nodes_in_env: Optional[str] = get_secret("REDIS_CLUSTER_NODES")  # type: ignore
    if _redis_cluster_nodes_in_env is not None:
        try:
            redis_kwargs["startup_nodes"] = json.loads(_redis_cluster_nodes_in_env)
        except json.JSONDecodeError:
            raise ValueError(
                "REDIS_CLUSTER_NODES environment variable is not valid JSON. Please ensure it's properly formatted."
            )

    verbose_logger.debug("init_redis_cluster: startup nodes are being initialized.")
    from redis.cluster import ClusterNode

    allowed = _get_redis_cluster_kwargs()
    cluster_kwargs = {key: redis_kwargs[key] for key in redis_kwargs if key in allowed}

    new_startup_nodes: List[ClusterNode] = [
        ClusterNode(**node) for node in redis_kwargs["startup_nodes"]
    ]

    # startup_nodes goes in explicitly as ClusterNode objects, not raw dicts.
    cluster_kwargs.pop("startup_nodes", None)
    return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)  # type: ignore
|
||||
|
||||
|
||||
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:
    """Connect through Redis Sentinel and return the master client."""
    sentinel_nodes = redis_kwargs.get("sentinel_nodes")
    sentinel_password = redis_kwargs.get("sentinel_password")
    service_name = redis_kwargs.get("service_name")

    # Both pieces are mandatory: nodes to contact and the service to resolve.
    if not sentinel_nodes or not service_name:
        raise ValueError(
            "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
        )

    verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")

    sentinel_client = redis.Sentinel(
        sentinel_nodes,
        socket_timeout=REDIS_SOCKET_TIMEOUT,
        password=sentinel_password,
    )
    # Hand back the master connection for the requested service.
    return sentinel_client.master_for(service_name)
|
||||
|
||||
|
||||
def _init_async_redis_sentinel(redis_kwargs) -> async_redis.Redis:
    """Async twin of ``_init_redis_sentinel``: return the sentinel master client."""
    sentinel_nodes = redis_kwargs.get("sentinel_nodes")
    sentinel_password = redis_kwargs.get("sentinel_password")
    service_name = redis_kwargs.get("service_name")

    # Both pieces are mandatory: nodes to contact and the service to resolve.
    if not sentinel_nodes or not service_name:
        raise ValueError(
            "Both 'sentinel_nodes' and 'service_name' are required for Redis Sentinel."
        )

    verbose_logger.debug("init_redis_sentinel: sentinel nodes are being initialized.")

    sentinel_client = async_redis.Sentinel(
        sentinel_nodes,
        socket_timeout=REDIS_SOCKET_TIMEOUT,
        password=sentinel_password,
    )
    # Hand back the master connection for the requested service.
    return sentinel_client.master_for(service_name)
|
||||
|
||||
|
||||
def get_redis_client(**env_overrides):
    """Create a sync redis client: URL-based, cluster, sentinel, or plain."""
    redis_kwargs = _get_redis_client_logic(**env_overrides)

    # URL configuration takes priority; forward only from_url-compatible args.
    if redis_kwargs.get("url") is not None:
        allowed = _get_redis_url_kwargs()
        url_kwargs = {key: redis_kwargs[key] for key in redis_kwargs if key in allowed}
        return redis.Redis.from_url(**url_kwargs)

    # Cluster configuration (explicit or via REDIS_CLUSTER_NODES).
    if "startup_nodes" in redis_kwargs or get_secret("REDIS_CLUSTER_NODES") is not None:  # type: ignore
        return init_redis_cluster(redis_kwargs)

    # Redis Sentinel deployment.
    if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
        return _init_redis_sentinel(redis_kwargs)

    return redis.Redis(**redis_kwargs)
|
||||
|
||||
|
||||
def get_redis_async_client(
    connection_pool: Optional[async_redis.BlockingConnectionPool] = None,
    **env_overrides,
) -> Union[async_redis.Redis, async_redis.RedisCluster]:
    """Create an async redis client: URL-based, cluster (with optional GCP IAM
    token auth), sentinel, or plain.

    Args:
        connection_pool: optional pre-built blocking pool; when given with a
            URL config (or no cluster/sentinel config) it is used directly.
        **env_overrides: kwargs layered on top of REDIS_* environment config.
    """
    redis_kwargs = _get_redis_client_logic(**env_overrides)
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        # A caller-supplied pool wins over re-parsing the URL.
        if connection_pool is not None:
            return async_redis.Redis(connection_pool=connection_pool)
        args = _get_redis_url_kwargs(client=async_redis.Redis.from_url)
        url_kwargs = {}
        for arg in redis_kwargs:
            if arg in args:
                url_kwargs[arg] = redis_kwargs[arg]
            else:
                verbose_logger.debug(
                    "REDIS: ignoring argument: {}. Not an allowed async_redis.Redis.from_url arg.".format(
                        arg
                    )
                )
        return async_redis.Redis.from_url(**url_kwargs)

    if "startup_nodes" in redis_kwargs:
        from redis.cluster import ClusterNode

        # Keep only kwargs the cluster client accepts.
        args = _get_redis_cluster_kwargs()
        cluster_kwargs = {}
        for arg in redis_kwargs:
            if arg in args:
                cluster_kwargs[arg] = redis_kwargs[arg]

        # Handle GCP IAM authentication for async clusters
        redis_connect_func = cluster_kwargs.pop("redis_connect_func", None)
        from litellm import get_secret_str

        # Get GCP service account - first try from redis_connect_func, then from environment
        gcp_service_account = None
        if redis_connect_func and hasattr(redis_connect_func, "_gcp_service_account"):
            gcp_service_account = redis_connect_func._gcp_service_account
        else:
            gcp_service_account = redis_kwargs.get(
                "gcp_service_account"
            ) or get_secret_str("REDIS_GCP_SERVICE_ACCOUNT")

        verbose_logger.debug(
            f"DEBUG: Redis cluster kwargs: redis_connect_func={redis_connect_func is not None}, gcp_service_account_provided={gcp_service_account is not None}"
        )

        # If GCP IAM is configured (indicated by redis_connect_func), generate access token and use as password
        if redis_connect_func and gcp_service_account:
            verbose_logger.debug(
                "DEBUG: Generating IAM token for service account (value not logged for security reasons)"
            )
            try:
                # Generate IAM access token using the helper function
                access_token = _generate_gcp_iam_access_token(gcp_service_account)
                # The async cluster client can't run the custom connect func,
                # so the token is passed as a plain password instead.
                cluster_kwargs["password"] = access_token
                verbose_logger.debug(
                    "DEBUG: Successfully generated GCP IAM access token for async Redis cluster"
                )
            except Exception as e:
                verbose_logger.error(f"Failed to generate GCP IAM access token: {e}")
                from redis.exceptions import AuthenticationError

                raise AuthenticationError("Failed to generate GCP IAM access token")
        else:
            verbose_logger.debug(
                f"DEBUG: Not using GCP IAM auth - redis_connect_func={redis_connect_func is not None}, gcp_service_account_provided={gcp_service_account is not None}"
            )

        # Convert raw node dicts into ClusterNode objects.
        new_startup_nodes: List[ClusterNode] = []

        for item in redis_kwargs["startup_nodes"]:
            new_startup_nodes.append(ClusterNode(**item))
        cluster_kwargs.pop("startup_nodes", None)

        # Create async RedisCluster with IAM token as password if available
        cluster_client = async_redis.RedisCluster(
            startup_nodes=new_startup_nodes, **cluster_kwargs  # type: ignore
        )

        return cluster_client

    # Check for Redis Sentinel
    if "sentinel_nodes" in redis_kwargs and "service_name" in redis_kwargs:
        return _init_async_redis_sentinel(redis_kwargs)
    # Debug-level dump of the final (masked) configuration.
    _pretty_print_redis_config(redis_kwargs=redis_kwargs)

    if connection_pool is not None:
        redis_kwargs["connection_pool"] = connection_pool

    return async_redis.Redis(
        **redis_kwargs,
    )
|
||||
|
||||
|
||||
def get_redis_connection_pool(**env_overrides):
    """Create an async ``BlockingConnectionPool`` from env/override kwargs.

    URL-based configs go through ``BlockingConnectionPool.from_url``; other
    configs build the pool directly, selecting an SSL connection class when an
    ``ssl`` kwarg is present.
    """
    redis_kwargs = _get_redis_client_logic(**env_overrides)
    verbose_logger.debug("get_redis_connection_pool: redis_kwargs", redis_kwargs)
    if "url" in redis_kwargs and redis_kwargs["url"] is not None:
        pool_kwargs = {
            "timeout": REDIS_CONNECTION_POOL_TIMEOUT,
            "url": redis_kwargs["url"],
        }
        if "max_connections" in redis_kwargs:
            try:
                pool_kwargs["max_connections"] = int(redis_kwargs["max_connections"])
            except (TypeError, ValueError):
                # Best-effort: drop an unparseable max_connections instead of failing.
                verbose_logger.warning(
                    "REDIS: invalid max_connections value %r, ignoring",
                    redis_kwargs["max_connections"],
                )
        return async_redis.BlockingConnectionPool.from_url(**pool_kwargs)
    connection_class = async_redis.Connection
    if "ssl" in redis_kwargs:
        # NOTE(review): this tests key *presence*, not truthiness — an explicit
        # ssl=False would still select SSLConnection. Confirm intended behavior.
        connection_class = async_redis.SSLConnection
        redis_kwargs.pop("ssl", None)
        redis_kwargs["connection_class"] = connection_class
    # startup_nodes is a cluster concept; the plain pool can't accept it.
    redis_kwargs.pop("startup_nodes", None)
    return async_redis.BlockingConnectionPool(
        timeout=REDIS_CONNECTION_POOL_TIMEOUT, **redis_kwargs
    )
|
||||
|
||||
|
||||
def _pretty_print_redis_config(redis_kwargs: dict) -> None:
    """Pretty print the Redis configuration using rich with sensitive data masking

    Only emits output when the package logger is at DEBUG level. Falls back to
    plain logging when ``rich`` is not installed; never raises to the caller.
    """
    try:
        import logging

        from rich.console import Console
        from rich.panel import Panel
        from rich.table import Table
        from rich.text import Text

        # Skip the (relatively expensive) rendering unless debugging is on.
        if not verbose_logger.isEnabledFor(logging.DEBUG):
            return

        console = Console()

        # Initialize the sensitive data masker
        masker = SensitiveDataMasker()

        # Mask sensitive data in redis_kwargs
        masked_redis_kwargs = masker.mask_dict(redis_kwargs)

        # Create main panel title
        title = Text("Redis Configuration", style="bold blue")

        # Create configuration table
        config_table = Table(
            title="🔧 Redis Connection Parameters",
            show_header=True,
            header_style="bold magenta",
            title_justify="left",
        )
        config_table.add_column("Parameter", style="cyan", no_wrap=True)
        config_table.add_column("Value", style="yellow")

        # Add rows for each configuration parameter
        for key, value in masked_redis_kwargs.items():
            if value is not None:
                # Special handling for complex objects
                if isinstance(value, list):
                    if key == "startup_nodes" and value:
                        # Special handling for cluster nodes
                        value_str = f"[{len(value)} cluster nodes]"
                    elif key == "sentinel_nodes" and value:
                        # Special handling for sentinel nodes
                        value_str = f"[{len(value)} sentinel nodes]"
                    else:
                        value_str = str(value)
                else:
                    value_str = str(value)

                config_table.add_row(key, value_str)

        # Determine connection type
        connection_type = "Standard Redis"
        if masked_redis_kwargs.get("startup_nodes"):
            connection_type = "Redis Cluster"
        elif masked_redis_kwargs.get("sentinel_nodes"):
            connection_type = "Redis Sentinel"
        elif masked_redis_kwargs.get("url"):
            connection_type = "Redis (URL-based)"

        # Create connection type info
        info_table = Table(
            title="📊 Connection Info",
            show_header=True,
            header_style="bold green",
            title_justify="left",
        )
        info_table.add_column("Property", style="cyan", no_wrap=True)
        info_table.add_column("Value", style="yellow")
        info_table.add_row("Connection Type", connection_type)

        # Print everything in a nice panel
        console.print("\n")
        console.print(Panel(title, border_style="blue"))
        console.print(info_table)
        console.print(config_table)
        console.print("\n")

    except ImportError:
        # Fallback to simple logging if rich is not available
        masker = SensitiveDataMasker()
        masked_redis_kwargs = masker.mask_dict(redis_kwargs)
        verbose_logger.info(f"Redis configuration: {masked_redis_kwargs}")
    except Exception as e:
        # Diagnostics must never break client creation.
        verbose_logger.error(f"Error pretty printing Redis configuration: {e}")
|
||||
@@ -0,0 +1,323 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
from .integrations.custom_logger import CustomLogger
|
||||
from .integrations.datadog.datadog import DataDogLogger
|
||||
from .integrations.opentelemetry import OpenTelemetry
|
||||
from .integrations.prometheus_services import PrometheusServicesLogger
|
||||
from .types.services import ServiceLoggerPayload, ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
OTELClass = OpenTelemetry
|
||||
else:
|
||||
Span = Any
|
||||
OTELClass = Any
|
||||
UserAPIKeyAuth = Any
|
||||
|
||||
|
||||
class ServiceLogging(CustomLogger):
    """
    Separate class used for monitoring health of litellm-adjacent services (redis/postgres).
    """

    def __init__(self, mock_testing: bool = False) -> None:
        # When mock_testing is on, hooks only bump counters so tests can
        # observe that they fired without touching real callbacks.
        self.mock_testing = mock_testing
        self.mock_testing_sync_success_hook = 0
        self.mock_testing_async_success_hook = 0
        self.mock_testing_sync_failure_hook = 0
        self.mock_testing_async_failure_hook = 0
        # Eagerly build the prometheus logger only when it is configured.
        if "prometheus_system" in litellm.service_callback:
            self.prometheusServicesLogger = PrometheusServicesLogger()

    def service_success_hook(
        self,
        service: ServiceTypes,
        duration: float,
        call_type: str,
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[float, datetime]] = None,
    ):
        """
        Handles both sync and async monitoring by checking for existing event loop.
        """

        if self.mock_testing:
            self.mock_testing_sync_success_hook += 1

        try:
            # Try to get the current event loop
            loop = asyncio.get_event_loop()
            # Check if the loop is running
            if loop.is_running():
                # If we're in a running loop, create a task
                loop.create_task(
                    self.async_service_success_hook(
                        service=service,
                        duration=duration,
                        call_type=call_type,
                        parent_otel_span=parent_otel_span,
                        start_time=start_time,
                        end_time=end_time,
                    )
                )
            else:
                # Loop exists but not running, we can use run_until_complete
                loop.run_until_complete(
                    self.async_service_success_hook(
                        service=service,
                        duration=duration,
                        call_type=call_type,
                        parent_otel_span=parent_otel_span,
                        start_time=start_time,
                        end_time=end_time,
                    )
                )
        except RuntimeError:
            # No event loop exists, create a new one and run
            asyncio.run(
                self.async_service_success_hook(
                    service=service,
                    duration=duration,
                    call_type=call_type,
                    parent_otel_span=parent_otel_span,
                    start_time=start_time,
                    end_time=end_time,
                )
            )

    def service_failure_hook(
        self, service: ServiceTypes, duration: float, error: Exception, call_type: str
    ):
        """
        [TODO] Not implemented for sync calls yet. V0 is focused on async monitoring (used by proxy).
        """
        if self.mock_testing:
            self.mock_testing_sync_failure_hook += 1

    async def async_service_success_hook(
        self,
        service: ServiceTypes,
        call_type: str,
        duration: float,
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[datetime, float]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """
        - For counting if the redis, postgres call is successful
        """
        if self.mock_testing:
            self.mock_testing_async_success_hook += 1

        payload = ServiceLoggerPayload(
            is_error=False,
            error=None,
            service=service,
            duration=duration,
            call_type=call_type,
            event_metadata=event_metadata,
        )

        # Fan the payload out to every configured service callback.
        for callback in litellm.service_callback:
            if callback == "prometheus_system":
                await self.init_prometheus_services_logger_if_none()
                await self.prometheusServicesLogger.async_service_success_hook(
                    payload=payload
                )
            elif callback == "datadog" or isinstance(callback, DataDogLogger):
                await self.init_datadog_logger_if_none()
                await self.dd_logger.async_service_success_hook(
                    payload=payload,
                    parent_otel_span=parent_otel_span,
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata=event_metadata,
                )
            elif callback == "otel" or isinstance(callback, OpenTelemetry):
                # Prefer a caller-provided OTEL instance; otherwise fall back
                # to the proxy's global one.
                _otel_logger_to_use: Optional[OpenTelemetry] = None
                if isinstance(callback, OpenTelemetry):
                    _otel_logger_to_use = callback
                else:
                    from litellm.proxy.proxy_server import open_telemetry_logger

                    if open_telemetry_logger is not None and isinstance(
                        open_telemetry_logger, OpenTelemetry
                    ):
                        _otel_logger_to_use = open_telemetry_logger

                # OTEL spans need a parent span to attach to.
                if _otel_logger_to_use is not None and parent_otel_span is not None:
                    await _otel_logger_to_use.async_service_success_hook(
                        payload=payload,
                        parent_otel_span=parent_otel_span,
                        start_time=start_time,
                        end_time=end_time,
                        event_metadata=event_metadata,
                    )

    async def init_prometheus_services_logger_if_none(self):
        """
        initializes prometheusServicesLogger if it is None or no attribute exists on ServiceLogging Object

        """
        if not hasattr(self, "prometheusServicesLogger"):
            self.prometheusServicesLogger = PrometheusServicesLogger()
        elif self.prometheusServicesLogger is None:
            # Bug fix: previously this called ``self.prometheusServicesLogger()``,
            # i.e. ``None()`` — a guaranteed TypeError. Instantiate the logger
            # class instead.
            self.prometheusServicesLogger = PrometheusServicesLogger()
        return

    async def init_datadog_logger_if_none(self):
        """
        initializes dd_logger if it is None or no attribute exists on ServiceLogging Object

        """
        from litellm.integrations.datadog.datadog import DataDogLogger

        if not hasattr(self, "dd_logger"):
            self.dd_logger: DataDogLogger = DataDogLogger()

        return

    async def init_otel_logger_if_none(self):
        """
        initializes otel_logger if it is None or no attribute exists on ServiceLogging Object

        """
        from litellm.proxy.proxy_server import open_telemetry_logger

        if not hasattr(self, "otel_logger"):
            if open_telemetry_logger is not None and isinstance(
                open_telemetry_logger, OpenTelemetry
            ):
                self.otel_logger: OpenTelemetry = open_telemetry_logger
            else:
                verbose_logger.warning(
                    "ServiceLogger: open_telemetry_logger is None or not an instance of OpenTelemetry"
                )
        return

    async def async_service_failure_hook(
        self,
        service: ServiceTypes,
        duration: float,
        error: Union[str, Exception],
        call_type: str,
        parent_otel_span: Optional[Span] = None,
        start_time: Optional[Union[datetime, float]] = None,
        end_time: Optional[Union[float, datetime]] = None,
        event_metadata: Optional[dict] = None,
    ):
        """
        - For counting if the redis, postgres call is unsuccessful
        """
        if self.mock_testing:
            self.mock_testing_async_failure_hook += 1

        # Normalize the error to a string for the payload.
        error_message = ""
        if isinstance(error, Exception):
            error_message = str(error)
        elif isinstance(error, str):
            error_message = error

        payload = ServiceLoggerPayload(
            is_error=True,
            error=error_message,
            service=service,
            duration=duration,
            call_type=call_type,
            event_metadata=event_metadata,
        )

        # Fan the failure payload out to every configured service callback.
        for callback in litellm.service_callback:
            if callback == "prometheus_system":
                await self.init_prometheus_services_logger_if_none()
                await self.prometheusServicesLogger.async_service_failure_hook(
                    payload=payload,
                    error=error,
                )
            elif callback == "datadog" or isinstance(callback, DataDogLogger):
                await self.init_datadog_logger_if_none()
                await self.dd_logger.async_service_failure_hook(
                    payload=payload,
                    error=error_message,
                    parent_otel_span=parent_otel_span,
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata=event_metadata,
                )
            elif callback == "otel" or isinstance(callback, OpenTelemetry):
                _otel_logger_to_use: Optional[OpenTelemetry] = None
                if isinstance(callback, OpenTelemetry):
                    _otel_logger_to_use = callback
                else:
                    from litellm.proxy.proxy_server import open_telemetry_logger

                    if open_telemetry_logger is not None and isinstance(
                        open_telemetry_logger, OpenTelemetry
                    ):
                        _otel_logger_to_use = open_telemetry_logger

                # OTEL expects a string error.
                if not isinstance(error, str):
                    error = str(error)

                if _otel_logger_to_use is not None and parent_otel_span is not None:
                    await _otel_logger_to_use.async_service_failure_hook(
                        payload=payload,
                        error=error,
                        parent_otel_span=parent_otel_span,
                        start_time=start_time,
                        end_time=end_time,
                        event_metadata=event_metadata,
                    )

    async def async_post_call_failure_hook(
        self,
        request_data: dict,
        original_exception: Exception,
        user_api_key_dict: UserAPIKeyAuth,
        traceback_str: Optional[str] = None,
    ):
        """
        Hook to track failed litellm-service calls
        """
        return await super().async_post_call_failure_hook(
            request_data,
            original_exception,
            user_api_key_dict,
        )

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Hook to track latency for litellm proxy llm api calls
        """
        try:
            # start/end may be datetimes (-> timedelta) or plain floats.
            _duration = end_time - start_time
            if isinstance(_duration, timedelta):
                _duration = _duration.total_seconds()
            elif isinstance(_duration, float):
                pass
            else:
                raise Exception(
                    "Duration={} is not a float or timedelta object. type={}".format(
                        _duration, type(_duration)
                    )
                )  # invalid _duration value
            # Batch polling callbacks (check_batch_cost) don't include call_type in kwargs.
            # Use .get() to avoid KeyError.
            await self.async_service_success_hook(
                service=ServiceTypes.LITELLM,
                duration=_duration,
                call_type=kwargs.get("call_type", "unknown"),
            )
        except Exception as e:
            raise e
|
||||
16
llm-gateway-competitors/litellm-wheel-src/litellm/_uuid.py
Normal file
16
llm-gateway-competitors/litellm-wheel-src/litellm/_uuid.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Internal unified UUID helper.
|
||||
|
||||
Always uses fastuuid for performance.
|
||||
"""
|
||||
|
||||
import fastuuid as _uuid  # type: ignore


# Expose a module-like alias so callers can use: uuid.uuid4()
uuid = _uuid


def uuid4():
    """Return a UUID4 using the selected backend (fastuuid)."""
    return uuid.uuid4()
|
||||
@@ -0,0 +1,6 @@
|
||||
import importlib_metadata


try:
    # Resolve the installed litellm package version from distribution metadata.
    version = importlib_metadata.version("litellm")
except Exception:
    # Metadata unavailable (e.g. running from a source checkout) — fall back.
    version = "unknown"
|
||||
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
LiteLLM A2A - Wrapper for invoking A2A protocol agents.
|
||||
|
||||
This module provides a thin wrapper around the official `a2a` SDK that:
|
||||
- Handles httpx client creation and agent card resolution
|
||||
- Adds LiteLLM logging via @client decorator
|
||||
- Matches the A2A SDK interface (SendMessageRequest, SendMessageResponse, etc.)
|
||||
|
||||
Example usage (standalone functions with @client decorator):
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message
|
||||
from a2a.types import SendMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
)
|
||||
)
|
||||
response = await asend_message(
|
||||
base_url="http://localhost:10001",
|
||||
request=request,
|
||||
)
|
||||
print(response.model_dump(mode='json', exclude_none=True))
|
||||
```
|
||||
|
||||
Example usage (class-based):
|
||||
```python
|
||||
from litellm.a2a_protocol import A2AClient
|
||||
|
||||
client = A2AClient(base_url="http://localhost:10001")
|
||||
response = await client.send_message(request)
|
||||
```
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.client import A2AClient
|
||||
from litellm.a2a_protocol.exceptions import (
|
||||
A2AAgentCardError,
|
||||
A2AConnectionError,
|
||||
A2AError,
|
||||
A2ALocalhostURLError,
|
||||
)
|
||||
from litellm.a2a_protocol.main import (
|
||||
aget_agent_card,
|
||||
asend_message,
|
||||
asend_message_streaming,
|
||||
create_a2a_client,
|
||||
send_message,
|
||||
)
|
||||
from litellm.types.agents import LiteLLMSendMessageResponse
|
||||
|
||||
__all__ = [
|
||||
# Client
|
||||
"A2AClient",
|
||||
# Functions
|
||||
"asend_message",
|
||||
"send_message",
|
||||
"asend_message_streaming",
|
||||
"aget_agent_card",
|
||||
"create_a2a_client",
|
||||
# Response types
|
||||
"LiteLLMSendMessageResponse",
|
||||
# Exceptions
|
||||
"A2AError",
|
||||
"A2AConnectionError",
|
||||
"A2AAgentCardError",
|
||||
"A2ALocalhostURLError",
|
||||
]
|
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Custom A2A Card Resolver for LiteLLM.
|
||||
|
||||
Extends the A2A SDK's card resolver to support multiple well-known paths.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import LOCALHOST_URL_PATTERNS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import AgentCard
|
||||
|
||||
# Runtime imports with availability check
|
||||
_A2ACardResolver: Any = None
|
||||
AGENT_CARD_WELL_KNOWN_PATH: str = "/.well-known/agent-card.json"
|
||||
PREV_AGENT_CARD_WELL_KNOWN_PATH: str = "/.well-known/agent.json"
|
||||
|
||||
try:
|
||||
from a2a.client import A2ACardResolver as _A2ACardResolver # type: ignore[no-redef]
|
||||
from a2a.utils.constants import ( # type: ignore[no-redef]
|
||||
AGENT_CARD_WELL_KNOWN_PATH,
|
||||
PREV_AGENT_CARD_WELL_KNOWN_PATH,
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def is_localhost_or_internal_url(url: Optional[str]) -> bool:
    """
    Report whether *url* points at a localhost/internal address.

    Catches common development URLs that are accidentally left in agent
    cards when deploying to production.

    Args:
        url: The URL to check (may be None or empty).

    Returns:
        True if the URL matches any known localhost/internal pattern.
    """
    if not url:
        return False

    lowered = url.lower()
    for pattern in LOCALHOST_URL_PATTERNS:
        if pattern in lowered:
            return True
    return False
|
||||
|
||||
|
||||
def fix_agent_card_url(agent_card: "AgentCard", base_url: str) -> "AgentCard":
    """
    Replace a localhost/internal URL in an agent card with *base_url*.

    Many A2A agents ship agent cards pointing at internal addresses such as
    "http://0.0.0.0:8001/" or "http://localhost:8000/"; those are swapped
    for the caller-supplied public base URL.

    Args:
        agent_card: The agent card to (possibly) mutate in place.
        base_url: Replacement base URL.

    Returns:
        The same agent card, with ``url`` rewritten when it was internal.
    """
    current_url = getattr(agent_card, "url", None)

    if current_url and is_localhost_or_internal_url(current_url):
        # Normalize so the replacement always carries a trailing slash.
        agent_card.url = base_url.rstrip("/") + "/"

    return agent_card
|
||||
|
||||
|
||||
class LiteLLMA2ACardResolver(_A2ACardResolver):  # type: ignore[misc]
    """
    A2A card resolver that falls back across multiple well-known paths.

    Extends the base A2ACardResolver to try, in order:
    - /.well-known/agent-card.json (standard)
    - /.well-known/agent.json (previous/alternative)
    """

    async def get_agent_card(
        self,
        relative_card_path: Optional[str] = None,
        http_kwargs: Optional[Dict[str, Any]] = None,
    ) -> "AgentCard":
        """
        Fetch the agent card, trying multiple well-known paths.

        Args:
            relative_card_path: Explicit card endpoint; when provided, no
                fallback is attempted.
            http_kwargs: Extra keyword arguments forwarded to httpx.get.

        Returns:
            AgentCard from the A2A agent.

        Raises:
            Exception: The last error seen if every candidate path fails
                (typically A2AClientHTTPError or A2AClientJSONError).
        """
        # An explicit path means the caller knows best — defer to the parent.
        if relative_card_path is not None:
            return await super().get_agent_card(
                relative_card_path=relative_card_path,
                http_kwargs=http_kwargs,
            )

        candidate_paths = [
            AGENT_CARD_WELL_KNOWN_PATH,
            PREV_AGENT_CARD_WELL_KNOWN_PATH,
        ]

        last_error: Optional[Exception] = None
        for candidate in candidate_paths:
            try:
                verbose_logger.debug(
                    f"Attempting to fetch agent card from {self.base_url}{candidate}"
                )
                return await super().get_agent_card(
                    relative_card_path=candidate,
                    http_kwargs=http_kwargs,
                )
            except Exception as e:
                verbose_logger.debug(
                    f"Failed to fetch agent card from {self.base_url}{candidate}: {e}"
                )
                last_error = e

        # All candidates failed — surface the most recent failure.
        if last_error is not None:
            raise last_error

        # Unreachable with a non-empty path list; kept as a safety net.
        raise Exception(
            f"Failed to fetch agent card from {self.base_url}. "
            f"Tried paths: {', '.join(candidate_paths)}"
        )
|
||||
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
LiteLLM A2A Client class.
|
||||
|
||||
Provides a class-based interface for A2A agent invocation.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, AsyncIterator, Dict, Optional
|
||||
|
||||
from litellm.types.agents import LiteLLMSendMessageResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import A2AClient as A2AClientType
|
||||
from a2a.types import (
|
||||
AgentCard,
|
||||
SendMessageRequest,
|
||||
SendStreamingMessageRequest,
|
||||
SendStreamingMessageResponse,
|
||||
)
|
||||
|
||||
|
||||
class A2AClient:
    """
    LiteLLM wrapper for A2A agent invocation.

    The underlying SDK client is created lazily on first use and cached for
    the lifetime of this wrapper.

    Example:
        ```python
        from litellm.a2a_protocol import A2AClient
        from a2a.types import SendMessageRequest, MessageSendParams
        from uuid import uuid4

        client = A2AClient(base_url="http://localhost:10001")

        request = SendMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={
                    "role": "user",
                    "parts": [{"kind": "text", "text": "Hello!"}],
                    "messageId": uuid4().hex,
                }
            )
        )
        response = await client.send_message(request)
        ```
    """

    def __init__(
        self,
        base_url: str,
        timeout: float = 60.0,
        extra_headers: Optional[Dict[str, str]] = None,
    ):
        """
        Initialize the A2A client wrapper.

        Args:
            base_url: Base URL of the A2A agent (e.g. "http://localhost:10001").
            timeout: Request timeout in seconds (default: 60.0).
            extra_headers: Optional additional headers for every request.
        """
        self.base_url = base_url
        self.timeout = timeout
        self.extra_headers = extra_headers
        # Lazily-created underlying SDK client; see _get_client().
        self._a2a_client: Optional["A2AClientType"] = None

    async def _get_client(self) -> "A2AClientType":
        """Return the cached underlying A2A client, creating it on first call."""
        if self._a2a_client is None:
            from litellm.a2a_protocol.main import create_a2a_client

            self._a2a_client = await create_a2a_client(
                base_url=self.base_url,
                timeout=self.timeout,
                extra_headers=self.extra_headers,
            )
        return self._a2a_client

    async def get_agent_card(self) -> "AgentCard":
        """Fetch the agent card from the server."""
        from litellm.a2a_protocol.main import aget_agent_card

        return await aget_agent_card(
            base_url=self.base_url,
            timeout=self.timeout,
            extra_headers=self.extra_headers,
        )

    async def send_message(
        self, request: "SendMessageRequest"
    ) -> LiteLLMSendMessageResponse:
        """Send a (non-streaming) message to the A2A agent."""
        from litellm.a2a_protocol.main import asend_message

        underlying = await self._get_client()
        return await asend_message(a2a_client=underlying, request=request)

    async def send_message_streaming(
        self, request: "SendStreamingMessageRequest"
    ) -> AsyncIterator["SendStreamingMessageResponse"]:
        """Send a streaming message, yielding response chunks as they arrive."""
        from litellm.a2a_protocol.main import asend_message_streaming

        underlying = await self._get_client()
        async for chunk in asend_message_streaming(
            a2a_client=underlying, request=request
        ):
            yield chunk
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Cost calculator for A2A (Agent-to-Agent) calls.
|
||||
|
||||
Supports dynamic cost parameters that allow platform owners
|
||||
to define custom costs per agent query or per token.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as LitellmLoggingObject,
    )
else:
    LitellmLoggingObject = Any


class A2ACostCalculator:
    """Static helpers for pricing A2A (Agent-to-Agent) calls."""

    @staticmethod
    def calculate_a2a_cost(
        litellm_logging_obj: Optional[LitellmLoggingObject],
    ) -> float:
        """
        Calculate the cost of an A2A send_message call.

        Pricing sources, checked in priority order:
        1. response_cost — set directly (backward compatibility)
        2. cost_per_query — fixed cost per query
        3. input_cost_per_token + output_cost_per_token — token-based cost
        4. Default 0.0

        Args:
            litellm_logging_obj: LiteLLM logging object with call details.

        Returns:
            float: The cost of the A2A call.
        """
        if litellm_logging_obj is None:
            return 0.0

        details = litellm_logging_obj.model_call_details

        # 1. A directly-set response cost wins (backward compatibility).
        direct_cost = details.get("response_cost", None)
        if direct_cost is not None:
            return float(direct_cost)

        params = details.get("litellm_params", {}) or {}

        # 2. Flat per-query pricing.
        per_query = params.get("cost_per_query")
        if per_query is not None:
            return float(per_query)

        # 3. Token-based pricing applies when either rate is configured.
        in_rate = params.get("input_cost_per_token")
        out_rate = params.get("output_cost_per_token")
        if in_rate is None and out_rate is None:
            # 4. No pricing configured — A2A calls default to free.
            return 0.0

        return A2ACostCalculator._calculate_token_based_cost(
            model_call_details=details,
            input_cost_per_token=in_rate,
            output_cost_per_token=out_rate,
        )

    @staticmethod
    def _calculate_token_based_cost(
        model_call_details: dict,
        input_cost_per_token: Optional[float],
        output_cost_per_token: Optional[float],
    ) -> float:
        """
        Price a call from token usage and per-token rates.

        Args:
            model_call_details: Call details containing a ``usage`` object.
            input_cost_per_token: Rate per prompt token (None/0 → free).
            output_cost_per_token: Rate per completion token (None/0 → free).

        Returns:
            float: prompt_tokens * input_rate + completion_tokens * output_rate,
            or 0.0 when no usage information is available.
        """
        usage = model_call_details.get("usage")
        if usage is None:
            return 0.0

        # Missing attributes or explicit None both count as zero tokens.
        prompt_tokens = getattr(usage, "prompt_tokens", 0) or 0
        completion_tokens = getattr(usage, "completion_tokens", 0) or 0

        in_rate = float(input_cost_per_token) if input_cost_per_token else 0.0
        out_rate = float(output_cost_per_token) if output_cost_per_token else 0.0

        return prompt_tokens * in_rate + completion_tokens * out_rate
|
||||
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
A2A Protocol Exception Mapping Utils.
|
||||
|
||||
Maps A2A SDK exceptions to LiteLLM A2A exception types.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.card_resolver import (
|
||||
fix_agent_card_url,
|
||||
is_localhost_or_internal_url,
|
||||
)
|
||||
from litellm.a2a_protocol.exceptions import (
|
||||
A2AAgentCardError,
|
||||
A2AConnectionError,
|
||||
A2AError,
|
||||
A2ALocalhostURLError,
|
||||
)
|
||||
from litellm.constants import CONNECTION_ERROR_PATTERNS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import A2AClient as A2AClientType
|
||||
|
||||
|
||||
# Runtime import
|
||||
A2A_SDK_AVAILABLE = False
|
||||
try:
|
||||
from a2a.client import A2AClient as _A2AClient # type: ignore[no-redef]
|
||||
|
||||
A2A_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
_A2AClient = None # type: ignore[assignment, misc]
|
||||
|
||||
|
||||
class A2AExceptionCheckers:
    """Helpers for classifying A2A error conditions from error strings/URLs."""

    @staticmethod
    def is_connection_error(error_str: str) -> bool:
        """
        Report whether an error string looks like a connection failure.

        Args:
            error_str: The error string to inspect.

        Returns:
            True when any known connection-error pattern matches.
        """
        if not isinstance(error_str, str):
            return False

        lowered = error_str.lower()
        return any(p in lowered for p in CONNECTION_ERROR_PATTERNS)

    @staticmethod
    def is_localhost_url(url: Optional[str]) -> bool:
        """
        Report whether *url* is a localhost/internal URL.

        Args:
            url: The URL to inspect.

        Returns:
            True for localhost/internal URLs.
        """
        return is_localhost_or_internal_url(url)

    @staticmethod
    def is_agent_card_error(error_str: str) -> bool:
        """
        Report whether an error relates to agent-card fetching/parsing.

        Args:
            error_str: The error string to inspect.

        Returns:
            True when an agent-card-related pattern matches.
        """
        if not isinstance(error_str, str):
            return False

        lowered = error_str.lower()
        card_patterns = (
            "agent card",
            "agent-card",
            ".well-known",
            "card not found",
            "invalid agent",
        )
        return any(p in lowered for p in card_patterns)
|
||||
|
||||
|
||||
def map_a2a_exception(
    original_exception: Exception,
    card_url: Optional[str] = None,
    api_base: Optional[str] = None,
    model: Optional[str] = None,
) -> Exception:
    """
    Map an A2A SDK exception to a LiteLLM A2A exception type.

    NOTE(review): despite the ``-> Exception`` annotation, every branch
    *raises* the mapped exception rather than returning it — callers must
    treat this as a raising function; confirm the annotation's intent.

    Args:
        original_exception: The original exception from the A2A SDK
        card_url: The URL from the agent card (if available)
        api_base: The original API base URL
        model: The model/agent name

    Raises:
        A2ALocalhostURLError: If the error is a connection error to a localhost URL
        A2AAgentCardError: If the error is related to agent card issues
        A2AConnectionError: If the error is a general connection error
        A2AError: For other A2A-related errors (default)
    """
    error_str = str(original_exception)

    # Check for localhost URL connection error (special case - retryable)
    if (
        card_url
        and api_base
        and A2AExceptionCheckers.is_localhost_url(card_url)
        and A2AExceptionCheckers.is_connection_error(error_str)
    ):
        raise A2ALocalhostURLError(
            localhost_url=card_url,
            base_url=api_base,
            original_error=original_exception,
            model=model,
        )

    # Check for agent card errors
    if A2AExceptionCheckers.is_agent_card_error(error_str):
        raise A2AAgentCardError(
            message=error_str,
            url=api_base,
            model=model,
        )

    # Check for general connection errors
    if A2AExceptionCheckers.is_connection_error(error_str):
        raise A2AConnectionError(
            message=error_str,
            url=card_url or api_base,
            model=model,
        )

    # Default: wrap in generic A2AError
    raise A2AError(
        message=error_str,
        model=model,
    )
|
||||
|
||||
|
||||
def handle_a2a_localhost_retry(
    error: A2ALocalhostURLError,
    agent_card: Any,
    a2a_client: "A2AClientType",
    is_streaming: bool = False,
) -> "A2AClientType":
    """
    Handle A2ALocalhostURLError by fixing the URL and creating a new client.

    Called after catching an A2ALocalhostURLError to retry the request with
    the corrected (public) URL.

    Args:
        error: The localhost URL error
        agent_card: The agent card object to fix (mutated in place)
        a2a_client: The current A2A client
        is_streaming: Whether this is a streaming request (for logging)

    Returns:
        A new A2A client with the fixed URL

    Raises:
        ImportError: If the A2A SDK is not installed
    """
    if not A2A_SDK_AVAILABLE or _A2AClient is None:
        raise ImportError(
            "A2A SDK is required for localhost retry handling. "
            "Install it with: pip install a2a"
        )

    request_type = "streaming " if is_streaming else ""
    verbose_logger.warning(
        f"A2A {request_type}request to '{error.localhost_url}' failed: {error.original_error}. "
        f"Agent card contains localhost/internal URL. "
        f"Retrying with base_url '{error.base_url}'."
    )

    # Fix the agent card URL (mutates agent_card in place)
    fix_agent_card_url(agent_card, error.base_url)

    # Create a new client with the fixed agent card (transport caches URL).
    # NOTE(review): reaches into the SDK client's private _transport attribute
    # to reuse its httpx client — may break on a2a SDK upgrades; confirm.
    return _A2AClient(
        httpx_client=a2a_client._transport.httpx_client,  # type: ignore[union-attr]
        agent_card=agent_card,
    )
|
||||
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
A2A Protocol Exceptions.
|
||||
|
||||
Custom exception types for A2A protocol operations, following LiteLLM's exception pattern.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class A2AError(Exception):
    """
    Base exception for A2A protocol errors.

    Mirrors the structure of LiteLLM's main exception types: status code,
    provider, model, optional httpx response, and retry bookkeeping.
    """

    def __init__(
        self,
        message: str,
        status_code: int = 500,
        llm_provider: str = "a2a_agent",
        model: Optional[str] = None,
        response: Optional[httpx.Response] = None,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        self.status_code = status_code
        # Prefix identifies the error family in logs and user-facing output.
        self.message = f"litellm.A2AError: {message}"
        self.llm_provider = llm_provider
        self.model = model
        self.litellm_debug_info = litellm_debug_info
        self.max_retries = max_retries
        self.num_retries = num_retries
        # Synthesize a placeholder response when none was supplied, so
        # downstream handlers can always rely on .response existing.
        self.response = response or httpx.Response(
            status_code=self.status_code,
            request=httpx.Request(method="POST", url="https://litellm.ai"),
        )
        super().__init__(self.message)

    def __str__(self) -> str:
        rendered = self.message
        if self.num_retries:
            rendered += f" LiteLLM Retried: {self.num_retries} times"
        if self.max_retries:
            rendered += f", LiteLLM Max Retries: {self.max_retries}"
        return rendered

    def __repr__(self) -> str:
        return self.__str__()
|
||||
|
||||
|
||||
class A2AConnectionError(A2AError):
    """
    Raised when connection to an A2A agent fails.

    Typical causes: the agent is unreachable, the agent card contains a
    localhost/internal URL, or general network failures.
    """

    def __init__(
        self,
        message: str,
        url: Optional[str] = None,
        model: Optional[str] = None,
        response: Optional[httpx.Response] = None,
        litellm_debug_info: Optional[str] = None,
        max_retries: Optional[int] = None,
        num_retries: Optional[int] = None,
    ):
        # Remember which URL failed so callers can decide whether to retry.
        self.url = url
        super().__init__(
            message=message,
            status_code=503,  # service unavailable
            llm_provider="a2a_agent",
            model=model,
            response=response,
            litellm_debug_info=litellm_debug_info,
            max_retries=max_retries,
            num_retries=num_retries,
        )
|
||||
|
||||
|
||||
class A2AAgentCardError(A2AError):
    """
    Raised when there's an issue with the agent card.

    Covers failure to fetch the card, invalid card format, and missing
    required fields.
    """

    def __init__(
        self,
        message: str,
        url: Optional[str] = None,
        model: Optional[str] = None,
        response: Optional[httpx.Response] = None,
        litellm_debug_info: Optional[str] = None,
    ):
        # Remember where the card was fetched from, for diagnostics.
        self.url = url
        super().__init__(
            message=message,
            status_code=404,  # card endpoint not found / unusable
            llm_provider="a2a_agent",
            model=model,
            response=response,
            litellm_debug_info=litellm_debug_info,
        )
|
||||
|
||||
|
||||
class A2ALocalhostURLError(A2AConnectionError):
    """
    Raised when an agent card contains a localhost/internal URL.

    Many A2A agents are deployed with agent cards that contain internal URLs
    like "http://0.0.0.0:8001/" or "http://localhost:8000/". This error
    signals that the URL should be corrected and the request retried.

    Attributes:
        localhost_url: The localhost/internal URL found in the agent card
        base_url: The public base URL that should be used instead
        original_error: The original connection error that was raised
    """

    def __init__(
        self,
        localhost_url: str,
        base_url: str,
        original_error: Optional[Exception] = None,
        model: Optional[str] = None,
    ):
        self.localhost_url = localhost_url
        self.base_url = base_url
        self.original_error = original_error

        detail = (
            f"Agent card contains localhost/internal URL '{localhost_url}'. "
            f"Retrying with base URL '{base_url}'."
        )
        super().__init__(
            message=detail,
            url=localhost_url,
            model=model,
        )
|
||||
@@ -0,0 +1,74 @@
|
||||
# A2A to LiteLLM Completion Bridge
|
||||
|
||||
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
|
||||
|
||||
## Flow
|
||||
|
||||
```
|
||||
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
|
||||
```
|
||||
|
||||
## SDK Usage
|
||||
|
||||
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
|
||||
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message, asend_message_streaming
|
||||
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
# Non-streaming
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
response = await asend_message(
|
||||
request=request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
)
|
||||
|
||||
# Streaming
|
||||
stream_request = SendStreamingMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
async for chunk in asend_message_streaming(
|
||||
request=stream_request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
):
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
## Proxy Usage
|
||||
|
||||
Configure an agent with `custom_llm_provider` in `litellm_params`:
|
||||
|
||||
```yaml
|
||||
agents:
|
||||
- agent_name: my-langgraph-agent
|
||||
agent_card_params:
|
||||
name: "LangGraph Agent"
|
||||
url: "http://localhost:2024" # Used as api_base
|
||||
litellm_params:
|
||||
custom_llm_provider: langgraph
|
||||
model: agent
|
||||
```
|
||||
|
||||
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
|
||||
|
||||
1. Detects `custom_llm_provider` in agent's `litellm_params`
|
||||
2. Transforms A2A message → OpenAI messages
|
||||
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
|
||||
4. Transforms response → A2A format
|
||||
|
||||
## Classes
|
||||
|
||||
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
|
||||
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)
|
||||
|
||||
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
A2A to LiteLLM Completion Bridge.
|
||||
|
||||
This module provides transformation between A2A protocol messages and
|
||||
LiteLLM completion API, enabling any LiteLLM-supported provider to be
|
||||
invoked via the A2A protocol.
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.handler import (
|
||||
A2ACompletionBridgeHandler,
|
||||
handle_a2a_completion,
|
||||
handle_a2a_completion_streaming,
|
||||
)
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
|
||||
A2ACompletionBridgeTransformation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"A2ACompletionBridgeTransformation",
|
||||
"A2ACompletionBridgeHandler",
|
||||
"handle_a2a_completion",
|
||||
"handle_a2a_completion_streaming",
|
||||
]
|
||||
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Handler for A2A to LiteLLM completion bridge.
|
||||
|
||||
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
|
||||
|
||||
A2A Streaming Events (in order):
|
||||
1. Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
2. Status update (kind: "status-update") - Status change to "working"
|
||||
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
4. Status update (kind: "status-update") - Final status "completed" with final=true
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
|
||||
A2ACompletionBridgeTransformation,
|
||||
A2AStreamingContext,
|
||||
)
|
||||
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
|
||||
|
||||
|
||||
class A2ACompletionBridgeHandler:
|
||||
"""
|
||||
Static methods for handling A2A requests via LiteLLM completion.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
async def handle_non_streaming(
|
||||
request_id: str,
|
||||
params: Dict[str, Any],
|
||||
litellm_params: Dict[str, Any],
|
||||
api_base: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Handle non-streaming A2A request via litellm.acompletion.
|
||||
|
||||
Args:
|
||||
request_id: A2A JSON-RPC request ID
|
||||
params: A2A MessageSendParams containing the message
|
||||
litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
|
||||
api_base: API base URL from agent_card_params
|
||||
|
||||
Returns:
|
||||
A2A SendMessageResponse dict
|
||||
"""
|
||||
# Get provider config for custom_llm_provider
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
a2a_provider_config = A2AProviderConfigManager.get_provider_config(
|
||||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
# If provider config exists, use it
|
||||
if a2a_provider_config is not None:
|
||||
if api_base is None:
|
||||
raise ValueError(f"api_base is required for {custom_llm_provider}")
|
||||
|
||||
verbose_logger.info(f"A2A: Using provider config for {custom_llm_provider}")
|
||||
|
||||
response_data = await a2a_provider_config.handle_non_streaming(
|
||||
request_id=request_id,
|
||||
params=params,
|
||||
api_base=api_base,
|
||||
)
|
||||
|
||||
return response_data
|
||||
|
||||
# Extract message from params
|
||||
message = params.get("message", {})
|
||||
|
||||
# Transform A2A message to OpenAI format
|
||||
openai_messages = (
|
||||
A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
|
||||
)
|
||||
|
||||
# Get completion params
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
model = litellm_params.get("model", "agent")
|
||||
|
||||
# Build full model string if provider specified
|
||||
# Skip prepending if model already starts with the provider prefix
|
||||
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
|
||||
full_model = f"{custom_llm_provider}/{model}"
|
||||
else:
|
||||
full_model = model
|
||||
|
||||
verbose_logger.info(
|
||||
f"A2A completion bridge: model={full_model}, api_base={api_base}"
|
||||
)
|
||||
|
||||
# Build completion params dict
|
||||
completion_params = {
|
||||
"model": full_model,
|
||||
"messages": openai_messages,
|
||||
"api_base": api_base,
|
||||
"stream": False,
|
||||
}
|
||||
# Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.)
|
||||
litellm_params_to_add = {
|
||||
k: v
|
||||
for k, v in litellm_params.items()
|
||||
if k not in ("model", "custom_llm_provider")
|
||||
}
|
||||
completion_params.update(litellm_params_to_add)
|
||||
|
||||
# Call litellm.acompletion
|
||||
response = await litellm.acompletion(**completion_params)
|
||||
|
||||
# Transform response to A2A format
|
||||
a2a_response = (
|
||||
A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
|
||||
response=response,
|
||||
request_id=request_id,
|
||||
)
|
||||
)
|
||||
|
||||
verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")
|
||||
|
||||
return a2a_response
|
||||
|
||||
@staticmethod
async def handle_streaming(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
    """
    Stream an A2A request through litellm.acompletion(stream=True).

    Event sequence emitted:
        1. "task"            - initial task, status "submitted"
        2. "status-update"   - status "working"
        3. "artifact-update" - accumulated response content
        4. "status-update"   - final "completed" status

    Args:
        request_id: A2A JSON-RPC request ID.
        params: A2A MessageSendParams containing the message.
        litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.).
        api_base: API base URL from agent_card_params.

    Yields:
        A2A streaming response events (dicts).
    """
    # Delegate entirely to a provider-specific config when one is registered.
    provider_name = litellm_params.get("custom_llm_provider")
    provider_cfg = A2AProviderConfigManager.get_provider_config(
        custom_llm_provider=provider_name
    )

    if provider_cfg is not None:
        if api_base is None:
            raise ValueError(f"api_base is required for {provider_name}")

        verbose_logger.info(
            f"A2A: Using provider config for {provider_name} (streaming)"
        )

        async for event in provider_cfg.handle_streaming(
            request_id=request_id,
            params=params,
            api_base=api_base,
        ):
            yield event

        return

    # No provider config: bridge through litellm.acompletion ourselves.
    inbound_message = params.get("message", {})

    ctx = A2AStreamingContext(
        request_id=request_id,
        input_message=inbound_message,
    )

    openai_messages = (
        A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(
            inbound_message
        )
    )

    model_name = litellm_params.get("model", "agent")

    # Prefix with the provider unless the model string already carries it.
    if provider_name and not model_name.startswith(f"{provider_name}/"):
        full_model = f"{provider_name}/{model_name}"
    else:
        full_model = model_name

    verbose_logger.info(
        f"A2A completion bridge streaming: model={full_model}, api_base={api_base}"
    )

    # Extra litellm params (api_key, client_id, client_secret, tenant_id, etc.)
    # are merged last so they take precedence over the base keys, matching
    # the non-streaming path.
    passthrough_params = {
        k: v
        for k, v in litellm_params.items()
        if k not in ("model", "custom_llm_provider")
    }
    completion_params: Dict[str, Any] = {
        "model": full_model,
        "messages": openai_messages,
        "api_base": api_base,
        "stream": True,
        **passthrough_params,
    }

    # 1. Initial task event (kind: "task", status: "submitted")
    yield A2ACompletionBridgeTransformation.create_task_event(ctx)

    # 2. Status update (kind: "status-update", status: "working")
    yield A2ACompletionBridgeTransformation.create_status_update_event(
        ctx=ctx,
        state="working",
        final=False,
        message_text="Processing request...",
    )

    stream_response = await litellm.acompletion(**completion_params)

    # 3. Drain the stream, accumulating delta text across all chunks.
    text_so_far = ""
    n_chunks = 0
    async for piece in stream_response:  # type: ignore[union-attr]
        n_chunks += 1

        delta_text = ""
        if piece is not None and hasattr(piece, "choices") and piece.choices:
            first_choice = piece.choices[0]
            if hasattr(first_choice, "delta") and first_choice.delta:
                delta_text = first_choice.delta.content or ""

        if delta_text:
            text_so_far += delta_text

    # Emit a single artifact update carrying all accumulated content.
    if text_so_far:
        yield A2ACompletionBridgeTransformation.create_artifact_update_event(
            ctx=ctx,
            text=text_so_far,
        )

    # 4. Final status update (kind: "status-update", status: "completed", final: true)
    yield A2ACompletionBridgeTransformation.create_status_update_event(
        ctx=ctx,
        state="completed",
        final=True,
    )

    verbose_logger.info(
        f"A2A completion bridge streaming completed: request_id={request_id}, chunks={n_chunks}"
    )
|
||||
|
||||
|
||||
# Convenience functions that delegate to the class methods
|
||||
async def handle_a2a_completion(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> Dict[str, Any]:
    """Module-level convenience wrapper around
    A2ACompletionBridgeHandler.handle_non_streaming."""
    result = await A2ACompletionBridgeHandler.handle_non_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )
    return result
|
||||
|
||||
|
||||
async def handle_a2a_completion_streaming(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
    """Module-level convenience wrapper around
    A2ACompletionBridgeHandler.handle_streaming."""
    event_stream = A2ACompletionBridgeHandler.handle_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )
    async for event in event_stream:
        yield event
|
||||
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Transformation utilities for A2A <-> OpenAI message format conversion.
|
||||
|
||||
A2A Message Format:
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": "abc123"
|
||||
}
|
||||
|
||||
OpenAI Message Format:
|
||||
{"role": "user", "content": "Hello!"}
|
||||
|
||||
A2A Streaming Events:
|
||||
- Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
- Status update (kind: "status-update") - Status changes (working, completed)
|
||||
- Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class A2AStreamingContext:
    """
    Mutable per-request state for an A2A streaming response.

    Holds the generated task/context identifiers, the inbound message
    (echoed back in the task history), and accumulation bookkeeping used
    while emitting streaming events.
    """

    def __init__(self, request_id: str, input_message: Dict[str, Any]):
        # Identifiers tied to this streaming exchange.
        self.request_id = request_id
        self.task_id = str(uuid4())
        self.context_id = str(uuid4())
        # Original inbound A2A message.
        self.input_message = input_message
        # Response text accumulated so far plus one-shot emission flags.
        self.accumulated_text = ""
        self.has_emitted_task = False
        self.has_emitted_working = False
|
||||
|
||||
|
||||
class A2ACompletionBridgeTransformation:
    """
    Stateless converters between A2A and OpenAI message formats, plus
    factories for the A2A streaming event envelopes (task, status-update,
    artifact-update).
    """

    @staticmethod
    def a2a_message_to_openai_messages(
        a2a_message: Dict[str, Any],
    ) -> List[Dict[str, str]]:
        """
        Convert a single A2A message into OpenAI chat-message format.

        Args:
            a2a_message: A2A message dict with role, parts, and messageId.

        Returns:
            A one-element list of OpenAI-format messages.
        """
        role = a2a_message.get("role", "user")
        parts = a2a_message.get("parts", [])

        # A2A roles map 1:1 onto OpenAI roles; unrecognized roles pass
        # through unchanged.
        openai_role = {
            "user": "user",
            "assistant": "assistant",
            "system": "system",
        }.get(role, role)

        # Join all text parts with newlines; non-text parts are ignored.
        texts = [
            part.get("text", "")
            for part in parts
            if part.get("kind", "") == "text"
        ]
        content = "\n".join(texts) if texts else ""

        verbose_logger.debug(
            f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
        )

        return [{"role": openai_role, "content": content}]

    @staticmethod
    def openai_response_to_a2a_response(
        response: Any,
        request_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Wrap a LiteLLM ModelResponse as an A2A SendMessageResponse dict.

        Args:
            response: LiteLLM ModelResponse object.
            request_id: Original A2A request ID (echoed as the JSON-RPC id).

        Returns:
            A2A SendMessageResponse dict.
        """
        # Pull the assistant text out of the first choice, if present.
        content = ""
        choices = getattr(response, "choices", None)
        if choices:
            first = choices[0]
            message_obj = getattr(first, "message", None)
            if message_obj:
                content = message_obj.content or ""

        a2a_response = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": {
                    "role": "agent",
                    "parts": [{"kind": "text", "text": content}],
                    "messageId": uuid4().hex,
                },
            },
        }

        verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")

        return a2a_response

    @staticmethod
    def _get_timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string with offset."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def create_task_event(
        ctx: A2AStreamingContext,
    ) -> Dict[str, Any]:
        """
        Build the initial task event (kind "task", state "submitted").

        This is always the first event of an A2A streaming response; the
        inbound message is echoed in the task history.
        """
        inbound = ctx.input_message
        history_entry = {
            "contextId": ctx.context_id,
            "kind": "message",
            "messageId": inbound.get("messageId", uuid4().hex),
            "parts": inbound.get("parts", []),
            "role": inbound.get("role", "user"),
            "taskId": ctx.task_id,
        }
        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "contextId": ctx.context_id,
                "history": [history_entry],
                "id": ctx.task_id,
                "kind": "task",
                "status": {
                    "state": "submitted",
                },
            },
        }

    @staticmethod
    def create_status_update_event(
        ctx: A2AStreamingContext,
        state: str,
        final: bool = False,
        message_text: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Build a status-update event.

        Args:
            ctx: Streaming context.
            state: Status state ('working', 'completed').
            final: Whether this is the terminal event of the stream.
            message_text: Optional agent message text, attached only for
                the 'working' state.
        """
        status: Dict[str, Any] = {
            "state": state,
            "timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
        }

        # A 'working' update may carry an informational agent message.
        if state == "working" and message_text:
            status["message"] = {
                "contextId": ctx.context_id,
                "kind": "message",
                "messageId": str(uuid4()),
                "parts": [{"kind": "text", "text": message_text}],
                "role": "agent",
                "taskId": ctx.task_id,
            }

        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "contextId": ctx.context_id,
                "final": final,
                "kind": "status-update",
                "status": status,
                "taskId": ctx.task_id,
            },
        }

    @staticmethod
    def create_artifact_update_event(
        ctx: A2AStreamingContext,
        text: str,
    ) -> Dict[str, Any]:
        """
        Build an artifact-update event carrying response text.

        Args:
            ctx: Streaming context.
            text: The text content for the artifact.
        """
        artifact = {
            "artifactId": str(uuid4()),
            "name": "response",
            "parts": [{"kind": "text", "text": text}],
        }
        return {
            "id": ctx.request_id,
            "jsonrpc": "2.0",
            "result": {
                "artifact": artifact,
                "contextId": ctx.context_id,
                "kind": "artifact-update",
                "taskId": ctx.task_id,
            },
        }

    @staticmethod
    def openai_chunk_to_a2a_chunk(
        chunk: Any,
        request_id: Optional[str] = None,
        is_final: bool = False,
    ) -> Optional[Dict[str, Any]]:
        """
        Convert one LiteLLM streaming chunk into the legacy A2A chunk shape.

        NOTE: This method is deprecated for streaming. Use the event-based
        methods (create_task_event, create_status_update_event,
        create_artifact_update_event) instead for proper A2A streaming.

        Args:
            chunk: LiteLLM ModelResponse chunk.
            request_id: Original A2A request ID.
            is_final: Whether this is the final chunk.

        Returns:
            A2A streaming chunk dict, or None for an empty non-final chunk.
        """
        # Extract this chunk's delta text, if any.
        delta_text = ""
        if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
            first = chunk.choices[0]
            if hasattr(first, "delta") and first.delta:
                delta_text = first.delta.content or ""

        # Empty intermediate chunks produce no event.
        if not delta_text and not is_final:
            return None

        # Legacy A2A streaming chunk envelope.
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": {
                    "role": "agent",
                    "parts": [{"kind": "text", "text": delta_text}],
                    "messageId": uuid4().hex,
                },
                "final": is_final,
            },
        }
|
||||
@@ -0,0 +1,744 @@
|
||||
"""
|
||||
LiteLLM A2A SDK functions.
|
||||
|
||||
Provides standalone functions with @client decorator for LiteLLM logging integration.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Coroutine, Dict, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.a2a_protocol.streaming_iterator import A2AStreamingIterator
|
||||
from litellm.a2a_protocol.utils import A2ARequestUtils
|
||||
from litellm.constants import DEFAULT_A2A_AGENT_TIMEOUT
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.types.agents import LiteLLMSendMessageResponse
|
||||
from litellm.utils import client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.client import A2AClient as A2AClientType
|
||||
from a2a.types import AgentCard, SendMessageRequest, SendStreamingMessageRequest
|
||||
|
||||
# Runtime imports with availability check
|
||||
A2A_SDK_AVAILABLE = False
|
||||
A2ACardResolver: Any = None
|
||||
_A2AClient: Any = None
|
||||
|
||||
try:
|
||||
from a2a.client import A2AClient as _A2AClient # type: ignore[no-redef]
|
||||
|
||||
A2A_SDK_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Import our custom card resolver that supports multiple well-known paths
|
||||
from litellm.a2a_protocol.card_resolver import LiteLLMA2ACardResolver
|
||||
from litellm.a2a_protocol.exception_mapping_utils import (
|
||||
handle_a2a_localhost_retry,
|
||||
map_a2a_exception,
|
||||
)
|
||||
from litellm.a2a_protocol.exceptions import A2ALocalhostURLError
|
||||
|
||||
# Use our custom resolver instead of the default A2A SDK resolver
|
||||
A2ACardResolver = LiteLLMA2ACardResolver
|
||||
|
||||
|
||||
def _set_usage_on_logging_obj(
|
||||
kwargs: Dict[str, Any],
|
||||
prompt_tokens: int,
|
||||
completion_tokens: int,
|
||||
) -> None:
|
||||
"""
|
||||
Set usage on litellm_logging_obj for standard logging payload.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs dict containing litellm_logging_obj
|
||||
prompt_tokens: Number of input tokens
|
||||
completion_tokens: Number of output tokens
|
||||
"""
|
||||
litellm_logging_obj = kwargs.get("litellm_logging_obj")
|
||||
if litellm_logging_obj is not None:
|
||||
usage = litellm.Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
litellm_logging_obj.model_call_details["usage"] = usage
|
||||
|
||||
|
||||
def _set_agent_id_on_logging_obj(
|
||||
kwargs: Dict[str, Any],
|
||||
agent_id: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Set agent_id on litellm_logging_obj for SpendLogs tracking.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs dict containing litellm_logging_obj
|
||||
agent_id: The A2A agent ID
|
||||
"""
|
||||
if agent_id is None:
|
||||
return
|
||||
|
||||
litellm_logging_obj = kwargs.get("litellm_logging_obj")
|
||||
if litellm_logging_obj is not None:
|
||||
# Set agent_id directly on model_call_details (same pattern as custom_llm_provider)
|
||||
litellm_logging_obj.model_call_details["agent_id"] = agent_id
|
||||
|
||||
|
||||
def _get_a2a_model_info(a2a_client: Any, kwargs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract agent info and set model/custom_llm_provider for cost tracking.
|
||||
|
||||
Sets model info on the litellm_logging_obj if available.
|
||||
Returns the agent name for logging.
|
||||
"""
|
||||
agent_name = "unknown"
|
||||
|
||||
# Try to get agent card from our stored attribute first, then fallback to SDK attribute
|
||||
agent_card = getattr(a2a_client, "_litellm_agent_card", None)
|
||||
if agent_card is None:
|
||||
agent_card = getattr(a2a_client, "agent_card", None)
|
||||
|
||||
if agent_card is not None:
|
||||
agent_name = getattr(agent_card, "name", "unknown") or "unknown"
|
||||
|
||||
# Build model string
|
||||
model = f"a2a_agent/{agent_name}"
|
||||
custom_llm_provider = "a2a_agent"
|
||||
|
||||
# Set on litellm_logging_obj if available (for standard logging payload)
|
||||
litellm_logging_obj = kwargs.get("litellm_logging_obj")
|
||||
if litellm_logging_obj is not None:
|
||||
litellm_logging_obj.model = model
|
||||
litellm_logging_obj.custom_llm_provider = custom_llm_provider
|
||||
litellm_logging_obj.model_call_details["model"] = model
|
||||
litellm_logging_obj.model_call_details[
|
||||
"custom_llm_provider"
|
||||
] = custom_llm_provider
|
||||
|
||||
return agent_name
|
||||
|
||||
|
||||
async def _send_message_via_completion_bridge(
    request: "SendMessageRequest",
    custom_llm_provider: str,
    api_base: Optional[str],
    litellm_params: Dict[str, Any],
) -> LiteLLMSendMessageResponse:
    """
    Route a send_message through the LiteLLM completion bridge
    (e.g. LangGraph, Bedrock AgentCore).

    Args:
        request: The A2A SendMessageRequest to forward (required).
        custom_llm_provider: Bridge provider name.
        api_base: Optional endpoint; some providers derive it from the model.
        litellm_params: Params forwarded to the bridge handler.

    Returns:
        The bridge result wrapped as a LiteLLMSendMessageResponse.
    """
    verbose_logger.info(
        f"A2A using completion bridge: provider={custom_llm_provider}, api_base={api_base}"
    )

    # NOTE: imported here rather than at module top — presumably to avoid
    # an import cycle with the bridge handler module.
    from litellm.a2a_protocol.litellm_completion_bridge.handler import (
        A2ACompletionBridgeHandler,
    )

    # Pydantic request params serialize via model_dump; plain mappings copy.
    if hasattr(request.params, "model_dump"):
        bridge_params = request.params.model_dump(mode="json")
    else:
        bridge_params = dict(request.params)

    result = await A2ACompletionBridgeHandler.handle_non_streaming(
        request_id=str(request.id),
        params=bridge_params,
        litellm_params=litellm_params,
        api_base=api_base,
    )

    return LiteLLMSendMessageResponse.from_dict(result)
|
||||
|
||||
|
||||
async def _execute_a2a_send_with_retry(
    a2a_client: Any,
    request: Any,
    agent_card: Any,
    card_url: Optional[str],
    api_base: Optional[str],
    agent_name: Optional[str],
) -> Any:
    """
    Send an A2A message, retrying once when a localhost URL error occurs.

    Args:
        a2a_client: Client used to send the message; may be replaced by
            handle_a2a_localhost_retry on a localhost URL error.
        request: The SendMessageRequest to transmit.
        agent_card: Agent card used to rebuild the client/URL on retry.
        card_url: Current agent card URL (passed to exception mapping).
        api_base: API base URL (passed to exception mapping).
        agent_name: Agent name, used as the model for exception mapping.

    Returns:
        The raw response from a2a_client.send_message.

    Raises:
        RuntimeError: If no response was obtained after all attempts.
        Exception: Whatever map_a2a_exception raises for non-retryable errors.
    """
    a2a_response = None
    for _ in range(2):  # max 2 attempts: original + 1 retry
        try:
            a2a_response = await a2a_client.send_message(request)
            break  # success, exit retry loop
        except A2ALocalhostURLError as e:
            # Direct localhost error: swap in a rebuilt client and retry.
            a2a_client = handle_a2a_localhost_retry(
                error=e,
                agent_card=agent_card,
                a2a_client=a2a_client,
                is_streaming=False,
            )
            # Refresh the card URL for the next attempt's exception mapping.
            card_url = agent_card.url if agent_card else None
        except Exception as e:
            # map_a2a_exception presumably raises a mapped exception; if it
            # returns without raising, the loop simply retries — TODO confirm.
            try:
                map_a2a_exception(e, card_url, api_base, model=agent_name)
            except A2ALocalhostURLError as localhost_err:
                # The mapped error turned out to be a localhost issue:
                # rebuild the client and retry.
                a2a_client = handle_a2a_localhost_retry(
                    error=localhost_err,
                    agent_card=agent_card,
                    a2a_client=a2a_client,
                    is_streaming=False,
                )
                card_url = agent_card.url if agent_card else None
                continue
            except Exception:
                # Non-localhost mapped error: propagate to the caller.
                raise
    if a2a_response is None:
        raise RuntimeError(
            "A2A send_message failed: no response received after retry attempts."
        )
    return a2a_response
|
||||
|
||||
|
||||
@client
async def asend_message(
    a2a_client: Optional["A2AClientType"] = None,
    request: Optional["SendMessageRequest"] = None,
    api_base: Optional[str] = None,
    litellm_params: Optional[Dict[str, Any]] = None,
    agent_id: Optional[str] = None,
    agent_extra_headers: Optional[Dict[str, str]] = None,
    **kwargs: Any,
) -> LiteLLMSendMessageResponse:
    """
    Async: Send a message to an A2A agent.

    Uses the @client decorator for LiteLLM logging and tracking.
    If litellm_params contains custom_llm_provider, routes through the completion bridge.

    Args:
        a2a_client: An initialized a2a.client.A2AClient instance (optional if using completion bridge)
        request: SendMessageRequest from a2a.types (optional if using completion bridge with api_base)
        api_base: API base URL (required for completion bridge, optional for standard A2A)
        litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
        agent_id: Optional agent ID for tracking in SpendLogs
        agent_extra_headers: Optional headers overlaid on the auto-created
            client's headers; they take precedence over LiteLLM's internal
            X-LiteLLM-* headers
        **kwargs: Additional arguments passed to the client decorator

    Returns:
        LiteLLMSendMessageResponse (wraps a2a SendMessageResponse with _hidden_params)

    Example (standard A2A):
        ```python
        from litellm.a2a_protocol import asend_message, create_a2a_client
        from a2a.types import SendMessageRequest, MessageSendParams
        from uuid import uuid4

        a2a_client = await create_a2a_client(base_url="http://localhost:10001")
        request = SendMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        response = await asend_message(a2a_client=a2a_client, request=request)
        ```

    Example (completion bridge with LangGraph):
        ```python
        from litellm.a2a_protocol import asend_message
        from a2a.types import SendMessageRequest, MessageSendParams
        from uuid import uuid4

        request = SendMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        response = await asend_message(
            request=request,
            api_base="http://localhost:2024",
            litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
        )
        ```
    """
    litellm_params = litellm_params or {}
    logging_obj = kwargs.get("litellm_logging_obj")
    # Reuse the decorator's trace id when available so headers/context match.
    trace_id = getattr(logging_obj, "litellm_trace_id", None) if logging_obj else None
    custom_llm_provider = litellm_params.get("custom_llm_provider")

    # Route through completion bridge if custom_llm_provider is set
    if custom_llm_provider:
        if request is None:
            raise ValueError("request is required for completion bridge")
        return await _send_message_via_completion_bridge(
            request=request,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            litellm_params=litellm_params,
        )

    # Standard A2A client flow
    if request is None:
        raise ValueError("request is required")

    # Create A2A client if not provided but api_base is available
    if a2a_client is None:
        if api_base is None:
            raise ValueError(
                "Either a2a_client or api_base is required for standard A2A flow"
            )
        trace_id = trace_id or str(uuid.uuid4())
        extra_headers: Dict[str, str] = {"X-LiteLLM-Trace-Id": trace_id}
        if agent_id:
            extra_headers["X-LiteLLM-Agent-Id"] = agent_id
        # Overlay agent-level headers (agent headers take precedence over LiteLLM internal ones)
        if agent_extra_headers:
            extra_headers.update(agent_extra_headers)
        a2a_client = await create_a2a_client(
            base_url=api_base, extra_headers=extra_headers
        )

    # Type assertion: a2a_client is guaranteed to be non-None here
    assert a2a_client is not None

    # Stamps model/custom_llm_provider info on the logging obj as a side effect.
    agent_name = _get_a2a_model_info(a2a_client, kwargs)

    verbose_logger.info(f"A2A send_message request_id={request.id}, agent={agent_name}")

    # Get agent card URL for localhost retry logic
    agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
        a2a_client, "agent_card", None
    )
    card_url = getattr(agent_card, "url", None) if agent_card else None

    # Backfill context_id on the outbound message (dict or model object)
    # only when the caller has not already set one.
    context_id = trace_id or str(uuid.uuid4())
    message = request.params.message
    if isinstance(message, dict):
        if message.get("context_id") is None:
            message["context_id"] = context_id
    else:
        if getattr(message, "context_id", None) is None:
            message.context_id = context_id

    a2a_response = await _execute_a2a_send_with_retry(
        a2a_client=a2a_client,
        request=request,
        agent_card=agent_card,
        card_url=card_url,
        api_base=api_base,
        agent_name=agent_name,
    )

    verbose_logger.info(f"A2A send_message completed, request_id={request.id}")

    # Wrap in LiteLLM response type for _hidden_params support
    response = LiteLLMSendMessageResponse.from_a2a_response(a2a_response)

    # Calculate token usage from request and response
    response_dict = a2a_response.model_dump(mode="json", exclude_none=True)
    (
        prompt_tokens,
        completion_tokens,
        _,
    ) = A2ARequestUtils.calculate_usage_from_request_response(
        request=request,
        response_dict=response_dict,
    )

    # Set usage on logging obj for standard logging payload
    _set_usage_on_logging_obj(
        kwargs=kwargs,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )

    # Set agent_id on logging obj for SpendLogs tracking
    _set_agent_id_on_logging_obj(kwargs=kwargs, agent_id=agent_id)

    return response
|
||||
|
||||
|
||||
@client
def send_message(
    a2a_client: "A2AClientType",
    request: "SendMessageRequest",
    **kwargs: Any,
) -> Union[LiteLLMSendMessageResponse, Coroutine[Any, Any, LiteLLMSendMessageResponse]]:
    """
    Sync entry point for sending a message to an A2A agent.

    Delegates to asend_message. When called while an event loop is already
    running, the coroutine is returned for the caller to await; otherwise it
    is driven to completion via asyncio.run.

    Args:
        a2a_client: An initialized a2a.client.A2AClient instance.
        request: SendMessageRequest from a2a.types.
        **kwargs: Additional arguments passed to the client decorator.

    Returns:
        LiteLLMSendMessageResponse, or the pending coroutine when an event
        loop is already running.
    """
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    coro = asend_message(a2a_client=a2a_client, request=request, **kwargs)
    if running_loop is not None:
        # Inside a live event loop: hand the coroutine back to the caller.
        return coro
    return asyncio.run(coro)
|
||||
|
||||
|
||||
def _build_streaming_logging_obj(
    request: "SendStreamingMessageRequest",
    agent_name: str,
    agent_id: Optional[str],
    litellm_params: Optional[Dict[str, Any]],
    metadata: Optional[Dict[str, Any]],
    proxy_server_request: Optional[Dict[str, Any]],
) -> Logging:
    """
    Construct a Logging object for a streaming A2A request.

    Stamps model/provider info, the optional agent_id, and the merged
    litellm params onto the object so downstream logging callbacks see a
    consistent payload.
    """
    model_str = f"a2a_agent/{agent_name}"

    log_obj = Logging(
        model=model_str,
        messages=[{"role": "user", "content": "streaming-request"}],
        stream=False,
        call_type="asend_message_streaming",
        start_time=datetime.datetime.now(),
        litellm_call_id=str(request.id),
        function_id=str(request.id),
    )

    # Mirror model/provider info onto the object and its call details.
    log_obj.model = model_str
    log_obj.custom_llm_provider = "a2a_agent"
    log_obj.model_call_details["model"] = model_str
    log_obj.model_call_details["custom_llm_provider"] = "a2a_agent"
    if agent_id:
        log_obj.model_call_details["agent_id"] = agent_id

    # Merge request metadata into a shallow copy of the caller's params.
    merged_params: Dict[str, Any] = dict(litellm_params) if litellm_params else {}
    if metadata:
        merged_params["metadata"] = metadata
    if proxy_server_request:
        merged_params["proxy_server_request"] = proxy_server_request

    # All three fields reference the same merged dict (shallow-shared).
    log_obj.litellm_params = merged_params
    log_obj.optional_params = merged_params
    log_obj.model_call_details["litellm_params"] = merged_params
    log_obj.model_call_details["metadata"] = metadata or {}

    return log_obj
|
||||
|
||||
|
||||
async def asend_message_streaming(  # noqa: PLR0915
    a2a_client: Optional["A2AClientType"] = None,
    request: Optional["SendStreamingMessageRequest"] = None,
    api_base: Optional[str] = None,
    litellm_params: Optional[Dict[str, Any]] = None,
    agent_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    proxy_server_request: Optional[Dict[str, Any]] = None,
    agent_extra_headers: Optional[Dict[str, str]] = None,
) -> AsyncIterator[Any]:
    """
    Async: Send a streaming message to an A2A agent.

    If litellm_params contains custom_llm_provider, routes through the completion bridge.

    Args:
        a2a_client: An initialized a2a.client.A2AClient instance (optional if using completion bridge)
        request: SendStreamingMessageRequest from a2a.types
        api_base: API base URL (required for completion bridge)
        litellm_params: Optional dict with custom_llm_provider, model, etc. for completion bridge
        agent_id: Optional agent ID for tracking in SpendLogs
        metadata: Optional metadata dict (contains user_api_key, user_id, team_id, etc.)
        proxy_server_request: Optional proxy server request data
        agent_extra_headers: Optional extra HTTP headers; merged into the trace/agent-id
            headers when this function creates the A2A client itself (ignored when a
            pre-built a2a_client is passed in)

    Yields:
        SendStreamingMessageResponse chunks from the agent

    Example (completion bridge with LangGraph):
        ```python
        from litellm.a2a_protocol import asend_message_streaming
        from a2a.types import SendStreamingMessageRequest, MessageSendParams
        from uuid import uuid4

        request = SendStreamingMessageRequest(
            id=str(uuid4()),
            params=MessageSendParams(
                message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
            )
        )
        async for chunk in asend_message_streaming(
            request=request,
            api_base="http://localhost:2024",
            litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
        ):
            print(chunk)
        ```
    """
    litellm_params = litellm_params or {}
    custom_llm_provider = litellm_params.get("custom_llm_provider")

    # Route through completion bridge if custom_llm_provider is set
    if custom_llm_provider:
        if request is None:
            raise ValueError("request is required for completion bridge")
        # api_base is optional for providers that derive endpoint from model (e.g., bedrock/agentcore)

        verbose_logger.info(
            f"A2A streaming using completion bridge: provider={custom_llm_provider}"
        )

        # Local import avoids a circular dependency between this module and the bridge.
        from litellm.a2a_protocol.litellm_completion_bridge.handler import (
            A2ACompletionBridgeHandler,
        )

        # Extract params from request (pydantic model or plain mapping)
        params = (
            request.params.model_dump(mode="json")
            if hasattr(request.params, "model_dump")
            else dict(request.params)
        )

        async for chunk in A2ACompletionBridgeHandler.handle_streaming(
            request_id=str(request.id),
            params=params,
            litellm_params=litellm_params,
            api_base=api_base,
        ):
            yield chunk
        return

    # Standard A2A client flow
    if request is None:
        raise ValueError("request is required")

    # Create A2A client if not provided but api_base is available
    if a2a_client is None:
        if api_base is None:
            raise ValueError(
                "Either a2a_client or api_base is required for standard A2A flow"
            )
        # Mirror the non-streaming path: always include trace and agent-id headers
        streaming_extra_headers: Dict[str, str] = {
            "X-LiteLLM-Trace-Id": str(request.id),
        }
        if agent_id:
            streaming_extra_headers["X-LiteLLM-Agent-Id"] = agent_id
        if agent_extra_headers:
            streaming_extra_headers.update(agent_extra_headers)
        a2a_client = await create_a2a_client(
            base_url=api_base, extra_headers=streaming_extra_headers
        )

    # Type assertion: a2a_client is guaranteed to be non-None here
    assert a2a_client is not None

    verbose_logger.info(f"A2A send_message_streaming request_id={request.id}")

    # Build logging object for streaming completion callbacks.
    # _litellm_agent_card is the card stashed by create_a2a_client; fall back to the
    # SDK's own attribute if present.
    agent_card = getattr(a2a_client, "_litellm_agent_card", None) or getattr(
        a2a_client, "agent_card", None
    )
    card_url = getattr(agent_card, "url", None) if agent_card else None
    agent_name = getattr(agent_card, "name", "unknown") if agent_card else "unknown"

    logging_obj = _build_streaming_logging_obj(
        request=request,
        agent_name=agent_name,
        agent_id=agent_id,
        litellm_params=litellm_params,
        metadata=metadata,
        proxy_server_request=proxy_server_request,
    )

    # Retry loop: if connection fails due to localhost URL in agent card, retry with fixed URL
    # Connection errors in streaming typically occur on first chunk iteration
    first_chunk = True
    for attempt in range(2):  # max 2 attempts: original + 1 retry
        stream = a2a_client.send_message_streaming(request)
        iterator = A2AStreamingIterator(
            stream=stream,
            request=request,
            logging_obj=logging_obj,
            agent_name=agent_name,
        )

        try:
            # Reset per attempt: a retried connection starts over on its first chunk.
            first_chunk = True
            async for chunk in iterator:
                if first_chunk:
                    first_chunk = False  # connection succeeded
                yield chunk
            return  # stream completed successfully
        except A2ALocalhostURLError as e:
            # Only retry on first chunk, not mid-stream
            if first_chunk and attempt == 0:
                a2a_client = handle_a2a_localhost_retry(
                    error=e,
                    agent_card=agent_card,
                    a2a_client=a2a_client,
                    is_streaming=True,
                )
                # Card URL may have been rewritten by the retry helper; refresh it.
                card_url = agent_card.url if agent_card else None
            else:
                raise
        except Exception as e:
            # Only map exception on first chunk
            if first_chunk and attempt == 0:
                try:
                    # map_a2a_exception raises a mapped (more specific) exception;
                    # a plain return means no mapping applied.
                    map_a2a_exception(e, card_url, api_base, model=agent_name)
                except A2ALocalhostURLError as localhost_err:
                    # Localhost URL error - fix and retry
                    a2a_client = handle_a2a_localhost_retry(
                        error=localhost_err,
                        agent_card=agent_card,
                        a2a_client=a2a_client,
                        is_streaming=True,
                    )
                    card_url = agent_card.url if agent_card else None
                    continue
                except Exception:
                    # Re-raise the mapped exception
                    raise
            # Mid-stream failure, second attempt, or unmapped error: propagate original.
            raise
|
||||
|
||||
|
||||
async def create_a2a_client(
    base_url: str,
    timeout: float = 60.0,
    extra_headers: Optional[Dict[str, str]] = None,
) -> "A2AClientType":
    """
    Create an A2A client for the given agent URL.

    This resolves the agent card and returns a ready-to-use A2A client.
    The client can be reused for multiple requests.

    Args:
        base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
        timeout: Request timeout in seconds (default: 60.0)
        extra_headers: Optional additional headers to include in requests

    Returns:
        An initialized a2a.client.A2AClient instance

    Raises:
        ImportError: If the a2a-sdk package is not installed.

    Example:
        ```python
        from litellm.a2a_protocol import create_a2a_client, asend_message

        # Create client once
        client = await create_a2a_client(base_url="http://localhost:10001")

        # Reuse for multiple requests
        response1 = await asend_message(a2a_client=client, request=request1)
        response2 = await asend_message(a2a_client=client, request=request2)
        ```
    """
    if not A2A_SDK_AVAILABLE:
        raise ImportError(
            "The 'a2a' package is required for A2A agent invocation. "
            "Install it with: pip install a2a-sdk"
        )

    verbose_logger.info(f"Creating A2A client for {base_url}")

    # Use get_async_httpx_client with per-agent params so that different agents
    # (with different extra_headers) get separate cached clients. The params
    # dict is hashed into the cache key, keeping agent auth isolated while
    # still reusing connections within the same agent.
    #
    # Only pass params that AsyncHTTPHandler.__init__ accepts (e.g. timeout).
    # Use "disable_aiohttp_transport" key for cache-key-only data (it's
    # filtered out before reaching the constructor).
    _client_params: dict = {"timeout": timeout}
    if extra_headers:
        # Encode headers into a cache-key-only param so each unique header
        # set produces a distinct cache key.
        _client_params["disable_aiohttp_transport"] = str(sorted(extra_headers.items()))
    _async_handler = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.A2AProvider,
        params=_client_params,
    )
    httpx_client = _async_handler.client
    if extra_headers:
        # NOTE(review): this mutates default headers on the cached client; the
        # headers persist for the client's lifetime. Presumably safe because the
        # cache key above isolates each header set — confirm when changing caching.
        httpx_client.headers.update(extra_headers)
        verbose_proxy_logger.debug(
            f"A2A client created with extra_headers={list(extra_headers.keys())}"
        )

    # Resolve agent card
    resolver = A2ACardResolver(
        httpx_client=httpx_client,
        base_url=base_url,
    )
    agent_card = await resolver.get_agent_card()

    verbose_logger.debug(
        f"Resolved agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
    )

    # Create A2A client
    a2a_client = _A2AClient(
        httpx_client=httpx_client,
        agent_card=agent_card,
    )

    # Store agent_card on client for later retrieval (SDK doesn't expose it)
    a2a_client._litellm_agent_card = agent_card  # type: ignore[attr-defined]

    verbose_logger.info(f"A2A client created for {base_url}")

    return a2a_client
|
||||
|
||||
|
||||
async def aget_agent_card(
    base_url: str,
    timeout: float = DEFAULT_A2A_AGENT_TIMEOUT,
    extra_headers: Optional[Dict[str, str]] = None,
) -> "AgentCard":
    """
    Fetch the agent card from an A2A agent.

    Args:
        base_url: The base URL of the A2A agent (e.g., "http://localhost:10001")
        timeout: Request timeout in seconds (default: DEFAULT_A2A_AGENT_TIMEOUT)
        extra_headers: Optional additional headers to include in requests
            (e.g., agent auth headers)

    Returns:
        AgentCard from the A2A agent

    Raises:
        ImportError: If the a2a-sdk package is not installed.
    """
    if not A2A_SDK_AVAILABLE:
        raise ImportError(
            "The 'a2a' package is required for A2A agent invocation. "
            "Install it with: pip install a2a-sdk"
        )

    verbose_logger.info(f"Fetching agent card from {base_url}")

    # Use LiteLLM's cached httpx client. Mirror create_a2a_client: fold any
    # extra_headers into the cache-key-only "disable_aiohttp_transport" param
    # so agents with different auth headers do not share a cached client.
    _client_params: dict = {"timeout": timeout}
    if extra_headers:
        _client_params["disable_aiohttp_transport"] = str(sorted(extra_headers.items()))
    http_handler = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.A2A,
        params=_client_params,
    )
    httpx_client = http_handler.client
    if extra_headers:
        # Fix: extra_headers was previously accepted but never applied, so
        # authenticated agents could not serve their card through this path.
        httpx_client.headers.update(extra_headers)

    resolver = A2ACardResolver(
        httpx_client=httpx_client,
        base_url=base_url,
    )
    agent_card = await resolver.get_agent_card()

    verbose_logger.info(
        f"Fetched agent card: {agent_card.name if hasattr(agent_card, 'name') else 'unknown'}"
    )
    return agent_card
|
||||
@@ -0,0 +1,10 @@
|
||||
"""
|
||||
A2A Protocol Providers.
|
||||
|
||||
This module contains provider-specific implementations for the A2A protocol.
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
|
||||
from litellm.a2a_protocol.providers.config_manager import A2AProviderConfigManager
|
||||
|
||||
__all__ = ["BaseA2AProviderConfig", "A2AProviderConfigManager"]
|
||||
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Base configuration for A2A protocol providers.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
|
||||
class BaseA2AProviderConfig(ABC):
    """Abstract interface implemented by A2A protocol providers.

    A concrete subclass defines how A2A requests are executed for one
    specific agent type, in both non-streaming and streaming form.
    """

    @abstractmethod
    async def handle_non_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """Execute a non-streaming A2A request.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the agent
            **kwargs: Additional provider-specific parameters

        Returns:
            A2A SendMessageResponse dict
        """
        ...

    @abstractmethod
    async def handle_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Execute a streaming A2A request.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the agent
            **kwargs: Additional provider-specific parameters

        Yields:
            A2A streaming response events
        """
        # The unreachable yield marks this abstract method as an async
        # generator so overrides inherit the right call semantics.
        if False:  # pragma: no cover
            yield {}
|
||||
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
A2A Provider Config Manager.
|
||||
|
||||
Manages provider-specific configurations for A2A protocol.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
|
||||
|
||||
|
||||
class A2AProviderConfigManager:
    """Registry-style lookup of A2A provider configurations.

    Similar to ProviderConfigManager in litellm.utils but specifically for A2A providers.
    """

    @staticmethod
    def get_provider_config(
        custom_llm_provider: Optional[str],
    ) -> Optional[BaseA2AProviderConfig]:
        """Return the provider configuration for *custom_llm_provider*.

        Args:
            custom_llm_provider: The provider identifier (e.g., "pydantic_ai_agents")

        Returns:
            Provider configuration instance, or None when the provider is
            unknown or not given.
        """
        if custom_llm_provider == "pydantic_ai_agents":
            # Imported lazily so the provider package is only loaded on use.
            from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
                PydanticAIProviderConfig,
            )

            return PydanticAIProviderConfig()

        # Register additional providers here as they are added, e.g.:
        # if custom_llm_provider == "another_provider":
        #     from ...another_provider.config import AnotherProviderConfig
        #     return AnotherProviderConfig()
        return None
|
||||
@@ -0,0 +1,74 @@
|
||||
# A2A to LiteLLM Completion Bridge
|
||||
|
||||
Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A.
|
||||
|
||||
## Flow
|
||||
|
||||
```
|
||||
A2A Request → Transform → litellm.acompletion → Transform → A2A Response
|
||||
```
|
||||
|
||||
## SDK Usage
|
||||
|
||||
Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`:
|
||||
|
||||
```python
|
||||
from litellm.a2a_protocol import asend_message, asend_message_streaming
|
||||
from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams
|
||||
from uuid import uuid4
|
||||
|
||||
# Non-streaming
|
||||
request = SendMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
response = await asend_message(
|
||||
request=request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
)
|
||||
|
||||
# Streaming
|
||||
stream_request = SendStreamingMessageRequest(
|
||||
id=str(uuid4()),
|
||||
params=MessageSendParams(
|
||||
message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex}
|
||||
)
|
||||
)
|
||||
async for chunk in asend_message_streaming(
|
||||
request=stream_request,
|
||||
api_base="http://localhost:2024",
|
||||
litellm_params={"custom_llm_provider": "langgraph", "model": "agent"},
|
||||
):
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
## Proxy Usage
|
||||
|
||||
Configure an agent with `custom_llm_provider` in `litellm_params`:
|
||||
|
||||
```yaml
|
||||
agents:
|
||||
- agent_name: my-langgraph-agent
|
||||
agent_card_params:
|
||||
name: "LangGraph Agent"
|
||||
url: "http://localhost:2024" # Used as api_base
|
||||
litellm_params:
|
||||
custom_llm_provider: langgraph
|
||||
model: agent
|
||||
```
|
||||
|
||||
When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge:
|
||||
|
||||
1. Detects `custom_llm_provider` in agent's `litellm_params`
|
||||
2. Transforms A2A message → OpenAI messages
|
||||
3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")`
|
||||
4. Transforms response → A2A format
|
||||
|
||||
## Classes
|
||||
|
||||
- `A2ACompletionBridgeTransformation` - Static methods for message format conversion
|
||||
- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming)
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
LiteLLM Completion bridge provider for A2A protocol.
|
||||
|
||||
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
|
||||
"""
|
||||
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Handler for A2A to LiteLLM completion bridge.
|
||||
|
||||
Routes A2A requests through litellm.acompletion based on custom_llm_provider.
|
||||
|
||||
A2A Streaming Events (in order):
|
||||
1. Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
2. Status update (kind: "status-update") - Status change to "working"
|
||||
3. Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
4. Status update (kind: "status-update") - Final status "completed" with final=true
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.pydantic_ai_transformation import (
|
||||
PydanticAITransformation,
|
||||
)
|
||||
from litellm.a2a_protocol.litellm_completion_bridge.transformation import (
|
||||
A2ACompletionBridgeTransformation,
|
||||
A2AStreamingContext,
|
||||
)
|
||||
|
||||
|
||||
class A2ACompletionBridgeHandler:
    """
    Static methods for handling A2A requests via LiteLLM completion.

    Both entry points share the same request preparation (message transform,
    model-name resolution, param forwarding), factored into
    ``_prepare_completion_params`` so the two code paths cannot drift apart.
    """

    @staticmethod
    def _prepare_completion_params(
        params: Dict[str, Any],
        litellm_params: Dict[str, Any],
        api_base: Optional[str],
        stream: bool,
    ) -> Dict[str, Any]:
        """Build the kwargs dict for litellm.acompletion from an A2A request.

        Transforms the A2A message into OpenAI chat messages, resolves the
        fully qualified model name (prefixing custom_llm_provider exactly once),
        and forwards every remaining litellm_params entry (api_key, client_id,
        client_secret, tenant_id, etc.).
        """
        message = params.get("message", {})
        openai_messages = (
            A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message)
        )

        custom_llm_provider = litellm_params.get("custom_llm_provider")
        model = litellm_params.get("model", "agent")
        # Skip prepending if model already starts with the provider prefix
        if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
            full_model = f"{custom_llm_provider}/{model}"
        else:
            full_model = model

        completion_params: Dict[str, Any] = {
            "model": full_model,
            "messages": openai_messages,
            "api_base": api_base,
            "stream": stream,
        }
        # Forward remaining litellm_params; model/provider are already encoded
        # in the "model" key above.
        completion_params.update(
            {
                k: v
                for k, v in litellm_params.items()
                if k not in ("model", "custom_llm_provider")
            }
        )
        return completion_params

    @staticmethod
    async def handle_non_streaming(
        request_id: str,
        params: Dict[str, Any],
        litellm_params: Dict[str, Any],
        api_base: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Handle non-streaming A2A request via litellm.acompletion.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
            api_base: API base URL from agent_card_params

        Returns:
            A2A SendMessageResponse dict

        Raises:
            ValueError: If api_base is missing for a Pydantic AI agent.
        """
        # Pydantic AI agents speak A2A natively — proxy the request through
        # directly instead of going via litellm.acompletion.
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        if custom_llm_provider == "pydantic_ai_agents":
            if api_base is None:
                raise ValueError("api_base is required for Pydantic AI agents")

            verbose_logger.info(
                f"Pydantic AI: Routing to Pydantic AI agent at {api_base}"
            )

            response_data = await PydanticAITransformation.send_non_streaming_request(
                api_base=api_base,
                request_id=request_id,
                params=params,
            )
            return response_data

        completion_params = A2ACompletionBridgeHandler._prepare_completion_params(
            params=params,
            litellm_params=litellm_params,
            api_base=api_base,
            stream=False,
        )

        verbose_logger.info(
            f"A2A completion bridge: model={completion_params['model']}, api_base={api_base}"
        )

        # Call litellm.acompletion
        response = await litellm.acompletion(**completion_params)

        # Transform response to A2A format
        a2a_response = (
            A2ACompletionBridgeTransformation.openai_response_to_a2a_response(
                response=response,
                request_id=request_id,
            )
        )

        verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}")

        return a2a_response

    @staticmethod
    async def handle_streaming(
        request_id: str,
        params: Dict[str, Any],
        litellm_params: Dict[str, Any],
        api_base: Optional[str] = None,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Handle streaming A2A request via litellm.acompletion with stream=True.

        Emits proper A2A streaming events:
        1. Task event (kind: "task") - Initial task with status "submitted"
        2. Status update (kind: "status-update") - Status "working"
        3. Artifact update (kind: "artifact-update") - Content delivery
        4. Status update (kind: "status-update") - Final "completed" status

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.)
            api_base: API base URL from agent_card_params

        Yields:
            A2A streaming response events

        Raises:
            ValueError: If api_base is missing for a Pydantic AI agent.
        """
        # Pydantic AI agents: fetch the full response and fake the stream.
        custom_llm_provider = litellm_params.get("custom_llm_provider")
        if custom_llm_provider == "pydantic_ai_agents":
            if api_base is None:
                raise ValueError("api_base is required for Pydantic AI agents")

            verbose_logger.info(
                f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
            )

            response_data = await PydanticAITransformation.send_non_streaming_request(
                api_base=api_base,
                request_id=request_id,
                params=params,
            )

            async for chunk in PydanticAITransformation.fake_streaming_from_response(
                response_data=response_data,
                request_id=request_id,
            ):
                yield chunk

            return

        # Context tracks task/context ids shared by every event in this stream.
        message = params.get("message", {})
        ctx = A2AStreamingContext(
            request_id=request_id,
            input_message=message,
        )

        completion_params = A2ACompletionBridgeHandler._prepare_completion_params(
            params=params,
            litellm_params=litellm_params,
            api_base=api_base,
            stream=True,
        )

        verbose_logger.info(
            f"A2A completion bridge streaming: model={completion_params['model']}, api_base={api_base}"
        )

        # 1. Emit initial task event (kind: "task", status: "submitted")
        yield A2ACompletionBridgeTransformation.create_task_event(ctx)

        # 2. Emit status update (kind: "status-update", status: "working")
        yield A2ACompletionBridgeTransformation.create_status_update_event(
            ctx=ctx,
            state="working",
            final=False,
            message_text="Processing request...",
        )

        # Call litellm.acompletion with streaming
        response = await litellm.acompletion(**completion_params)

        # 3. Accumulate all delta content, then emit a single artifact update.
        accumulated_text = ""
        chunk_count = 0
        async for chunk in response:  # type: ignore[union-attr]
            chunk_count += 1

            # Extract delta content defensively: chunks may lack choices/delta.
            content = ""
            if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
                choice = chunk.choices[0]
                if hasattr(choice, "delta") and choice.delta:
                    content = choice.delta.content or ""

            if content:
                accumulated_text += content

        if accumulated_text:
            yield A2ACompletionBridgeTransformation.create_artifact_update_event(
                ctx=ctx,
                text=accumulated_text,
            )

        # 4. Emit final status update (kind: "status-update", status: "completed", final: true)
        yield A2ACompletionBridgeTransformation.create_status_update_event(
            ctx=ctx,
            state="completed",
            final=True,
        )

        verbose_logger.info(
            f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}"
        )
|
||||
|
||||
|
||||
# Convenience functions that delegate to the class methods
|
||||
async def handle_a2a_completion(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> Dict[str, Any]:
    """Module-level convenience wrapper for non-streaming A2A completion.

    Delegates directly to A2ACompletionBridgeHandler.handle_non_streaming.
    """
    response = await A2ACompletionBridgeHandler.handle_non_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )
    return response
|
||||
|
||||
|
||||
async def handle_a2a_completion_streaming(
    request_id: str,
    params: Dict[str, Any],
    litellm_params: Dict[str, Any],
    api_base: Optional[str] = None,
) -> AsyncIterator[Dict[str, Any]]:
    """Module-level convenience wrapper for streaming A2A completion.

    Re-yields every event from A2ACompletionBridgeHandler.handle_streaming.
    """
    stream = A2ACompletionBridgeHandler.handle_streaming(
        request_id=request_id,
        params=params,
        litellm_params=litellm_params,
        api_base=api_base,
    )
    async for event in stream:
        yield event
|
||||
@@ -0,0 +1,284 @@
|
||||
"""
|
||||
Transformation utilities for A2A <-> OpenAI message format conversion.
|
||||
|
||||
A2A Message Format:
|
||||
{
|
||||
"role": "user",
|
||||
"parts": [{"kind": "text", "text": "Hello!"}],
|
||||
"messageId": "abc123"
|
||||
}
|
||||
|
||||
OpenAI Message Format:
|
||||
{"role": "user", "content": "Hello!"}
|
||||
|
||||
A2A Streaming Events:
|
||||
- Task event (kind: "task") - Initial task creation with status "submitted"
|
||||
- Status update (kind: "status-update") - Status changes (working, completed)
|
||||
- Artifact update (kind: "artifact-update") - Content/artifact delivery
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class A2AStreamingContext:
    """Mutable state carried across one A2A streaming response.

    Holds the caller's request id, freshly minted task/context ids, the
    original input message, and accumulation / event-emission flags.
    """

    def __init__(self, request_id: str, input_message: Dict[str, Any]):
        # Identifiers: request id is caller-supplied; task and context ids
        # are generated fresh for this stream.
        self.request_id = request_id
        self.task_id = str(uuid4())
        self.context_id = str(uuid4())
        # Input message and accumulated artifact text.
        self.input_message = input_message
        self.accumulated_text = ""
        # Lifecycle flags: which initial events have already been emitted.
        self.has_emitted_task = False
        self.has_emitted_working = False
|
||||
|
||||
|
||||
class A2ACompletionBridgeTransformation:
|
||||
"""
|
||||
Static methods for transforming between A2A and OpenAI message formats.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def a2a_message_to_openai_messages(
|
||||
a2a_message: Dict[str, Any],
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Transform an A2A message to OpenAI message format.
|
||||
|
||||
Args:
|
||||
a2a_message: A2A message with role, parts, and messageId
|
||||
|
||||
Returns:
|
||||
List of OpenAI-format messages
|
||||
"""
|
||||
role = a2a_message.get("role", "user")
|
||||
parts = a2a_message.get("parts", [])
|
||||
|
||||
# Map A2A roles to OpenAI roles
|
||||
openai_role = role
|
||||
if role == "user":
|
||||
openai_role = "user"
|
||||
elif role == "assistant":
|
||||
openai_role = "assistant"
|
||||
elif role == "system":
|
||||
openai_role = "system"
|
||||
|
||||
# Extract text content from parts
|
||||
content_parts = []
|
||||
for part in parts:
|
||||
kind = part.get("kind", "")
|
||||
if kind == "text":
|
||||
text = part.get("text", "")
|
||||
content_parts.append(text)
|
||||
|
||||
content = "\n".join(content_parts) if content_parts else ""
|
||||
|
||||
verbose_logger.debug(
|
||||
f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}"
|
||||
)
|
||||
|
||||
return [{"role": openai_role, "content": content}]
|
||||
|
||||
@staticmethod
|
||||
def openai_response_to_a2a_response(
|
||||
response: Any,
|
||||
request_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Transform a LiteLLM ModelResponse to A2A SendMessageResponse format.
|
||||
|
||||
Args:
|
||||
response: LiteLLM ModelResponse object
|
||||
request_id: Original A2A request ID
|
||||
|
||||
Returns:
|
||||
A2A SendMessageResponse dict
|
||||
"""
|
||||
# Extract content from response
|
||||
content = ""
|
||||
if hasattr(response, "choices") and response.choices:
|
||||
choice = response.choices[0]
|
||||
if hasattr(choice, "message") and choice.message:
|
||||
content = choice.message.content or ""
|
||||
|
||||
# Build A2A message
|
||||
a2a_message = {
|
||||
"role": "agent",
|
||||
"parts": [{"kind": "text", "text": content}],
|
||||
"messageId": uuid4().hex,
|
||||
}
|
||||
|
||||
# Build A2A response
|
||||
a2a_response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"message": a2a_message,
|
||||
},
|
||||
}
|
||||
|
||||
verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}")
|
||||
|
||||
return a2a_response
|
||||
|
||||
@staticmethod
|
||||
def _get_timestamp() -> str:
|
||||
"""Get current timestamp in ISO format with timezone."""
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
@staticmethod
|
||||
def create_task_event(
|
||||
ctx: A2AStreamingContext,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create the initial task event with status 'submitted'.
|
||||
|
||||
This is the first event emitted in an A2A streaming response.
|
||||
"""
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"contextId": ctx.context_id,
|
||||
"history": [
|
||||
{
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "message",
|
||||
"messageId": ctx.input_message.get("messageId", uuid4().hex),
|
||||
"parts": ctx.input_message.get("parts", []),
|
||||
"role": ctx.input_message.get("role", "user"),
|
||||
"taskId": ctx.task_id,
|
||||
}
|
||||
],
|
||||
"id": ctx.task_id,
|
||||
"kind": "task",
|
||||
"status": {
|
||||
"state": "submitted",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_status_update_event(
|
||||
ctx: A2AStreamingContext,
|
||||
state: str,
|
||||
final: bool = False,
|
||||
message_text: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a status update event.
|
||||
|
||||
Args:
|
||||
ctx: Streaming context
|
||||
state: Status state ('working', 'completed')
|
||||
final: Whether this is the final event
|
||||
message_text: Optional message text for 'working' status
|
||||
"""
|
||||
status: Dict[str, Any] = {
|
||||
"state": state,
|
||||
"timestamp": A2ACompletionBridgeTransformation._get_timestamp(),
|
||||
}
|
||||
|
||||
# Add message for 'working' status
|
||||
if state == "working" and message_text:
|
||||
status["message"] = {
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "message",
|
||||
"messageId": str(uuid4()),
|
||||
"parts": [{"kind": "text", "text": message_text}],
|
||||
"role": "agent",
|
||||
"taskId": ctx.task_id,
|
||||
}
|
||||
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"contextId": ctx.context_id,
|
||||
"final": final,
|
||||
"kind": "status-update",
|
||||
"status": status,
|
||||
"taskId": ctx.task_id,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def create_artifact_update_event(
|
||||
ctx: A2AStreamingContext,
|
||||
text: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create an artifact update event with content.
|
||||
|
||||
Args:
|
||||
ctx: Streaming context
|
||||
text: The text content for the artifact
|
||||
"""
|
||||
return {
|
||||
"id": ctx.request_id,
|
||||
"jsonrpc": "2.0",
|
||||
"result": {
|
||||
"artifact": {
|
||||
"artifactId": str(uuid4()),
|
||||
"name": "response",
|
||||
"parts": [{"kind": "text", "text": text}],
|
||||
},
|
||||
"contextId": ctx.context_id,
|
||||
"kind": "artifact-update",
|
||||
"taskId": ctx.task_id,
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def openai_chunk_to_a2a_chunk(
|
||||
chunk: Any,
|
||||
request_id: Optional[str] = None,
|
||||
is_final: bool = False,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Transform a LiteLLM streaming chunk to A2A streaming format.
|
||||
|
||||
NOTE: This method is deprecated for streaming. Use the event-based
|
||||
methods (create_task_event, create_status_update_event,
|
||||
create_artifact_update_event) instead for proper A2A streaming.
|
||||
|
||||
Args:
|
||||
chunk: LiteLLM ModelResponse chunk
|
||||
request_id: Original A2A request ID
|
||||
is_final: Whether this is the final chunk
|
||||
|
||||
Returns:
|
||||
A2A streaming chunk dict or None if no content
|
||||
"""
|
||||
# Extract delta content
|
||||
content = ""
|
||||
if chunk is not None and hasattr(chunk, "choices") and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
if hasattr(choice, "delta") and choice.delta:
|
||||
content = choice.delta.content or ""
|
||||
|
||||
if not content and not is_final:
|
||||
return None
|
||||
|
||||
# Build A2A streaming chunk (legacy format)
|
||||
a2a_chunk = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
"result": {
|
||||
"message": {
|
||||
"role": "agent",
|
||||
"parts": [{"kind": "text", "text": content}],
|
||||
"messageId": uuid4().hex,
|
||||
},
|
||||
"final": is_final,
|
||||
},
|
||||
}
|
||||
|
||||
return a2a_chunk
|
||||
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Pydantic AI agent provider for A2A protocol.
|
||||
|
||||
Pydantic AI agents follow A2A protocol but don't support streaming natively.
|
||||
This provider handles fake streaming by converting non-streaming responses into streaming chunks.
|
||||
"""
|
||||
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.config import (
|
||||
PydanticAIProviderConfig,
|
||||
)
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
|
||||
PydanticAITransformation,
|
||||
)
|
||||
|
||||
__all__ = ["PydanticAIHandler", "PydanticAITransformation", "PydanticAIProviderConfig"]
|
||||
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Pydantic AI provider configuration.
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
from litellm.a2a_protocol.providers.base import BaseA2AProviderConfig
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.handler import PydanticAIHandler
|
||||
|
||||
|
||||
class PydanticAIProviderConfig(BaseA2AProviderConfig):
    """
    Provider configuration for Pydantic AI agents.

    These agents speak the A2A protocol but have no native streaming, so
    streaming requests are served by replaying a non-streaming response as
    a sequence of chunks (fake streaming).
    """

    async def handle_non_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> Dict[str, Any]:
        """Forward a non-streaming request to the Pydantic AI handler."""
        timeout = kwargs.get("timeout", 60.0)
        return await PydanticAIHandler.handle_non_streaming(
            request_id=request_id,
            params=params,
            api_base=api_base,
            timeout=timeout,
        )

    async def handle_streaming(
        self,
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Serve a streaming request via fake (replayed) streaming."""
        timeout = kwargs.get("timeout", 60.0)
        chunk_size = kwargs.get("chunk_size", 50)
        delay_ms = kwargs.get("delay_ms", 10)
        async for event in PydanticAIHandler.handle_streaming(
            request_id=request_id,
            params=params,
            api_base=api_base,
            timeout=timeout,
            chunk_size=chunk_size,
            delay_ms=delay_ms,
        ):
            yield event
|
||||
@@ -0,0 +1,102 @@
|
||||
"""
|
||||
Handler for Pydantic AI agents.
|
||||
|
||||
Pydantic AI agents follow A2A protocol but don't support streaming natively.
|
||||
This handler provides fake streaming by converting non-streaming responses into streaming chunks.
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Dict
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.providers.pydantic_ai_agents.transformation import (
|
||||
PydanticAITransformation,
|
||||
)
|
||||
|
||||
|
||||
class PydanticAIHandler:
    """
    Request handler for Pydantic AI agents.

    Responsibilities:
    - Forwarding non-streaming A2A requests straight to the agent
    - Emulating streaming by replaying a non-streaming response as chunks
    """

    @staticmethod
    async def handle_non_streaming(
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Forward a non-streaming request to a Pydantic AI agent.

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the Pydantic AI agent
            timeout: Request timeout in seconds

        Returns:
            A2A SendMessageResponse dict
        """
        verbose_logger.info(f"Pydantic AI: Routing to Pydantic AI agent at {api_base}")

        # Delegate straight to the transformation layer and hand back its result.
        return await PydanticAITransformation.send_non_streaming_request(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

    @staticmethod
    async def handle_streaming(
        request_id: str,
        params: Dict[str, Any],
        api_base: str,
        timeout: float = 60.0,
        chunk_size: int = 50,
        delay_ms: int = 10,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Emulate streaming for a Pydantic AI agent.

        The agent only supports non-streaming calls, so this:
        1. Performs one non-streaming request
        2. Re-emits the completed response as A2A streaming events

        Args:
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            api_base: Base URL of the Pydantic AI agent
            timeout: Request timeout in seconds
            chunk_size: Number of characters per chunk
            delay_ms: Delay between chunks in milliseconds

        Yields:
            A2A streaming response events
        """
        verbose_logger.info(
            f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}"
        )

        # Fetch the raw (untransformed) task response to feed the fake stream.
        task_response = await PydanticAITransformation.send_and_get_raw_response(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

        # Replay the completed response as a sequence of streaming events.
        async for event in PydanticAITransformation.fake_streaming_from_response(
            response_data=task_response,
            request_id=request_id,
            chunk_size=chunk_size,
            delay_ms=delay_ms,
        ):
            yield event
|
||||
@@ -0,0 +1,530 @@
|
||||
"""
|
||||
Transformation layer for Pydantic AI agents.
|
||||
|
||||
Pydantic AI agents follow A2A protocol but don't support streaming.
|
||||
This module provides fake streaming by converting non-streaming responses into streaming chunks.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, AsyncIterator, Dict, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
|
||||
|
||||
class PydanticAITransformation:
    """
    Transformation layer for Pydantic AI agents.

    Handles:
    - Direct A2A requests to Pydantic AI endpoints
    - Polling for task completion (since Pydantic AI doesn't support streaming)
    - Fake streaming by chunking non-streaming responses
    """

    @staticmethod
    def _remove_none_values(obj: Any) -> Any:
        """
        Recursively remove None values from a dict/list structure.

        FastA2A/Pydantic AI servers don't accept None values for optional fields -
        they expect those fields to be omitted entirely.

        Args:
            obj: Dict, list, or other value to clean

        Returns:
            Cleaned object with None values removed (new dict/list objects are
            built at every level, so the input is never mutated)
        """
        if isinstance(obj, dict):
            return {
                k: PydanticAITransformation._remove_none_values(v)
                for k, v in obj.items()
                if v is not None
            }
        elif isinstance(obj, list):
            return [
                PydanticAITransformation._remove_none_values(item)
                for item in obj
                if item is not None
            ]
        else:
            return obj

    @staticmethod
    def _params_to_dict(params: Any) -> Dict[str, Any]:
        """
        Convert params to a dict, handling Pydantic models.

        Args:
            params: Dict or Pydantic model

        Returns:
            Dict representation of params (None-valued fields excluded for
            Pydantic models via exclude_none)
        """
        if hasattr(params, "model_dump"):
            # Pydantic v2 model
            return params.model_dump(mode="python", exclude_none=True)
        elif hasattr(params, "dict"):
            # Pydantic v1 model
            return params.dict(exclude_none=True)
        elif isinstance(params, dict):
            return params
        else:
            # Last resort: let dict() raise if params is not mapping-like
            return dict(params)

    @staticmethod
    async def _poll_for_completion(
        client: AsyncHTTPHandler,
        endpoint: str,
        task_id: str,
        request_id: str,
        max_attempts: int = 30,
        poll_interval: float = 0.5,
    ) -> Dict[str, Any]:
        """
        Poll for task completion using tasks/get method.

        Args:
            client: HTTPX async client
            endpoint: API endpoint URL
            task_id: Task ID to poll for
            request_id: JSON-RPC request ID
            max_attempts: Maximum polling attempts
            poll_interval: Seconds between poll attempts

        Returns:
            Completed task response

        Raises:
            Exception: if the task ends in a 'failed' or 'canceled' state
            TimeoutError: if the task does not complete within max_attempts polls
        """
        for attempt in range(max_attempts):
            # Each poll gets a distinct JSON-RPC id derived from the original request.
            poll_request = {
                "jsonrpc": "2.0",
                "id": f"{request_id}-poll-{attempt}",
                "method": "tasks/get",
                "params": {"id": task_id},
            }

            response = await client.post(
                endpoint,
                json=poll_request,
                headers={"Content-Type": "application/json"},
            )
            response.raise_for_status()
            poll_data = response.json()

            result = poll_data.get("result", {})
            status = result.get("status", {})
            state = status.get("state", "")

            verbose_logger.debug(
                f"Pydantic AI: Poll attempt {attempt + 1}/{max_attempts}, state={state}"
            )

            if state == "completed":
                return poll_data
            elif state in ("failed", "canceled"):
                raise Exception(f"Task {task_id} ended with state: {state}")

            await asyncio.sleep(poll_interval)

        # NOTE: the advertised bound counts only sleep time; real elapsed time
        # also includes per-request latency and may be longer.
        raise TimeoutError(
            f"Task {task_id} did not complete within {max_attempts * poll_interval} seconds"
        )

    @staticmethod
    async def _send_and_poll_raw(
        api_base: str,
        request_id: str,
        params: Any,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a request to Pydantic AI agent and return the raw task response.

        This is an internal method used by both non-streaming and streaming handlers.
        Returns the raw Pydantic AI task format with history/artifacts.

        Args:
            api_base: Base URL of the Pydantic AI agent
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            timeout: Request timeout in seconds

        Returns:
            Raw Pydantic AI task response (with history/artifacts)
        """
        # Convert params to dict if it's a Pydantic model
        params_dict = PydanticAITransformation._params_to_dict(params)

        # Remove None values - FastA2A doesn't accept null for optional fields.
        # This also yields a fresh copy, so the mutation below never touches
        # the caller's params.
        params_dict = PydanticAITransformation._remove_none_values(params_dict)

        # Ensure the message has 'kind': 'message' as required by FastA2A/Pydantic AI
        if "message" in params_dict:
            params_dict["message"]["kind"] = "message"

        # Build A2A JSON-RPC request using message/send method for FastA2A compatibility
        a2a_request = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": "message/send",
            "params": params_dict,
        }

        # FastA2A uses root endpoint (/) not /messages
        endpoint = api_base.rstrip("/")

        verbose_logger.info(f"Pydantic AI: Sending non-streaming request to {endpoint}")

        # Send request to Pydantic AI agent using shared async HTTP client
        client = get_async_httpx_client(
            llm_provider=cast(Any, "pydantic_ai_agent"),
            params={"timeout": timeout},
        )
        response = await client.post(
            endpoint,
            json=a2a_request,
            headers={"Content-Type": "application/json"},
        )
        response.raise_for_status()
        response_data = response.json()

        # Check if task is already completed
        result = response_data.get("result", {})
        status = result.get("status", {})
        state = status.get("state", "")

        if state != "completed":
            # Need to poll for completion.
            # NOTE(review): if the agent returns no task id here, the
            # not-yet-completed response is returned as-is — confirm callers
            # tolerate a non-completed task payload.
            task_id = result.get("id")
            if task_id:
                verbose_logger.info(
                    f"Pydantic AI: Task {task_id} submitted, polling for completion..."
                )
                response_data = await PydanticAITransformation._poll_for_completion(
                    client=client,
                    endpoint=endpoint,
                    task_id=task_id,
                    request_id=request_id,
                )

        verbose_logger.info(
            f"Pydantic AI: Received completed response for request_id={request_id}"
        )

        return response_data

    @staticmethod
    async def send_non_streaming_request(
        api_base: str,
        request_id: str,
        params: Any,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a non-streaming A2A request to Pydantic AI agent and wait for completion.

        Args:
            api_base: Base URL of the Pydantic AI agent (e.g., "http://localhost:9999")
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message (dict or Pydantic model)
            timeout: Request timeout in seconds

        Returns:
            Standard A2A non-streaming response format with message
        """
        # Get raw task response
        raw_response = await PydanticAITransformation._send_and_poll_raw(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

        # Transform to standard A2A non-streaming format
        return PydanticAITransformation._transform_to_a2a_response(
            response_data=raw_response,
            request_id=request_id,
        )

    @staticmethod
    async def send_and_get_raw_response(
        api_base: str,
        request_id: str,
        params: Any,
        timeout: float = 60.0,
    ) -> Dict[str, Any]:
        """
        Send a request to Pydantic AI agent and return the raw task response.

        Used by streaming handler to get raw response for fake streaming.

        Args:
            api_base: Base URL of the Pydantic AI agent
            request_id: A2A JSON-RPC request ID
            params: A2A MessageSendParams containing the message
            timeout: Request timeout in seconds

        Returns:
            Raw Pydantic AI task response (with history/artifacts)
        """
        return await PydanticAITransformation._send_and_poll_raw(
            api_base=api_base,
            request_id=request_id,
            params=params,
            timeout=timeout,
        )

    @staticmethod
    def _transform_to_a2a_response(
        response_data: Dict[str, Any],
        request_id: str,
    ) -> Dict[str, Any]:
        """
        Transform Pydantic AI task response to standard A2A non-streaming format.

        Pydantic AI returns a task with history/artifacts, but the standard A2A
        non-streaming format expects:
        {
            "jsonrpc": "2.0",
            "id": "...",
            "result": {
                "message": {
                    "role": "agent",
                    "parts": [{"kind": "text", "text": "..."}],
                    "messageId": "..."
                }
            }
        }

        Args:
            response_data: Pydantic AI task response
            request_id: Original request ID

        Returns:
            Standard A2A non-streaming response format
        """
        # Extract the agent response text
        full_text, message_id, parts = PydanticAITransformation._extract_response_text(
            response_data
        )

        # Build standard A2A message; prefer the original parts when present,
        # otherwise fall back to a single synthesized text part.
        a2a_message = {
            "role": "agent",
            "parts": parts if parts else [{"kind": "text", "text": full_text}],
            "messageId": message_id,
        }

        # Return standard A2A non-streaming format
        return {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "message": a2a_message,
            },
        }

    @staticmethod
    def _extract_response_text(response_data: Dict[str, Any]) -> tuple[str, str, list]:
        """
        Extract response text from completed task response.

        Pydantic AI returns completed tasks with:
        - history: list of messages (user and agent)
        - artifacts: list of result artifacts

        Lookup order: artifacts first, then the most recent agent message in
        history, then a top-level 'message' field, finally an empty result.

        Args:
            response_data: Completed task response

        Returns:
            Tuple of (full_text, message_id, parts)
        """
        result = response_data.get("result", {})

        # Try to extract from artifacts first (preferred for results)
        artifacts = result.get("artifacts", [])
        if artifacts:
            for artifact in artifacts:
                parts = artifact.get("parts", [])
                for part in parts:
                    if part.get("kind") == "text":
                        text = part.get("text", "")
                        if text:
                            # Artifacts carry no message id, so a fresh one is minted.
                            return text, str(uuid4()), parts

        # Fall back to history - get the last agent message
        history = result.get("history", [])
        for msg in reversed(history):
            if msg.get("role") == "agent":
                parts = msg.get("parts", [])
                message_id = msg.get("messageId", str(uuid4()))
                full_text = ""
                for part in parts:
                    if part.get("kind") == "text":
                        full_text += part.get("text", "")
                if full_text:
                    return full_text, message_id, parts

        # Fall back to message field (original format)
        message = result.get("message", {})
        if message:
            parts = message.get("parts", [])
            message_id = message.get("messageId", str(uuid4()))
            full_text = ""
            for part in parts:
                if part.get("kind") == "text":
                    full_text += part.get("text", "")
            return full_text, message_id, parts

        return "", str(uuid4()), []

    @staticmethod
    async def fake_streaming_from_response(
        response_data: Dict[str, Any],
        request_id: str,
        chunk_size: int = 50,
        delay_ms: int = 10,
    ) -> AsyncIterator[Dict[str, Any]]:
        """
        Convert a non-streaming A2A response into fake streaming chunks.

        Emits proper A2A streaming events:
        1. Task event (kind: "task") - Initial task with status "submitted"
        2. Status update (kind: "status-update") - Status "working"
        3. Artifact update chunks (kind: "artifact-update") - Content delivery in chunks
        4. Status update (kind: "status-update") - Final "completed" status

        Args:
            response_data: Non-streaming A2A response dict (completed task)
            request_id: A2A JSON-RPC request ID
            chunk_size: Number of characters per chunk (default: 50)
            delay_ms: Delay between chunks in milliseconds (default: 10)

        Yields:
            A2A streaming response events
        """
        # Extract the response text from completed task
        full_text, message_id, parts = PydanticAITransformation._extract_response_text(
            response_data
        )

        # Extract input message from raw response for history (first user message)
        result = response_data.get("result", {})
        history = result.get("history", [])
        input_message = {}
        for msg in history:
            if msg.get("role") == "user":
                input_message = msg
                break

        # Generate fresh IDs for streaming events.
        # NOTE(review): these do not reuse the upstream task/context ids from
        # the raw response — confirm downstream consumers don't correlate on them.
        task_id = str(uuid4())
        context_id = str(uuid4())
        artifact_id = str(uuid4())
        input_message_id = input_message.get("messageId", str(uuid4()))

        # 1. Emit initial task event (kind: "task", status: "submitted")
        # Format matches A2ACompletionBridgeTransformation.create_task_event
        task_event = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "contextId": context_id,
                "history": [
                    {
                        "contextId": context_id,
                        "kind": "message",
                        "messageId": input_message_id,
                        "parts": input_message.get(
                            "parts", [{"kind": "text", "text": ""}]
                        ),
                        "role": "user",
                        "taskId": task_id,
                    }
                ],
                "id": task_id,
                "kind": "task",
                "status": {
                    "state": "submitted",
                },
            },
        }
        yield task_event

        # 2. Emit status update (kind: "status-update", status: "working")
        # Format matches A2ACompletionBridgeTransformation.create_status_update_event
        working_event = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "contextId": context_id,
                "final": False,
                "kind": "status-update",
                "status": {
                    "state": "working",
                },
                "taskId": task_id,
            },
        }
        yield working_event

        # Small delay to simulate processing
        await asyncio.sleep(delay_ms / 1000.0)

        # 3. Emit artifact update chunks (kind: "artifact-update")
        # Format matches A2ACompletionBridgeTransformation.create_artifact_update_event
        if full_text:
            # Split text into chunks
            for i in range(0, len(full_text), chunk_size):
                chunk_text = full_text[i : i + chunk_size]
                is_last_chunk = (i + chunk_size) >= len(full_text)

                artifact_event = {
                    "jsonrpc": "2.0",
                    "id": request_id,
                    "result": {
                        "contextId": context_id,
                        "kind": "artifact-update",
                        "taskId": task_id,
                        "artifact": {
                            "artifactId": artifact_id,
                            "parts": [
                                {
                                    "kind": "text",
                                    "text": chunk_text,
                                }
                            ],
                        },
                    },
                }
                yield artifact_event

                # Add delay between chunks (except for last chunk)
                if not is_last_chunk:
                    await asyncio.sleep(delay_ms / 1000.0)

        # 4. Emit final status update (kind: "status-update", status: "completed", final: true)
        completed_event = {
            "jsonrpc": "2.0",
            "id": request_id,
            "result": {
                "contextId": context_id,
                "final": True,
                "kind": "status-update",
                "status": {
                    "state": "completed",
                },
                "taskId": task_id,
            },
        }
        yield completed_event

        verbose_logger.info(
            f"Pydantic AI: Fake streaming completed for request_id={request_id}"
        )
|
||||
@@ -0,0 +1,184 @@
|
||||
"""
|
||||
A2A Streaming Iterator with token tracking and logging support.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.a2a_protocol.cost_calculator import A2ACostCalculator
|
||||
from litellm.a2a_protocol.utils import A2ARequestUtils
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.thread_pool_executor import executor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import SendStreamingMessageRequest, SendStreamingMessageResponse
|
||||
|
||||
|
||||
class A2AStreamingIterator:
    """
    Async iterator for A2A streaming responses with token tracking.

    Wraps an underlying chunk stream, passing each chunk through unchanged
    while collecting its text for token counting; on stream exhaustion it
    fires the LiteLLM success handlers with usage and cost details.
    """

    def __init__(
        self,
        stream: AsyncIterator["SendStreamingMessageResponse"],
        request: "SendStreamingMessageRequest",
        logging_obj: LiteLLMLoggingObj,
        agent_name: str = "unknown",
    ):
        self.stream = stream
        self.request = request
        self.logging_obj = logging_obj
        self.agent_name = agent_name
        # NOTE(review): naive local time, not timezone-aware UTC — confirm
        # downstream logging expects this
        self.start_time = datetime.now()

        # Collect chunks for token counting
        self.chunks: List[Any] = []
        self.collected_text_parts: List[str] = []
        # Last chunk whose status reached "completed" (or the final chunk seen)
        self.final_chunk: Optional[Any] = None

    def __aiter__(self) -> "A2AStreamingIterator":
        return self

    async def __anext__(self) -> "SendStreamingMessageResponse":
        """Yield the next upstream chunk, recording it for later accounting."""
        try:
            chunk = await self.stream.__anext__()

            # Store chunk
            self.chunks.append(chunk)

            # Extract text from chunk for token counting
            self._collect_text_from_chunk(chunk)

            # Check if this is the final chunk (completed status)
            if self._is_completed_chunk(chunk):
                self.final_chunk = chunk

            return chunk

        except StopAsyncIteration:
            # Stream ended - handle logging; fall back to the last chunk seen
            # if no explicit "completed" chunk was observed
            if self.final_chunk is None and self.chunks:
                self.final_chunk = self.chunks[-1]
            await self._handle_stream_complete()
            raise

    def _collect_text_from_chunk(self, chunk: Any) -> None:
        """Extract text from a streaming chunk and add to collected parts."""
        try:
            # Chunks without model_dump contribute nothing (empty dict -> no text)
            chunk_dict = (
                chunk.model_dump(mode="json", exclude_none=True)
                if hasattr(chunk, "model_dump")
                else {}
            )
            text = A2ARequestUtils.extract_text_from_response(chunk_dict)
            if text:
                self.collected_text_parts.append(text)
        except Exception:
            # Best-effort: text collection failures must not break the stream
            verbose_logger.debug("Failed to extract text from A2A streaming chunk")

    def _is_completed_chunk(self, chunk: Any) -> bool:
        """Check if chunk indicates stream completion (status.state == 'completed')."""
        try:
            chunk_dict = (
                chunk.model_dump(mode="json", exclude_none=True)
                if hasattr(chunk, "model_dump")
                else {}
            )
            result = chunk_dict.get("result", {})
            if isinstance(result, dict):
                status = result.get("status", {})
                if isinstance(status, dict):
                    return status.get("state") == "completed"
        except Exception:
            # Malformed chunks are simply treated as non-final
            pass
        return False

    async def _handle_stream_complete(self) -> None:
        """Handle logging and token counting when stream completes.

        Best-effort: any failure is logged at debug level and swallowed so
        that stream consumers are never affected.
        """
        try:
            end_time = datetime.now()

            # Calculate tokens from collected text
            input_message = A2ARequestUtils.get_input_message_from_request(self.request)
            input_text = A2ARequestUtils.extract_text_from_message(input_message)
            prompt_tokens = A2ARequestUtils.count_tokens(input_text)

            # Use the last (most complete) text from chunks
            output_text = (
                self.collected_text_parts[-1] if self.collected_text_parts else ""
            )
            completion_tokens = A2ARequestUtils.count_tokens(output_text)

            total_tokens = prompt_tokens + completion_tokens

            # Create usage object
            usage = litellm.Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=total_tokens,
            )

            # Set usage on logging obj
            self.logging_obj.model_call_details["usage"] = usage
            # Mark stream flag for downstream callbacks
            # NOTE(review): sets stream=False even though this was a streaming
            # call — confirm downstream callbacks expect the non-stream path
            self.logging_obj.model_call_details["stream"] = False

            # Calculate cost using A2ACostCalculator
            response_cost = A2ACostCalculator.calculate_a2a_cost(self.logging_obj)
            self.logging_obj.model_call_details["response_cost"] = response_cost

            # Build result for logging
            result = self._build_logging_result(usage)

            # Call success handlers - they will build standard_logging_object
            # NOTE(review): the created task reference is not retained; asyncio
            # may garbage-collect it before it runs — consider holding a reference
            asyncio.create_task(
                self.logging_obj.async_success_handler(
                    result=result,
                    start_time=self.start_time,
                    end_time=end_time,
                    cache_hit=None,
                )
            )

            # Sync success handler runs on the shared thread pool
            executor.submit(
                self.logging_obj.success_handler,
                result=result,
                cache_hit=None,
                start_time=self.start_time,
                end_time=end_time,
            )

            verbose_logger.info(
                f"A2A streaming completed: prompt_tokens={prompt_tokens}, "
                f"completion_tokens={completion_tokens}, total_tokens={total_tokens}, "
                f"response_cost={response_cost}"
            )

        except Exception as e:
            verbose_logger.debug(f"Error in A2A streaming completion handler: {e}")

    def _build_logging_result(self, usage: litellm.Usage) -> Dict[str, Any]:
        """Build a JSON-RPC-shaped result dict (id/jsonrpc/usage[/result]) for logging."""
        result: Dict[str, Any] = {
            "id": getattr(self.request, "id", "unknown"),
            "jsonrpc": "2.0",
            "usage": usage.model_dump()
            if hasattr(usage, "model_dump")
            else dict(usage),
        }

        # Add final chunk result if available
        if self.final_chunk:
            try:
                chunk_dict = self.final_chunk.model_dump(mode="json", exclude_none=True)
                result["result"] = chunk_dict.get("result", {})
            except Exception:
                # Serialization failure: log without the final chunk payload
                pass

        return result
|
||||
@@ -0,0 +1,138 @@
|
||||
"""
|
||||
Utility functions for A2A protocol.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import SendMessageRequest, SendStreamingMessageRequest
|
||||
|
||||
|
||||
class A2ARequestUtils:
    """Utility class for A2A request/response processing."""

    @staticmethod
    def extract_text_from_message(message: Any) -> str:
        """
        Extract text content from A2A message parts.

        Args:
            message: A2A message dict or object with 'parts' containing text parts

        Returns:
            Concatenated text from all text parts, joined by single spaces
        """
        if message is None:
            return ""

        # Parts may live on a dict or on an object attribute.
        if isinstance(message, dict):
            parts = message.get("parts", [])
        else:
            parts = getattr(message, "parts", []) or []

        collected: List[str] = []
        for item in parts:
            if isinstance(item, dict):
                kind = item.get("kind")
                text = item.get("text", "")
            else:
                kind = getattr(item, "kind", None)
                text = getattr(item, "text", "")
            # Only parts explicitly marked as text contribute to the output.
            if kind == "text":
                collected.append(text)

        return " ".join(collected)

    @staticmethod
    def extract_text_from_response(response_dict: Dict[str, Any]) -> str:
        """
        Extract text content from an A2A response result.

        Args:
            response_dict: A2A response dict with 'result' containing a message

        Returns:
            Text from the response message parts ("" when result is not a dict)
        """
        payload = response_dict.get("result", {})
        if isinstance(payload, dict):
            return A2ARequestUtils.extract_text_from_message(
                payload.get("message", {})
            )
        return ""

    @staticmethod
    def get_input_message_from_request(
        request: "Union[SendMessageRequest, SendStreamingMessageRequest]",
    ) -> Any:
        """
        Extract the input message from an A2A request.

        Args:
            request: The A2A SendMessageRequest or SendStreamingMessageRequest

        Returns:
            The message object/dict, or None when params/message are absent
        """
        params = getattr(request, "params", None)
        return getattr(params, "message", None) if params is not None else None

    @staticmethod
    def count_tokens(text: str) -> int:
        """
        Count tokens in *text* using litellm.token_counter.

        Args:
            text: Text to count tokens for

        Returns:
            Token count; 0 for empty text or when counting fails
        """
        if not text:
            return 0
        try:
            return litellm.token_counter(text=text)
        except Exception:
            # Token counting is best-effort; never let it break the caller.
            verbose_logger.debug("Failed to count tokens")
            return 0

    @staticmethod
    def calculate_usage_from_request_response(
        request: "Union[SendMessageRequest, SendStreamingMessageRequest]",
        response_dict: Dict[str, Any],
    ) -> Tuple[int, int, int]:
        """
        Calculate token usage from an A2A request and response.

        Args:
            request: The A2A SendMessageRequest or SendStreamingMessageRequest
            response_dict: The A2A response as a dict

        Returns:
            Tuple of (prompt_tokens, completion_tokens, total_tokens)
        """
        prompt_text = A2ARequestUtils.extract_text_from_message(
            A2ARequestUtils.get_input_message_from_request(request)
        )
        completion_text = A2ARequestUtils.extract_text_from_response(response_dict)

        prompt_tokens = A2ARequestUtils.count_tokens(prompt_text)
        completion_tokens = A2ARequestUtils.count_tokens(completion_text)

        return prompt_tokens, completion_tokens, prompt_tokens + completion_tokens
|
||||
|
||||
|
||||
# Backwards compatibility aliases
|
||||
def extract_text_from_a2a_message(message: Any) -> str:
    """Backwards-compatibility alias for ``A2ARequestUtils.extract_text_from_message``."""
    return A2ARequestUtils.extract_text_from_message(message)
|
||||
|
||||
|
||||
def extract_text_from_a2a_response(response_dict: Dict[str, Any]) -> str:
    """Backwards-compatibility alias for ``A2ARequestUtils.extract_text_from_response``."""
    return A2ARequestUtils.extract_text_from_response(response_dict)
|
||||
@@ -0,0 +1,182 @@
|
||||
{
|
||||
"description": "Mapping of Anthropic beta headers for each provider. Keys are input header names, values are provider-specific header names (or null if unsupported). Only headers present in mapping keys with non-null values can be forwarded.",
|
||||
"anthropic": {
|
||||
"advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20",
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": "code-execution-2025-08-25",
|
||||
"compact-2026-01-12": "compact-2026-01-12",
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": "effort-2025-11-24",
|
||||
"fast-mode-2026-02-01": "fast-mode-2026-02-01",
|
||||
"files-api-2025-04-14": "files-api-2025-04-14",
|
||||
"structured-output-2024-03-01": null,
|
||||
"fine-grained-tool-streaming-2025-05-14": "fine-grained-tool-streaming-2025-05-14",
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": "mcp-client-2025-11-20",
|
||||
"mcp-client-2025-04-04": "mcp-client-2025-04-04",
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"oauth-2025-04-20": "oauth-2025-04-20",
|
||||
"output-128k-2025-02-19": "output-128k-2025-02-19",
|
||||
"prompt-caching-scope-2026-01-05": "prompt-caching-scope-2026-01-05",
|
||||
"skills-2025-10-02": "skills-2025-10-02",
|
||||
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": "token-efficient-tools-2025-02-19",
|
||||
"web-fetch-2025-09-10": "web-fetch-2025-09-10",
|
||||
"web-search-2025-03-05": "web-search-2025-03-05"
|
||||
},
|
||||
"azure_ai": {
|
||||
"advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20",
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": "code-execution-2025-08-25",
|
||||
"compact-2026-01-12": null,
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": "effort-2025-11-24",
|
||||
"fast-mode-2026-02-01": null,
|
||||
"files-api-2025-04-14": "files-api-2025-04-14",
|
||||
"fine-grained-tool-streaming-2025-05-14": null,
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": "mcp-client-2025-11-20",
|
||||
"mcp-client-2025-04-04": "mcp-client-2025-04-04",
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"output-128k-2025-02-19": null,
|
||||
"structured-output-2024-03-01": null,
|
||||
"prompt-caching-scope-2026-01-05": "prompt-caching-scope-2026-01-05",
|
||||
"skills-2025-10-02": "skills-2025-10-02",
|
||||
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": null,
|
||||
"web-fetch-2025-09-10": "web-fetch-2025-09-10",
|
||||
"web-search-2025-03-05": "web-search-2025-03-05"
|
||||
},
|
||||
"bedrock_converse": {
|
||||
"advanced-tool-use-2025-11-20": null,
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": null,
|
||||
"compact-2026-01-12": null,
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": null,
|
||||
"fast-mode-2026-02-01": null,
|
||||
"files-api-2025-04-14": null,
|
||||
"fine-grained-tool-streaming-2025-05-14": null,
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": null,
|
||||
"mcp-client-2025-04-04": null,
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"output-128k-2025-02-19": null,
|
||||
"structured-output-2024-03-01": null,
|
||||
"prompt-caching-scope-2026-01-05": null,
|
||||
"skills-2025-10-02": null,
|
||||
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": null,
|
||||
"tool-search-tool-2025-10-19": null,
|
||||
"web-fetch-2025-09-10": null,
|
||||
"web-search-2025-03-05": null
|
||||
},
|
||||
"bedrock": {
|
||||
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": null,
|
||||
"compact-2026-01-12": "compact-2026-01-12",
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": null,
|
||||
"fast-mode-2026-02-01": null,
|
||||
"files-api-2025-04-14": null,
|
||||
"fine-grained-tool-streaming-2025-05-14": null,
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": null,
|
||||
"mcp-client-2025-04-04": null,
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"output-128k-2025-02-19": null,
|
||||
"structured-output-2024-03-01": null,
|
||||
"prompt-caching-scope-2026-01-05": null,
|
||||
"skills-2025-10-02": null,
|
||||
"structured-outputs-2025-11-13": null,
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": null,
|
||||
"tool-search-tool-2025-10-19": "tool-search-tool-2025-10-19",
|
||||
"web-fetch-2025-09-10": null,
|
||||
"web-search-2025-03-05": null
|
||||
},
|
||||
"vertex_ai": {
|
||||
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": null,
|
||||
"compact-2026-01-12": null,
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": null,
|
||||
"fast-mode-2026-02-01": null,
|
||||
"files-api-2025-04-14": null,
|
||||
"fine-grained-tool-streaming-2025-05-14": null,
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": null,
|
||||
"mcp-client-2025-04-04": null,
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"output-128k-2025-02-19": null,
|
||||
"structured-output-2024-03-01": null,
|
||||
"prompt-caching-scope-2026-01-05": null,
|
||||
"skills-2025-10-02": null,
|
||||
"structured-outputs-2025-11-13": null,
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": null,
|
||||
"tool-search-tool-2025-10-19": "tool-search-tool-2025-10-19",
|
||||
"web-fetch-2025-09-10": null,
|
||||
"web-search-2025-03-05": "web-search-2025-03-05"
|
||||
},
|
||||
"databricks": {
|
||||
"advanced-tool-use-2025-11-20": "advanced-tool-use-2025-11-20",
|
||||
"bash_20241022": null,
|
||||
"bash_20250124": null,
|
||||
"code-execution-2025-08-25": "code-execution-2025-08-25",
|
||||
"compact-2026-01-12": "compact-2026-01-12",
|
||||
"computer-use-2025-01-24": "computer-use-2025-01-24",
|
||||
"computer-use-2025-11-24": "computer-use-2025-11-24",
|
||||
"context-1m-2025-08-07": "context-1m-2025-08-07",
|
||||
"context-management-2025-06-27": "context-management-2025-06-27",
|
||||
"effort-2025-11-24": "effort-2025-11-24",
|
||||
"fast-mode-2026-02-01": "fast-mode-2026-02-01",
|
||||
"files-api-2025-04-14": "files-api-2025-04-14",
|
||||
"structured-output-2024-03-01": null,
|
||||
"fine-grained-tool-streaming-2025-05-14": "fine-grained-tool-streaming-2025-05-14",
|
||||
"interleaved-thinking-2025-05-14": "interleaved-thinking-2025-05-14",
|
||||
"mcp-client-2025-11-20": "mcp-client-2025-11-20",
|
||||
"mcp-client-2025-04-04": "mcp-client-2025-04-04",
|
||||
"mcp-servers-2025-12-04": null,
|
||||
"oauth-2025-04-20": "oauth-2025-04-20",
|
||||
"output-128k-2025-02-19": "output-128k-2025-02-19",
|
||||
"prompt-caching-scope-2026-01-05": "prompt-caching-scope-2026-01-05",
|
||||
"skills-2025-10-02": "skills-2025-10-02",
|
||||
"structured-outputs-2025-11-13": "structured-outputs-2025-11-13",
|
||||
"text_editor_20241022": null,
|
||||
"text_editor_20250124": null,
|
||||
"token-efficient-tools-2025-02-19": "token-efficient-tools-2025-02-19",
|
||||
"web-fetch-2025-09-10": "web-fetch-2025-09-10",
|
||||
"web-search-2025-03-05": "web-search-2025-03-05"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,385 @@
|
||||
"""
|
||||
Centralized manager for Anthropic beta headers across different providers.
|
||||
|
||||
This module provides utilities to:
|
||||
1. Load beta header configuration from JSON (mapping of supported headers per provider)
|
||||
2. Filter and map beta headers based on provider support
|
||||
3. Handle provider-specific header name mappings (e.g., advanced-tool-use -> tool-search-tool)
|
||||
4. Support remote fetching and caching similar to model cost map
|
||||
|
||||
Design:
|
||||
- JSON config contains mapping of beta headers for each provider
|
||||
- Keys are input header names, values are provider-specific header names (or null if unsupported)
|
||||
- Only headers present in mapping keys with non-null values can be forwarded
|
||||
- This enforces stricter validation than the previous unsupported list approach
|
||||
|
||||
Configuration can be loaded from:
|
||||
- Remote URL (default): Fetches from GitHub repository
|
||||
- Local file: Set LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS=True to use bundled config only
|
||||
|
||||
Environment Variables:
|
||||
- LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS: Set to "True" to disable remote fetching
|
||||
- LITELLM_ANTHROPIC_BETA_HEADERS_URL: Custom URL for remote config (optional)
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
from importlib.resources import files
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import verbose_logger
|
||||
|
||||
# Cache for the loaded configuration
|
||||
_BETA_HEADERS_CONFIG: Optional[Dict] = None
|
||||
|
||||
|
||||
class GetAnthropicBetaHeadersConfig:
    """
    Handles fetching, validating, and loading the Anthropic beta headers configuration.

    Mirrors the GetModelCostMap pattern: a remote config is preferred, with the
    bundled local copy used as a fallback on any fetch or validation failure.
    """

    @staticmethod
    def load_local_beta_headers_config() -> Dict:
        """Load the beta headers config bundled with the litellm package."""
        try:
            raw = (
                files("litellm")
                .joinpath("anthropic_beta_headers_config.json")
                .read_text(encoding="utf-8")
            )
            return json.loads(raw)
        except Exception as e:
            verbose_logger.error(f"Failed to load local beta headers config: {e}")
            # Empty per-provider mappings keep callers functional (all headers dropped).
            return {
                "anthropic": {},
                "azure_ai": {},
                "bedrock": {},
                "bedrock_converse": {},
                "vertex_ai": {},
                "provider_aliases": {},
            }

    @staticmethod
    def _check_is_valid_dict(fetched_config: dict) -> bool:
        """Return True when the fetched config is a non-empty dict with provider keys."""
        if not isinstance(fetched_config, dict):
            verbose_logger.warning(
                "LiteLLM: Fetched beta headers config is not a dict (type=%s). "
                "Falling back to local backup.",
                type(fetched_config).__name__,
            )
            return False

        if not fetched_config:
            verbose_logger.warning(
                "LiteLLM: Fetched beta headers config is empty. "
                "Falling back to local backup.",
            )
            return False

        # At least one known provider section must be present.
        known_providers = (
            "anthropic",
            "azure_ai",
            "bedrock",
            "bedrock_converse",
            "vertex_ai",
        )
        if not any(name in fetched_config for name in known_providers):
            verbose_logger.warning(
                "LiteLLM: Fetched beta headers config missing provider keys. "
                "Falling back to local backup.",
            )
            return False

        return True

    @classmethod
    def validate_beta_headers_config(cls, fetched_config: dict) -> bool:
        """Run all integrity checks on a fetched config; True means usable."""
        return cls._check_is_valid_dict(fetched_config)

    @staticmethod
    def fetch_remote_beta_headers_config(url: str, timeout: int = 5) -> dict:
        """
        Fetch and parse the beta headers config from *url*.

        Network and parse errors propagate — the caller handles the fallback.
        """
        resp = httpx.get(url, timeout=timeout)
        resp.raise_for_status()
        return resp.json()
|
||||
|
||||
|
||||
def get_beta_headers_config(url: str) -> dict:
    """
    Public entry point — returns the beta headers config dict.

    Resolution order:
      1. ``LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS=True`` forces the bundled local copy.
      2. Otherwise the remote config at *url* is fetched and validated; any
         fetch or validation failure falls back to the local copy.

    Args:
        url: URL to fetch the remote beta headers configuration from

    Returns:
        Dict containing the beta headers configuration
    """
    load_local = GetAnthropicBetaHeadersConfig.load_local_beta_headers_config

    # Local-only mode: skip the network entirely.
    if os.getenv("LITELLM_LOCAL_ANTHROPIC_BETA_HEADERS", "").lower() == "true":
        return load_local()

    try:
        fetched = GetAnthropicBetaHeadersConfig.fetch_remote_beta_headers_config(url)
    except Exception as e:
        verbose_logger.warning(
            "LiteLLM: Failed to fetch remote beta headers config from %s: %s. "
            "Falling back to local backup.",
            url,
            str(e),
        )
        return load_local()

    if not GetAnthropicBetaHeadersConfig.validate_beta_headers_config(
        fetched_config=fetched
    ):
        verbose_logger.warning(
            "LiteLLM: Fetched beta headers config failed integrity check. "
            "Using local backup instead. url=%s",
            url,
        )
        return load_local()

    return fetched
|
||||
|
||||
|
||||
def _load_beta_headers_config() -> Dict:
    """
    Return the beta headers configuration, loading and caching it on first use.

    All public API functions in this module go through this accessor so the
    remote fetch / file read happens at most once per process.

    Returns:
        Dict containing the beta headers configuration
    """
    global _BETA_HEADERS_CONFIG

    if _BETA_HEADERS_CONFIG is None:
        # Deferred import avoids a circular import at module load time.
        from litellm import anthropic_beta_headers_url

        _BETA_HEADERS_CONFIG = get_beta_headers_config(url=anthropic_beta_headers_url)
        verbose_logger.debug("Loaded and cached beta headers config")

    return _BETA_HEADERS_CONFIG
|
||||
|
||||
|
||||
def reload_beta_headers_config() -> Dict:
    """
    Drop the cached configuration and load a fresh copy from source
    (remote URL or local backup, per the usual resolution rules).

    Returns:
        Dict containing the newly loaded beta headers configuration
    """
    global _BETA_HEADERS_CONFIG
    _BETA_HEADERS_CONFIG = None
    verbose_logger.info("Reloading beta headers config (cache cleared)")
    return _load_beta_headers_config()
|
||||
|
||||
|
||||
def get_provider_name(provider: str) -> str:
    """
    Resolve a provider alias to its canonical provider name.

    Args:
        provider: Provider name (may be an alias)

    Returns:
        Canonical provider name; unknown names pass through unchanged
    """
    aliases = _load_beta_headers_config().get("provider_aliases", {})
    return aliases.get(provider, provider)
|
||||
|
||||
|
||||
def filter_and_transform_beta_headers(
    beta_headers: List[str],
    provider: str,
) -> List[str]:
    """
    Filter and transform beta headers using the provider's mapping configuration.

    Rules applied, in order:
      1. Headers absent from the provider's mapping keys are dropped.
      2. Headers mapped to null (unsupported) are dropped.
      3. Remaining headers are rewritten to the provider-specific name
         (e.g. advanced-tool-use -> tool-search-tool).

    Args:
        beta_headers: List of Anthropic beta header values
        provider: Provider name (e.g. "anthropic", "bedrock", "vertex_ai")

    Returns:
        Sorted, de-duplicated list of provider-ready beta headers
    """
    if not beta_headers:
        return []

    config = _load_beta_headers_config()
    # Canonicalize once; log messages below use the resolved name.
    provider = get_provider_name(provider)
    mapping = config.get(provider, {})

    accepted: Set[str] = set()
    for raw_header in beta_headers:
        header = raw_header.strip()

        if header not in mapping:
            verbose_logger.debug(
                f"Dropping unknown beta header '{header}' for provider '{provider}' (not in mapping)"
            )
            continue

        mapped_header = mapping[header]
        if mapped_header is None:
            # Null value means the provider explicitly does not support it.
            verbose_logger.debug(
                f"Dropping unsupported beta header '{header}' for provider '{provider}'"
            )
            continue

        accepted.add(mapped_header)

    return sorted(accepted)
|
||||
|
||||
|
||||
def is_beta_header_supported(
    beta_header: str,
    provider: str,
) -> bool:
    """
    Check whether a specific beta header is supported by a provider.

    Args:
        beta_header: The Anthropic beta header value
        provider: Provider name

    Returns:
        True iff the header appears in the provider mapping with a non-null value
    """
    mapping = _load_beta_headers_config().get(get_provider_name(provider), {})
    # Missing keys and null values both come back as None here.
    return mapping.get(beta_header) is not None
|
||||
|
||||
|
||||
def get_provider_beta_header(
    anthropic_beta_header: str,
    provider: str,
) -> Optional[str]:
    """
    Get the provider-specific beta header name for an Anthropic beta header.

    Handles header renames (e.g. advanced-tool-use -> tool-search-tool).

    Args:
        anthropic_beta_header: The Anthropic beta header value
        provider: Provider name

    Returns:
        The provider-specific header name, or None when the header is
        unknown to the provider or explicitly unsupported (null mapping)
    """
    mapping = _load_beta_headers_config().get(get_provider_name(provider), {})
    # dict.get collapses "missing" and "mapped to null" into None, matching the contract.
    return mapping.get(anthropic_beta_header)
|
||||
|
||||
|
||||
def update_headers_with_filtered_beta(
    headers: dict,
    provider: str,
) -> dict:
    """
    Rewrite the ``anthropic-beta`` header value for the target provider.

    Filters and renames the comma-separated beta values in place; when no
    value survives filtering, the header is removed entirely.

    Args:
        headers: Request headers dict (modified in place)
        provider: Provider name

    Returns:
        The same headers dict, updated
    """
    raw_value = headers.get("anthropic-beta")
    if not raw_value:
        return headers

    requested = [item.strip() for item in raw_value.split(",") if item.strip()]
    allowed = filter_and_transform_beta_headers(
        beta_headers=requested,
        provider=provider,
    )

    if allowed:
        headers["anthropic-beta"] = ",".join(allowed)
    else:
        # Nothing survived filtering — don't send an empty header.
        headers.pop("anthropic-beta", None)

    return headers
|
||||
|
||||
|
||||
def get_unsupported_headers(provider: str) -> List[str]:
    """
    List the beta headers a provider explicitly does not support
    (i.e. headers mapped to null in the provider's configuration).

    Args:
        provider: Provider name

    Returns:
        List of unsupported Anthropic beta header names
    """
    mapping = _load_beta_headers_config().get(get_provider_name(provider), {})
    return [name for name, mapped in mapping.items() if mapped is None]
|
||||
@@ -0,0 +1,6 @@
|
||||
"""
|
||||
Anthropic module for LiteLLM
|
||||
"""
|
||||
from .messages import acreate, create
|
||||
|
||||
__all__ = ["acreate", "create"]
|
||||
@@ -0,0 +1,19 @@
|
||||
"""Anthropic error format utilities."""
|
||||
|
||||
from .exception_mapping_utils import (
|
||||
ANTHROPIC_ERROR_TYPE_MAP,
|
||||
AnthropicExceptionMapping,
|
||||
)
|
||||
from .exceptions import (
|
||||
AnthropicErrorDetail,
|
||||
AnthropicErrorResponse,
|
||||
AnthropicErrorType,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnthropicErrorType",
|
||||
"AnthropicErrorDetail",
|
||||
"AnthropicErrorResponse",
|
||||
"ANTHROPIC_ERROR_TYPE_MAP",
|
||||
"AnthropicExceptionMapping",
|
||||
]
|
||||
@@ -0,0 +1,172 @@
|
||||
"""
|
||||
Utilities for mapping exceptions to Anthropic error format.
|
||||
|
||||
Similar to litellm/litellm_core_utils/exception_mapping_utils.py but for Anthropic response format.
|
||||
"""
|
||||
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
from typing import Dict, Optional
|
||||
|
||||
from .exceptions import AnthropicErrorResponse, AnthropicErrorType
|
||||
|
||||
|
||||
# HTTP status code -> Anthropic error type
# Source: https://docs.anthropic.com/en/api/errors
# Status codes not listed here fall back to "api_error"
# (see AnthropicExceptionMapping.get_error_type).
ANTHROPIC_ERROR_TYPE_MAP: Dict[int, AnthropicErrorType] = {
    400: "invalid_request_error",
    401: "authentication_error",
    403: "permission_error",
    404: "not_found_error",
    413: "request_too_large",
    429: "rate_limit_error",
    500: "api_error",
    529: "overloaded_error",
}
|
||||
|
||||
|
||||
class AnthropicExceptionMapping:
    """
    Helper class for mapping exceptions to Anthropic error format.

    Similar pattern to ExceptionCheckers in litellm_core_utils/exception_mapping_utils.py
    """

    @staticmethod
    def get_error_type(status_code: int) -> AnthropicErrorType:
        """Map an HTTP status code to an Anthropic error type ("api_error" fallback)."""
        return ANTHROPIC_ERROR_TYPE_MAP.get(status_code, "api_error")

    @staticmethod
    def create_error_response(
        status_code: int,
        message: str,
        request_id: Optional[str] = None,
    ) -> AnthropicErrorResponse:
        """
        Create an Anthropic-formatted error response dict.

        Anthropic error format:
        {
            "type": "error",
            "error": {"type": "...", "message": "..."},
            "request_id": "req_..."
        }

        Args:
            status_code: HTTP status code (drives the error type)
            message: Human-readable error message
            request_id: Optional request ID; included only when truthy

        Returns:
            AnthropicErrorResponse dict
        """
        error_type = AnthropicExceptionMapping.get_error_type(status_code)

        response: AnthropicErrorResponse = {
            "type": "error",
            "error": {
                "type": error_type,
                "message": message,
            },
        }

        if request_id:
            response["request_id"] = request_id

        return response

    @staticmethod
    def extract_error_message(raw_message: str) -> str:
        """
        Extract an error message from various provider response formats.

        Handles:
        - Bedrock: {"detail": {"message": "..."}}
        - AWS: {"Message": "..."}
        - Generic: {"message": "..."}
        - Plain strings (returned unchanged)
        """
        # Delegate to the shared dict extractor instead of duplicating its logic.
        parsed = safe_json_loads(raw_message)
        if isinstance(parsed, dict):
            return AnthropicExceptionMapping._extract_message_from_dict(
                parsed, raw_message
            )
        return raw_message

    @staticmethod
    def _is_anthropic_error_dict(parsed: dict) -> bool:
        """
        Check if a parsed dict is already in Anthropic error format.

        Anthropic error format:
        {
            "type": "error",
            "error": {"type": "...", "message": "..."}
        }
        """
        return (
            parsed.get("type") == "error"
            and isinstance(parsed.get("error"), dict)
            and "type" in parsed["error"]
            and "message" in parsed["error"]
        )

    @staticmethod
    def _extract_message_from_dict(parsed: dict, raw_message: str) -> str:
        """
        Extract an error message from a parsed provider-specific dict.

        Handles:
        - Bedrock: {"detail": {"message": "..."}}
        - AWS: {"Message": "..."}
        - Generic: {"message": "..."}

        Falls back to *raw_message* when no message field is found.
        """
        # Bedrock format
        if "detail" in parsed and isinstance(parsed["detail"], dict):
            return parsed["detail"].get("message", raw_message)
        # AWS/generic format. NOTE: an empty-string message also falls
        # through to the next candidate because "" is falsy.
        return parsed.get("Message") or parsed.get("message") or raw_message

    @staticmethod
    def transform_to_anthropic_error(
        status_code: int,
        raw_message: str,
        request_id: Optional[str] = None,
    ) -> AnthropicErrorResponse:
        """
        Transform an error message to Anthropic format.

        - If already in Anthropic format: passthrough unchanged
        - Otherwise: extract message and create Anthropic error

        Parses JSON only once for efficiency.

        Args:
            status_code: HTTP status code
            raw_message: Raw error message (may be JSON string or plain text)
            request_id: Optional request ID to include

        Returns:
            AnthropicErrorResponse dict
        """
        # Try to parse as JSON once
        parsed: Optional[dict] = safe_json_loads(raw_message)
        if not isinstance(parsed, dict):
            parsed = None

        # Already in Anthropic format - pass through (optionally adding request_id)
        if parsed is not None and AnthropicExceptionMapping._is_anthropic_error_dict(
            parsed
        ):
            if request_id and "request_id" not in parsed:
                parsed["request_id"] = request_id
            return parsed  # type: ignore

        # Extract message - use parsed dict if available, otherwise raw string
        if parsed is not None:
            message = AnthropicExceptionMapping._extract_message_from_dict(
                parsed, raw_message
            )
        else:
            message = raw_message

        return AnthropicExceptionMapping.create_error_response(
            status_code=status_code,
            message=message,
            request_id=request_id,
        )
|
||||
@@ -0,0 +1,41 @@
|
||||
"""Anthropic error format type definitions."""
|
||||
|
||||
from typing_extensions import Literal, Required, TypedDict
|
||||
|
||||
|
||||
# Known Anthropic error types
# Source: https://docs.anthropic.com/en/api/errors
# NOTE: this is a closed Literal — an error type not listed here will fail
# static type checks; unknown upstream types should be handled by callers.
AnthropicErrorType = Literal[
    "invalid_request_error",
    "authentication_error",
    "permission_error",
    "not_found_error",
    "request_too_large",
    "rate_limit_error",
    "api_error",
    "overloaded_error",
]
|
||||
|
||||
|
||||
class AnthropicErrorDetail(TypedDict):
    """Inner error detail in Anthropic format (the value of the "error" key)."""

    # One of the known Anthropic error type strings, e.g. "rate_limit_error".
    type: AnthropicErrorType
    # Human-readable description of the error.
    message: str
|
||||
|
||||
|
||||
class AnthropicErrorResponse(TypedDict, total=False):
    """
    Anthropic-formatted error response.

    Format:
    {
        "type": "error",
        "error": {"type": "...", "message": "..."},
        "request_id": "req_..."  # optional
    }

    The class uses total=False so that only the fields marked Required
    must be present; "request_id" may be omitted.
    """

    # Always the literal string "error".
    type: Required[Literal["error"]]
    # The structured error payload.
    error: Required[AnthropicErrorDetail]
    # Optional Anthropic request identifier (e.g. "req_...").
    request_id: str
|
||||
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
Interface for Anthropic's messages API
|
||||
|
||||
Use this to call LLMs in Anthropic /messages Request/Response format
|
||||
|
||||
This is an __init__.py file to allow the following interface
|
||||
|
||||
- litellm.messages.acreate
|
||||
- litellm.messages.create
|
||||
|
||||
"""
|
||||
|
||||
from typing import Any, AsyncIterator, Coroutine, Dict, List, Optional, Union
|
||||
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
|
||||
anthropic_messages as _async_anthropic_messages,
|
||||
)
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages.handler import (
|
||||
anthropic_messages_handler as _sync_anthropic_messages,
|
||||
)
|
||||
from litellm.types.llms.anthropic_messages.anthropic_response import (
|
||||
AnthropicMessagesResponse,
|
||||
)
|
||||
|
||||
|
||||
async def acreate(
    max_tokens: int,
    messages: List[Dict],
    model: str,
    metadata: Optional[Dict] = None,
    stop_sequences: Optional[List[str]] = None,
    stream: Optional[bool] = False,
    system: Optional[str] = None,
    temperature: Optional[float] = None,
    thinking: Optional[Dict] = None,
    tool_choice: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    container: Optional[Dict] = None,
    **kwargs
) -> Union[AnthropicMessagesResponse, AsyncIterator]:
    """
    Async wrapper over the Anthropic `/v1/messages` handler.

    Args:
        max_tokens: Maximum tokens to generate (required).
        messages: Message objects with role and content (required).
        model: Model name to use (required).
        metadata: Request metadata.
        stop_sequences: Custom stop sequences.
        stream: Whether to stream the response.
        system: System prompt.
        temperature: Sampling temperature (0.0 to 1.0).
        thinking: Extended thinking configuration.
        tool_choice: Tool choice configuration.
        tools: List of tool definitions.
        top_k: Top-K sampling parameter.
        top_p: Nucleus sampling parameter.
        container: Container config with skills for code execution.
        **kwargs: Additional arguments, forwarded verbatim to the handler.

    Returns:
        The Anthropic-format response, or an async iterator when streaming.
    """
    # Collect the explicit parameters once, then forward them alongside any
    # extra kwargs. Duplicated keys still raise TypeError, exactly as a
    # direct keyword call would.
    request_params: Dict[str, Any] = {
        "max_tokens": max_tokens,
        "messages": messages,
        "model": model,
        "metadata": metadata,
        "stop_sequences": stop_sequences,
        "stream": stream,
        "system": system,
        "temperature": temperature,
        "thinking": thinking,
        "tool_choice": tool_choice,
        "tools": tools,
        "top_k": top_k,
        "top_p": top_p,
        "container": container,
    }
    return await _async_anthropic_messages(**request_params, **kwargs)
|
||||
|
||||
|
||||
def create(
    max_tokens: int,
    messages: List[Dict],
    model: str,
    metadata: Optional[Dict] = None,
    stop_sequences: Optional[List[str]] = None,
    stream: Optional[bool] = False,
    system: Optional[str] = None,
    temperature: Optional[float] = None,
    thinking: Optional[Dict] = None,
    tool_choice: Optional[Dict] = None,
    tools: Optional[List[Dict]] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    container: Optional[Dict] = None,
    **kwargs
) -> Union[
    AnthropicMessagesResponse,
    AsyncIterator[Any],
    Coroutine[Any, Any, Union[AnthropicMessagesResponse, AsyncIterator[Any]]],
]:
    """
    Synchronous wrapper over the Anthropic `/v1/messages` handler.

    Args:
        max_tokens: Maximum tokens to generate (required).
        messages: Message objects with role and content (required).
        model: Model name to use (required).
        metadata: Request metadata.
        stop_sequences: Custom stop sequences.
        stream: Whether to stream the response.
        system: System prompt.
        temperature: Sampling temperature (0.0 to 1.0).
        thinking: Extended thinking configuration.
        tool_choice: Tool choice configuration.
        tools: List of tool definitions.
        top_k: Top-K sampling parameter.
        top_p: Nucleus sampling parameter.
        container: Container config with skills for code execution.
        **kwargs: Additional arguments, forwarded verbatim to the handler.

    Returns:
        The Anthropic-format response, a stream iterator, or — depending on
        the underlying handler's routing — a coroutine resolving to either.
    """
    # Gather explicit parameters and forward them together with the caller's
    # extra kwargs; duplicate keys raise TypeError just like a direct call.
    request_params: Dict[str, Any] = {
        "max_tokens": max_tokens,
        "messages": messages,
        "model": model,
        "metadata": metadata,
        "stop_sequences": stop_sequences,
        "stream": stream,
        "system": system,
        "temperature": temperature,
        "thinking": thinking,
        "tool_choice": tool_choice,
        "tools": tools,
        "top_k": top_k,
        "top_p": top_p,
        "container": container,
    }
    return _sync_anthropic_messages(**request_params, **kwargs)
|
||||
@@ -0,0 +1,116 @@
|
||||
## Use LLM API endpoints in Anthropic Interface
|
||||
|
||||
Note: This module is named `anthropic_interface` rather than `anthropic` because the latter would shadow the well-known `anthropic` Python package and caused mypy type-checking failures.
|
||||
|
||||
|
||||
## Usage
|
||||
---
|
||||
|
||||
### LiteLLM Python SDK
|
||||
|
||||
#### Non-streaming example
|
||||
```python showLineNumbers title="Example using LiteLLM Python SDK"
|
||||
import litellm
|
||||
response = await litellm.anthropic.messages.acreate(
|
||||
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
||||
api_key=api_key,
|
||||
model="anthropic/claude-3-haiku-20240307",
|
||||
max_tokens=100,
|
||||
)
|
||||
```
|
||||
|
||||
Example response:
|
||||
```json
|
||||
{
|
||||
"content": [
|
||||
{
|
||||
"text": "Hi! this is a very short joke",
|
||||
"type": "text"
|
||||
}
|
||||
],
|
||||
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
|
||||
"model": "claude-3-7-sonnet-20250219",
|
||||
"role": "assistant",
|
||||
"stop_reason": "end_turn",
|
||||
"stop_sequence": null,
|
||||
"type": "message",
|
||||
"usage": {
|
||||
"input_tokens": 2095,
|
||||
"output_tokens": 503,
|
||||
"cache_creation_input_tokens": 2095,
|
||||
"cache_read_input_tokens": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Streaming example
|
||||
```python showLineNumbers title="Example using LiteLLM Python SDK"
|
||||
import litellm
|
||||
response = await litellm.anthropic.messages.acreate(
|
||||
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
||||
api_key=api_key,
|
||||
model="anthropic/claude-3-haiku-20240307",
|
||||
max_tokens=100,
|
||||
stream=True,
|
||||
)
|
||||
async for chunk in response:
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
### LiteLLM Proxy Server
|
||||
|
||||
|
||||
1. Setup config.yaml
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: anthropic-claude
|
||||
litellm_params:
|
||||
model: claude-3-7-sonnet-latest
|
||||
```
|
||||
|
||||
2. Start proxy
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Test it!
|
||||
|
||||
<Tabs>
|
||||
<TabItem label="Anthropic Python SDK" value="python">
|
||||
|
||||
```python showLineNumbers title="Example using LiteLLM Proxy Server"
|
||||
import anthropic
|
||||
|
||||
# point anthropic sdk to litellm proxy
|
||||
client = anthropic.Anthropic(
|
||||
base_url="http://0.0.0.0:4000",
|
||||
api_key="sk-1234",
|
||||
)
|
||||
|
||||
response = client.messages.create(
|
||||
messages=[{"role": "user", "content": "Hello, can you tell me a short joke?"}],
|
||||
model="anthropic/claude-3-haiku-20240307",
|
||||
max_tokens=100,
|
||||
)
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem label="curl" value="curl">
|
||||
|
||||
```bash showLineNumbers title="Example using LiteLLM Proxy Server"
|
||||
curl -L -X POST 'http://0.0.0.0:4000/v1/messages' \
|
||||
-H 'content-type: application/json' \
|
||||
-H "x-api-key: $LITELLM_API_KEY" \
|
||||
-H 'anthropic-version: 2023-06-01' \
|
||||
-d '{
|
||||
"model": "anthropic-claude",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello, can you tell me a short joke?"
|
||||
}
|
||||
],
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
1484
llm-gateway-competitors/litellm-wheel-src/litellm/assistants/main.py
Normal file
1484
llm-gateway-competitors/litellm-wheel-src/litellm/assistants/main.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,161 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import litellm
|
||||
|
||||
from ..exceptions import UnsupportedParamsError
|
||||
from ..types.llms.openai import *
|
||||
|
||||
|
||||
def get_optional_params_add_message(
    role: Optional[str],
    content: Optional[
        Union[
            str,
            List[
                Union[
                    MessageContentTextObject,
                    MessageContentImageFileObject,
                    MessageContentImageURLObject,
                ]
            ],
        ]
    ],
    attachments: Optional[List[Attachment]],
    metadata: Optional[dict],
    custom_llm_provider: str,
    **kwargs,
):
    """
    Build the provider-specific optional-params dict for an assistants
    "create message" call, dropping or rejecting unsupported params.

    Azure doesn't support 'attachments' for creating a message

    Reference - https://learn.microsoft.com/en-us/azure/ai-services/openai/assistants-reference-messages?tabs=python#create-message
    """
    # locals() is captured FIRST so it snapshots exactly the declared
    # parameters plus kwargs; order of the next statements matters.
    passed_params = locals()
    custom_llm_provider = passed_params.pop("custom_llm_provider")
    special_params = passed_params.pop("kwargs")
    # Fold provider-specific extras into the same flat dict.
    for k, v in special_params.items():
        passed_params[k] = v

    # Defaults for the OpenAI-style params; anything equal to its default is
    # treated as "not passed".
    default_params = {
        "role": None,
        "content": None,
        "attachments": None,
        "metadata": None,
    }

    non_default_params = {
        k: v
        for k, v in passed_params.items()
        if (k in default_params and v != default_params[k])
    }
    optional_params = {}

    ## raise exception if non-default value passed for non-openai/azure embedding calls
    def _check_valid_arg(supported_params):
        # Mutates non_default_params in place when litellm.drop_params is set;
        # otherwise raises on the first unsupported key.
        if len(non_default_params.keys()) > 0:
            keys = list(non_default_params.keys())
            for k in keys:
                if (
                    litellm.drop_params is True and k not in supported_params
                ):  # drop the unsupported non-default values
                    non_default_params.pop(k, None)
                elif k not in supported_params:
                    raise litellm.utils.UnsupportedParamsError(
                        status_code=500,
                        message="k={}, not supported by {}. Supported params={}. To drop it from the call, set `litellm.drop_params = True`.".format(
                            k, custom_llm_provider, supported_params
                        ),
                    )
        return non_default_params

    if custom_llm_provider == "openai":
        # OpenAI supports everything we accept — pass through unchanged.
        optional_params = non_default_params
    elif custom_llm_provider == "azure":
        supported_params = (
            litellm.AzureOpenAIAssistantsAPIConfig().get_supported_openai_create_message_params()
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.AzureOpenAIAssistantsAPIConfig().map_openai_params_create_message_params(
            non_default_params=non_default_params, optional_params=optional_params
        )
    # Provider-specific extras (anything not in default_params) are always
    # forwarded, regardless of provider.
    for k in passed_params.keys():
        if k not in default_params.keys():
            optional_params[k] = passed_params[k]
    return optional_params
|
||||
|
||||
|
||||
def get_optional_params_image_gen(
    n: Optional[int] = None,
    quality: Optional[str] = None,
    response_format: Optional[str] = None,
    size: Optional[str] = None,
    style: Optional[str] = None,
    user: Optional[str] = None,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
):
    """
    Build the provider-specific optional-params dict for an image-generation
    call, translating OpenAI-style params to each provider's names.

    Args:
        n: Number of images to generate.
        quality: Image quality setting.
        response_format: Response format (e.g. url / b64_json).
        size: Image size as "<width>x<height>".
        style: Image style setting.
        user: End-user identifier.
        custom_llm_provider: Target provider; controls param mapping.
        **kwargs: Provider-specific extras, forwarded unchanged.

    Returns:
        dict: Optional params to forward to the provider.

    Raises:
        UnsupportedParamsError: when a non-default param is not supported by
            the provider and ``litellm.drop_params`` is not enabled.
    """
    # retrieve all parameters passed to the function
    # (locals() must be captured before any local assignment)
    passed_params = locals()
    custom_llm_provider = passed_params.pop("custom_llm_provider")
    special_params = passed_params.pop("kwargs")
    for k, v in special_params.items():
        passed_params[k] = v

    # Anything equal to its default is treated as "not passed".
    default_params = {
        "n": None,
        "quality": None,
        "response_format": None,
        "size": None,
        "style": None,
        "user": None,
    }

    non_default_params = {
        k: v
        for k, v in passed_params.items()
        if (k in default_params and v != default_params[k])
    }
    optional_params = {}

    ## raise exception if non-default value passed for non-openai/azure embedding calls
    def _check_valid_arg(supported_params):
        # Mutates non_default_params in place when litellm.drop_params is set;
        # otherwise raises on the first unsupported key.
        if len(non_default_params.keys()) > 0:
            keys = list(non_default_params.keys())
            for k in keys:
                if (
                    litellm.drop_params is True and k not in supported_params
                ):  # drop the unsupported non-default values
                    non_default_params.pop(k, None)
                elif k not in supported_params:
                    # Name the actual offending param (the previous message
                    # claimed "user/encoding format" regardless of which
                    # param failed) — same format as
                    # get_optional_params_add_message for consistency.
                    raise UnsupportedParamsError(
                        status_code=500,
                        message="k={}, not supported by {}. Supported params={}. To drop it from the call, set `litellm.drop_params = True`.".format(
                            k, custom_llm_provider, supported_params
                        ),
                    )
        return non_default_params

    if (
        custom_llm_provider == "openai"
        or custom_llm_provider == "azure"
        or custom_llm_provider in litellm.openai_compatible_providers
    ):
        # OpenAI-compatible providers take the params unchanged.
        optional_params = non_default_params
    elif custom_llm_provider == "bedrock":
        supported_params = ["size"]
        _check_valid_arg(supported_params=supported_params)
        if size is not None:
            # Bedrock wants separate integer width/height, not "WxH".
            width, height = size.split("x")
            optional_params["width"] = int(width)
            optional_params["height"] = int(height)
    elif custom_llm_provider == "vertex_ai":
        supported_params = ["n"]
        # All params here: https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/imagegeneration?project=adroit-crow-413218
        _check_valid_arg(supported_params=supported_params)
        if n is not None:
            optional_params["sampleCount"] = int(n)

    # Provider-specific extras (anything not in default_params) are always
    # forwarded, regardless of provider.
    for k in passed_params.keys():
        if k not in default_params.keys():
            optional_params[k] = passed_params[k]
    return optional_params
|
||||
@@ -0,0 +1,11 @@
|
||||
# Implementation of `litellm.batch_completion`, `litellm.batch_completion_models`, `litellm.batch_completion_models_all_responses`
|
||||
|
||||
Doc: https://docs.litellm.ai/docs/completion/batching
|
||||
|
||||
|
||||
LiteLLM Python SDK allows you to:
|
||||
1. `litellm.batch_completion` Batch litellm.completion function for a given model.
|
||||
2. `litellm.batch_completion_models` Send a request to multiple language models concurrently and return the response
|
||||
as soon as one of the models responds.
|
||||
3. `litellm.batch_completion_models_all_responses` Send a request to multiple language models concurrently and return a list of responses
|
||||
from all models that respond.
|
||||
@@ -0,0 +1,273 @@
|
||||
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.utils import get_optional_params
|
||||
|
||||
from ..llms.vllm.completion import handler as vllm_handler
|
||||
|
||||
|
||||
def batch_completion(
    model: str,
    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
    messages: List = [],  # NOTE(review): mutable default — harmless today (only read), but worth migrating to None
    functions: Optional[List] = None,
    function_call: Optional[str] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    n: Optional[int] = None,
    stream: Optional[bool] = None,
    stop=None,
    max_tokens: Optional[int] = None,
    presence_penalty: Optional[float] = None,
    frequency_penalty: Optional[float] = None,
    logit_bias: Optional[dict] = None,
    user: Optional[str] = None,
    deployment_id=None,
    request_timeout: Optional[int] = None,
    timeout: Optional[int] = 600,
    max_workers: Optional[int] = 100,
    # Optional liteLLM function params
    **kwargs,
):
    """
    Batch litellm.completion function for a given model.

    Args:
        model (str): The model to use for generating completions.
        messages (List, optional): List of message-lists, one per request. Defaults to [].
        functions (List, optional): List of functions to use as input for generating completions. Defaults to None.
        function_call (str, optional): The function call to use as input for generating completions. Defaults to None.
        temperature (float, optional): The temperature parameter for generating completions. Defaults to None.
        top_p (float, optional): The top-p parameter for generating completions. Defaults to None.
        n (int, optional): The number of completions to generate. Defaults to None.
        stream (bool, optional): Whether to stream completions or not. Defaults to None.
        stop (optional): The stop parameter for generating completions. Defaults to None.
        max_tokens (float, optional): The maximum number of tokens to generate. Defaults to None.
        presence_penalty (float, optional): The presence penalty for generating completions. Defaults to None.
        frequency_penalty (float, optional): The frequency penalty for generating completions. Defaults to None.
        logit_bias (dict, optional): The logit bias for generating completions. Defaults to None.
        user (str, optional): The user string for generating completions. Defaults to None.
        deployment_id (optional): The deployment ID for generating completions. Defaults to None.
        request_timeout (int, optional): The request timeout for generating completions. Defaults to None.
        max_workers (int, optional): The maximum number of threads to use for parallel processing.

    Returns:
        list: One entry per request, in order. On the threaded (non-vllm)
        path, a failed request's entry is the Exception instance itself
        rather than a response — callers must check element types.
    """
    # Snapshot every declared parameter (plus the kwargs dict itself) BEFORE
    # any local reassignment; this dict is replayed per-request below.
    args = locals()

    batch_messages = messages
    completions = []
    model = model
    custom_llm_provider = None
    # "provider/model" prefix routing, e.g. "vllm/facebook/opt-125m".
    if model.split("/", 1)[0] in litellm.provider_list:
        custom_llm_provider = model.split("/", 1)[0]
        model = model.split("/", 1)[1]
    if custom_llm_provider == "vllm":
        # vllm can batch natively — map OpenAI params once and hand the whole
        # message list to the handler.
        optional_params = get_optional_params(
            functions=functions,
            function_call=function_call,
            temperature=temperature,
            top_p=top_p,
            n=n,
            stream=stream or False,
            stop=stop,
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            logit_bias=logit_bias,
            user=user,
            # params to identify the model
            model=model,
            custom_llm_provider=custom_llm_provider,
        )
        results = vllm_handler.batch_completions(
            model=model,
            messages=batch_messages,
            custom_prompt_dict=litellm.custom_prompt_dict,
            optional_params=optional_params,
        )
    # all non VLLM models for batch completion models
    else:

        def chunks(lst, n):
            """Yield successive n-sized chunks from lst."""
            for i in range(0, len(lst), n):
                yield lst[i : i + n]

        # Fan each message-list out to litellm.completion on a thread pool,
        # submitting in chunks of 100 at a time.
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            for sub_batch in chunks(batch_messages, 100):
                for message_list in sub_batch:
                    # Rebuild the original call's kwargs for this one request:
                    # drop max_workers (not a completion param), swap in this
                    # request's messages, and re-flatten the captured **kwargs.
                    kwargs_modified = args.copy()
                    kwargs_modified.pop("max_workers")
                    kwargs_modified["messages"] = message_list
                    original_kwargs = {}
                    if "kwargs" in kwargs_modified:
                        original_kwargs = kwargs_modified.pop("kwargs")
                    future = executor.submit(
                        litellm.completion, **kwargs_modified, **original_kwargs
                    )
                    completions.append(future)

        # Collect results in submission order; a raised exception becomes the
        # result entry for that request instead of aborting the batch.
        results = []
        for future in completions:
            try:
                results.append(future.result())
            except Exception as exc:
                results.append(exc)

    return results
|
||||
|
||||
|
||||
# send one request to multiple models
|
||||
# return as soon as one of the llms responds
|
||||
def batch_completion_models(*args, **kwargs):
    """
    Send a request to multiple language models concurrently and return the response
    as soon as one of the models responds.

    Args:
        *args: Variable-length positional arguments passed to the completion function.
        **kwargs: Additional keyword arguments:
            - models (str or list of str): The language models to send requests to.
            - deployments (list of dict): Alternative to models; per-deployment
              param dicts (model, api_base, etc.).
            - Other keyword arguments to be passed to the completion function.

    Returns:
        str or None: The response from one of the language models, or None if no response is received.

    Note:
        This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
        NOTE(review): in the "models" branch the code blocks on each future in
        the caller's listed order, so it returns the first non-None result by
        list position — not necessarily the model that finished first. The
        "deployments" branch does use FIRST_COMPLETED semantics.
    """

    if "model" in kwargs:
        kwargs.pop("model")
    if "models" in kwargs:
        models = kwargs["models"]
        kwargs.pop("models")
        futures = {}
        with ThreadPoolExecutor(max_workers=len(models)) as executor:
            for model in models:
                futures[model] = executor.submit(
                    litellm.completion, *args, model=model, **kwargs
                )

            # Inspect results in the order the caller listed the models.
            for model, future in sorted(
                futures.items(), key=lambda x: models.index(x[0])
            ):
                if future.result() is not None:
                    return future.result()
    elif "deployments" in kwargs:
        deployments = kwargs["deployments"]
        kwargs.pop("deployments")
        kwargs.pop("model_list")
        nested_kwargs = kwargs.pop("kwargs", {})
        futures = {}
        with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
            for deployment in deployments:
                for key in kwargs.keys():
                    if (
                        key not in deployment
                    ):  # don't override deployment values e.g. model name, api base, etc.
                        deployment[key] = kwargs[key]
                # Each submission's kwargs = deployment params + nested kwargs.
                kwargs = {**deployment, **nested_kwargs}
                futures[deployment["model"]] = executor.submit(
                    litellm.completion, **kwargs
                )

            # Return the first deployment that completes successfully; failed
            # futures are pruned and we keep waiting on the rest.
            while futures:
                # wait for the first returned future
                print_verbose("\n\n waiting for next result\n\n")
                done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
                print_verbose(f"done list\n{done}")
                for future in done:
                    try:
                        result = future.result()
                        return result
                    except Exception:
                        # if model 1 fails, continue with response from model 2, model3
                        print_verbose(
                            "\n\ngot an exception, ignoring, removing from futures"
                        )
                        print_verbose(futures)
                        # Rebuild the futures map without the failed future.
                        new_futures = {}
                        for key, value in futures.items():
                            if future == value:
                                print_verbose(f"removing key{key}")
                                continue
                            else:
                                new_futures[key] = value
                        futures = new_futures
                        print_verbose(f"new futures{futures}")
                        continue

                print_verbose("\n\ndone looping through futures\n\n")
                print_verbose(futures)

    return None  # If no response is received from any model
|
||||
|
||||
|
||||
def batch_completion_models_all_responses(*args, **kwargs):
    """
    Send a request to multiple language models concurrently and return a list of responses
    from all models that respond.

    Args:
        *args: Variable-length positional arguments passed to the completion function.
        **kwargs: Additional keyword arguments:
            - models (str or list of str): The language models to send requests to.
            - Other keyword arguments to be passed to the completion function.

    Returns:
        list: A list of responses from the language models that responded.
        Models that raised are skipped (logged via print_verbose).

    Raises:
        Exception: if no 'models' kwarg was supplied.
        TypeError: if 'models' is neither a string nor a list/tuple of strings.

    Note:
        This function utilizes a ThreadPoolExecutor to parallelize requests to multiple models.
        It sends requests concurrently and collects responses from all models that respond.
    """
    # 'model' (singular) would conflict with the per-model submission below.
    if "model" in kwargs:
        kwargs.pop("model")
    if "models" in kwargs:
        models = kwargs.pop("models")
    else:
        raise Exception("'models' param not in kwargs")

    # Normalize to a concrete list so len()/iteration are well-defined.
    if isinstance(models, str):
        models = [models]
    elif isinstance(models, (list, tuple)):
        models = list(models)
    else:
        raise TypeError("'models' must be a string or list of strings")

    if len(models) == 0:
        return []

    responses = []

    # Uses the module-level ThreadPoolExecutor import (the previous local
    # `import concurrent.futures` was redundant).
    with ThreadPoolExecutor(max_workers=len(models)) as executor:
        futures = [
            executor.submit(litellm.completion, *args, model=model, **kwargs)
            for model in models
        ]

        # Collect in submission order; a failed model is logged and skipped.
        for future in futures:
            try:
                result = future.result()
                if result is not None:
                    responses.append(result)
            except Exception as e:
                print_verbose(
                    f"batch_completion_models_all_responses: model request failed: {str(e)}"
                )
                continue

    return responses
|
||||
@@ -0,0 +1,442 @@
|
||||
import json
|
||||
from typing import Any, List, Literal, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.llms.openai import Batch
|
||||
from litellm.types.utils import CallTypes, ModelInfo, Usage
|
||||
from litellm.utils import token_counter
|
||||
|
||||
|
||||
async def calculate_batch_cost_and_usage(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ],
    model_name: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> Tuple[float, Usage, List[str]]:
    """Aggregate cost, token usage, and model names for a batch output file.

    Args:
        file_content_dictionary: Parsed batch output file, one dict per request.
        custom_llm_provider: Provider the batch ran against.
        model_name: Optional model name; when set it drives pricing and is
            returned as the sole batch model.
        model_info: Optional deployment-level model info with custom batch
            pricing. Threaded through to the cost calculator so that
            deployment-specific pricing (e.g. input_cost_per_token_batches)
            is used instead of the global cost map.

    Returns:
        Tuple of (total cost, aggregated Usage, list of model names seen).
    """
    total_cost = _batch_cost_calculator(
        custom_llm_provider=custom_llm_provider,
        file_content_dictionary=file_content_dictionary,
        model_name=model_name,
        model_info=model_info,
    )
    aggregated_usage = _get_batch_job_total_usage_from_file_content(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
        model_name=model_name,
    )
    models_seen = _get_batch_models_from_file_content(
        file_content_dictionary, model_name
    )
    return total_cost, aggregated_usage, models_seen
|
||||
|
||||
|
||||
async def _handle_completed_batch(
    batch: Batch,
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ],
    model_name: Optional[str] = None,
    litellm_params: Optional[dict] = None,
) -> Tuple[float, Usage, List[str]]:
    """Process a completed batch: fetch its output file and compute cost/usage/models.

    Args:
        batch: The completed batch object.
        custom_llm_provider: The LLM provider the batch ran against.
        model_name: Optional model name.
        litellm_params: Optional litellm parameters containing credentials
            (api_key, api_base, etc.) used when fetching the output file.

    Returns:
        Tuple of (total cost, aggregated Usage, list of model names seen).
    """
    # Get batch results as a list of per-request dicts.
    file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
        batch, custom_llm_provider, litellm_params=litellm_params
    )

    # Delegate to the shared aggregation helper instead of duplicating its
    # three calculation steps (cost, usage, models) inline here.
    return await calculate_batch_cost_and_usage(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
        model_name=model_name,
    )
|
||||
|
||||
|
||||
def _get_batch_models_from_file_content(
|
||||
file_content_dictionary: List[dict],
|
||||
model_name: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get the models from the file content
|
||||
"""
|
||||
if model_name:
|
||||
return [model_name]
|
||||
batch_models = []
|
||||
for _item in file_content_dictionary:
|
||||
if _batch_response_was_successful(_item):
|
||||
_response_body = _get_response_from_batch_job_output_file(_item)
|
||||
_model = _response_body.get("model")
|
||||
if _model:
|
||||
batch_models.append(_model)
|
||||
return batch_models
|
||||
|
||||
|
||||
def _batch_cost_calculator(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    model_name: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> float:
    """Compute the total cost of a batch from its output file content.

    Vertex AI (when a model_name is known) is priced via its specialized
    helper; every other provider goes through the generic per-line cost path.
    """
    # Vertex AI ships usageMetadata per line and needs its own pricing path.
    if custom_llm_provider == "vertex_ai" and model_name:
        vertex_cost, _ = calculate_vertex_ai_batch_cost_and_usage(
            file_content_dictionary, model_name
        )
        verbose_logger.debug("vertex_ai_total_cost=%s", vertex_cost)
        return vertex_cost

    # Generic path: sum per-line costs, honoring any deployment-level
    # batch pricing carried in model_info.
    generic_cost = _get_batch_job_cost_from_file_content(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
        model_info=model_info,
    )
    verbose_logger.debug("total_cost=%s", generic_cost)
    return generic_cost
|
||||
|
||||
|
||||
def calculate_vertex_ai_batch_cost_and_usage(
    vertex_ai_batch_responses: List[dict],
    model_name: Optional[str] = None,
) -> Tuple[float, Usage]:
    """
    Calculate both cost and usage from Vertex AI batch responses.

    Vertex AI batch output lines have format:
    {"request": ..., "status": "", "response": {"candidates": [...], "usageMetadata": {...}}}

    usageMetadata contains promptTokenCount, candidatesTokenCount, totalTokenCount.

    Lines without a "response" key are skipped. Per-line pricing failures are
    logged at debug level but do not abort the aggregation — token counts for
    the failing line are still included in the returned Usage.
    """
    # Imported lazily to avoid a circular import with litellm.cost_calculator.
    from litellm.cost_calculator import batch_cost_calculator

    total_cost = 0.0
    total_tokens = 0
    prompt_tokens = 0
    completion_tokens = 0
    # Fallback model used for pricing when none was supplied.
    # NOTE(review): hard-coded default "gemini-2.0-flash-001" — confirm this
    # is the intended pricing fallback for unlabelled batches.
    actual_model_name = model_name or "gemini-2.0-flash-001"

    for response in vertex_ai_batch_responses:
        response_body = response.get("response")
        if response_body is None:
            continue

        # `or 0` guards against explicit null values in usageMetadata.
        usage_metadata = response_body.get("usageMetadata", {})
        _prompt = usage_metadata.get("promptTokenCount", 0) or 0
        _completion = usage_metadata.get("candidatesTokenCount", 0) or 0
        _total = usage_metadata.get("totalTokenCount", 0) or (_prompt + _completion)

        line_usage = Usage(
            prompt_tokens=_prompt,
            completion_tokens=_completion,
            total_tokens=_total,
        )

        try:
            # batch_cost_calculator returns (prompt_cost, completion_cost).
            p_cost, c_cost = batch_cost_calculator(
                usage=line_usage,
                model=actual_model_name,
                custom_llm_provider="vertex_ai",
            )
            total_cost += p_cost + c_cost
        except Exception as e:
            # Best-effort: a single unpriceable line should not fail the batch.
            verbose_logger.debug(
                "vertex_ai batch cost calculation error for line: %s", str(e)
            )

        prompt_tokens += _prompt
        completion_tokens += _completion
        total_tokens += _total

    verbose_logger.info(
        "vertex_ai batch cost: cost=%s, prompt=%d, completion=%d, total=%d",
        total_cost,
        prompt_tokens,
        completion_tokens,
        total_tokens,
    )

    return total_cost, Usage(
        total_tokens=total_tokens,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )
|
||||
|
||||
|
||||
async def _get_batch_output_file_content_as_dictionary(
    batch: Batch,
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    litellm_params: Optional[dict] = None,
) -> List[dict]:
    """
    Get the batch output file content as a list of dictionaries.

    Args:
        batch: The batch object
        custom_llm_provider: The LLM provider
        litellm_params: Optional litellm parameters containing credentials (api_key, api_base, etc.)
            Required for Azure and other providers that need authentication

    Returns:
        One dict per JSONL line of the batch's output file.

    Raises:
        ValueError: if the provider is Vertex AI (file retrieval unsupported)
            or the batch has no ``output_file_id``.
    """
    from litellm.files.main import afile_content
    from litellm.proxy.openai_files_endpoints.common_utils import (
        _is_base64_encoded_unified_file_id,
    )

    if custom_llm_provider == "vertex_ai":
        raise ValueError("Vertex AI does not support file content retrieval")

    if batch.output_file_id is None:
        raise ValueError("Output file id is None cannot retrieve file content")

    file_id = batch.output_file_id
    # NOTE(review): the helper's return value is parsed as a string below, so it
    # presumably returns the decoded unified id (truthy) rather than a plain
    # bool -- confirm against _is_base64_encoded_unified_file_id's contract.
    is_base64_unified_file_id = _is_base64_encoded_unified_file_id(file_id)
    if is_base64_unified_file_id:
        try:
            # Unified file ids embed the provider file id as
            # "...llm_output_file_id,<id>;..." -- pull out the <id> portion.
            file_id = is_base64_unified_file_id.split("llm_output_file_id,")[1].split(
                ";"
            )[0]
            verbose_logger.debug(
                f"Extracted LLM output file ID from unified file ID: {file_id}"
            )
        except (IndexError, AttributeError) as e:
            # Extraction failure is non-fatal: log and fall through with the
            # original output_file_id still bound to file_id.
            verbose_logger.error(
                f"Failed to extract LLM output file ID from unified file ID: {batch.output_file_id}, error: {e}"
            )

    # Build kwargs for afile_content with credentials from litellm_params
    file_content_kwargs = {
        "file_id": file_id,
        "custom_llm_provider": custom_llm_provider,
    }

    # Extract and add credentials for file access
    credentials = _extract_file_access_credentials(litellm_params)
    file_content_kwargs.update(credentials)

    _file_content = await afile_content(**file_content_kwargs)  # type: ignore[reportArgumentType]
    return _get_file_content_as_dictionary(_file_content.content)
|
||||
|
||||
|
||||
def _extract_file_access_credentials(litellm_params: Optional[dict]) -> dict:
|
||||
"""
|
||||
Extract credentials from litellm_params for file access operations.
|
||||
|
||||
This method extracts relevant authentication and configuration parameters
|
||||
needed for accessing files across different providers (Azure, Vertex AI, etc.).
|
||||
|
||||
Args:
|
||||
litellm_params: Dictionary containing litellm parameters with credentials
|
||||
|
||||
Returns:
|
||||
Dictionary containing only the credentials needed for file access
|
||||
"""
|
||||
credentials = {}
|
||||
|
||||
if litellm_params:
|
||||
# List of credential keys that should be passed to file operations
|
||||
credential_keys = [
|
||||
"api_key",
|
||||
"api_base",
|
||||
"api_version",
|
||||
"organization",
|
||||
"azure_ad_token",
|
||||
"azure_ad_token_provider",
|
||||
"vertex_project",
|
||||
"vertex_location",
|
||||
"vertex_credentials",
|
||||
"timeout",
|
||||
"max_retries",
|
||||
]
|
||||
for key in credential_keys:
|
||||
if key in litellm_params:
|
||||
credentials[key] = litellm_params[key]
|
||||
|
||||
return credentials
|
||||
|
||||
|
||||
def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
    """Parse JSON Lines (JSONL) file content into a list of dicts.

    Each non-empty line of the UTF-8-decoded content is parsed as one JSON
    object; empty lines are skipped.
    """
    try:
        decoded = file_content.decode("utf-8")
        # One JSON object per non-empty line.
        parsed = [json.loads(line) for line in decoded.strip().split("\n") if line]
        verbose_logger.debug("json_objects=%s", json.dumps(parsed, indent=4))
        return parsed
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
def _get_batch_job_cost_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    model_info: Optional[ModelInfo] = None,
) -> float:
    """
    Get the cost of a batch job from the file content.

    Sums the cost of every successful (HTTP 200) response line. When
    ``model_info`` is provided, pricing uses ``batch_cost_calculator`` with the
    explicit model info; otherwise it falls back to ``litellm.completion_cost``.

    Raises:
        Exception: re-raises any error encountered while pricing the lines.
    """
    from litellm.cost_calculator import batch_cost_calculator

    try:
        total_cost: float = 0.0
        # parse the file content as json
        verbose_logger.debug(
            "file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
        )
        for _item in file_content_dictionary:
            # Skip failed lines -- only 200 responses are billable.
            if not _batch_response_was_successful(_item):
                continue
            _response_body = _get_response_from_batch_job_output_file(_item)
            if model_info is not None:
                usage = _get_batch_job_usage_from_response_body(_response_body)
                model = _response_body.get("model", "")
                prompt_cost, completion_cost = batch_cost_calculator(
                    usage=usage,
                    model=model,
                    custom_llm_provider=custom_llm_provider,
                    model_info=model_info,
                )
                total_cost += prompt_cost + completion_cost
            else:
                total_cost += litellm.completion_cost(
                    completion_response=_response_body,
                    custom_llm_provider=custom_llm_provider,
                    call_type=CallTypes.aretrieve_batch.value,
                )
        verbose_logger.debug("total_cost=%s", total_cost)
        return total_cost
    except Exception as e:
        # Fix: use lazy %-formatting so the exception is actually rendered.
        # Previously the exception was passed as a stray positional arg with no
        # placeholder, which the logging module reports as a formatting error
        # instead of logging the exception.
        verbose_logger.error(
            "error in _get_batch_job_cost_from_file_content: %s", e
        )
        raise e
|
||||
|
||||
|
||||
def _get_batch_job_total_usage_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal[
        "openai", "azure", "vertex_ai", "hosted_vllm", "anthropic"
    ] = "openai",
    model_name: Optional[str] = None,
) -> Usage:
    """Sum token usage across every successful line of a batch output file.

    Vertex AI batches (with a known model) are delegated to the specialized
    cost/usage helper; other providers accumulate the per-line ``usage``
    payloads directly.
    """
    # Vertex AI output lines use a different schema, so delegate.
    if custom_llm_provider == "vertex_ai" and model_name:
        _, vertex_usage = calculate_vertex_ai_batch_cost_and_usage(
            file_content_dictionary, model_name
        )
        return vertex_usage

    totals = {"total_tokens": 0, "prompt_tokens": 0, "completion_tokens": 0}
    for line in file_content_dictionary:
        # Only successful (HTTP 200) lines contribute to usage.
        if not _batch_response_was_successful(line):
            continue
        body = _get_response_from_batch_job_output_file(line)
        line_usage: Usage = _get_batch_job_usage_from_response_body(body)
        totals["total_tokens"] += line_usage.total_tokens
        totals["prompt_tokens"] += line_usage.prompt_tokens
        totals["completion_tokens"] += line_usage.completion_tokens

    return Usage(**totals)
|
||||
|
||||
|
||||
def _get_batch_job_input_file_usage(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    model_name: Optional[str] = None,
) -> Usage:
    """Count prompt tokens across all requests in a batch *input* file.

    Used for batch rate limiting. Only prompt tokens accumulate -- the
    requests have not run yet, so completion tokens are always zero.
    """
    prompt_total: int = 0

    for request_line in file_content_dictionary:
        request_body = request_line.get("body", {})
        # Prefer the per-request model; fall back to the caller-supplied one.
        request_model = request_body.get("model", model_name or "")
        request_messages = request_body.get("messages", [])
        if request_messages:
            prompt_total += token_counter(
                model=request_model, messages=request_messages
            )

    return Usage(
        total_tokens=prompt_total,
        prompt_tokens=prompt_total,
        completion_tokens=0,
    )
|
||||
|
||||
|
||||
def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
    """Build a ``Usage`` object from the ``usage`` field of a response body.

    A missing or null ``usage`` field yields a default-initialized ``Usage``.
    """
    usage_payload = response_body.get("usage") or {}
    return Usage(**usage_payload)
|
||||
|
||||
|
||||
def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
|
||||
"""
|
||||
Get the response from the batch job output file
|
||||
"""
|
||||
_response: dict = batch_job_output_file.get("response", None) or {}
|
||||
_response_body = _response.get("body", None) or {}
|
||||
return _response_body
|
||||
|
||||
|
||||
def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
|
||||
"""
|
||||
Check if the batch job response status == 200
|
||||
"""
|
||||
_response: dict = batch_job_output_file.get("response", None) or {}
|
||||
return _response.get("status_code", None) == 200
|
||||
1181
llm-gateway-competitors/litellm-wheel-src/litellm/batches/main.py
Normal file
1181
llm-gateway-competitors/litellm-wheel-src/litellm/batches/main.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,10 @@
|
||||
{
|
||||
"posts": [
|
||||
{
|
||||
"title": "Incident Report: SERVER_ROOT_PATH regression broke UI routing",
|
||||
"description": "How a single line removal caused UI 404s for all deployments using SERVER_ROOT_PATH, and the tests we added to prevent it from happening again.",
|
||||
"date": "2026-02-21",
|
||||
"url": "https://docs.litellm.ai/blog/server-root-path-incident"
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -0,0 +1,230 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | NOT PROXY BUDGET MANAGER |
|
||||
# | proxy budget manager is in proxy_server.py |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.constants import (
|
||||
DAYS_IN_A_MONTH,
|
||||
DAYS_IN_A_WEEK,
|
||||
DAYS_IN_A_YEAR,
|
||||
HOURS_IN_A_DAY,
|
||||
)
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
|
||||
class BudgetManager:
|
||||
def __init__(
|
||||
self,
|
||||
project_name: str,
|
||||
client_type: str = "local",
|
||||
api_base: Optional[str] = None,
|
||||
headers: Optional[dict] = None,
|
||||
):
|
||||
self.client_type = client_type
|
||||
self.project_name = project_name
|
||||
self.api_base = api_base or "https://api.litellm.ai"
|
||||
self.headers = headers or {"Content-Type": "application/json"}
|
||||
## load the data or init the initial dictionaries
|
||||
self.load_data()
|
||||
|
||||
def print_verbose(self, print_statement):
|
||||
try:
|
||||
if litellm.set_verbose:
|
||||
import logging
|
||||
|
||||
logging.info(print_statement)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def load_data(self):
|
||||
if self.client_type == "local":
|
||||
# Check if user dict file exists
|
||||
if os.path.isfile("user_cost.json"):
|
||||
# Load the user dict
|
||||
with open("user_cost.json", "r") as json_file:
|
||||
self.user_dict = json.load(json_file)
|
||||
else:
|
||||
self.print_verbose("User Dictionary not found!")
|
||||
self.user_dict = {}
|
||||
self.print_verbose(f"user dict from local: {self.user_dict}")
|
||||
elif self.client_type == "hosted":
|
||||
# Load the user_dict from hosted db
|
||||
url = self.api_base + "/get_budget"
|
||||
data = {"project_name": self.project_name}
|
||||
response = litellm.module_level_client.post(
|
||||
url, headers=self.headers, json=data
|
||||
)
|
||||
response = response.json()
|
||||
if response["status"] == "error":
|
||||
self.user_dict = (
|
||||
{}
|
||||
) # assume this means the user dict hasn't been stored yet
|
||||
else:
|
||||
self.user_dict = response["data"]
|
||||
|
||||
def create_budget(
|
||||
self,
|
||||
total_budget: float,
|
||||
user: str,
|
||||
duration: Optional[Literal["daily", "weekly", "monthly", "yearly"]] = None,
|
||||
created_at: float = time.time(),
|
||||
):
|
||||
self.user_dict[user] = {"total_budget": total_budget}
|
||||
if duration is None:
|
||||
return self.user_dict[user]
|
||||
|
||||
if duration == "daily":
|
||||
duration_in_days = 1
|
||||
elif duration == "weekly":
|
||||
duration_in_days = DAYS_IN_A_WEEK
|
||||
elif duration == "monthly":
|
||||
duration_in_days = DAYS_IN_A_MONTH
|
||||
elif duration == "yearly":
|
||||
duration_in_days = DAYS_IN_A_YEAR
|
||||
else:
|
||||
raise ValueError(
|
||||
"""duration needs to be one of ["daily", "weekly", "monthly", "yearly"]"""
|
||||
)
|
||||
self.user_dict[user] = {
|
||||
"total_budget": total_budget,
|
||||
"duration": duration_in_days,
|
||||
"created_at": created_at,
|
||||
"last_updated_at": created_at,
|
||||
}
|
||||
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
|
||||
return self.user_dict[user]
|
||||
|
||||
def projected_cost(self, model: str, messages: list, user: str):
|
||||
text = "".join(message["content"] for message in messages)
|
||||
prompt_tokens = litellm.token_counter(model=model, text=text)
|
||||
prompt_cost, _ = litellm.cost_per_token(
|
||||
model=model, prompt_tokens=prompt_tokens, completion_tokens=0
|
||||
)
|
||||
current_cost = self.user_dict[user].get("current_cost", 0)
|
||||
projected_cost = prompt_cost + current_cost
|
||||
return projected_cost
|
||||
|
||||
def get_total_budget(self, user: str):
|
||||
return self.user_dict[user]["total_budget"]
|
||||
|
||||
def update_cost(
|
||||
self,
|
||||
user: str,
|
||||
completion_obj: Optional[ModelResponse] = None,
|
||||
model: Optional[str] = None,
|
||||
input_text: Optional[str] = None,
|
||||
output_text: Optional[str] = None,
|
||||
):
|
||||
if model and input_text and output_text:
|
||||
prompt_tokens = litellm.token_counter(
|
||||
model=model, messages=[{"role": "user", "content": input_text}]
|
||||
)
|
||||
completion_tokens = litellm.token_counter(
|
||||
model=model, messages=[{"role": "user", "content": output_text}]
|
||||
)
|
||||
(
|
||||
prompt_tokens_cost_usd_dollar,
|
||||
completion_tokens_cost_usd_dollar,
|
||||
) = litellm.cost_per_token(
|
||||
model=model,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
cost = prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
|
||||
elif completion_obj:
|
||||
cost = litellm.completion_cost(completion_response=completion_obj)
|
||||
model = completion_obj[
|
||||
"model"
|
||||
] # if this throws an error try, model = completion_obj['model']
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either a chat completion object or the text response needs to be passed in. Learn more - https://docs.litellm.ai/docs/budget_manager"
|
||||
)
|
||||
|
||||
self.user_dict[user]["current_cost"] = cost + self.user_dict[user].get(
|
||||
"current_cost", 0
|
||||
)
|
||||
if "model_cost" in self.user_dict[user]:
|
||||
self.user_dict[user]["model_cost"][model] = cost + self.user_dict[user][
|
||||
"model_cost"
|
||||
].get(model, 0)
|
||||
else:
|
||||
self.user_dict[user]["model_cost"] = {model: cost}
|
||||
|
||||
self._save_data_thread() # [Non-Blocking] Update persistent storage without blocking execution
|
||||
return {"user": self.user_dict[user]}
|
||||
|
||||
def get_current_cost(self, user):
|
||||
return self.user_dict[user].get("current_cost", 0)
|
||||
|
||||
def get_model_cost(self, user):
|
||||
return self.user_dict[user].get("model_cost", 0)
|
||||
|
||||
def is_valid_user(self, user: str) -> bool:
|
||||
return user in self.user_dict
|
||||
|
||||
def get_users(self):
|
||||
return list(self.user_dict.keys())
|
||||
|
||||
def reset_cost(self, user):
|
||||
self.user_dict[user]["current_cost"] = 0
|
||||
self.user_dict[user]["model_cost"] = {}
|
||||
return {"user": self.user_dict[user]}
|
||||
|
||||
def reset_on_duration(self, user: str):
|
||||
# Get current and creation time
|
||||
last_updated_at = self.user_dict[user]["last_updated_at"]
|
||||
current_time = time.time()
|
||||
|
||||
# Convert duration from days to seconds
|
||||
duration_in_seconds = (
|
||||
self.user_dict[user]["duration"] * HOURS_IN_A_DAY * 60 * 60
|
||||
)
|
||||
|
||||
# Check if duration has elapsed
|
||||
if current_time - last_updated_at >= duration_in_seconds:
|
||||
# Reset cost if duration has elapsed and update the creation time
|
||||
self.reset_cost(user)
|
||||
self.user_dict[user]["last_updated_at"] = current_time
|
||||
self._save_data_thread() # Save the data
|
||||
|
||||
def update_budget_all_users(self):
|
||||
for user in self.get_users():
|
||||
if "duration" in self.user_dict[user]:
|
||||
self.reset_on_duration(user)
|
||||
|
||||
def _save_data_thread(self):
|
||||
thread = threading.Thread(
|
||||
target=self.save_data
|
||||
) # [Non-Blocking]: saves data without blocking execution
|
||||
thread.start()
|
||||
|
||||
def save_data(self):
|
||||
if self.client_type == "local":
|
||||
import json
|
||||
|
||||
# save the user dict
|
||||
with open("user_cost.json", "w") as json_file:
|
||||
json.dump(
|
||||
self.user_dict, json_file, indent=4
|
||||
) # Indent for pretty formatting
|
||||
return {"status": "success"}
|
||||
elif self.client_type == "hosted":
|
||||
url = self.api_base + "/set_budget"
|
||||
data = {"project_name": self.project_name, "user_dict": self.user_dict}
|
||||
response = litellm.module_level_client.post(
|
||||
url, headers=self.headers, json=data
|
||||
)
|
||||
response = response.json()
|
||||
return response
|
||||
@@ -0,0 +1,41 @@
|
||||
# Caching on LiteLLM
|
||||
|
||||
LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case.
|
||||
|
||||
The following caching mechanisms are supported:
|
||||
|
||||
1. **RedisCache**
|
||||
2. **RedisSemanticCache**
|
||||
3. **QdrantSemanticCache**
|
||||
4. **InMemoryCache**
|
||||
5. **DiskCache**
|
||||
6. **S3Cache**
|
||||
7. **AzureBlobCache**
|
||||
8. **DualCache** (updates both Redis and an in-memory cache simultaneously)
|
||||
|
||||
## Folder Structure
|
||||
|
||||
```
litellm/caching/
├── azure_blob_cache.py
├── base_cache.py
├── caching.py
├── caching_handler.py
├── disk_cache.py
├── dual_cache.py
├── gcs_cache.py
├── in_memory_cache.py
├── qdrant_semantic_cache.py
├── redis_cache.py
├── redis_cluster_cache.py
├── redis_semantic_cache.py
├── s3_cache.py
```
|
||||
|
||||
## Documentation
|
||||
- [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching)
|
||||
- [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from .azure_blob_cache import AzureBlobCache
|
||||
from .caching import Cache, LiteLLMCacheType
|
||||
from .disk_cache import DiskCache
|
||||
from .dual_cache import DualCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||
from .redis_cache import RedisCache
|
||||
from .redis_cluster_cache import RedisClusterCache
|
||||
from .redis_semantic_cache import RedisSemanticCache
|
||||
from .s3_cache import S3Cache
|
||||
from .gcs_cache import GCSCache
|
||||
@@ -0,0 +1,30 @@
|
||||
from functools import lru_cache, wraps
from typing import Callable, Optional, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def lru_cache_wrapper(
    maxsize: Optional[int] = None,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Wrapper for lru_cache that caches successes AND exceptions.

    A plain ``lru_cache`` re-executes the function whenever it raised, because
    nothing was cached. This decorator memoizes the outcome either way: a
    cached success is returned, a cached exception is re-raised.

    Args:
        maxsize: forwarded to ``lru_cache``; ``None`` means unbounded.
    """

    def decorator(f: Callable[..., T]) -> Callable[..., T]:
        @lru_cache(maxsize=maxsize)
        def cached_outcome(*args, **kwargs):
            # Store ("success", value) or ("error", exc) so both outcomes are
            # memoized under the same cache key.
            try:
                return ("success", f(*args, **kwargs))
            except Exception as e:
                return ("error", e)

        # Fix: preserve the wrapped function's metadata (__name__, __doc__,
        # __module__) -- previously lost, which breaks introspection/debugging.
        @wraps(f)
        def wrapped(*args, **kwargs):
            tag, payload = cached_outcome(*args, **kwargs)
            if tag == "error":
                raise payload
            return payload

        return wrapped

    return decorator
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Azure Blob Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from contextlib import suppress
|
||||
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class AzureBlobCache(BaseCache):
    """Cache backed by an Azure Blob Storage container.

    Holds both a sync and an async container client; values are stored as
    JSON-serialized blobs keyed by the cache key. All set/get failures are
    non-fatal (logged, not raised).
    """

    def __init__(self, account_url, container) -> None:
        from azure.core.exceptions import ResourceExistsError
        from azure.identity import DefaultAzureCredential
        from azure.identity.aio import (
            DefaultAzureCredential as AsyncDefaultAzureCredential,
        )
        from azure.storage.blob import BlobServiceClient
        from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient

        self.container_client = BlobServiceClient(
            account_url=account_url,
            credential=DefaultAzureCredential(),
        ).get_container_client(container)
        self.async_container_client = AsyncBlobServiceClient(
            account_url=account_url,
            credential=AsyncDefaultAzureCredential(),
        ).get_container_client(container)

        # Idempotent container creation: ignore "already exists".
        with suppress(ResourceExistsError):
            self.container_client.create_container()

    def set_cache(self, key, value, **kwargs) -> None:
        """Store ``value`` (JSON-serialized) under ``key``."""
        print_verbose(f"LiteLLM SET Cache - Azure Blob. Key={key}. Value={value}")
        serialized_value = json.dumps(value)
        try:
            # Fix: pass overwrite=True so updating an existing key succeeds,
            # matching async_set_cache (previously this raised
            # ResourceExistsError, which was silently swallowed below and the
            # cached value was never refreshed).
            self.container_client.upload_blob(key, serialized_value, overwrite=True)
        except Exception as e:
            # NON blocking - notify users Azure Blob is throwing an exception
            print_verbose(f"LiteLLM set_cache() - Got exception from Azure Blob: {e}")

    async def async_set_cache(self, key, value, **kwargs) -> None:
        """Async variant of :meth:`set_cache`."""
        print_verbose(f"LiteLLM SET Cache - Azure Blob. Key={key}. Value={value}")
        serialized_value = json.dumps(value)
        try:
            await self.async_container_client.upload_blob(
                key, serialized_value, overwrite=True
            )
        except Exception as e:
            # NON blocking - notify users Azure Blob is throwing an exception
            print_verbose(f"LiteLLM set_cache() - Got exception from Azure Blob: {e}")

    def get_cache(self, key, **kwargs):
        """Return the JSON-decoded value for ``key``, or None if absent."""
        from azure.core.exceptions import ResourceNotFoundError

        try:
            print_verbose(f"Get Azure Blob Cache: key: {key}")
            as_bytes = self.container_client.download_blob(key).readall()
            as_str = as_bytes.decode("utf-8")
            cached_response = json.loads(as_str)

            verbose_logger.debug(
                f"Got Azure Blob Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
            )

            return cached_response
        except ResourceNotFoundError:
            # Cache miss: the blob does not exist.
            return None

    async def async_get_cache(self, key, **kwargs):
        """Async variant of :meth:`get_cache`."""
        from azure.core.exceptions import ResourceNotFoundError

        try:
            print_verbose(f"Get Azure Blob Cache: key: {key}")
            blob = await self.async_container_client.download_blob(key)
            as_bytes = await blob.readall()
            as_str = as_bytes.decode("utf-8")
            cached_response = json.loads(as_str)
            verbose_logger.debug(
                f"Got Azure Blob Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
            )
            return cached_response
        except ResourceNotFoundError:
            # Cache miss: the blob does not exist.
            return None

    def flush_cache(self) -> None:
        """Delete every blob in the container."""
        for blob in self.container_client.walk_blobs():
            self.container_client.delete_blob(blob.name)

    async def disconnect(self) -> None:
        """Close both the sync and async container clients."""
        self.container_client.close()
        await self.async_container_client.close()

    async def async_set_cache_pipeline(self, cache_list, **kwargs) -> None:
        """Concurrently store every (key, value) pair in ``cache_list``."""
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
|
||||
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Base Cache implementation. All cache implementations should inherit from this class.
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class BaseCache(ABC):
    """Abstract base for all LiteLLM cache backends.

    Subclasses implement the sync/async get & set methods; ``get_ttl`` is a
    shared helper for resolving a per-call TTL with a fallback default.
    """

    def __init__(self, default_ttl: int = 60):
        # TTL (seconds) used whenever a call does not supply a usable ``ttl``.
        self.default_ttl = default_ttl

    def get_ttl(self, **kwargs) -> Optional[int]:
        """Resolve the TTL for a call.

        Returns ``int(kwargs["ttl"])`` when present and coercible; otherwise
        the instance default.
        """
        requested = kwargs.get("ttl")
        if requested is None:
            return self.default_ttl
        try:
            return int(requested)
        except ValueError:
            return self.default_ttl

    def set_cache(self, key, value, **kwargs):
        raise NotImplementedError

    async def async_set_cache(self, key, value, **kwargs):
        raise NotImplementedError

    @abstractmethod
    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        pass

    def get_cache(self, key, **kwargs):
        raise NotImplementedError

    async def async_get_cache(self, key, **kwargs):
        raise NotImplementedError

    async def batch_cache_write(self, key, value, **kwargs):
        raise NotImplementedError

    async def disconnect(self):
        raise NotImplementedError

    async def test_connection(self) -> dict:
        """
        Test the cache connection.

        Returns:
            dict: {"status": "success" | "failed", "message": str, "error": Optional[str]}
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,926 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | Give Feedback / Get Help |
|
||||
# | https://github.com/BerriAI/litellm/issues/new |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import ast
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
|
||||
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
|
||||
from litellm.types.caching import *
|
||||
from litellm.types.utils import EmbeddingResponse, all_litellm_params
|
||||
|
||||
from .azure_blob_cache import AzureBlobCache
|
||||
from .base_cache import BaseCache
|
||||
from .disk_cache import DiskCache
|
||||
from .dual_cache import DualCache # noqa
|
||||
from .gcs_cache import GCSCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||
from .redis_cache import RedisCache
|
||||
from .redis_cluster_cache import RedisClusterCache
|
||||
from .redis_semantic_cache import RedisSemanticCache
|
||||
from .s3_cache import S3Cache
|
||||
|
||||
|
||||
def print_verbose(print_statement):
    """Emit ``print_statement`` to the debug logger, and also to stdout when
    ``litellm.set_verbose`` is enabled. Never raises."""
    try:
        verbose_logger.debug(print_statement)
        if litellm.set_verbose:
            print(print_statement)  # noqa
    except Exception:
        pass
|
||||
|
||||
|
||||
class CacheMode(str, Enum):
    """Whether caching applies by default or must be opted into per call."""

    # Caching is applied to every call unless explicitly disabled.
    default_on = "default_on"
    # Caching is opt-in: callers must explicitly enable it per call.
    default_off = "default_off"
|
||||
|
||||
#### LiteLLM.Completion / Embedding Cache ####
|
||||
class Cache:
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
||||
mode: Optional[
|
||||
CacheMode
|
||||
] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
namespace: Optional[str] = None,
|
||||
ttl: Optional[float] = None,
|
||||
default_in_memory_ttl: Optional[float] = None,
|
||||
default_in_redis_ttl: Optional[float] = None,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
||||
"completion",
|
||||
"acompletion",
|
||||
"embedding",
|
||||
"aembedding",
|
||||
"atranscription",
|
||||
"transcription",
|
||||
"atext_completion",
|
||||
"text_completion",
|
||||
"arerank",
|
||||
"rerank",
|
||||
"responses",
|
||||
"aresponses",
|
||||
],
|
||||
# s3 Bucket, boto3 configuration
|
||||
azure_account_url: Optional[str] = None,
|
||||
azure_blob_container: Optional[str] = None,
|
||||
s3_bucket_name: Optional[str] = None,
|
||||
s3_region_name: Optional[str] = None,
|
||||
s3_api_version: Optional[str] = None,
|
||||
s3_use_ssl: Optional[bool] = True,
|
||||
s3_verify: Optional[Union[bool, str]] = None,
|
||||
s3_endpoint_url: Optional[str] = None,
|
||||
s3_aws_access_key_id: Optional[str] = None,
|
||||
s3_aws_secret_access_key: Optional[str] = None,
|
||||
s3_aws_session_token: Optional[str] = None,
|
||||
s3_config: Optional[Any] = None,
|
||||
s3_path: Optional[str] = None,
|
||||
gcs_bucket_name: Optional[str] = None,
|
||||
gcs_path_service_account: Optional[str] = None,
|
||||
gcs_path: Optional[str] = None,
|
||||
redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
|
||||
redis_semantic_cache_index_name: Optional[str] = None,
|
||||
redis_flush_size: Optional[int] = None,
|
||||
redis_startup_nodes: Optional[List] = None,
|
||||
disk_cache_dir: Optional[str] = None,
|
||||
qdrant_api_base: Optional[str] = None,
|
||||
qdrant_api_key: Optional[str] = None,
|
||||
qdrant_collection_name: Optional[str] = None,
|
||||
qdrant_quantization_config: Optional[str] = None,
|
||||
qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
|
||||
qdrant_semantic_cache_vector_size: Optional[int] = None,
|
||||
# GCP IAM authentication parameters
|
||||
gcp_service_account: Optional[str] = None,
|
||||
gcp_ssl_ca_certs: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the cache based on the given type.
|
||||
|
||||
Args:
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
|
||||
|
||||
# Redis Cache Args
|
||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
|
||||
ttl (float, optional): The ttl for the Redis cache
|
||||
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
|
||||
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
|
||||
|
||||
# Qdrant Cache Args
|
||||
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
||||
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
|
||||
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
|
||||
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
|
||||
|
||||
# Disk Cache Args
|
||||
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
|
||||
|
||||
# S3 Cache Args
|
||||
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
|
||||
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
|
||||
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
|
||||
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
|
||||
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
|
||||
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
|
||||
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
|
||||
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
|
||||
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
|
||||
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
|
||||
|
||||
# GCS Cache Args
|
||||
gcs_bucket_name (str, optional): The bucket name for the gcs cache. Defaults to None.
|
||||
gcs_path_service_account (str, optional): Path to the service account json.
|
||||
gcs_path (str, optional): Folder path inside the bucket to store cache files.
|
||||
|
||||
# Common Cache Args
|
||||
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid cache type is provided.
|
||||
|
||||
Returns:
|
||||
None. Cache is set as a litellm param
|
||||
"""
|
||||
if type == LiteLLMCacheType.REDIS:
|
||||
# Check REDIS_CLUSTER_NODES env var if no explicit startup nodes
|
||||
if not redis_startup_nodes:
|
||||
_env_cluster_nodes = litellm.get_secret("REDIS_CLUSTER_NODES")
|
||||
if _env_cluster_nodes is not None and isinstance(
|
||||
_env_cluster_nodes, str
|
||||
):
|
||||
redis_startup_nodes = json.loads(_env_cluster_nodes)
|
||||
|
||||
if redis_startup_nodes:
|
||||
# Only pass GCP parameters if they are provided
|
||||
cluster_kwargs = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"password": password,
|
||||
"redis_flush_size": redis_flush_size,
|
||||
"startup_nodes": redis_startup_nodes,
|
||||
**kwargs,
|
||||
}
|
||||
if gcp_service_account is not None:
|
||||
cluster_kwargs["gcp_service_account"] = gcp_service_account
|
||||
if gcp_ssl_ca_certs is not None:
|
||||
cluster_kwargs["gcp_ssl_ca_certs"] = gcp_ssl_ca_certs
|
||||
|
||||
self.cache: BaseCache = RedisClusterCache(**cluster_kwargs)
|
||||
else:
|
||||
self.cache = RedisCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
redis_flush_size=redis_flush_size,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
|
||||
self.cache = RedisSemanticCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
similarity_threshold=similarity_threshold,
|
||||
embedding_model=redis_semantic_cache_embedding_model,
|
||||
index_name=redis_semantic_cache_index_name,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
|
||||
self.cache = QdrantSemanticCache(
|
||||
qdrant_api_base=qdrant_api_base,
|
||||
qdrant_api_key=qdrant_api_key,
|
||||
collection_name=qdrant_collection_name,
|
||||
similarity_threshold=similarity_threshold,
|
||||
quantization_config=qdrant_quantization_config,
|
||||
embedding_model=qdrant_semantic_cache_embedding_model,
|
||||
vector_size=qdrant_semantic_cache_vector_size,
|
||||
)
|
||||
elif type == LiteLLMCacheType.LOCAL:
|
||||
self.cache = InMemoryCache()
|
||||
elif type == LiteLLMCacheType.S3:
|
||||
self.cache = S3Cache(
|
||||
s3_bucket_name=s3_bucket_name,
|
||||
s3_region_name=s3_region_name,
|
||||
s3_api_version=s3_api_version,
|
||||
s3_use_ssl=s3_use_ssl,
|
||||
s3_verify=s3_verify,
|
||||
s3_endpoint_url=s3_endpoint_url,
|
||||
s3_aws_access_key_id=s3_aws_access_key_id,
|
||||
s3_aws_secret_access_key=s3_aws_secret_access_key,
|
||||
s3_aws_session_token=s3_aws_session_token,
|
||||
s3_config=s3_config,
|
||||
s3_path=s3_path,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.GCS:
|
||||
self.cache = GCSCache(
|
||||
bucket_name=gcs_bucket_name,
|
||||
path_service_account=gcs_path_service_account,
|
||||
gcs_path=gcs_path,
|
||||
)
|
||||
elif type == LiteLLMCacheType.AZURE_BLOB:
|
||||
self.cache = AzureBlobCache(
|
||||
account_url=azure_account_url,
|
||||
container=azure_blob_container,
|
||||
)
|
||||
elif type == LiteLLMCacheType.DISK:
|
||||
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
|
||||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
||||
if "cache" not in litellm._async_success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
||||
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
|
||||
self.type = type
|
||||
self.namespace = namespace
|
||||
self.redis_flush_size = redis_flush_size
|
||||
self.ttl = ttl
|
||||
self.mode: CacheMode = mode or CacheMode.default_on
|
||||
|
||||
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
|
||||
self.ttl = default_in_memory_ttl
|
||||
|
||||
if (
|
||||
self.type == LiteLLMCacheType.REDIS
|
||||
or self.type == LiteLLMCacheType.REDIS_SEMANTIC
|
||||
) and default_in_redis_ttl is not None:
|
||||
self.ttl = default_in_redis_ttl
|
||||
|
||||
if self.namespace is not None and isinstance(self.cache, RedisCache):
|
||||
self.cache.namespace = self.namespace
|
||||
|
||||
def get_cache_key(self, **kwargs) -> str:
    """
    Get the cache key for the given arguments.

    Args:
        **kwargs: kwargs to litellm.completion() or embedding()

    Returns:
        str: The cache key generated from the arguments, or None if no cache key could be generated.
    """
    cache_key = ""
    # verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)

    # Reuse a previously computed key stashed in kwargs["litellm_params"],
    # so repeated lookups for the same request skip re-hashing.
    preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
    if preset_cache_key is not None:
        verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
        return preset_cache_key

    combined_kwargs = ModelParamHelper._get_all_llm_api_params()
    litellm_param_kwargs = all_litellm_params
    for param in kwargs:
        if param in combined_kwargs:
            # Known LLM API param -> always folded into the key.
            param_value: Optional[str] = self._get_param_value(param, kwargs)
            if param_value is not None:
                cache_key += f"{str(param)}: {str(param_value)}"
        elif (
            param not in litellm_param_kwargs
        ):  # check if user passed in optional param - e.g. top_k
            if (
                litellm.enable_caching_on_provider_specific_optional_params is True
            ):  # feature flagged for now
                if kwargs[param] is None:
                    continue  # ignore None params
                param_value = kwargs[param]
                cache_key += f"{str(param)}: {str(param_value)}"

    verbose_logger.debug("\nCreated cache key: %s", cache_key)
    hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
    hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
    # Stash the computed key so later calls for this request can short-circuit.
    self._set_preset_cache_key_in_kwargs(
        preset_cache_key=hashed_cache_key, **kwargs
    )
    return hashed_cache_key
|
||||
|
||||
def _get_param_value(
|
||||
self,
|
||||
param: str,
|
||||
kwargs: dict,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the value for the given param from kwargs
|
||||
"""
|
||||
if param == "model":
|
||||
return self._get_model_param_value(kwargs)
|
||||
elif param == "file":
|
||||
return self._get_file_param_value(kwargs)
|
||||
return kwargs[param]
|
||||
|
||||
def _get_model_param_value(self, kwargs: dict) -> str:
    """
    Resolve the value used for the 'model' param when building the cache key.

    Precedence:
    1. caching group, if configured - https://docs.litellm.ai/docs/routing#caching-across-model-groups
    2. model_group, set for requests routed through litellm.Router()
    3. the raw `model` kwarg
    """
    request_metadata: Dict = kwargs.get("metadata", {}) or {}
    litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
    router_metadata: Dict = litellm_params.get("metadata", {}) or {}
    model_group: Optional[str] = request_metadata.get(
        "model_group"
    ) or router_metadata.get("model_group")
    caching_group = self._get_caching_group(request_metadata, model_group)
    return caching_group or model_group or kwargs["model"]
|
||||
|
||||
def _get_caching_group(
|
||||
self, metadata: dict, model_group: Optional[str]
|
||||
) -> Optional[str]:
|
||||
caching_groups: Optional[List] = metadata.get("caching_groups", [])
|
||||
if caching_groups:
|
||||
for group in caching_groups:
|
||||
if model_group in group:
|
||||
return str(group)
|
||||
return None
|
||||
|
||||
def _get_file_param_value(self, kwargs: dict) -> str:
|
||||
"""
|
||||
Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
|
||||
"""
|
||||
file = kwargs.get("file")
|
||||
metadata = kwargs.get("metadata", {})
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
return (
|
||||
metadata.get("file_checksum")
|
||||
or getattr(file, "name", None)
|
||||
or metadata.get("file_name")
|
||||
or litellm_params.get("file_name")
|
||||
)
|
||||
|
||||
def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
Get the preset cache key from kwargs["litellm_params"]
|
||||
|
||||
We use _get_preset_cache_keys for two reasons
|
||||
|
||||
1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
|
||||
2. avoid doing duplicate / repeated work
|
||||
"""
|
||||
if kwargs:
|
||||
if "litellm_params" in kwargs:
|
||||
return kwargs["litellm_params"].get("preset_cache_key", None)
|
||||
return None
|
||||
|
||||
def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
|
||||
"""
|
||||
Set the calculated cache key in kwargs
|
||||
|
||||
This is used to avoid doing duplicate / repeated work
|
||||
|
||||
Placed in kwargs["litellm_params"]
|
||||
"""
|
||||
if kwargs:
|
||||
if "litellm_params" in kwargs:
|
||||
kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
|
||||
|
||||
@staticmethod
def _get_hashed_cache_key(cache_key: str) -> str:
    """
    Hash the raw cache key string with SHA-256.

    Args:
        cache_key (str): The cache key to hash.

    Returns:
        str: Hexadecimal SHA-256 digest of the key.
    """
    hash_hex = hashlib.sha256(cache_key.encode()).hexdigest()
    verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
    return hash_hex
|
||||
|
||||
def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
    """
    Prefix the hashed key with a namespace when one applies.

    Precedence: per-request cache control > request metadata > instance
    default (self.namespace).

    Args:
        hash_hex (str): The hashed cache key.
        **kwargs: Additional keyword arguments.

    Returns:
        str: The final hashed cache key, namespace-prefixed when configured.
    """
    dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
    selected_namespace = (
        dynamic_cache_control.get("namespace")
        or kwargs.get("metadata", {}).get("redis_namespace")
        or self.namespace
    )
    if selected_namespace:
        hash_hex = f"{selected_namespace}:{hash_hex}"
    verbose_logger.debug("Final hashed key: %s", hash_hex)
    return hash_hex
|
||||
|
||||
def generate_streaming_content(self, content):
    """
    Re-stream a cached completion as assistant delta chunks, pausing briefly
    between chunks to mimic live streaming.
    """
    chunk_size = 5  # Adjust the chunk size as needed
    for start in range(0, len(content), chunk_size):
        piece = content[start : start + chunk_size]
        yield {"choices": [{"delta": {"role": "assistant", "content": piece}}]}
        time.sleep(CACHED_STREAMING_CHUNK_DELAY)
|
||||
|
||||
def _get_cache_logic(
|
||||
self,
|
||||
cached_result: Optional[Any],
|
||||
max_age: Optional[float],
|
||||
):
|
||||
"""
|
||||
Common get cache logic across sync + async implementations
|
||||
"""
|
||||
# Check if a timestamp was stored with the cached response
|
||||
if (
|
||||
cached_result is not None
|
||||
and isinstance(cached_result, dict)
|
||||
and "timestamp" in cached_result
|
||||
):
|
||||
timestamp = cached_result["timestamp"]
|
||||
current_time = time.time()
|
||||
|
||||
# Calculate age of the cached response
|
||||
response_age = current_time - timestamp
|
||||
|
||||
# Check if the cached response is older than the max-age
|
||||
if max_age is not None and response_age > max_age:
|
||||
return None # Cached response is too old
|
||||
|
||||
# If the response is fresh, or there's no max-age requirement, return the cached response
|
||||
# cached_response is in `b{} convert it to ModelResponse
|
||||
cached_response = cached_result.get("response")
|
||||
try:
|
||||
if isinstance(cached_response, dict):
|
||||
pass
|
||||
else:
|
||||
cached_response = json.loads(
|
||||
cached_response # type: ignore
|
||||
) # Convert string to dictionary
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response) # type: ignore
|
||||
return cached_response
|
||||
return cached_result
|
||||
|
||||
def get_cache(self, dynamic_cache_object: Optional[BaseCache] = None, **kwargs):
    """
    Retrieves the cached result for the given arguments.

    Args:
        dynamic_cache_object: per-request cache override; when set, the lookup
            goes to it instead of self.cache.
        **kwargs: kwargs to litellm.completion() or embedding()

    Returns:
        The cached result if it exists, otherwise None.
    """
    try:  # never block execution
        if self.should_use_cache(**kwargs) is not True:
            return
        messages = kwargs.get("messages", [])
        if "cache_key" in kwargs:
            cache_key = kwargs["cache_key"]
        else:
            cache_key = self.get_cache_key(**kwargs)
        if cache_key is not None:
            cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
            # NOTE(review): `or` treats a max-age of 0 as unset and falls
            # through to float("inf") — confirm that is intended. Also note
            # the async path checks "s-max-age" before "s-maxage".
            max_age = (
                cache_control_args.get("s-maxage")
                or cache_control_args.get("s-max-age")
                or float("inf")
            )
            if dynamic_cache_object is not None:
                cached_result = dynamic_cache_object.get_cache(
                    cache_key, messages=messages
                )
            else:
                cached_result = self.cache.get_cache(cache_key, messages=messages)
            return self._get_cache_logic(
                cached_result=cached_result, max_age=max_age
            )
    except Exception:
        # Cache failures must never break the LLM call itself.
        print_verbose(f"An exception occurred: {traceback.format_exc()}")
        return None
|
||||
|
||||
async def async_get_cache(
    self, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
):
    """
    Async get cache implementation.

    Used for embedding calls in async wrapper

    Args:
        dynamic_cache_object: per-request cache override; when set, the lookup
            goes to it instead of self.cache.
        **kwargs: kwargs to litellm.completion() or embedding()

    Returns:
        The cached result if it exists, otherwise None.
    """

    try:  # never block execution
        if self.should_use_cache(**kwargs) is not True:
            return

        kwargs.get("messages", [])  # NOTE(review): result unused — dead call?
        if "cache_key" in kwargs:
            cache_key = kwargs["cache_key"]
        else:
            cache_key = self.get_cache_key(**kwargs)
        if cache_key is not None:
            cache_control_args = kwargs.get("cache", {})
            # NOTE(review): checks "s-max-age" before "s-maxage" — the reverse
            # of the sync get_cache() precedence; confirm which is intended.
            max_age = cache_control_args.get(
                "s-max-age", cache_control_args.get("s-maxage", float("inf"))
            )
            if dynamic_cache_object is not None:
                cached_result = await dynamic_cache_object.async_get_cache(
                    cache_key, **kwargs
                )
            else:
                cached_result = await self.cache.async_get_cache(
                    cache_key, **kwargs
                )
            return self._get_cache_logic(
                cached_result=cached_result, max_age=max_age
            )
    except Exception:
        # Cache failures must never break the LLM call itself.
        print_verbose(f"An exception occurred: {traceback.format_exc()}")
        return None
|
||||
|
||||
def _add_cache_logic(self, result, **kwargs):
    """
    Common implementation across sync + async add_cache functions.

    Builds the (cache_key, cached_data, kwargs) triple handed to the
    concrete cache backend.

    Args:
        result: the API response to cache; pydantic BaseModel instances are
            serialized to JSON first.
        **kwargs: the original request kwargs; may supply "cache_key" and
            per-request "cache" controls, and is returned with "ttl" set.

    Returns:
        Tuple of (cache_key, cached_data, kwargs).

    Raises:
        Exception: if no cache key could be computed.
    """
    # Removed the original `try: ... except Exception as e: raise e` wrapper:
    # a bare re-raise adds nothing and obscures the real control flow.
    if "cache_key" in kwargs:
        cache_key = kwargs["cache_key"]
    else:
        cache_key = self.get_cache_key(**kwargs)
    if cache_key is None:
        raise Exception("cache key is None")

    if isinstance(result, BaseModel):
        result = result.model_dump_json()

    ## DEFAULT TTL ##
    if self.ttl is not None:
        kwargs["ttl"] = self.ttl
    ## Get Cache-Controls ##
    # A per-request "cache": {"ttl": ...} overrides the instance default.
    _cache_kwargs = kwargs.get("cache", None)
    if isinstance(_cache_kwargs, dict):
        for k, v in _cache_kwargs.items():
            if k == "ttl":
                kwargs["ttl"] = v

    # Timestamp lets _get_cache_logic enforce max-age on reads.
    cached_data = {"timestamp": time.time(), "response": result}
    return cache_key, cached_data, kwargs
|
||||
|
||||
def add_cache(self, result, **kwargs):
    """
    Adds a result to the cache.

    Args:
        result: the response object to cache.
        **kwargs: kwargs to litellm.completion() or embedding()

    Returns:
        None
    """
    try:
        if self.should_use_cache(**kwargs) is not True:
            return
        cache_key, cached_data, kwargs = self._add_cache_logic(
            result=result, **kwargs
        )
        self.cache.set_cache(cache_key, cached_data, **kwargs)
    except Exception as e:
        # Cache write failures are logged, never raised to the caller.
        # (fixed typo in log message: "Excepton" -> "Exception")
        verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
|
||||
|
||||
async def async_add_cache(
    self, result, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
):
    """
    Async implementation of add_cache.

    Routes high-traffic redis writes through the in-memory batch buffer when
    redis_flush_size is configured; otherwise writes directly to the cache
    (or to dynamic_cache_object when a per-request override is given).
    """
    try:
        if self.should_use_cache(**kwargs) is not True:
            return
        if self.type == "redis" and self.redis_flush_size is not None:
            # high traffic - fill in results in memory and then flush
            await self.batch_cache_write(result, **kwargs)
        else:
            cache_key, cached_data, kwargs = self._add_cache_logic(
                result=result, **kwargs
            )
            if dynamic_cache_object is not None:
                await dynamic_cache_object.async_set_cache(
                    cache_key, cached_data, **kwargs
                )
            else:
                await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
    except Exception as e:
        # Cache write failures are logged, never raised to the caller.
        # (fixed typo in log message: "Excepton" -> "Exception")
        verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
|
||||
|
||||
def _convert_to_cached_embedding(
    self, embedding_response: Any, model: Optional[str]
) -> CachedEmbedding:
    """
    Convert any embedding response into the standardized CachedEmbedding
    TypedDict format.

    Accepts a plain dict, a pydantic-style object exposing model_dump(), or
    any object whose attributes can be read via vars().

    Raises:
        ValueError: if an expected key is missing from the response.
    """
    try:
        # Normalize the response into a plain dict once, then build the
        # CachedEmbedding in a single place (the original duplicated the
        # dict construction in all three branches).
        if isinstance(embedding_response, dict):
            data = embedding_response
        elif hasattr(embedding_response, "model_dump"):
            data = embedding_response.model_dump()
        else:
            data = vars(embedding_response)
        return {
            "embedding": data.get("embedding"),
            "index": data.get("index"),
            "object": data.get("object"),
            "model": model,
        }
    except KeyError as e:
        # Explicit chaining keeps the original KeyError in the traceback.
        raise ValueError(f"Missing expected key in embedding response: {e}") from e
|
||||
|
||||
def add_embedding_response_to_cache(
    self,
    result: EmbeddingResponse,
    input: str,
    kwargs: dict,
    idx_in_result_data: int = 0,
) -> Tuple[str, dict, dict]:
    """
    Prepare one embedding entry from `result` for caching, keyed on `input`.

    Args:
        result: the full EmbeddingResponse from the provider.
        input: the single input string this embedding corresponds to.
        kwargs: the original request kwargs; mutated to carry "cache_key".
        idx_in_result_data: index of this input's embedding in result.data.

    Returns:
        (cache_key, cached_data, kwargs) ready for the cache backend.
    """
    # Key each embedding on its own input string, not the whole batch.
    preset_cache_key = self.get_cache_key(**{**kwargs, "input": input})
    kwargs["cache_key"] = preset_cache_key
    embedding_response = result.data[idx_in_result_data]

    # Always convert to properly typed CachedEmbedding
    model_name = result.model
    embedding_dict: CachedEmbedding = self._convert_to_cached_embedding(
        embedding_response, model_name
    )

    cache_key, cached_data, kwargs = self._add_cache_logic(
        result=embedding_dict,
        **kwargs,
    )
    return cache_key, cached_data, kwargs
|
||||
|
||||
async def async_add_cache_pipeline(
    self, result, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
):
    """
    Async implementation of add_cache for Embedding calls

    Does a bulk write, to prevent using too many clients

    Args:
        result: the EmbeddingResponse to cache (one entry per input).
        dynamic_cache_object: per-request cache override; when set, writes go
            to it instead of self.cache.
        **kwargs: original request kwargs; expected to contain "input"
            (str or list).
    """
    try:
        if self.should_use_cache(**kwargs) is not True:
            return

        # set default ttl if not set
        if self.ttl is not None:
            kwargs["ttl"] = self.ttl

        cache_list = []
        if isinstance(kwargs["input"], list):
            # One cache entry per input item, keyed on that item.
            for idx, i in enumerate(kwargs["input"]):
                (
                    cache_key,
                    cached_data,
                    kwargs,
                ) = self.add_embedding_response_to_cache(result, i, kwargs, idx)
                cache_list.append((cache_key, cached_data))
        elif isinstance(kwargs["input"], str):
            cache_key, cached_data, kwargs = self.add_embedding_response_to_cache(
                result, kwargs["input"], kwargs
            )
            cache_list.append((cache_key, cached_data))

        if dynamic_cache_object is not None:
            await dynamic_cache_object.async_set_cache_pipeline(
                cache_list=cache_list, **kwargs
            )
        else:
            await self.cache.async_set_cache_pipeline(
                cache_list=cache_list, **kwargs
            )
    except Exception as e:
        # Cache write failures are logged, never raised to the caller.
        # (fixed typo in log message: "Excepton" -> "Exception")
        verbose_logger.exception(f"LiteLLM Cache: Exception add_cache: {str(e)}")
|
||||
|
||||
def should_use_cache(self, **kwargs):
    """
    Returns True if we should use the cache for LLM API calls.

    default_on  -> always True.
    default_off -> only True when the request opted in via
                   kwargs["cache"]["use-cache"] = True.
    """
    if self.mode == CacheMode.default_on:
        return True

    # when mode == default_off -> Cache is opt in only
    request_cache = kwargs.get("cache", None)
    verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, request_cache)
    return bool(
        request_cache
        and isinstance(request_cache, dict)
        and request_cache.get("use-cache", False) is True
    )
|
||||
|
||||
async def batch_cache_write(self, result, **kwargs):
    # Buffer the entry in the backend's batch writer (used for redis when
    # redis_flush_size is configured); the backend flushes when the buffer
    # fills.
    cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
    await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
|
||||
|
||||
async def ping(self):
    """
    Ping the underlying cache backend, when it supports pinging.

    Returns:
        The backend's ping result, or None when the backend has no `ping`.
    """
    # getattr needs a default: the original getattr(self.cache, "ping")
    # raised AttributeError for backends without a ping method, defeating
    # the `if cache_ping:` guard below.
    cache_ping = getattr(self.cache, "ping", None)
    if cache_ping:
        return await cache_ping()
    return None
|
||||
|
||||
async def delete_cache_keys(self, keys):
    """
    Delete the given keys from the cache backend, when supported.

    Args:
        keys: iterable of cache keys to delete.

    Returns:
        The backend's result, or None when the backend has no
        `delete_cache_keys` method.
    """
    # getattr needs a default: without it, backends lacking
    # delete_cache_keys raised AttributeError instead of being skipped.
    cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys", None)
    if cache_delete_cache_keys:
        return await cache_delete_cache_keys(keys)
    return None
|
||||
|
||||
async def disconnect(self):
    """Close the backend connection when the backend exposes `disconnect`."""
    if not hasattr(self.cache, "disconnect"):
        return
    await self.cache.disconnect()
|
||||
|
||||
def _supports_async(self) -> bool:
|
||||
"""
|
||||
Internal method to check if the cache type supports async get/set operations
|
||||
|
||||
All cache types now support async operations
|
||||
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
def enable_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    # NOTE(review): mutable (list) default argument — benign here since it is
    # only read and passed through, but confirm before refactoring to None.
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
        "responses",
        "aresponses",
    ],
    **kwargs,
):
    """
    Enable cache with the specified configuration.

    Args:
        type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
        host (Optional[str]): The host address of the cache server. Defaults to None.
        port (Optional[str]): The port number of the cache server. Defaults to None.
        password (Optional[str]): The password for the cache server. Defaults to None.
        supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
            The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
        **kwargs: Additional keyword arguments.

    Returns:
        None

    Raises:
        None
    """
    print_verbose("LiteLLM: Enabling Cache")
    # Register the cache hooks once (idempotent across repeated calls).
    if "cache" not in litellm.input_callback:
        litellm.input_callback.append("cache")
    if "cache" not in litellm.success_callback:
        litellm.logging_callback_manager.add_litellm_success_callback("cache")
    if "cache" not in litellm._async_success_callback:
        litellm.logging_callback_manager.add_litellm_async_success_callback("cache")

    # Only create a new Cache when none exists; an existing litellm.cache is
    # left untouched (use update_cache() to replace its configuration).
    if litellm.cache is None:
        litellm.cache = Cache(
            type=type,
            host=host,
            port=port,
            password=password,
            supported_call_types=supported_call_types,
            **kwargs,
        )
    print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
|
||||
|
||||
|
||||
def update_cache(
    type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
    host: Optional[str] = None,
    port: Optional[str] = None,
    password: Optional[str] = None,
    # NOTE(review): mutable (list) default argument — benign here since it is
    # only read and passed through, but confirm before refactoring to None.
    supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
        "completion",
        "acompletion",
        "embedding",
        "aembedding",
        "atranscription",
        "transcription",
        "atext_completion",
        "text_completion",
        "arerank",
        "rerank",
        "responses",
        "aresponses",
    ],
    **kwargs,
):
    """
    Update the cache for LiteLLM.

    Unlike enable_cache(), this unconditionally replaces litellm.cache with a
    freshly constructed Cache and does not (re)register the cache callbacks.

    Args:
        type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
        host (Optional[str]): The host of the cache. Defaults to None.
        port (Optional[str]): The port of the cache. Defaults to None.
        password (Optional[str]): The password for the cache. Defaults to None.
        supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
            The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
        **kwargs: Additional keyword arguments for the cache.

    Returns:
        None

    """
    print_verbose("LiteLLM: Updating Cache")
    litellm.cache = Cache(
        type=type,
        host=host,
        port=port,
        password=password,
        supported_call_types=supported_call_types,
        **kwargs,
    )
    print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
    print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
|
||||
|
||||
|
||||
def disable_cache():
    """
    Disable the cache used by LiteLLM.

    Removes the "cache" callback from input_callback, success_callback, and
    _async_success_callback, then sets litellm.cache to None.

    Parameters:
        None

    Returns:
        None
    """
    from contextlib import suppress

    print_verbose("LiteLLM: Disabling Cache")
    # One suppress() per removal: the original wrapped all three removes in a
    # single block, so a ValueError from the first .remove() (e.g. "cache" was
    # never registered there) skipped the remaining removals entirely —
    # suppress() exits the with-block at the first matching exception.
    with suppress(ValueError):
        litellm.input_callback.remove("cache")
    with suppress(ValueError):
        litellm.success_callback.remove("cache")
    with suppress(ValueError):
        litellm._async_success_callback.remove("cache")

    litellm.cache = None
    print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class DiskCache(BaseCache):
    """Persistent on-disk cache backed by the `diskcache` package.

    The async methods simply delegate to their sync counterparts, since
    diskcache itself is synchronous.
    """

    def __init__(self, disk_cache_dir: Optional[str] = None):
        try:
            import diskcache as dc
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(
                "Please install litellm with `litellm[caching]` to use disk caching."
            ) from e

        # if users don't provider one, use the default litellm cache
        if disk_cache_dir is None:
            self.disk_cache = dc.Cache(".litellm_cache")
        else:
            self.disk_cache = dc.Cache(disk_cache_dir)

    def set_cache(self, key, value, **kwargs):
        """Store `value` under `key`, honoring an optional "ttl" kwarg."""
        if "ttl" in kwargs:
            self.disk_cache.set(key, value, expire=kwargs["ttl"])
            return
        self.disk_cache.set(key, value)

    async def async_set_cache(self, key, value, **kwargs):
        """Async wrapper over the synchronous set_cache."""
        self.set_cache(key=key, value=value, **kwargs)

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """Write a list of (key, value) pairs, sharing one optional ttl."""
        extra = {"ttl": kwargs["ttl"]} if "ttl" in kwargs else {}
        for entry_key, entry_value in cache_list:
            self.set_cache(key=entry_key, value=entry_value, **extra)

    def get_cache(self, key, **kwargs):
        """Return the stored value, JSON-decoded when possible, else None."""
        raw = self.disk_cache.get(key)
        if not raw:
            return None
        try:
            return json.loads(raw)  # type: ignore
        except Exception:
            # Not JSON (e.g. counters stored by increment_cache).
            return raw

    def batch_get_cache(self, keys: list, **kwargs):
        """Fetch several keys; missing entries come back as None."""
        return [self.get_cache(key=k, **kwargs) for k in keys]

    def increment_cache(self, key, value: int, **kwargs) -> int:
        """Add `value` to the stored counter (missing -> 0) and persist it."""
        current = self.get_cache(key=key) or 0
        value = current + value  # type: ignore
        self.set_cache(key, value, **kwargs)
        return value

    async def async_get_cache(self, key, **kwargs):
        """Async wrapper over the synchronous get_cache."""
        return self.get_cache(key=key, **kwargs)

    async def async_batch_get_cache(self, keys: list, **kwargs):
        """Async wrapper over batch retrieval; missing entries are None."""
        return [self.get_cache(key=k, **kwargs) for k in keys]

    async def async_increment(self, key, value: int, **kwargs) -> int:
        """Async counterpart of increment_cache."""
        current = await self.async_get_cache(key=key) or 0
        value = current + value  # type: ignore
        await self.async_set_cache(key, value, **kwargs)
        return value

    def flush_cache(self):
        """Drop every entry from the on-disk cache."""
        self.disk_cache.clear()

    async def disconnect(self):
        """Nothing to close for a local on-disk cache."""
        pass

    def delete_cache(self, key):
        """Remove a single entry from the cache."""
        self.disk_cache.pop(key)
|
||||
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.
|
||||
|
||||
Has 4 primary methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.caching import RedisPipelineIncrementOperation
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.constants import DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE
|
||||
|
||||
from .base_cache import BaseCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .redis_cache import RedisCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class LimitedSizeOrderedDict(OrderedDict):
    """An ``OrderedDict`` bounded to ``max_size`` entries.

    When inserting a *new* key would exceed the bound, the oldest
    (first-inserted) entry is evicted FIFO-style. Overwriting an
    existing key never evicts, because the size does not grow.
    """

    def __init__(self, *args, max_size=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_size = max_size

    def __setitem__(self, key, value):
        # Only evict when a genuinely new key would push us past max_size;
        # updating an existing key keeps the size unchanged, so evicting
        # here would wrongly shrink the dict and drop a live entry.
        if key not in self and len(self) >= self.max_size:
            self.popitem(last=False)
        super().__setitem__(key, value)
|
||||
|
||||
|
||||
class DualCache(BaseCache):
    """
    DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
    When data is updated or inserted, it is written to both the in-memory cache + Redis.
    This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
    """

    def __init__(
        self,
        in_memory_cache: Optional[InMemoryCache] = None,
        redis_cache: Optional[RedisCache] = None,
        default_in_memory_ttl: Optional[float] = None,
        default_redis_ttl: Optional[float] = None,
        default_redis_batch_cache_expiry: Optional[float] = None,
        default_max_redis_batch_cache_size: int = DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE,
    ) -> None:
        super().__init__()
        # If in_memory_cache is not provided, use the default InMemoryCache
        self.in_memory_cache = in_memory_cache or InMemoryCache()
        # If redis_cache is not provided, use the default RedisCache
        self.redis_cache = redis_cache
        # Tracks when each key was last fetched from Redis during batch
        # gets, to throttle repeat Redis reads. Bounded FIFO dict.
        self.last_redis_batch_access_time = LimitedSizeOrderedDict(
            max_size=default_max_redis_batch_cache_size
        )
        # Guards last_redis_batch_access_time against concurrent
        # check-then-act races (see _reserve_redis_batch_keys).
        self._last_redis_batch_access_time_lock = Lock()
        # Minimum seconds between Redis batch reads for the same key;
        # falls back to the litellm module-level setting, then 10s.
        self.redis_batch_cache_expiry = (
            default_redis_batch_cache_expiry
            or litellm.default_redis_batch_cache_expiry
            or 10
        )
        self.default_in_memory_ttl = (
            default_in_memory_ttl or litellm.default_in_memory_ttl
        )
        self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl

    def update_cache_ttl(
        self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
    ):
        """Override the default in-memory and/or Redis TTLs (None = keep current)."""
        if default_in_memory_ttl is not None:
            self.default_in_memory_ttl = default_in_memory_ttl

        if default_redis_ttl is not None:
            self.default_redis_ttl = default_redis_ttl

    def set_cache(self, key, value, local_only: bool = False, **kwargs):
        """Write to the in-memory cache, and to Redis unless ``local_only``.

        Exceptions are swallowed (logged via print_verbose) — cache writes
        are best-effort and must not break the caller.
        """
        # Update both Redis and in-memory cache
        try:
            if self.in_memory_cache is not None:
                if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
                    kwargs["ttl"] = self.default_in_memory_ttl

                self.in_memory_cache.set_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                self.redis_cache.set_cache(key, value, **kwargs)
        except Exception as e:
            print_verbose(e)

    def increment_cache(
        self, key, value: int, local_only: bool = False, **kwargs
    ) -> int:
        """
        Key - the key in cache

        Value - int - the value you want to increment by

        Returns - int - the incremented value
        """
        # NOTE: when both caches are configured, the returned value is the
        # Redis result (the in-memory result is overwritten).
        try:
            result: int = value
            if self.in_memory_cache is not None:
                result = self.in_memory_cache.increment_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                result = self.redis_cache.increment_cache(key, value, **kwargs)

            return result
        except Exception as e:
            verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
            raise e

    def get_cache(
        self,
        key,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        """Read-through get: in-memory first, then Redis (unless ``local_only``).

        A Redis hit is copied back into the in-memory cache. Returns None
        on a miss; exceptions are logged and also yield None.
        """
        # Try to fetch from in-memory cache first
        try:
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)

                if in_memory_result is not None:
                    result = in_memory_result

            if result is None and self.redis_cache is not None and local_only is False:
                # If not found in in-memory cache, try fetching from Redis
                redis_result = self.redis_cache.get_cache(
                    key, parent_otel_span=parent_otel_span
                )

                if redis_result is not None:
                    # Update in-memory cache with the value from Redis
                    self.in_memory_cache.set_cache(key, redis_result, **kwargs)

                    result = redis_result

            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())

    def batch_get_cache(
        self,
        keys: list,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        """Sync bridge over :meth:`async_batch_get_cache`.

        Runs the async implementation either in a fresh event loop, or —
        when already inside a running loop — in a worker thread with its
        own loop, to avoid "event loop already running" errors.
        """
        # Capture the call arguments so they can be forwarded verbatim.
        # NOTE(review): locals() packs **kwargs as a nested "kwargs" key,
        # so downstream **kwargs receive {"kwargs": {...}} rather than the
        # flattened keys — verify this is intended.
        received_args = locals()
        received_args.pop("self")

        def run_in_new_loop():
            """Run the coroutine in a new event loop within this thread."""
            new_loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(new_loop)
                return new_loop.run_until_complete(
                    self.async_batch_get_cache(**received_args)
                )
            finally:
                new_loop.close()
                asyncio.set_event_loop(None)

        try:
            # First, try to get the current event loop
            _ = asyncio.get_running_loop()
            # If we're already in an event loop, run in a separate thread
            # to avoid nested event loop issues
            with ThreadPoolExecutor(max_workers=1) as executor:
                future = executor.submit(run_in_new_loop)
                return future.result()

        except RuntimeError:
            # No running event loop, we can safely run in this thread
            return run_in_new_loop()

    async def async_get_cache(
        self,
        key,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        """Async read-through get: in-memory first, then Redis; Redis hits
        are written back into the in-memory cache. Returns None on miss."""
        # Try to fetch from in-memory cache first
        try:
            print_verbose(
                f"async get cache: cache key: {key}; local_only: {local_only}"
            )
            result = None
            if self.in_memory_cache is not None:
                in_memory_result = await self.in_memory_cache.async_get_cache(
                    key, **kwargs
                )

                print_verbose(f"in_memory_result: {in_memory_result}")
                if in_memory_result is not None:
                    result = in_memory_result

            if result is None and self.redis_cache is not None and local_only is False:
                # If not found in in-memory cache, try fetching from Redis
                redis_result = await self.redis_cache.async_get_cache(
                    key, parent_otel_span=parent_otel_span
                )

                if redis_result is not None:
                    # Update in-memory cache with the value from Redis
                    await self.in_memory_cache.async_set_cache(
                        key, redis_result, **kwargs
                    )

                    result = redis_result

            print_verbose(f"get cache: cache result: {result}")
            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())

    def _reserve_redis_batch_keys(
        self,
        current_time: float,
        keys: List[str],
        result: List[Any],
    ) -> Tuple[List[str], Dict[str, Optional[float]]]:
        """
        Atomically choose keys to fetch from Redis and reserve their access time.
        This prevents check-then-act races under concurrent async callers.
        """
        sublist_keys: List[str] = []
        previous_access_times: Dict[str, Optional[float]] = {}

        with self._last_redis_batch_access_time_lock:
            for key, value in zip(keys, result):
                if value is not None:
                    # Already satisfied by the in-memory cache; skip.
                    continue

                # Fetch from Redis only if this key hasn't been fetched
                # within the throttle window (redis_batch_cache_expiry).
                if (
                    key not in self.last_redis_batch_access_time
                    or current_time - self.last_redis_batch_access_time[key]
                    >= self.redis_batch_cache_expiry
                ):
                    sublist_keys.append(key)
                    # Remember the prior timestamp so a failed Redis read
                    # can be rolled back (see _rollback_redis_batch_key_reservations).
                    previous_access_times[key] = self.last_redis_batch_access_time.get(
                        key
                    )
                    self.last_redis_batch_access_time[key] = current_time

        return sublist_keys, previous_access_times

    def _rollback_redis_batch_key_reservations(
        self, previous_access_times: Dict[str, Optional[float]]
    ) -> None:
        """Restore reserved access times after a failed Redis batch read,
        so subsequent callers are not throttled by a fetch that never happened."""
        with self._last_redis_batch_access_time_lock:
            for key, previous_time in previous_access_times.items():
                if previous_time is None:
                    self.last_redis_batch_access_time.pop(key, None)
                else:
                    self.last_redis_batch_access_time[key] = previous_time

    async def async_batch_get_cache(
        self,
        keys: list,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ):
        """Batch read-through get.

        Fills from the in-memory cache first; any still-missing keys are
        fetched from Redis (throttled per-key) and written back to memory.
        Returns a list aligned with ``keys`` (None for misses); exceptions
        are logged and yield None.
        """
        try:
            result = [None] * len(keys)
            if self.in_memory_cache is not None:
                in_memory_result = await self.in_memory_cache.async_batch_get_cache(
                    keys, **kwargs
                )

                if in_memory_result is not None:
                    result = in_memory_result

            if None in result and self.redis_cache is not None and local_only is False:
                """
                - for the none values in the result
                - check the redis cache
                """
                current_time = time.time()
                sublist_keys, previous_access_times = self._reserve_redis_batch_keys(
                    current_time, keys, result
                )

                # Only hit Redis if enough time has passed since last access.
                if len(sublist_keys) > 0:
                    try:
                        # If not found in in-memory cache, try fetching from Redis
                        redis_result = await self.redis_cache.async_batch_get_cache(
                            sublist_keys, parent_otel_span=parent_otel_span
                        )
                    except Exception:
                        # Do not throttle subsequent callers if the Redis read fails.
                        self._rollback_redis_batch_key_reservations(
                            previous_access_times
                        )
                        raise

                    # Short-circuit if redis_result is None or contains only None values
                    if redis_result is None or all(
                        v is None for v in redis_result.values()
                    ):
                        return result

                    # Pre-compute key-to-index mapping for O(1) lookup
                    key_to_index = {key: i for i, key in enumerate(keys)}

                    # Update both result and in-memory cache in a single loop
                    for key, value in redis_result.items():
                        result[key_to_index[key]] = value

                        if value is not None and self.in_memory_cache is not None:
                            await self.in_memory_cache.async_set_cache(
                                key, value, **kwargs
                            )

            return result
        except Exception:
            verbose_logger.error(traceback.format_exc())

    async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
        """Async write to the in-memory cache, and to Redis unless ``local_only``.
        Best-effort: exceptions are logged, not raised."""
        print_verbose(
            f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
        )
        try:
            if self.in_memory_cache is not None:
                if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
                    kwargs["ttl"] = self.default_in_memory_ttl
                await self.in_memory_cache.async_set_cache(key, value, **kwargs)

            if self.redis_cache is not None and local_only is False:
                await self.redis_cache.async_set_cache(key, value, **kwargs)
        except Exception as e:
            verbose_logger.exception(
                f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
            )

    # async_batch_set_cache
    async def async_set_cache_pipeline(
        self, cache_list: list, local_only: bool = False, **kwargs
    ):
        """
        Batch write values to the cache
        """
        print_verbose(
            f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
        )
        try:
            if self.in_memory_cache is not None:
                if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
                    kwargs["ttl"] = self.default_in_memory_ttl
                await self.in_memory_cache.async_set_cache_pipeline(
                    cache_list=cache_list, **kwargs
                )

            if self.redis_cache is not None and local_only is False:
                # ttl is popped so it isn't passed twice via **kwargs.
                await self.redis_cache.async_set_cache_pipeline(
                    cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
                )
        except Exception as e:
            verbose_logger.exception(
                f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
            )

    async def async_increment_cache(
        self,
        key,
        value: float,
        parent_otel_span: Optional[Span] = None,
        local_only: bool = False,
        **kwargs,
    ) -> float:
        """
        Key - the key in cache

        Value - float - the value you want to increment by

        Returns - float - the incremented value
        """
        # NOTE: when both caches are configured, the Redis result wins.
        try:
            result: float = value
            if self.in_memory_cache is not None:
                result = await self.in_memory_cache.async_increment(
                    key, value, **kwargs
                )

            if self.redis_cache is not None and local_only is False:
                result = await self.redis_cache.async_increment(
                    key,
                    value,
                    parent_otel_span=parent_otel_span,
                    ttl=kwargs.get("ttl", None),
                )

            return result
        except Exception as e:
            raise e  # don't log if exception is raised

    async def async_increment_cache_pipeline(
        self,
        increment_list: List["RedisPipelineIncrementOperation"],
        local_only: bool = False,
        parent_otel_span: Optional[Span] = None,
        **kwargs,
    ) -> Optional[List[float]]:
        """Apply a batch of increment operations to both caches.

        Returns the incremented values (Redis results take precedence when
        Redis is configured and ``local_only`` is False).
        """
        try:
            result: Optional[List[float]] = None
            if self.in_memory_cache is not None:
                result = await self.in_memory_cache.async_increment_pipeline(
                    increment_list=increment_list,
                    parent_otel_span=parent_otel_span,
                )

            if self.redis_cache is not None and local_only is False:
                result = await self.redis_cache.async_increment_pipeline(
                    increment_list=increment_list,
                    parent_otel_span=parent_otel_span,
                )

            return result
        except Exception as e:
            raise e  # don't log if exception is raised

    async def async_set_cache_sadd(
        self, key, value: List, local_only: bool = False, **kwargs
    ) -> None:
        """
        Add value to a set

        Key - the key in cache

        Value - str - the value you want to add to the set

        Returns - None
        """
        try:
            if self.in_memory_cache is not None:
                _ = await self.in_memory_cache.async_set_cache_sadd(
                    key, value, ttl=kwargs.get("ttl", None)
                )

            if self.redis_cache is not None and local_only is False:
                _ = await self.redis_cache.async_set_cache_sadd(
                    key, value, ttl=kwargs.get("ttl", None)
                )

            return None
        except Exception as e:
            raise e  # don't log, if exception is raised

    def flush_cache(self):
        """Clear every entry from both the in-memory cache and Redis."""
        if self.in_memory_cache is not None:
            self.in_memory_cache.flush_cache()
        if self.redis_cache is not None:
            self.redis_cache.flush_cache()

    def delete_cache(self, key):
        """
        Delete a key from the cache
        """
        if self.in_memory_cache is not None:
            self.in_memory_cache.delete_cache(key)
        if self.redis_cache is not None:
            self.redis_cache.delete_cache(key)

    async def async_delete_cache(self, key: str):
        """
        Delete a key from the cache
        """
        if self.in_memory_cache is not None:
            self.in_memory_cache.delete_cache(key)
        if self.redis_cache is not None:
            await self.redis_cache.async_delete_cache(key)

    async def async_get_ttl(self, key: str) -> Optional[int]:
        """
        Get the remaining TTL of a key in in-memory cache or redis
        """
        ttl = await self.in_memory_cache.async_get_ttl(key)
        if ttl is None and self.redis_cache is not None:
            ttl = await self.redis_cache.async_get_ttl(key)
        return ttl
|
||||
@@ -0,0 +1,113 @@
|
||||
"""GCS Cache implementation
|
||||
Supports syncing responses to Google Cloud Storage Buckets using HTTP requests.
|
||||
"""
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
_get_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class GCSCache(BaseCache):
    """Cache backend that stores entries as objects in a Google Cloud
    Storage bucket via the GCS JSON API over HTTP.

    Values are JSON-serialized on write and JSON-decoded on read.
    All network errors are swallowed and logged (best-effort cache).
    """

    def __init__(
        self,
        bucket_name: Optional[str] = None,
        path_service_account: Optional[str] = None,
        gcs_path: Optional[str] = None,
    ) -> None:
        super().__init__()
        # Fall back to the bucket / service-account configured on
        # GCSBucketBase (typically from environment variables).
        self.bucket_name = bucket_name or GCSBucketBase(bucket_name=None).BUCKET_NAME
        self.path_service_account = (
            path_service_account
            or GCSBucketBase(bucket_name=None).path_service_account_json
        )
        # Optional folder prefix inside the bucket, normalized to end with "/".
        self.key_prefix = gcs_path.rstrip("/") + "/" if gcs_path else ""
        # create httpx clients
        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.LoggingCallback
        )
        self.sync_client = _get_httpx_client()

    def _construct_headers(self) -> dict:
        """Build authenticated request headers via GCSBucketBase.

        NOTE: constructs a fresh GCSBucketBase per call — presumably so
        auth tokens are refreshed; confirm this is not a hot-path cost.
        """
        base = GCSBucketBase(bucket_name=self.bucket_name)
        base.path_service_account_json = self.path_service_account
        base.BUCKET_NAME = self.bucket_name
        return base.sync_construct_request_headers()

    def set_cache(self, key, value, **kwargs):
        """Upload ``value`` (JSON-serialized) as object ``key_prefix + key``.

        Errors are logged and suppressed. kwargs (e.g. ttl) are ignored.
        """
        try:
            print_verbose(f"LiteLLM SET Cache - GCS. Key={key}. Value={value}")
            headers = self._construct_headers()
            object_name = self.key_prefix + key
            bucket_name = self.bucket_name
            # NOTE(review): object_name is not URL-encoded; keys containing
            # special characters may produce malformed request paths — verify.
            url = f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
            data = json.dumps(value)
            self.sync_client.post(url=url, data=data, headers=headers)
        except Exception as e:
            print_verbose(f"GCS Caching: set_cache() - Got exception from GCS: {e}")

    async def async_set_cache(self, key, value, **kwargs):
        """Async variant of :meth:`set_cache`; errors logged and suppressed."""
        try:
            headers = self._construct_headers()
            object_name = self.key_prefix + key
            bucket_name = self.bucket_name
            url = f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
            data = json.dumps(value)
            await self.async_client.post(url=url, data=data, headers=headers)
        except Exception as e:
            print_verbose(
                f"GCS Caching: async_set_cache() - Got exception from GCS: {e}"
            )

    def get_cache(self, key, **kwargs):
        """Download and JSON-decode object ``key_prefix + key``.

        Returns None when the object is missing (non-200) or on error.
        """
        try:
            headers = self._construct_headers()
            object_name = self.key_prefix + key
            bucket_name = self.bucket_name
            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
            response = self.sync_client.get(url=url, headers=headers)
            if response.status_code == 200:
                cached_response = json.loads(response.text)
                verbose_logger.debug(
                    f"Got GCS Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
                )
                return cached_response
            return None
        except Exception as e:
            # Falls through returning None implicitly after logging.
            verbose_logger.error(
                f"GCS Caching: get_cache() - Got exception from GCS: {e}"
            )

    async def async_get_cache(self, key, **kwargs):
        """Async variant of :meth:`get_cache`; returns None on miss or error."""
        try:
            headers = self._construct_headers()
            object_name = self.key_prefix + key
            bucket_name = self.bucket_name
            url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
            response = await self.async_client.get(url=url, headers=headers)
            if response.status_code == 200:
                return json.loads(response.text)
            return None
        except Exception as e:
            verbose_logger.error(
                f"GCS Caching: async_get_cache() - Got exception from GCS: {e}"
            )

    def flush_cache(self):
        # Intentionally a no-op: bulk-deleting GCS objects is not supported here.
        pass

    async def disconnect(self):
        # httpx clients are managed elsewhere; nothing to tear down.
        pass

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """Upload a batch of (key, value) pairs concurrently."""
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
|
||||
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
In-Memory Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import heapq
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.caching import RedisPipelineIncrementOperation
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
def __init__(
|
||||
self,
|
||||
max_size_in_memory: Optional[int] = 200,
|
||||
default_ttl: Optional[
|
||||
int
|
||||
] = 600, # default ttl is 10 minutes. At maximum litellm rate limiting logic requires objects to be in memory for 1 minute
|
||||
max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB
|
||||
):
|
||||
"""
|
||||
max_size_in_memory [int]: Maximum number of items in cache. done to prevent memory leaks. Use 200 items as a default
|
||||
"""
|
||||
self.max_size_in_memory = (
|
||||
max_size_in_memory if max_size_in_memory is not None else 200
|
||||
) # set an upper bound of 200 items in-memory
|
||||
self.default_ttl = default_ttl or 600
|
||||
self.max_size_per_item = (
|
||||
max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
) # 1MB = 1024KB
|
||||
|
||||
# in-memory cache
|
||||
self.cache_dict: dict = {}
|
||||
self.ttl_dict: dict = {}
|
||||
self.expiration_heap: list[tuple[float, str]] = []
|
||||
|
||||
def check_value_size(self, value: Any):
|
||||
"""
|
||||
Check if value size exceeds max_size_per_item (1MB)
|
||||
Returns True if value size is acceptable, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Fast path for common primitive types that are typically small
|
||||
if (
|
||||
isinstance(value, (bool, int, float, str))
|
||||
and len(str(value))
|
||||
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
): # Conservative estimate
|
||||
return True
|
||||
|
||||
# Direct size check for bytes objects
|
||||
if isinstance(value, bytes):
|
||||
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
|
||||
|
||||
# Handle special types without full conversion when possible
|
||||
if hasattr(value, "__sizeof__"): # Use __sizeof__ if available
|
||||
size = value.__sizeof__() / 1024
|
||||
return size <= self.max_size_per_item
|
||||
|
||||
# Fallback for complex types
|
||||
if isinstance(value, BaseModel) and hasattr(
|
||||
value, "model_dump"
|
||||
): # Pydantic v2
|
||||
value = value.model_dump()
|
||||
elif hasattr(value, "isoformat"): # datetime objects
|
||||
return True # datetime strings are always small
|
||||
|
||||
# Only convert to JSON if absolutely necessary
|
||||
if not isinstance(value, (str, bytes)):
|
||||
value = json.dumps(value, default=str)
|
||||
|
||||
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _is_key_expired(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a specific key is expired
|
||||
"""
|
||||
return key in self.ttl_dict and time.time() > self.ttl_dict[key]
|
||||
|
||||
def _remove_key(self, key: str) -> None:
|
||||
"""
|
||||
Remove a key from both cache_dict and ttl_dict
|
||||
"""
|
||||
self.cache_dict.pop(key, None)
|
||||
self.ttl_dict.pop(key, None)
|
||||
|
||||
def evict_cache(self):
|
||||
"""
|
||||
Eviction policy:
|
||||
1. First, remove expired items from ttl_dict and cache_dict
|
||||
2. If cache is still at or above max_size_in_memory, evict items with earliest expiration times
|
||||
|
||||
|
||||
This guarantees the following:
|
||||
- 1. When item ttl not set: At minimum each item will remain in memory for the default ttl
|
||||
- 2. When ttl is set: the item will remain in memory for at least that amount of time, unless cache size requires eviction
|
||||
- 3. the size of in-memory cache is bounded
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# Step 1: Remove expired or outdated items
|
||||
while self.expiration_heap:
|
||||
expiration_time, key = self.expiration_heap[0]
|
||||
|
||||
# Case 1: Heap entry is outdated
|
||||
if expiration_time != self.ttl_dict.get(key):
|
||||
heapq.heappop(self.expiration_heap)
|
||||
# Case 2: Entry is valid but expired
|
||||
elif expiration_time <= current_time:
|
||||
heapq.heappop(self.expiration_heap)
|
||||
self._remove_key(key)
|
||||
else:
|
||||
# Case 3: Entry is valid and not expired
|
||||
break
|
||||
|
||||
# Step 2: Evict if cache is still full
|
||||
while len(self.cache_dict) >= self.max_size_in_memory:
|
||||
expiration_time, key = heapq.heappop(self.expiration_heap)
|
||||
# Skip if key was removed or updated
|
||||
if self.ttl_dict.get(key) == expiration_time:
|
||||
self._remove_key(key)
|
||||
|
||||
# de-reference the removed item
|
||||
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
|
||||
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
|
||||
# This can occur when an object is referenced by another object, but the reference is never removed.
|
||||
|
||||
def allow_ttl_override(self, key: str) -> bool:
|
||||
"""
|
||||
Check if ttl is set for a key
|
||||
"""
|
||||
ttl_time = self.ttl_dict.get(key)
|
||||
if ttl_time is None: # if ttl is not set, allow override
|
||||
return True
|
||||
elif float(ttl_time) < time.time(): # if ttl is expired, allow override
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
# Handle the edge case where max_size_in_memory is 0
|
||||
if self.max_size_in_memory == 0:
|
||||
return # Don't cache anything if max size is 0
|
||||
|
||||
if len(self.cache_dict) >= self.max_size_in_memory:
|
||||
# only evict when cache is full
|
||||
self.evict_cache()
|
||||
if not self.check_value_size(value):
|
||||
return
|
||||
|
||||
self.cache_dict[key] = value
|
||||
if self.allow_ttl_override(key): # if ttl is not set, set it to default ttl
|
||||
if "ttl" in kwargs and kwargs["ttl"] is not None:
|
||||
self.ttl_dict[key] = time.time() + float(kwargs["ttl"])
|
||||
heapq.heappush(self.expiration_heap, (self.ttl_dict[key], key))
|
||||
else:
|
||||
self.ttl_dict[key] = time.time() + self.default_ttl
|
||||
heapq.heappush(self.expiration_heap, (self.ttl_dict[key], key))
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
self.set_cache(key=key, value=value, **kwargs)
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
|
||||
for cache_key, cache_value in cache_list:
|
||||
if ttl is not None:
|
||||
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
|
||||
else:
|
||||
self.set_cache(key=cache_key, value=cache_value)
|
||||
|
||||
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
|
||||
"""
|
||||
Add value to set
|
||||
"""
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or set()
|
||||
for val in value:
|
||||
init_value.add(val)
|
||||
self.set_cache(key, init_value, ttl=ttl)
|
||||
return value
|
||||
|
||||
def evict_element_if_expired(self, key: str) -> bool:
|
||||
"""
|
||||
Returns True if the element is expired and removed from the cache
|
||||
|
||||
Returns False if the element is not expired
|
||||
"""
|
||||
if self._is_key_expired(key):
|
||||
self._remove_key(key)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
if key in self.cache_dict:
|
||||
if self.evict_element_if_expired(key):
|
||||
return None
|
||||
original_cached_response = self.cache_dict[key]
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response)
|
||||
except Exception:
|
||||
cached_response = original_cached_response
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
def batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
def increment_cache(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
self.set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
||||
async def async_batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
async def async_increment(self, key, value: float, **kwargs) -> float:
|
||||
# get the value
|
||||
init_value = await self.async_get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
await self.async_set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_increment_pipeline(
|
||||
self, increment_list: List["RedisPipelineIncrementOperation"], **kwargs
|
||||
) -> Optional[List[float]]:
|
||||
results = []
|
||||
for increment in increment_list:
|
||||
result = await self.async_increment(
|
||||
increment["key"], increment["increment_value"], **kwargs
|
||||
)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
def flush_cache(self):
|
||||
self.cache_dict.clear()
|
||||
self.ttl_dict.clear()
|
||||
self.expiration_heap.clear()
|
||||
|
||||
    async def disconnect(self):
        # No-op: an in-memory cache holds no external connections; the method
        # exists only to satisfy the shared cache interface.
        pass
|
||||
|
||||
    def delete_cache(self, key):
        # Delegate to the shared removal helper so all bookkeeping stays in
        # one place (presumably it also clears TTL state — see _remove_key).
        self._remove_key(key)
|
||||
|
||||
    async def async_get_ttl(self, key: str) -> Optional[int]:
        """
        Get the remaining TTL of a key in in-memory cache
        """
        # NOTE(review): this returns the raw value stored in ttl_dict, which
        # async_get_oldest_n_keys sorts as an age/expiry timestamp — confirm
        # whether callers expect remaining seconds or an absolute expiry time.
        return self.ttl_dict.get(key, None)
|
||||
|
||||
async def async_get_oldest_n_keys(self, n: int) -> List[str]:
|
||||
"""
|
||||
Get the oldest n keys in the cache
|
||||
"""
|
||||
# sorted ttl dict by ttl
|
||||
sorted_ttl_dict = sorted(self.ttl_dict.items(), key=lambda x: x[1])
|
||||
return [key for key, _ in sorted_ttl_dict[:n]]
|
||||
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from .in_memory_cache import InMemoryCache
|
||||
|
||||
|
||||
class LLMClientCache(InMemoryCache):
    """Cache for LLM HTTP clients (OpenAI, Azure, httpx, etc.).

    IMPORTANT: This cache intentionally does NOT close clients on eviction.
    Evicted clients may still be in use by in-flight requests. Closing them
    eagerly causes ``RuntimeError: Cannot send a request, as the client has
    been closed.`` errors in production after the TTL (1 hour) expires.

    Clients that are no longer referenced will be garbage-collected normally.
    For explicit shutdown cleanup, use ``close_litellm_async_clients()``.
    """

    def update_cache_key_with_event_loop(self, key):
        """
        Add the event loop to the cache key, to prevent event loop closed errors.
        If none, use the key as is.
        """
        try:
            loop_id = id(asyncio.get_running_loop())
        except RuntimeError:
            # No running event loop (sync context) — use the bare key.
            return key
        return f"{key}-{loop_id}"

    def set_cache(self, key, value, **kwargs):
        """Store *value* under the loop-namespaced key."""
        namespaced_key = self.update_cache_key_with_event_loop(key)
        return super().set_cache(namespaced_key, value, **kwargs)

    async def async_set_cache(self, key, value, **kwargs):
        """Async store under the loop-namespaced key."""
        namespaced_key = self.update_cache_key_with_event_loop(key)
        return await super().async_set_cache(namespaced_key, value, **kwargs)

    def get_cache(self, key, **kwargs):
        """Look up *key* under the current loop's namespace."""
        namespaced_key = self.update_cache_key_with_event_loop(key)
        return super().get_cache(namespaced_key, **kwargs)

    async def async_get_cache(self, key, **kwargs):
        """Async lookup under the current loop's namespace."""
        namespaced_key = self.update_cache_key_with_event_loop(key)
        return await super().async_get_cache(namespaced_key, **kwargs)
|
||||
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
Qdrant Semantic Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class QdrantSemanticCache(BaseCache):
    """Semantic cache backed by a Qdrant vector collection.

    Prompts are embedded and stored as points; lookups run a top-1
    nearest-neighbour search and count as a hit when cosine similarity is at
    or above ``similarity_threshold``.

    Public interface:
        - set_cache / async_set_cache
        - get_cache / async_get_cache
        - async_set_cache_pipeline
    """

    def __init__(
        self,
        qdrant_api_base=None,
        qdrant_api_key=None,
        collection_name=None,
        similarity_threshold=None,
        quantization_config=None,
        embedding_model="text-embedding-ada-002",
        host_type=None,
        vector_size=None,
    ):
        """Initialize the cache and ensure the Qdrant collection exists.

        Args:
            qdrant_api_base: Qdrant base URL; supports "os.environ/" indirection
                and falls back to QDRANT_URL / QDRANT_API_BASE env vars.
            qdrant_api_key: Optional API key; supports "os.environ/" indirection
                and falls back to the QDRANT_API_KEY env var.
            collection_name: Name of the Qdrant collection (required).
            similarity_threshold: Minimum cosine similarity for a hit (required).
            quantization_config: One of "binary" (default), "scalar", "product".
            embedding_model: Model used to embed prompts.
            host_type: Unused; kept for backward compatibility with callers.
            vector_size: Embedding dimensionality; defaults to QDRANT_VECTOR_SIZE.

        Raises:
            Exception: Missing collection_name/similarity_threshold, unknown
                quantization_config, or collection creation failure.
            ValueError: No Qdrant URL resolvable, or the existence check fails.
        """
        import os

        from litellm.llms.custom_httpx.http_handler import (
            _get_httpx_client,
            get_async_httpx_client,
            httpxSpecialProvider,
        )
        from litellm.secret_managers.main import get_secret_str

        if collection_name is None:
            raise Exception("collection_name must be provided, passed None")

        self.collection_name = collection_name
        print_verbose(
            f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
        )

        if similarity_threshold is None:
            raise Exception("similarity_threshold must be provided, passed None")
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        self.vector_size = (
            vector_size if vector_size is not None else QDRANT_VECTOR_SIZE
        )

        # Resolve "os.environ/" indirection before falling back to env vars.
        if qdrant_api_base:
            if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
                "os.environ/"
            ):
                qdrant_api_base = get_secret_str(qdrant_api_base)
        if qdrant_api_key:
            if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
                "os.environ/"
            ):
                qdrant_api_key = get_secret_str(qdrant_api_key)

        qdrant_api_base = (
            qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
        )
        qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
        headers = {"Content-Type": "application/json"}
        if qdrant_api_key:
            headers["api-key"] = qdrant_api_key

        if qdrant_api_base is None:
            raise ValueError("Qdrant url must be provided")

        self.qdrant_api_base = qdrant_api_base
        self.qdrant_api_key = qdrant_api_key
        print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")

        self.headers = headers

        self.sync_client = _get_httpx_client()
        self.async_client = get_async_httpx_client(
            llm_provider=httpxSpecialProvider.Caching
        )

        if quantization_config is None:
            print_verbose(
                "Quantization config is not provided. Default binary quantization will be used."
            )

        self._ensure_collection(quantization_config)

    @staticmethod
    def _quantization_params(quantization_config) -> dict:
        """Map a quantization_config name to the Qdrant quantization payload.

        Raises:
            Exception: For any value other than None/"binary"/"scalar"/"product".
        """
        if quantization_config is None or quantization_config == "binary":
            return {
                "binary": {
                    "always_ram": False,
                }
            }
        if quantization_config == "scalar":
            return {
                "scalar": {
                    "type": "int8",
                    "quantile": QDRANT_SCALAR_QUANTILE,
                    "always_ram": False,
                }
            }
        if quantization_config == "product":
            return {"product": {"compression": "x16", "always_ram": False}}
        raise Exception(
            "Quantization config must be one of 'scalar', 'binary' or 'product'"
        )

    def _fetch_collection_details(self):
        """GET the collection metadata from Qdrant and return the parsed JSON."""
        response = self.sync_client.get(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
            headers=self.headers,
        )
        return response.json()

    def _ensure_collection(self, quantization_config) -> None:
        """Use the existing collection if present; otherwise create it.

        Sets ``self.collection_info`` either way.
        """
        collection_exists = self.sync_client.get(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
            headers=self.headers,
        )
        if collection_exists.status_code != 200:
            raise ValueError(
                f"Error from qdrant checking if /collections exist {collection_exists.text}"
            )

        if collection_exists.json()["result"]["exists"]:
            self.collection_info = self._fetch_collection_details()
            print_verbose(
                f"Collection already exists.\nCollection details:{self.collection_info}"
            )
            return

        new_collection_status = self.sync_client.put(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
            json={
                "vectors": {"size": self.vector_size, "distance": "Cosine"},
                "quantization_config": self._quantization_params(quantization_config),
            },
            headers=self.headers,
        )
        if new_collection_status.json()["result"]:
            self.collection_info = self._fetch_collection_details()
            print_verbose(
                f"New collection created.\nCollection details:{self.collection_info}"
            )
        else:
            raise Exception("Error while creating new collection")

    @staticmethod
    def _prompt_from_messages(messages) -> str:
        """Flatten chat messages into a single prompt string for embedding."""
        prompt = ""
        for message in messages:
            prompt += message["content"]
        return prompt

    @staticmethod
    def _upsert_payload(prompt: str, value: str, embedding) -> dict:
        """Build the Qdrant points-upsert body for one prompt/response pair."""
        from litellm._uuid import uuid

        return {
            "points": [
                {
                    "id": str(uuid.uuid4()),
                    "vector": embedding,
                    "payload": {
                        "text": prompt,
                        "response": value,
                    },
                },
            ]
        }

    @staticmethod
    def _search_payload(embedding) -> dict:
        """Build the top-1 nearest-neighbour search body (rescored, oversampled)."""
        return {
            "vector": embedding,
            "params": {
                "quantization": {
                    "ignore": False,
                    "rescore": True,
                    "oversampling": 3.0,
                }
            },
            "limit": 1,
            "with_payload": True,
        }

    def _sync_embedding(self, prompt: str):
        """Embed a prompt synchronously, bypassing the response cache to avoid recursion."""
        embedding_response = cast(
            EmbeddingResponse,
            litellm.embedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            ),
        )
        return embedding_response["data"][0]["embedding"]

    async def _async_embedding(self, prompt: str, **kwargs):
        """Embed a prompt asynchronously.

        Routes through the proxy's router (propagating user/trace metadata)
        when the embedding model is deployed on it; otherwise calls litellm
        directly. Bypasses the response cache in both cases.
        """
        from litellm.proxy.proxy_server import llm_model_list, llm_router

        router_model_names = (
            [m["model_name"] for m in llm_model_list]
            if llm_model_list is not None
            else []
        )
        if llm_router is not None and self.embedding_model in router_model_names:
            user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
            embedding_response = await llm_router.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
                metadata={
                    "user_api_key": user_api_key,
                    "semantic-cache-embedding": True,
                    "trace_id": kwargs.get("metadata", {}).get("trace_id", None),
                },
            )
        else:
            embedding_response = await litellm.aembedding(
                model=self.embedding_model,
                input=prompt,
                cache={"no-store": True, "no-cache": True},
            )
        return embedding_response["data"][0]["embedding"]

    def _get_cache_logic(self, cached_response: Any):
        """Deserialize a cached response: JSON first, Python-literal fallback."""
        if cached_response is None:
            return cached_response
        try:
            cached_response = json.loads(
                cached_response
            )  # Convert string to dictionary
        except Exception:
            cached_response = ast.literal_eval(cached_response)
        return cached_response

    def _resolve_search_results(self, results, prompt, metadata_kwargs=None):
        """Turn a Qdrant search result into a cache hit/miss.

        Returns the deserialized cached response when the top hit's score
        clears ``self.similarity_threshold``, else None. When
        ``metadata_kwargs`` is provided (async path), the observed similarity
        is recorded under ``metadata["semantic-similarity"]`` (0.0 on miss
        with no results).
        """

        def _record(similarity: float) -> None:
            if metadata_kwargs is not None:
                metadata_kwargs.setdefault("metadata", {})[
                    "semantic-similarity"
                ] = similarity

        if results is None:
            _record(0.0)
            return None
        if isinstance(results, list):
            if len(results) == 0:
                _record(0.0)
                return None

        similarity = results[0]["score"]
        cached_prompt = results[0]["payload"]["text"]

        # check similarity, if more than self.similarity_threshold, return results
        print_verbose(
            f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
        )
        _record(similarity)

        if similarity >= self.similarity_threshold:
            # cache hit !
            cached_value = results[0]["payload"]["response"]
            print_verbose(
                f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
            )
            return self._get_cache_logic(cached_response=cached_value)
        # cache miss !
        return None

    def set_cache(self, key, value, **kwargs):
        """Embed the prompt from kwargs["messages"] and upsert it with the
        stringified response.

        ``key`` is unused: semantic caching keys on embeddings, not exact keys.
        """
        print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")

        prompt = self._prompt_from_messages(kwargs["messages"])
        embedding = self._sync_embedding(prompt)

        data = self._upsert_payload(
            prompt=prompt, value=str(value), embedding=embedding
        )
        self.sync_client.put(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
            headers=self.headers,
            json=data,
        )
        return

    def get_cache(self, key, **kwargs):
        """Return the cached response for a semantically similar prompt, or None."""
        print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")

        prompt = self._prompt_from_messages(kwargs["messages"])
        embedding = self._sync_embedding(prompt)

        search_response = self.sync_client.post(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
            headers=self.headers,
            json=self._search_payload(embedding),
        )
        results = search_response.json()["result"]
        return self._resolve_search_results(results=results, prompt=prompt)

    async def async_set_cache(self, key, value, **kwargs):
        """Async variant of set_cache; embeds via the proxy router when available."""
        print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")

        prompt = self._prompt_from_messages(kwargs["messages"])
        embedding = await self._async_embedding(prompt, **kwargs)

        data = self._upsert_payload(
            prompt=prompt, value=str(value), embedding=embedding
        )
        await self.async_client.put(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
            headers=self.headers,
            json=data,
        )
        return

    async def async_get_cache(self, key, **kwargs):
        """Async variant of get_cache.

        Also records the observed similarity in
        kwargs["metadata"]["semantic-similarity"] for observability.
        """
        print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")

        prompt = self._prompt_from_messages(kwargs["messages"])
        embedding = await self._async_embedding(prompt, **kwargs)

        search_response = await self.async_client.post(
            url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
            headers=self.headers,
            json=self._search_payload(embedding),
        )
        results = search_response.json()["result"]
        return self._resolve_search_results(
            results=results, prompt=prompt, metadata_kwargs=kwargs
        )

    async def _collection_info(self):
        """Return the collection metadata captured during initialization."""
        return self.collection_info

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """Store a batch of (key, value) pairs concurrently."""
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Redis Cluster Cache implementation
|
||||
|
||||
Key differences:
|
||||
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from litellm.caching.redis_cache import RedisCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
from redis.asyncio import Redis, RedisCluster
|
||||
from redis.asyncio.client import Pipeline
|
||||
|
||||
pipeline = Pipeline
|
||||
async_redis_client = Redis
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
pipeline = Any
|
||||
async_redis_client = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
class RedisClusterCache(RedisCache):
    """RedisCache variant that talks to a Redis Cluster.

    The async cluster client is created once and reused across requests —
    recreating it per request adds ~3000ms latency (see module docstring).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Lazily-created, reused cluster clients.
        self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
        self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None

    def init_async_client(self):
        """Return the cached async cluster client, creating it on first call."""
        from redis.asyncio import RedisCluster

        from .._redis import get_redis_async_client

        cached_client = self.redis_async_redis_cluster_client
        if cached_client:
            return cached_client

        _redis_client = get_redis_async_client(
            connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
        )
        # Only genuine cluster clients are cached; anything else is returned
        # uncached (same construction path, no reuse).
        if isinstance(_redis_client, RedisCluster):
            self.redis_async_redis_cluster_client = _redis_client

        return _redis_client

    def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
        """
        Overrides `_run_redis_mget_operation` in redis_cache.py

        Clusters need the non-atomic MGET, which fans out per hash slot.
        """
        return self.redis_client.mget_nonatomic(keys=keys)  # type: ignore

    async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
        """
        Overrides `_async_run_redis_mget_operation` in redis_cache.py
        """
        cluster_client = self.init_async_client()
        return await cluster_client.mget_nonatomic(keys=keys)  # type: ignore

    async def test_connection(self) -> dict:
        """
        Test the Redis Cluster connection.

        Returns:
            dict: {"status": "success" | "failed", "message": str, "error": Optional[str]}
        """
        try:
            import redis.asyncio as redis_async
            from redis.cluster import ClusterNode

            # Build ClusterNode objects from the configured startup nodes.
            cluster_kwargs = self.redis_kwargs.copy()
            startup_nodes = cluster_kwargs.pop("startup_nodes", [])
            new_startup_nodes: List[ClusterNode] = [
                ClusterNode(**item) for item in startup_nodes
            ]

            # Use a throwaway client so the shared cached client is untouched.
            redis_client = redis_async.RedisCluster(
                startup_nodes=new_startup_nodes, **cluster_kwargs  # type: ignore
            )

            ping_result = await redis_client.ping()  # type: ignore[attr-defined, misc]
            await redis_client.aclose()  # type: ignore[attr-defined]

            if ping_result:
                return {
                    "status": "success",
                    "message": "Redis Cluster connection test successful",
                }
            return {
                "status": "failed",
                "message": "Redis Cluster ping returned False",
            }
        except Exception as e:
            from litellm._logging import verbose_logger

            verbose_logger.error(f"Redis Cluster connection test failed: {str(e)}")
            return {
                "status": "failed",
                "message": f"Redis Cluster connection failed: {str(e)}",
                "error": str(e),
            }
|
||||
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Redis Semantic Cache implementation for LiteLLM
|
||||
|
||||
The RedisSemanticCache provides semantic caching functionality using Redis as a backend.
|
||||
This cache stores responses based on the semantic similarity of prompts rather than
|
||||
exact matching, allowing for more flexible caching of LLM responses.
|
||||
|
||||
This implementation uses RedisVL's SemanticCache to find semantically similar prompts
|
||||
and their cached responses.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_str_from_messages,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class RedisSemanticCache(BaseCache):
|
||||
"""
|
||||
Redis-backed semantic cache for LLM responses.
|
||||
|
||||
This cache uses vector similarity to find semantically similar prompts that have been
|
||||
previously sent to the LLM, allowing for cache hits even when prompts are not identical
|
||||
but carry similar meaning.
|
||||
"""
|
||||
|
||||
    # Index name used when the caller does not supply one.
    DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[str] = None,
        password: Optional[str] = None,
        redis_url: Optional[str] = None,
        similarity_threshold: Optional[float] = None,
        embedding_model: str = "text-embedding-ada-002",
        index_name: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize the Redis Semantic Cache.

        Args:
            host: Redis host address
            port: Redis port
            password: Redis password
            redis_url: Full Redis URL (alternative to separate host/port/password)
            similarity_threshold: Threshold for semantic similarity (0.0 to 1.0)
                where 1.0 requires exact matches and 0.0 accepts any match
            embedding_model: Model to use for generating embeddings
            index_name: Name for the Redis index
            **kwargs: Additional arguments (per-entry TTL is supplied per call
                via set_cache kwargs, not here)

        Raises:
            ValueError: If similarity_threshold is not provided, or if redis_url
                is absent and any of REDIS_HOST/REDIS_PORT/REDIS_PASSWORD is
                missing from both arguments and the environment
        """
        from redisvl.extensions.llmcache import SemanticCache
        from redisvl.utils.vectorize import CustomTextVectorizer

        if index_name is None:
            index_name = self.DEFAULT_REDIS_INDEX_NAME

        print_verbose(f"Redis semantic-cache initializing index - {index_name}")

        # Validate similarity threshold
        if similarity_threshold is None:
            raise ValueError("similarity_threshold must be provided, passed None")

        # Store configuration
        self.similarity_threshold = similarity_threshold

        # Convert similarity threshold [0,1] to distance threshold [0,2]
        # For cosine distance: 0 = most similar, 2 = least similar
        # While similarity: 1 = most similar, 0 = least similar
        self.distance_threshold = 1 - similarity_threshold
        self.embedding_model = embedding_model

        # Set up Redis connection
        if redis_url is None:
            try:
                # Attempt to use provided parameters or fallback to environment variables
                host = host or os.environ["REDIS_HOST"]
                port = port or os.environ["REDIS_PORT"]
                password = password or os.environ["REDIS_PASSWORD"]
            except KeyError as e:
                # Raise a more informative exception if any of the required keys are missing
                missing_var = e.args[0]
                raise ValueError(
                    f"Missing required Redis configuration: {missing_var}. "
                    f"Provide {missing_var} or redis_url."
                ) from e

            redis_url = f"redis://:{password}@{host}:{port}"

        print_verbose(f"Redis semantic-cache redis_url: {redis_url}")

        # Initialize the Redis vectorizer and cache
        # _get_embedding is wrapped so RedisVL calls back into litellm to
        # produce vectors with the configured embedding model.
        cache_vectorizer = CustomTextVectorizer(self._get_embedding)

        self.llmcache = SemanticCache(
            name=index_name,
            redis_url=redis_url,
            vectorizer=cache_vectorizer,
            distance_threshold=self.distance_threshold,
            # overwrite=False keeps an existing index/schema if one is present
            overwrite=False,
        )
|
||||
|
||||
def _get_ttl(self, **kwargs) -> Optional[int]:
|
||||
"""
|
||||
Get the TTL (time-to-live) value for cache entries.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments that may contain a custom TTL
|
||||
|
||||
Returns:
|
||||
Optional[int]: The TTL value in seconds, or None if no TTL should be applied
|
||||
"""
|
||||
ttl = kwargs.get("ttl")
|
||||
if ttl is not None:
|
||||
ttl = int(ttl)
|
||||
return ttl
|
||||
|
||||
def _get_embedding(self, prompt: str) -> List[float]:
|
||||
"""
|
||||
Generate an embedding vector for the given prompt using the configured embedding model.
|
||||
|
||||
Args:
|
||||
prompt: The text to generate an embedding for
|
||||
|
||||
Returns:
|
||||
List[float]: The embedding vector
|
||||
"""
|
||||
# Create an embedding from prompt
|
||||
embedding_response = cast(
|
||||
EmbeddingResponse,
|
||||
litellm.embedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
),
|
||||
)
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
return embedding
|
||||
|
||||
def _get_cache_logic(self, cached_response: Any) -> Any:
|
||||
"""
|
||||
Process the cached response to prepare it for use.
|
||||
|
||||
Args:
|
||||
cached_response: The raw cached response
|
||||
|
||||
Returns:
|
||||
The processed cache response, or None if input was None
|
||||
"""
|
||||
if cached_response is None:
|
||||
return cached_response
|
||||
|
||||
# Convert bytes to string if needed
|
||||
if isinstance(cached_response, bytes):
|
||||
cached_response = cached_response.decode("utf-8")
|
||||
|
||||
# Convert string representation to Python object
|
||||
try:
|
||||
cached_response = json.loads(cached_response)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
except (ValueError, SyntaxError) as e:
|
||||
print_verbose(f"Error parsing cached response: {str(e)}")
|
||||
return None
|
||||
|
||||
return cached_response
|
||||
|
||||
def set_cache(self, key: str, value: Any, **kwargs) -> None:
|
||||
"""
|
||||
Store a value in the semantic cache.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
value: The response value to cache
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
and optional 'ttl' for time-to-live
|
||||
"""
|
||||
print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
|
||||
|
||||
value_str: Optional[str] = None
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic caching")
|
||||
return
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
value_str = str(value)
|
||||
|
||||
# Get TTL and store in Redis semantic cache
|
||||
ttl = self._get_ttl(**kwargs)
|
||||
if ttl is not None:
|
||||
self.llmcache.store(prompt, value_str, ttl=int(ttl))
|
||||
else:
|
||||
self.llmcache.store(prompt, value_str)
|
||||
except Exception as e:
|
||||
print_verbose(
|
||||
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
|
||||
)
|
||||
|
||||
def get_cache(self, key: str, **kwargs) -> Any:
|
||||
"""
|
||||
Retrieve a semantically similar cached response.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
|
||||
Returns:
|
||||
The cached response if a semantically similar prompt is found, else None
|
||||
"""
|
||||
print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}")
|
||||
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic cache lookup")
|
||||
return None
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
# Check the cache for semantically similar prompts
|
||||
results = self.llmcache.check(prompt=prompt)
|
||||
|
||||
# Return None if no similar prompts found
|
||||
if not results:
|
||||
return None
|
||||
|
||||
# Process the best matching result
|
||||
cache_hit = results[0]
|
||||
vector_distance = float(cache_hit["vector_distance"])
|
||||
|
||||
# Convert vector distance back to similarity score
|
||||
# For cosine distance: 0 = most similar, 2 = least similar
|
||||
# While similarity: 1 = most similar, 0 = least similar
|
||||
similarity = 1 - vector_distance
|
||||
|
||||
cached_prompt = cache_hit["prompt"]
|
||||
cached_response = cache_hit["response"]
|
||||
|
||||
print_verbose(
|
||||
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
|
||||
f"actual similarity: {similarity}, "
|
||||
f"current prompt: {prompt}, "
|
||||
f"cached prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
return self._get_cache_logic(cached_response=cached_response)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
|
||||
|
||||
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
|
||||
"""
|
||||
Asynchronously generate an embedding for the given prompt.
|
||||
|
||||
Args:
|
||||
prompt: The text to generate an embedding for
|
||||
**kwargs: Additional arguments that may contain metadata
|
||||
|
||||
Returns:
|
||||
List[float]: The embedding vector
|
||||
"""
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
# Route the embedding request through the proxy if appropriate
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
|
||||
try:
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
# Use the router for embedding generation
|
||||
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||
embedding_response = await llm_router.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
metadata={
|
||||
"user_api_key": user_api_key,
|
||||
"semantic-cache-embedding": True,
|
||||
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Generate embedding directly
|
||||
embedding_response = await litellm.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
)
|
||||
|
||||
# Extract and return the embedding vector
|
||||
return embedding_response["data"][0]["embedding"]
|
||||
except Exception as e:
|
||||
print_verbose(f"Error generating async embedding: {str(e)}")
|
||||
raise ValueError(f"Failed to generate embedding: {str(e)}") from e
|
||||
|
||||
async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
    """
    Asynchronously store a response in the semantic cache.

    The semantic cache is keyed on the embedded prompt text, not on `key`.

    Args:
        key: The cache key (unused by semantic caching; kept for interface
            compatibility with the other cache backends).
        value: The response value to cache.
        **kwargs: Additional arguments including 'messages' for the prompt
            and optional 'ttl' for time-to-live.
    """
    print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}")

    try:
        # Extract the prompt from messages
        messages = kwargs.get("messages", [])
        if not messages:
            print_verbose("No messages provided for semantic caching")
            return

        prompt = get_str_from_messages(messages)
        value_str = str(value)

        # Embed the *prompt* (not the response) — the prompt embedding is
        # what future lookups are matched against.
        prompt_embedding = await self._get_async_embedding(prompt, **kwargs)

        # Only pass ttl when one was provided, so the backend's default
        # expiry behavior is preserved for ttl-less writes.
        store_kwargs: Dict[str, Any] = {"vector": prompt_embedding}
        ttl = self._get_ttl(**kwargs)
        if ttl is not None:
            store_kwargs["ttl"] = ttl
        await self.llmcache.astore(prompt, value_str, **store_kwargs)
    except Exception as e:
        # Cache writes are best-effort: log and continue rather than fail
        # the underlying LLM call.
        print_verbose(f"Error in async_set_cache: {str(e)}")
|
||||
|
||||
async def async_get_cache(self, key: str, **kwargs) -> Any:
    """
    Asynchronously retrieve a semantically similar cached response.

    Looks the prompt (built from kwargs['messages']) up in the Redis
    semantic cache and records the achieved similarity score in
    kwargs['metadata']['semantic-similarity'] (0.0 on miss or error).

    Args:
        key: The cache key (unused by semantic caching; kept for interface
            compatibility with the other cache backends).
        **kwargs: Additional arguments including 'messages' for the prompt.

    Returns:
        The cached response if a semantically similar prompt is found,
        otherwise None.
    """
    print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}")

    try:
        # Extract the prompt from messages
        messages = kwargs.get("messages", [])
        if not messages:
            print_verbose("No messages provided for semantic cache lookup")
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
            return None

        prompt = get_str_from_messages(messages)

        # Generate embedding for the prompt
        prompt_embedding = await self._get_async_embedding(prompt, **kwargs)

        # Check the cache for semantically similar prompts
        results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)

        if not results:
            # Cache miss: record zero similarity, same as the no-messages path.
            kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
            return None

        cache_hit = results[0]
        vector_distance = float(cache_hit["vector_distance"])

        # Convert cosine distance (0 = most similar, 2 = least similar)
        # back to a similarity score (1 = most similar, 0 = least similar).
        similarity = 1 - vector_distance

        cached_prompt = cache_hit["prompt"]
        cached_response = cache_hit["response"]

        # Record similarity in kwargs["metadata"] without clobbering any
        # metadata the caller already attached.
        kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

        print_verbose(
            f"Cache hit: similarity threshold: {self.similarity_threshold}, "
            f"actual similarity: {similarity}, "
            f"current prompt: {prompt}, "
            f"cached prompt: {cached_prompt}"
        )

        return self._get_cache_logic(cached_response=cached_response)
    except Exception as e:
        # Treat lookup failures as cache misses.
        print_verbose(f"Error in async_get_cache: {str(e)}")
        kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
        return None
|
||||
|
||||
async def _index_info(self) -> Dict[str, Any]:
    """
    Fetch metadata about the underlying Redis search index.

    Returns:
        Dict[str, Any]: Index information as reported by Redis.
    """
    async_index = await self.llmcache._get_async_index()
    return await async_index.info()
|
||||
|
||||
async def async_set_cache_pipeline(
    self, cache_list: List[Tuple[str, Any]], **kwargs
) -> None:
    """
    Asynchronously store multiple values in the semantic cache.

    All writes are issued concurrently via asyncio.gather.

    Args:
        cache_list: List of (key, value) tuples to cache.
        **kwargs: Forwarded to async_set_cache for every pair.
    """
    try:
        await asyncio.gather(
            *(self.async_set_cache(k, v, **kwargs) for k, v in cache_list)
        )
    except Exception as e:
        print_verbose(f"Error in async_set_cache_pipeline: {str(e)}")
|
||||
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
S3 Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache (uses run_in_executor)
|
||||
- async_get_cache (uses run_in_executor)
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class S3Cache(BaseCache):
    """
    S3-backed cache.

    Stores each cache entry as a JSON object in an S3 bucket. The async
    methods wrap the synchronous boto3 calls with ``run_in_executor`` so
    they don't block the event loop.
    """

    def __init__(
        self,
        s3_bucket_name,
        s3_region_name=None,
        s3_api_version=None,
        s3_use_ssl: Optional[bool] = True,
        s3_verify=None,
        s3_endpoint_url=None,
        s3_aws_access_key_id=None,
        s3_aws_secret_access_key=None,
        s3_aws_session_token=None,
        s3_config=None,
        s3_path=None,
        **kwargs,
    ):
        import boto3

        self.bucket_name = s3_bucket_name
        # Normalize the optional path prefix to "prefix/" (or "" when unset).
        self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""

        # Create an S3 client honoring any custom endpoint URL
        # (e.g. MinIO, Cloudflare R2).
        self.s3_client = boto3.client(
            "s3",
            region_name=s3_region_name,
            endpoint_url=s3_endpoint_url,
            api_version=s3_api_version,
            use_ssl=s3_use_ssl,
            verify=s3_verify,
            aws_access_key_id=s3_aws_access_key_id,
            aws_secret_access_key=s3_aws_secret_access_key,
            aws_session_token=s3_aws_session_token,
            config=s3_config,
            **kwargs,
        )

    def _to_s3_key(self, key: str) -> str:
        """Convert a cache key to an S3 object key (":" becomes "/")."""
        return self.key_prefix + key.replace(":", "/")

    def set_cache(self, key, value, **kwargs):
        """Serialize `value` to JSON and upload it to S3 under `key`.

        Honors an optional 'ttl' kwarg via the object's Expires and
        Cache-Control headers. Failures are logged and swallowed — cache
        writes are best-effort.
        """
        try:
            print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
            ttl = kwargs.get("ttl", None)
            # Convert value to JSON before storing in S3
            serialized_value = json.dumps(value)
            key = self._to_s3_key(key)

            put_kwargs = {
                "Bucket": self.bucket_name,
                "Key": key,
                "Body": serialized_value,
                "ContentType": "application/json",
                "ContentLanguage": "en",
                "ContentDisposition": f'inline; filename="{key}.json"',
            }
            if ttl is not None:
                # Expire after `ttl` seconds; mirror it in Cache-Control
                # for any CDN/proxy in front of the bucket.
                put_kwargs["Expires"] = datetime.now(timezone.utc) + timedelta(
                    seconds=ttl
                )
                put_kwargs["CacheControl"] = f"immutable, max-age={ttl}, s-maxage={ttl}"
            else:
                # No TTL: mark the object effectively immutable (1 year).
                put_kwargs["CacheControl"] = (
                    "immutable, max-age=31536000, s-maxage=31536000"
                )

            self.s3_client.put_object(**put_kwargs)
        except Exception as e:
            print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")

    async def async_set_cache(self, key, value, **kwargs):
        """
        Asynchronously set cache using run_in_executor to avoid blocking the event loop.
        Compatible with Python 3.8+.
        """
        try:
            verbose_logger.debug(f"Set ASYNC S3 Cache: Key={key}. Value={value}")
            # get_event_loop() is deprecated inside coroutines; use the
            # running loop directly.
            loop = asyncio.get_running_loop()
            func = partial(self.set_cache, key, value, **kwargs)
            await loop.run_in_executor(None, func)
        except Exception as e:
            verbose_logger.error(
                f"S3 Caching: async_set_cache() - Got exception from S3: {e}"
            )

    def get_cache(self, key, **kwargs):
        """Fetch and deserialize a cached value from S3.

        Returns None when the key is missing, the entry has expired, or any
        error occurs.
        """
        import botocore

        try:
            key = self._to_s3_key(key)

            print_verbose(f"Get S3 Cache: key: {key}")
            # Download the data from S3
            cached_response = self.s3_client.get_object(
                Bucket=self.bucket_name, Key=key
            )

            if cached_response is not None:
                # Respect the Expires header we set on upload.
                if "Expires" in cached_response:
                    expires_time = cached_response["Expires"]
                    current_time = datetime.now(expires_time.tzinfo)

                    if current_time > expires_time:
                        return None

                # The body is the raw bytes of the JSON payload.
                cached_response = (
                    cached_response["Body"].read().decode("utf-8")
                )  # Convert bytes to string
                try:
                    cached_response = json.loads(
                        cached_response
                    )  # Convert string to dictionary
                except Exception:
                    # Fall back for legacy entries stored as Python reprs.
                    cached_response = ast.literal_eval(cached_response)
            if not isinstance(cached_response, dict):
                cached_response = dict(cached_response)
            verbose_logger.debug(
                f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
            )

            return cached_response
        except botocore.exceptions.ClientError as e:  # type: ignore
            if e.response["Error"]["Code"] == "NoSuchKey":
                verbose_logger.debug(
                    f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
                )
                return None
            # Other client errors (auth, bucket config) used to be silently
            # dropped; log them so misconfiguration is visible. Still treat
            # the lookup as a miss.
            verbose_logger.error(
                f"S3 Caching: get_cache() - Got exception from S3: {e}"
            )
            return None

        except Exception as e:
            verbose_logger.error(
                f"S3 Caching: get_cache() - Got exception from S3: {e}"
            )

    async def async_get_cache(self, key, **kwargs):
        """
        Asynchronously get cache using run_in_executor to avoid blocking the event loop.
        Compatible with Python 3.8+.
        """
        try:
            verbose_logger.debug(f"Get ASYNC S3 Cache: key: {key}")
            loop = asyncio.get_running_loop()
            func = partial(self.get_cache, key, **kwargs)
            result = await loop.run_in_executor(None, func)
            return result
        except Exception as e:
            verbose_logger.error(
                f"S3 Caching: async_get_cache() - Got exception from S3: {e}"
            )
            return None

    def flush_cache(self):
        """No-op: flushing an entire S3 bucket is intentionally unsupported."""
        pass

    async def disconnect(self):
        """No-op: boto3 clients need no explicit teardown."""
        pass

    async def async_set_cache_pipeline(self, cache_list, **kwargs):
        """Concurrently store multiple (key, value) pairs via async_set_cache."""
        tasks = []
        for val in cache_list:
            tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
        await asyncio.gather(*tasks)
|
||||
@@ -0,0 +1,4 @@
|
||||
Logic specific for `litellm.completion`.
|
||||
|
||||
Includes:
|
||||
- Bridge for transforming completion requests to responses api requests
|
||||
@@ -0,0 +1,3 @@
|
||||
from .litellm_responses_transformation import responses_api_bridge
|
||||
|
||||
__all__ = ["responses_api_bridge"]
|
||||
@@ -0,0 +1,3 @@
|
||||
from .handler import responses_api_bridge
|
||||
|
||||
__all__ = ["responses_api_bridge"]
|
||||
@@ -0,0 +1,331 @@
|
||||
"""
|
||||
Handler for transforming /chat/completions api requests to litellm.responses requests
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Coroutine, Optional, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from litellm.types.llms.openai import ResponsesAPIResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import CustomStreamWrapper, LiteLLMLoggingObj, ModelResponse
|
||||
|
||||
|
||||
class ResponsesToCompletionBridgeHandlerInputKwargs(TypedDict):
    """Validated keyword arguments for ResponsesToCompletionBridgeHandler.

    Produced by ``validate_input_kwargs``; every field is required and
    type-checked before the bridge proceeds.
    """

    model: str  # model name to route the request to
    messages: list  # chat-completion style message list
    optional_params: dict  # OpenAI-compatible optional params (may hold 'stream')
    litellm_params: dict  # litellm-internal params (fallback 'stream' flag)
    headers: dict  # extra request headers
    model_response: "ModelResponse"  # response object to populate
    logging_obj: "LiteLLMLoggingObj"  # logging handle for this call
    custom_llm_provider: str  # provider slug used for stream post-processing
|
||||
|
||||
|
||||
class ResponsesToCompletionBridgeHandler:
|
||||
def __init__(self):
    """Set up the bridge with its request/response transformation handler."""
    from .transformation import LiteLLMResponsesTransformationHandler

    super().__init__()
    # One handler instance is shared by all completion/acompletion calls.
    handler_cls = LiteLLMResponsesTransformationHandler
    self.transformation_handler = handler_cls()
|
||||
|
||||
@staticmethod
|
||||
def _resolve_stream_flag(optional_params: dict, litellm_params: dict) -> bool:
|
||||
stream = optional_params.get("stream")
|
||||
if stream is None:
|
||||
stream = litellm_params.get("stream", False)
|
||||
return bool(stream)
|
||||
|
||||
@staticmethod
def _coerce_response_object(
    response_obj: Any,
    hidden_params: Optional[dict],
) -> "ResponsesAPIResponse":
    """Normalize a stream payload into a ResponsesAPIResponse.

    Accepts an already-built ResponsesAPIResponse or a raw dict. Dicts are
    validated first, falling back to model_construct for non-conforming
    payloads. `hidden_params` are merged into the response's
    _hidden_params without overwriting keys already present.

    Raises:
        ValueError: If the payload is neither a response object nor a dict.
    """
    if isinstance(response_obj, ResponsesAPIResponse):
        response = response_obj
    elif isinstance(response_obj, dict):
        try:
            response = ResponsesAPIResponse(**response_obj)
        except Exception:
            # Dict doesn't validate cleanly; build without validation.
            response = ResponsesAPIResponse.model_construct(**response_obj)
    else:
        raise ValueError("Unexpected responses stream payload")

    if hidden_params:
        current = getattr(response, "_hidden_params", None)
        if isinstance(current, dict) and current:
            # Keep existing values; only fill in what's missing.
            for hp_key, hp_value in hidden_params.items():
                current.setdefault(hp_key, hp_value)
        else:
            setattr(response, "_hidden_params", dict(hidden_params))
    return response
|
||||
|
||||
def _collect_response_from_stream(self, stream_iter: Any) -> "ResponsesAPIResponse":
    """Drain a synchronous responses stream and return its final response.

    Raises:
        ValueError: If the stream ends without a completed response or the
            coerced response is not a ResponsesAPIResponse.
    """
    # Exhaust the iterator; the stream object accumulates the completed
    # response on itself as a side effect of iteration.
    for _chunk in stream_iter:
        pass

    final_chunk = getattr(stream_iter, "completed_response", None)
    raw_response = getattr(final_chunk, "response", None) if final_chunk else None
    if raw_response is None:
        raise ValueError("Stream ended without a completed response")

    hidden = getattr(stream_iter, "_hidden_params", None)
    coerced = self._coerce_response_object(raw_response, hidden)
    if not isinstance(coerced, ResponsesAPIResponse):
        raise ValueError("Stream completed response is invalid")
    return coerced
|
||||
|
||||
async def _collect_response_from_stream_async(
    self, stream_iter: Any
) -> "ResponsesAPIResponse":
    """Async twin of _collect_response_from_stream.

    Drains an async responses stream and returns its final response.

    Raises:
        ValueError: If the stream ends without a completed response or the
            coerced response is not a ResponsesAPIResponse.
    """
    # Exhaust the async iterator; the stream object accumulates the
    # completed response on itself as a side effect of iteration.
    async for _chunk in stream_iter:
        pass

    final_chunk = getattr(stream_iter, "completed_response", None)
    raw_response = getattr(final_chunk, "response", None) if final_chunk else None
    if raw_response is None:
        raise ValueError("Stream ended without a completed response")

    hidden = getattr(stream_iter, "_hidden_params", None)
    coerced = self._coerce_response_object(raw_response, hidden)
    if not isinstance(coerced, ResponsesAPIResponse):
        raise ValueError("Stream completed response is invalid")
    return coerced
|
||||
|
||||
def validate_input_kwargs(
    self, kwargs: dict
) -> ResponsesToCompletionBridgeHandlerInputKwargs:
    """Validate the raw bridge kwargs and return them as a typed dict.

    Every required field must be present and of the expected type;
    otherwise a ValueError naming the first offending field is raised.

    Raises:
        ValueError: "<field> is required" for the first missing or
            wrongly-typed field, in the documented check order.
    """
    from litellm import LiteLLMLoggingObj
    from litellm.types.utils import ModelResponse

    # (field name, required type) — checked in a fixed order so the first
    # invalid field is the one reported, matching historical behavior.
    required_fields = (
        ("model", str),
        ("custom_llm_provider", str),
        ("messages", list),
        ("optional_params", dict),
        ("litellm_params", dict),
        ("headers", dict),
        ("model_response", ModelResponse),
        ("logging_obj", LiteLLMLoggingObj),
    )
    validated: dict = {}
    for field_name, expected_type in required_fields:
        field_value = kwargs.get(field_name)
        if field_value is None or not isinstance(field_value, expected_type):
            raise ValueError(f"{field_name} is required")
        validated[field_name] = field_value

    return ResponsesToCompletionBridgeHandlerInputKwargs(
        model=validated["model"],
        messages=validated["messages"],
        optional_params=validated["optional_params"],
        litellm_params=validated["litellm_params"],
        headers=validated["headers"],
        model_response=validated["model_response"],
        logging_obj=validated["logging_obj"],
        custom_llm_provider=validated["custom_llm_provider"],
    )
|
||||
|
||||
def completion(
    self, *args, **kwargs
) -> Union[
    Coroutine[Any, Any, Union["ModelResponse", "CustomStreamWrapper"]],
    "ModelResponse",
    "CustomStreamWrapper",
]:
    """Handle a /chat/completions call by bridging to litellm.responses.

    Dispatches to acompletion() when kwargs['acompletion'] is True.
    Otherwise transforms the chat request into a responses-API request,
    invokes it, and converts the result back into a chat-completion shaped
    response — or a CustomStreamWrapper when streaming was requested.
    """
    if kwargs.get("acompletion") is True:
        return self.acompletion(**kwargs)

    from litellm import responses
    from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper

    validated = self.validate_input_kwargs(kwargs)
    model = validated["model"]
    messages = validated["messages"]
    optional_params = validated["optional_params"]
    litellm_params = validated["litellm_params"]
    model_response = validated["model_response"]
    logging_obj = validated["logging_obj"]
    custom_llm_provider = validated["custom_llm_provider"]

    request_data = self.transformation_handler.transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=validated["headers"],
        litellm_logging_obj=logging_obj,
        client=kwargs.get("client"),
    )

    result = responses(**request_data)

    def _to_model_response(raw_response):
        # Shared conversion of a completed responses-API payload back into
        # the caller's chat-completion ModelResponse.
        return self.transformation_handler.transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=kwargs.get("encoding"),
            api_key=kwargs.get("api_key"),
            json_mode=kwargs.get("json_mode"),
        )

    stream = self._resolve_stream_flag(optional_params, litellm_params)
    if isinstance(result, ResponsesAPIResponse):
        # Provider returned a complete response object directly.
        return _to_model_response(result)
    if not stream:
        # Caller didn't ask for streaming but we received a stream: drain
        # it and convert the completed response.
        return _to_model_response(self._collect_response_from_stream(result))

    # Streaming path: wrap the responses stream as a chat-completion stream.
    completion_stream = self.transformation_handler.get_model_response_iterator(
        streaming_response=result,  # type: ignore
        sync_stream=True,
        json_mode=kwargs.get("json_mode"),
    )
    streamwrapper = CustomStreamWrapper(
        completion_stream=completion_stream,
        model=model,
        custom_llm_provider=custom_llm_provider,
        logging_obj=logging_obj,
    )
    return self._apply_post_stream_processing(
        streamwrapper, model, custom_llm_provider
    )
|
||||
|
||||
async def acompletion(
    self, *args, **kwargs
) -> Union["ModelResponse", "CustomStreamWrapper"]:
    """Async counterpart of completion(): bridge to litellm.aresponses.

    Transforms the chat request into a responses-API request, awaits it,
    and converts the result back into a chat-completion shaped response —
    or a CustomStreamWrapper when streaming was requested.
    """
    from litellm import aresponses
    from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper

    validated = self.validate_input_kwargs(kwargs)
    model = validated["model"]
    messages = validated["messages"]
    optional_params = validated["optional_params"]
    litellm_params = validated["litellm_params"]
    model_response = validated["model_response"]
    logging_obj = validated["logging_obj"]
    custom_llm_provider = validated["custom_llm_provider"]

    # NOTE: a previous version wrapped this call in
    # `try: ... except Exception as e: raise e`, which was a no-op;
    # exceptions now propagate directly.
    request_data = self.transformation_handler.transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=validated["headers"],
        litellm_logging_obj=logging_obj,
    )

    result = await aresponses(
        **request_data,
        aresponses=True,
    )

    def _to_model_response(raw_response):
        # Shared conversion of a completed responses-API payload back into
        # the caller's chat-completion ModelResponse.
        return self.transformation_handler.transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=kwargs.get("encoding"),
            api_key=kwargs.get("api_key"),
            json_mode=kwargs.get("json_mode"),
        )

    stream = self._resolve_stream_flag(optional_params, litellm_params)
    if isinstance(result, ResponsesAPIResponse):
        # Provider returned a complete response object directly.
        return _to_model_response(result)
    if not stream:
        # Caller didn't ask for streaming but we received a stream: drain
        # it and convert the completed response.
        responses_api_response = await self._collect_response_from_stream_async(
            result
        )
        return _to_model_response(responses_api_response)

    # Streaming path: wrap the responses stream as a chat-completion stream.
    completion_stream = self.transformation_handler.get_model_response_iterator(
        streaming_response=result,  # type: ignore
        sync_stream=False,
        json_mode=kwargs.get("json_mode"),
    )
    streamwrapper = CustomStreamWrapper(
        completion_stream=completion_stream,
        model=model,
        custom_llm_provider=custom_llm_provider,
        logging_obj=logging_obj,
    )
    return self._apply_post_stream_processing(
        streamwrapper, model, custom_llm_provider
    )
|
||||
|
||||
@staticmethod
def _apply_post_stream_processing(
    stream: "CustomStreamWrapper",
    model: str,
    custom_llm_provider: str,
) -> Any:
    """Apply provider-specific post-stream processing if available.

    Unknown provider slugs and providers without a chat config leave the
    stream untouched.
    """
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    try:
        provider_config = ProviderConfigManager.get_provider_chat_config(
            model=model, provider=LlmProviders(custom_llm_provider)
        )
    except (ValueError, KeyError):
        # Provider slug didn't resolve: nothing to post-process.
        return stream

    if provider_config is None:
        return stream
    return provider_config.post_stream_processing(stream)
|
||||
|
||||
|
||||
responses_api_bridge = ResponsesToCompletionBridgeHandler()
|
||||
File diff suppressed because it is too large
Load Diff
1530
llm-gateway-competitors/litellm-wheel-src/litellm/constants.py
Normal file
1530
llm-gateway-competitors/litellm-wheel-src/litellm/constants.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,241 @@
|
||||
# Container Files API
|
||||
|
||||
This module provides a unified interface for container file operations across multiple LLM providers (OpenAI, Azure OpenAI, etc.).
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
endpoints.json # Declarative endpoint definitions
|
||||
↓
|
||||
endpoint_factory.py # Auto-generates SDK functions
|
||||
↓
|
||||
container_handler.py # Generic HTTP handler
|
||||
↓
|
||||
BaseContainerConfig # Provider-specific transformations
|
||||
├── OpenAIContainerConfig
|
||||
└── AzureContainerConfig (example)
|
||||
```
|
||||
|
||||
## Files Overview
|
||||
|
||||
| File | Purpose |
|
||||
|------|---------|
|
||||
| `endpoints.json` | **Single source of truth** - Defines all container file endpoints |
|
||||
| `endpoint_factory.py` | Auto-generates SDK functions (`list_container_files`, etc.) |
|
||||
| `main.py` | Core container operations (create, list, retrieve, delete containers) |
|
||||
| `utils.py` | Request parameter utilities |
|
||||
|
||||
## Adding a New Endpoint
|
||||
|
||||
To add a new container file endpoint (e.g., `get_container_file_content`):
|
||||
|
||||
### Step 1: Add to `endpoints.json`
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "get_container_file_content",
|
||||
"async_name": "aget_container_file_content",
|
||||
"path": "/containers/{container_id}/files/{file_id}/content",
|
||||
"method": "GET",
|
||||
"path_params": ["container_id", "file_id"],
|
||||
"query_params": [],
|
||||
"response_type": "ContainerFileContentResponse"
|
||||
}
|
||||
```
|
||||
|
||||
### Step 2: Add Response Type (if new)
|
||||
|
||||
In `litellm/types/containers/main.py`:
|
||||
|
||||
```python
|
||||
class ContainerFileContentResponse(BaseModel):
|
||||
"""Response for file content download."""
|
||||
content: bytes
|
||||
# ... other fields
|
||||
```
|
||||
|
||||
### Step 3: Register Response Type
|
||||
|
||||
In `litellm/llms/custom_httpx/container_handler.py`, add to `RESPONSE_TYPES`:
|
||||
|
||||
```python
|
||||
RESPONSE_TYPES = {
|
||||
# ... existing types
|
||||
"ContainerFileContentResponse": ContainerFileContentResponse,
|
||||
}
|
||||
```
|
||||
|
||||
### Step 4: Update Router (one-time setup)
|
||||
|
||||
In `litellm/router.py`, add the call_type to the factory_function Literal and `_init_containers_api_endpoints` condition.
|
||||
|
||||
In `litellm/proxy/route_llm_request.py`, add to the route mappings and skip-model-routing lists.
|
||||
|
||||
### Step 5: Update Proxy Handler Factory (if new path params)
|
||||
|
||||
If your endpoint has a new combination of path parameters, add a handler in `litellm/proxy/container_endpoints/handler_factory.py`:
|
||||
|
||||
```python
|
||||
elif path_params == ["container_id", "file_id", "new_param"]:
|
||||
async def handler(...):
|
||||
# handler implementation
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Adding a New Provider (e.g., Azure OpenAI)
|
||||
|
||||
### Step 1: Create Provider Config
|
||||
|
||||
Create `litellm/llms/azure/containers/transformation.py`:
|
||||
|
||||
```python
|
||||
from typing import Dict, Optional, Tuple, Any
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
|
||||
from litellm.types.containers.main import (
|
||||
ContainerFileListResponse,
|
||||
ContainerFileObject,
|
||||
DeleteContainerFileResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class AzureContainerConfig(BaseContainerConfig):
|
||||
"""Configuration class for Azure OpenAI container API."""
|
||||
|
||||
def get_supported_openai_params(self) -> list:
|
||||
return ["name", "expires_after", "file_ids", "extra_headers"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
container_create_optional_params,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
return dict(container_create_optional_params)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""Azure uses api-key header instead of Bearer token."""
|
||||
import litellm
|
||||
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.azure_key
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
headers["api-key"] = api_key
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Azure format:
|
||||
https://{resource}.openai.azure.com/openai/containers?api-version=2024-xx
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required for Azure")
|
||||
|
||||
api_version = litellm_params.get("api_version", "2024-02-15-preview")
|
||||
return f"{api_base.rstrip('/')}/openai/containers?api-version={api_version}"
|
||||
|
||||
# Implement remaining abstract methods from BaseContainerConfig:
|
||||
# - transform_container_create_request
|
||||
# - transform_container_create_response
|
||||
# - transform_container_list_request
|
||||
# - transform_container_list_response
|
||||
# - transform_container_retrieve_request
|
||||
# - transform_container_retrieve_response
|
||||
# - transform_container_delete_request
|
||||
# - transform_container_delete_response
|
||||
# - transform_container_file_list_request
|
||||
# - transform_container_file_list_response
|
||||
```
|
||||
|
||||
### Step 2: Register Provider Config
|
||||
|
||||
In `litellm/utils.py`, find `ProviderConfigManager.get_provider_container_config()` and add:
|
||||
|
||||
```python
|
||||
@staticmethod
|
||||
def get_provider_container_config(
|
||||
provider: LlmProviders,
|
||||
) -> Optional[BaseContainerConfig]:
|
||||
if provider == LlmProviders.OPENAI:
|
||||
from litellm.llms.openai.containers.transformation import OpenAIContainerConfig
|
||||
return OpenAIContainerConfig()
|
||||
elif provider == LlmProviders.AZURE:
|
||||
from litellm.llms.azure.containers.transformation import AzureContainerConfig
|
||||
return AzureContainerConfig()
|
||||
return None
|
||||
```
|
||||
|
||||
### Step 3: Test the New Provider
|
||||
|
||||
```bash
|
||||
# Create container via Azure
|
||||
curl -X POST "http://localhost:4000/v1/containers" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "custom-llm-provider: azure" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"name": "My Azure Container"}'
|
||||
|
||||
# List container files via Azure
|
||||
curl -X GET "http://localhost:4000/v1/containers/cntr_123/files" \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "custom-llm-provider: azure"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## How Provider Selection Works
|
||||
|
||||
1. **Proxy receives request** with `custom-llm-provider` header/query/body
|
||||
2. **Router calls** `ProviderConfigManager.get_provider_container_config(provider)`
|
||||
3. **Generic handler** uses the provider config for:
|
||||
- URL construction (`get_complete_url`)
|
||||
- Authentication (`validate_environment`)
|
||||
- Request/response transformation
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
Run the container API tests:
|
||||
|
||||
```bash
|
||||
cd /path/to/litellm
|
||||
python -m pytest tests/test_litellm/containers/ -v
|
||||
```
|
||||
|
||||
Test via proxy:
|
||||
|
||||
```bash
|
||||
# Start proxy
|
||||
cd litellm/proxy && python proxy_cli.py --config proxy_config.yaml --port 4000
|
||||
|
||||
# Test endpoints
|
||||
curl -X GET "http://localhost:4000/v1/containers/cntr_123/files" \
|
||||
-H "Authorization: Bearer sk-1234"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Endpoint Reference
|
||||
|
||||
| Endpoint | Method | Path |
|
||||
|----------|--------|------|
|
||||
| List container files | GET | `/v1/containers/{container_id}/files` |
|
||||
| Retrieve container file | GET | `/v1/containers/{container_id}/files/{file_id}` |
|
||||
| Delete container file | DELETE | `/v1/containers/{container_id}/files/{file_id}` |
|
||||
|
||||
See `endpoints.json` for the complete list.
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
"""Container management functions for LiteLLM."""
|
||||
|
||||
# Auto-generated container file functions from endpoints.json
|
||||
from .endpoint_factory import (
|
||||
adelete_container_file,
|
||||
alist_container_files,
|
||||
aretrieve_container_file,
|
||||
aretrieve_container_file_content,
|
||||
delete_container_file,
|
||||
list_container_files,
|
||||
retrieve_container_file,
|
||||
retrieve_container_file_content,
|
||||
)
|
||||
from .main import (
|
||||
acreate_container,
|
||||
adelete_container,
|
||||
alist_containers,
|
||||
aretrieve_container,
|
||||
create_container,
|
||||
delete_container,
|
||||
list_containers,
|
||||
retrieve_container,
|
||||
)
|
||||
|
||||
# Public API of the containers package: core container CRUD plus the
# file-level operations generated from endpoints.json.
__all__ = [
    # Core container operations
    "acreate_container",
    "adelete_container",
    "alist_containers",
    "aretrieve_container",
    "create_container",
    "delete_container",
    "list_containers",
    "retrieve_container",
    # Container file operations (auto-generated from endpoints.json)
    "adelete_container_file",
    "alist_container_files",
    "aretrieve_container_file",
    "aretrieve_container_file_content",
    "delete_container_file",
    "list_container_files",
    "retrieve_container_file",
    "retrieve_container_file_content",
]
|
||||
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
Factory for generating container SDK functions from JSON config.
|
||||
|
||||
This module reads endpoints.json and dynamically generates SDK functions
|
||||
that use the generic container handler.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Type
|
||||
|
||||
import litellm
|
||||
from litellm.constants import request_timeout as DEFAULT_REQUEST_TIMEOUT
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
|
||||
from litellm.llms.custom_httpx.container_handler import generic_container_handler
|
||||
from litellm.types.containers.main import (
|
||||
ContainerFileListResponse,
|
||||
ContainerFileObject,
|
||||
DeleteContainerFileResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import ProviderConfigManager, client
|
||||
|
||||
# Response type mapping
|
||||
RESPONSE_TYPES: Dict[str, Type] = {
|
||||
"ContainerFileListResponse": ContainerFileListResponse,
|
||||
"ContainerFileObject": ContainerFileObject,
|
||||
"DeleteContainerFileResponse": DeleteContainerFileResponse,
|
||||
}
|
||||
|
||||
|
||||
def _load_endpoints_config() -> Dict:
    """Read and parse the endpoints.json file that sits next to this module."""
    config_file = Path(__file__).with_name("endpoints.json")
    with open(config_file) as config_handle:
        return json.load(config_handle)
|
||||
|
||||
|
||||
def create_sync_endpoint_function(endpoint_config: Dict) -> Callable:
    """
    Create a sync SDK function from endpoint config.

    Uses the generic container handler instead of individual handler methods.

    Args:
        endpoint_config: One entry from endpoints.json. Must contain
            "name" and "response_type"; may contain "path_params".

    Returns:
        A callable (decorated with ``@client``) that executes the endpoint.
    """
    endpoint_name = endpoint_config["name"]
    # Maps the config's response_type string to a concrete response class;
    # None when the endpoint returns raw/binary content (e.g. "raw").
    response_type = RESPONSE_TYPES.get(endpoint_config["response_type"])
    path_params = endpoint_config.get("path_params", [])

    @client
    def endpoint_func(
        timeout: int = 600,
        custom_llm_provider: Literal["openai"] = "openai",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        # NOTE(review): extra_body is accepted but never forwarded to the
        # handler below — confirm whether it should be passed through.
        extra_body: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Captured for error reporting via litellm.exception_type below.
        local_vars = locals()
        try:
            # Injected by the @client decorator.
            litellm_logging_obj: LiteLLMLoggingObj = kwargs.pop("litellm_logging_obj")
            litellm_call_id: Optional[str] = kwargs.get("litellm_call_id")
            # Set by the async wrapper; routes the handler down its async path.
            _is_async = kwargs.pop("async_call", False) is True

            # Check for mock response: short-circuit before any provider work.
            mock_response = kwargs.get("mock_response")
            if mock_response is not None:
                if isinstance(mock_response, str):
                    mock_response = json.loads(mock_response)
                if response_type:
                    return response_type(**mock_response)
                return mock_response

            # Get provider config
            litellm_params = GenericLiteLLMParams(**kwargs)
            container_provider_config: Optional[
                BaseContainerConfig
            ] = ProviderConfigManager.get_provider_container_config(
                provider=litellm.LlmProviders(custom_llm_provider),
            )

            if container_provider_config is None:
                raise ValueError(
                    f"Container provider config not found for: {custom_llm_provider}"
                )

            # Build optional params for logging (only the declared path params).
            optional_params = {k: kwargs.get(k) for k in path_params if k in kwargs}

            # Pre-call logging
            litellm_logging_obj.update_environment_variables(
                model="",
                optional_params=optional_params,
                litellm_params={"litellm_call_id": litellm_call_id},
                custom_llm_provider=custom_llm_provider,
            )

            # Use generic handler
            return generic_container_handler.handle(
                endpoint_name=endpoint_name,
                container_provider_config=container_provider_config,
                litellm_params=litellm_params,
                logging_obj=litellm_logging_obj,
                extra_headers=extra_headers,
                extra_query=extra_query,
                timeout=timeout or DEFAULT_REQUEST_TIMEOUT,
                _is_async=_is_async,
                **kwargs,
            )

        except Exception as e:
            # Normalize any failure into a mapped litellm exception type.
            raise litellm.exception_type(
                model="",
                custom_llm_provider=custom_llm_provider,
                original_exception=e,
                completion_kwargs=local_vars,
                extra_kwargs=kwargs,
            )

    return endpoint_func
|
||||
|
||||
|
||||
def create_async_endpoint_function(
    sync_func: Callable,
    endpoint_config: Dict,
) -> Callable:
    """Create an async SDK function that wraps the sync function.

    The wrapper marks the call as async (``async_call=True``) and runs the
    sync function in the default executor so the event loop is not blocked
    while the request is prepared. If the sync function returns a coroutine
    (the generic handler went down its async path), it is awaited.

    Args:
        sync_func: The sync endpoint function produced by
            ``create_sync_endpoint_function``.
        endpoint_config: The endpoint's entry from endpoints.json
            (currently unused here; kept for parity with the sync factory).

    Returns:
        An awaitable function with the same keyword interface as ``sync_func``.
    """

    @client
    async def async_endpoint_func(
        timeout: int = 600,
        custom_llm_provider: Literal["openai"] = "openai",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        extra_body: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        # Captured for error reporting via litellm.exception_type below.
        local_vars = locals()
        try:
            # get_running_loop() is the supported way to obtain the loop from
            # inside a coroutine; get_event_loop() here is deprecated (3.10+).
            loop = asyncio.get_running_loop()
            kwargs["async_call"] = True

            func = partial(
                sync_func,
                timeout=timeout,
                custom_llm_provider=custom_llm_provider,
                extra_headers=extra_headers,
                extra_query=extra_query,
                extra_body=extra_body,
                **kwargs,
            )

            # Propagate the caller's context (request-scoped contextvars)
            # into the executor thread.
            ctx = contextvars.copy_context()
            func_with_context = partial(ctx.run, func)
            init_response = await loop.run_in_executor(None, func_with_context)

            if asyncio.iscoroutine(init_response):
                response = await init_response
            else:
                response = init_response

            return response
        except Exception as e:
            # Normalize any failure into a mapped litellm exception type.
            raise litellm.exception_type(
                model="",
                custom_llm_provider=custom_llm_provider,
                original_exception=e,
                completion_kwargs=local_vars,
                extra_kwargs=kwargs,
            )

    return async_endpoint_func
|
||||
|
||||
|
||||
def generate_container_endpoints() -> Dict[str, Callable]:
    """
    Generate all container endpoint functions from the JSON config.

    Returns a dict mapping function names to their implementations.
    """
    registry: Dict[str, Callable] = {}

    for entry in _load_endpoints_config()["endpoints"]:
        # Build the sync variant first, then derive the async wrapper from it.
        sync_fn = create_sync_endpoint_function(entry)
        registry[entry["name"]] = sync_fn
        registry[entry["async_name"]] = create_async_endpoint_function(sync_fn, entry)

    return registry
|
||||
|
||||
|
||||
def get_all_endpoint_names() -> List[str]:
    """Get all endpoint names (sync and async) from config."""
    config = _load_endpoints_config()
    # Each endpoint contributes its sync name followed by its async name.
    return [
        name
        for endpoint in config["endpoints"]
        for name in (endpoint["name"], endpoint["async_name"])
    ]
|
||||
|
||||
|
||||
def get_async_endpoint_names() -> List[str]:
    """Get all async endpoint names for router registration."""
    entries = _load_endpoints_config()["endpoints"]
    return [entry["async_name"] for entry in entries]
|
||||
|
||||
|
||||
# Generate endpoints on module load: the factory reads endpoints.json once
# at import time and returns a dict of {endpoint_name: callable}.
_generated_endpoints = generate_container_endpoints()

# Export generated functions dynamically.
# NOTE(review): .get() silently yields None when a name is missing from
# endpoints.json, which only surfaces later as "'NoneType' is not callable" —
# consider indexing with [] to fail fast at import time.
list_container_files = _generated_endpoints.get("list_container_files")
alist_container_files = _generated_endpoints.get("alist_container_files")
upload_container_file = _generated_endpoints.get("upload_container_file")
aupload_container_file = _generated_endpoints.get("aupload_container_file")
retrieve_container_file = _generated_endpoints.get("retrieve_container_file")
aretrieve_container_file = _generated_endpoints.get("aretrieve_container_file")
delete_container_file = _generated_endpoints.get("delete_container_file")
adelete_container_file = _generated_endpoints.get("adelete_container_file")
retrieve_container_file_content = _generated_endpoints.get(
    "retrieve_container_file_content"
)
aretrieve_container_file_content = _generated_endpoints.get(
    "aretrieve_container_file_content"
)
|
||||
@@ -0,0 +1,51 @@
|
||||
{
|
||||
"endpoints": [
|
||||
{
|
||||
"name": "list_container_files",
|
||||
"async_name": "alist_container_files",
|
||||
"path": "/containers/{container_id}/files",
|
||||
"method": "GET",
|
||||
"path_params": ["container_id"],
|
||||
"query_params": ["after", "limit", "order"],
|
||||
"response_type": "ContainerFileListResponse"
|
||||
},
|
||||
{
|
||||
"name": "upload_container_file",
|
||||
"async_name": "aupload_container_file",
|
||||
"path": "/containers/{container_id}/files",
|
||||
"method": "POST",
|
||||
"path_params": ["container_id"],
|
||||
"query_params": [],
|
||||
"response_type": "ContainerFileObject",
|
||||
"is_multipart": true
|
||||
},
|
||||
{
|
||||
"name": "retrieve_container_file",
|
||||
"async_name": "aretrieve_container_file",
|
||||
"path": "/containers/{container_id}/files/{file_id}",
|
||||
"method": "GET",
|
||||
"path_params": ["container_id", "file_id"],
|
||||
"query_params": [],
|
||||
"response_type": "ContainerFileObject"
|
||||
},
|
||||
{
|
||||
"name": "delete_container_file",
|
||||
"async_name": "adelete_container_file",
|
||||
"path": "/containers/{container_id}/files/{file_id}",
|
||||
"method": "DELETE",
|
||||
"path_params": ["container_id", "file_id"],
|
||||
"query_params": [],
|
||||
"response_type": "DeleteContainerFileResponse"
|
||||
},
|
||||
{
|
||||
"name": "retrieve_container_file_content",
|
||||
"async_name": "aretrieve_container_file_content",
|
||||
"path": "/containers/{container_id}/files/{file_id}/content",
|
||||
"method": "GET",
|
||||
"path_params": ["container_id", "file_id"],
|
||||
"query_params": [],
|
||||
"response_type": "raw",
|
||||
"returns_binary": true
|
||||
}
|
||||
]
|
||||
}
|
||||
1290
llm-gateway-competitors/litellm-wheel-src/litellm/containers/main.py
Normal file
1290
llm-gateway-competitors/litellm-wheel-src/litellm/containers/main.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,70 @@
|
||||
from typing import Dict
|
||||
|
||||
from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
|
||||
from litellm.types.containers.main import (
|
||||
ContainerCreateOptionalRequestParams,
|
||||
ContainerListOptionalRequestParams,
|
||||
)
|
||||
|
||||
|
||||
class ContainerRequestUtils:
    """Helpers for filtering and mapping container request parameters."""

    @staticmethod
    def get_requested_container_create_optional_param(
        passed_params: dict,
    ) -> ContainerCreateOptionalRequestParams:
        """Extract only valid container creation parameters from the passed parameters."""
        allowed_keys = (
            "expires_after",
            "file_ids",
            "extra_headers",
            "extra_body",
        )

        requested = ContainerCreateOptionalRequestParams()
        for key in allowed_keys:
            value = passed_params.get(key)
            if value is not None:
                requested[key] = value  # type: ignore

        return requested

    @staticmethod
    def get_optional_params_container_create(
        container_provider_config: BaseContainerConfig,
        container_create_optional_params: ContainerCreateOptionalRequestParams,
    ) -> Dict:
        """Get the optional parameters for container creation."""
        supported = container_provider_config.get_supported_openai_params()

        # Keep only the parameters this provider actually supports.
        provider_ready = {
            name: value
            for name, value in container_create_optional_params.items()
            if name in supported
        }

        return container_provider_config.map_openai_params(
            container_create_optional_params=provider_ready,  # type: ignore
            drop_params=False,
        )

    @staticmethod
    def get_requested_container_list_optional_param(
        passed_params: dict,
    ) -> ContainerListOptionalRequestParams:
        """Extract only valid container list parameters from the passed parameters."""
        allowed_keys = (
            "after",
            "limit",
            "order",
            "extra_headers",
            "extra_query",
        )

        requested = ContainerListOptionalRequestParams()
        for key in allowed_keys:
            value = passed_params.get(key)
            if value is not None:
                requested[key] = value  # type: ignore

        return requested
|
||||
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"gpt-3.5-turbo-0613": 0.00015000000000000001,
|
||||
"claude-2": 0.00016454,
|
||||
"gpt-4-0613": 0.015408
|
||||
}
|
||||
2268
llm-gateway-competitors/litellm-wheel-src/litellm/cost_calculator.py
Normal file
2268
llm-gateway-competitors/litellm-wheel-src/litellm/cost_calculator.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,128 @@
|
||||
"""
|
||||
Handler for transforming /chat/completions api requests to litellm.responses requests
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
|
||||
|
||||
class SpeechToCompletionBridgeHandlerInputKwargs(TypedDict):
|
||||
model: str
|
||||
input: str
|
||||
voice: Optional[Union[str, dict]]
|
||||
optional_params: dict
|
||||
litellm_params: dict
|
||||
logging_obj: "LiteLLMLoggingObj"
|
||||
headers: dict
|
||||
custom_llm_provider: str
|
||||
|
||||
|
||||
class SpeechToCompletionBridgeHandler:
    """Routes ``speech()`` requests through ``litellm.completion`` for
    providers that expose text-to-speech via audio-modality chat completions.
    """

    def __init__(self):
        from .transformation import SpeechToCompletionBridgeTransformationHandler

        super().__init__()
        self.transformation_handler = SpeechToCompletionBridgeTransformationHandler()

    def validate_input_kwargs(
        self, kwargs: dict
    ) -> SpeechToCompletionBridgeHandlerInputKwargs:
        """Validate raw kwargs and narrow them into the typed input dict.

        Args:
            kwargs: The raw keyword arguments received by ``speech``.

        Raises:
            ValueError: if any required argument is missing or has the
                wrong type.
        """
        from litellm import LiteLLMLoggingObj

        model = kwargs.get("model")
        if model is None or not isinstance(model, str):
            raise ValueError("model is required")

        custom_llm_provider = kwargs.get("custom_llm_provider")
        if custom_llm_provider is None or not isinstance(custom_llm_provider, str):
            raise ValueError("custom_llm_provider is required")

        input = kwargs.get("input")
        if input is None or not isinstance(input, str):
            raise ValueError("input is required")

        optional_params = kwargs.get("optional_params")
        if optional_params is None or not isinstance(optional_params, dict):
            raise ValueError("optional_params is required")

        litellm_params = kwargs.get("litellm_params")
        if litellm_params is None or not isinstance(litellm_params, dict):
            raise ValueError("litellm_params is required")

        # The headers check appeared twice verbatim in the original; a single
        # check is sufficient.
        headers = kwargs.get("headers")
        if headers is None or not isinstance(headers, dict):
            raise ValueError("headers is required")

        logging_obj = kwargs.get("logging_obj")
        if logging_obj is None or not isinstance(logging_obj, LiteLLMLoggingObj):
            raise ValueError("logging_obj is required")

        return SpeechToCompletionBridgeHandlerInputKwargs(
            model=model,
            input=input,
            voice=kwargs.get("voice"),
            optional_params=optional_params,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            headers=headers,
        )

    def speech(
        self,
        model: str,
        input: str,
        voice: Optional[Union[str, dict]],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
        logging_obj: "LiteLLMLoggingObj",
        custom_llm_provider: str,
    ) -> "HttpxBinaryResponseContent":
        """Execute a TTS request by transforming it into a chat completion.

        Returns:
            The binary audio response.

        Raises:
            Exception: if the completion call returns anything other than a
                ``ModelResponse`` (e.g. a stream), which this bridge cannot map.
        """
        received_args = locals()
        from litellm import completion
        from litellm.types.utils import ModelResponse

        validated_kwargs = self.validate_input_kwargs(received_args)
        model = validated_kwargs["model"]
        input = validated_kwargs["input"]
        optional_params = validated_kwargs["optional_params"]
        litellm_params = validated_kwargs["litellm_params"]
        headers = validated_kwargs["headers"]
        logging_obj = validated_kwargs["logging_obj"]
        custom_llm_provider = validated_kwargs["custom_llm_provider"]
        voice = validated_kwargs["voice"]

        request_data = self.transformation_handler.transform_request(
            model=model,
            input=input,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
            litellm_logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            voice=voice,
        )

        result = completion(
            **request_data,
        )

        if isinstance(result, ModelResponse):
            return self.transformation_handler.transform_response(
                model_response=result,
            )
        else:
            raise Exception("Unmapped response type. Got type: {}".format(type(result)))


# Shared singleton used by callers of the speech-to-completion bridge.
speech_to_completion_bridge_handler = SpeechToCompletionBridgeHandler()
|
||||
@@ -0,0 +1,134 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union, cast
|
||||
|
||||
from litellm.constants import OPENAI_CHAT_COMPLETION_PARAMS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
class SpeechToCompletionBridgeTransformationHandler:
    """Transforms speech (TTS) requests into chat-completion requests and
    converts the resulting audio part back into a binary HTTP response."""

    def transform_request(
        self,
        model: str,
        input: str,
        voice: Optional[Union[str, dict]],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
        litellm_logging_obj: "LiteLLMLoggingObj",
        custom_llm_provider: str,
    ) -> dict:
        """Build the kwargs dict for an audio-modality ``completion`` call."""
        # Forward only parameters the chat-completions API understands.
        forwarded = {
            key: optional_params[key]
            for key in optional_params
            if key in OPENAI_CHAT_COMPLETION_PARAMS
        }

        # A string voice maps onto the chat "audio" option; response_format
        # (e.g. "mp3") becomes the requested audio output format.
        if isinstance(voice, str):
            audio_option: dict = {"voice": voice}
            if "response_format" in optional_params:
                audio_option["format"] = optional_params["response_format"]
            forwarded["audio"] = audio_option

        merged = {
            "model": model,
            "messages": [
                {
                    "role": "user",
                    "content": input,
                }
            ],
            "modalities": ["audio"],
            **forwarded,
            **litellm_params,
            "headers": headers,
            "litellm_logging_obj": litellm_logging_obj,
            "custom_llm_provider": custom_llm_provider,
        }

        # Drop None values so downstream validation does not trip on them.
        return {key: value for key, value in merged.items() if value is not None}

    def _convert_pcm16_to_wav(
        self, pcm_data: bytes, sample_rate: int = 24000, channels: int = 1
    ) -> bytes:
        """
        Convert raw PCM16 data to WAV format.

        Args:
            pcm_data: Raw PCM16 audio data
            sample_rate: Sample rate in Hz (Gemini TTS typically uses 24000)
            channels: Number of audio channels (1 for mono)

        Returns:
            bytes: WAV formatted audio data
        """
        import struct

        bytes_per_frame = channels * 2  # 16-bit samples are 2 bytes each
        payload_len = len(pcm_data)

        # Standard 44-byte RIFF/WAVE header for uncompressed PCM.
        header = struct.pack(
            "<4sI4s4sIHHIIHH4sI",
            b"RIFF",                           # Chunk ID
            36 + payload_len,                  # Chunk Size
            b"WAVE",                           # Format
            b"fmt ",                           # Subchunk1 ID
            16,                                # Subchunk1 Size (PCM)
            1,                                 # Audio Format (PCM)
            channels,                          # Number of Channels
            sample_rate,                       # Sample Rate
            sample_rate * bytes_per_frame,     # Byte Rate
            bytes_per_frame,                   # Block Align
            16,                                # Bits per Sample
            b"data",                           # Subchunk2 ID
            payload_len,                       # Subchunk2 Size
        )
        return header + pcm_data

    def _is_gemini_tts_model(self, model: str) -> bool:
        """Check if the model is a Gemini TTS model that returns PCM16 data."""
        lowered = model.lower()
        # "preview-tts" already contains "tts", so one substring test suffices.
        return "gemini" in lowered and "tts" in lowered

    def transform_response(
        self, model_response: "ModelResponse"
    ) -> "HttpxBinaryResponseContent":
        """Extract the audio part of a completion and wrap it as binary content."""
        import base64

        import httpx

        from litellm.types.llms.openai import HttpxBinaryResponseContent
        from litellm.types.utils import Choices

        audio = cast(Choices, model_response.choices[0]).message.audio
        if audio is None:
            raise ValueError("No audio part found in the response")

        # The audio payload arrives base64-encoded in the message.
        raw_audio = base64.b64decode(audio.data)

        response_headers = {}
        if self._is_gemini_tts_model(getattr(model_response, "model", "")):
            # Gemini TTS returns raw PCM16; wrap it in a WAV container so the
            # bytes are playable as an audio file.
            raw_audio = self._convert_pcm16_to_wav(raw_audio)
            response_headers["Content-Type"] = "audio/wav"
        else:
            response_headers["Content-Type"] = "audio/mpeg"

        wrapped = httpx.Response(
            status_code=200, content=raw_audio, headers=response_headers
        )
        return HttpxBinaryResponseContent(wrapped)
|
||||
@@ -0,0 +1,33 @@
|
||||
"""
|
||||
Evals API operations
|
||||
"""
|
||||
|
||||
from .main import (
|
||||
acancel_eval,
|
||||
acreate_eval,
|
||||
adelete_eval,
|
||||
aget_eval,
|
||||
alist_evals,
|
||||
aupdate_eval,
|
||||
cancel_eval,
|
||||
create_eval,
|
||||
delete_eval,
|
||||
get_eval,
|
||||
list_evals,
|
||||
update_eval,
|
||||
)
|
||||
|
||||
# Public API of the evals module: async variants first, then sync.
__all__ = [
    "acreate_eval",
    "alist_evals",
    "aget_eval",
    "aupdate_eval",
    "adelete_eval",
    "acancel_eval",
    "create_eval",
    "list_evals",
    "get_eval",
    "update_eval",
    "delete_eval",
    "cancel_eval",
]
|
||||
1975
llm-gateway-competitors/litellm-wheel-src/litellm/evals/main.py
Normal file
1975
llm-gateway-competitors/litellm-wheel-src/litellm/evals/main.py
Normal file
File diff suppressed because it is too large
Load Diff
1030
llm-gateway-competitors/litellm-wheel-src/litellm/exceptions.py
Normal file
1030
llm-gateway-competitors/litellm-wheel-src/litellm/exceptions.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,6 @@
|
||||
# LiteLLM MCP Client
|
||||
|
||||
LiteLLM MCP Client is a client that allows you to use MCP tools with LiteLLM.
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from .tools import call_openai_tool, load_mcp_tools
|
||||
|
||||
__all__ = ["load_mcp_tools", "call_openai_tool"]
|
||||
@@ -0,0 +1,697 @@
|
||||
"""
|
||||
LiteLLM Proxy uses this MCP Client to connect to other MCP servers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from mcp import ClientSession, ReadResourceResult, Resource, StdioServerParameters
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
streamable_http_client: Optional[Any] = None
|
||||
try:
|
||||
import mcp.client.streamable_http as streamable_http_module # type: ignore
|
||||
|
||||
streamable_http_client = getattr(
|
||||
streamable_http_module, "streamable_http_client", None
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||
from mcp.types import CallToolResult as MCPCallToolResult
|
||||
from mcp.types import (
|
||||
GetPromptRequestParams,
|
||||
GetPromptResult,
|
||||
Prompt,
|
||||
ResourceTemplate,
|
||||
TextContent,
|
||||
)
|
||||
from mcp.types import Tool as MCPTool
|
||||
from pydantic import AnyUrl
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import MCP_CLIENT_TIMEOUT
|
||||
from litellm.llms.custom_httpx.http_handler import get_ssl_configuration
|
||||
from litellm.types.llms.custom_http import VerifyTypes
|
||||
from litellm.types.mcp import (
|
||||
MCPAuth,
|
||||
MCPAuthType,
|
||||
MCPStdioConfig,
|
||||
MCPTransport,
|
||||
MCPTransportType,
|
||||
)
|
||||
|
||||
|
||||
def to_basic_auth(auth_value: str) -> str:
    """Base64-encode *auth_value* (e.g. "user:pass") for HTTP Basic Auth.

    Note: returns only the encoded credential, without the "Basic " prefix.
    """
    encoded = base64.b64encode(auth_value.encode("utf-8"))
    return encoded.decode()
|
||||
|
||||
|
||||
TSessionResult = TypeVar("TSessionResult")
|
||||
|
||||
|
||||
class MCPSigV4Auth(httpx.Auth):
    """
    httpx Auth class that signs each request with AWS SigV4.

    This is used for MCP servers that require AWS SigV4 authentication,
    such as AWS Bedrock AgentCore MCP servers. httpx calls auth_flow()
    for every outgoing request, enabling per-request signature computation.
    """

    # Tells httpx to load the full request body before invoking auth_flow,
    # since the SigV4 signature covers the payload.
    requires_request_body = True

    def __init__(
        self,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_region_name: Optional[str] = None,
        aws_service_name: Optional[str] = None,
    ):
        """Resolve AWS credentials from explicit keys or the default chain.

        Args:
            aws_access_key_id: Explicit access key; if provided together with
                the secret key, bypasses the default credential chain.
            aws_secret_access_key: Explicit secret key.
            aws_session_token: Optional session token for temporary creds.
            aws_region_name: Signing region; defaults to "us-east-1".
            aws_service_name: Signing service; defaults to "bedrock-agentcore".

        Raises:
            ImportError: if botocore is not installed.
            ValueError: if no credentials can be resolved.
        """
        try:
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError(
                "Missing botocore to use AWS SigV4 authentication. "
                "Run 'pip install boto3'."
            )

        self.service_name = aws_service_name or "bedrock-agentcore"
        self.region_name = aws_region_name or "us-east-1"

        # Note: os.environ/ prefixed values are already resolved by
        # ProxyConfig._check_for_os_environ_vars() at config load time.
        # Values arrive here as plain strings.
        if aws_access_key_id and aws_secret_access_key:
            self.credentials = Credentials(
                access_key=aws_access_key_id,
                secret_key=aws_secret_access_key,
                token=aws_session_token,
            )
        else:
            # Fall back to default boto3 credential chain
            import botocore.session

            session = botocore.session.get_session()
            self.credentials = session.get_credentials()
            if self.credentials is None:
                raise ValueError(
                    "No AWS credentials found. Provide aws_access_key_id and "
                    "aws_secret_access_key, or configure default credentials "
                    "(env vars, ~/.aws/credentials, instance profile)."
                )

    def auth_flow(
        self, request: httpx.Request
    ) -> Generator[httpx.Request, httpx.Response, None]:
        """Sign the outgoing request with SigV4 and yield it for sending."""
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest

        # Build AWSRequest from the httpx Request.
        # Pass all request headers so the canonical SigV4 signature covers them.
        aws_request = AWSRequest(
            method=request.method,
            url=str(request.url),
            data=request.content,
            headers=dict(request.headers),
        )

        # Sign the request — SigV4Auth.add_auth() adds Authorization,
        # X-Amz-Date, and X-Amz-Security-Token (if session token present).
        # Host header is derived automatically from the URL.
        sigv4 = SigV4Auth(self.credentials, self.service_name, self.region_name)
        sigv4.add_auth(aws_request)

        # Copy SigV4 headers back to the httpx request
        for header_name, header_value in aws_request.headers.items():
            request.headers[header_name] = header_value

        yield request
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""
|
||||
MCP Client supporting:
|
||||
SSE and HTTP transports
|
||||
Authentication via Bearer token, Basic Auth, or API Key
|
||||
Tool calling with error handling and result parsing
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        server_url: str = "",
        transport_type: MCPTransportType = MCPTransport.http,
        auth_type: MCPAuthType = None,
        auth_value: Optional[Union[str, Dict[str, str]]] = None,
        timeout: Optional[float] = None,
        stdio_config: Optional[MCPStdioConfig] = None,
        extra_headers: Optional[Dict[str, str]] = None,
        ssl_verify: Optional[VerifyTypes] = None,
        aws_auth: Optional[httpx.Auth] = None,
    ):
        """Store connection settings; no network I/O happens here.

        Args:
            server_url: URL of the MCP server (unused for stdio transport).
            transport_type: One of the MCPTransport values (http/sse/stdio).
            auth_type: Authentication scheme to apply to requests.
            auth_value: Credential material; processed by update_auth_value().
            timeout: Request timeout in seconds; defaults to MCP_CLIENT_TIMEOUT.
            stdio_config: Command/args/env for stdio transport.
            extra_headers: Additional headers merged into each request.
            ssl_verify: TLS verification setting passed to httpx.
            aws_auth: Optional httpx.Auth (e.g. MCPSigV4Auth) for AWS signing.
        """
        self.server_url: str = server_url
        self.transport_type: MCPTransport = transport_type
        self.auth_type: MCPAuthType = auth_type
        self.timeout: float = timeout if timeout is not None else MCP_CLIENT_TIMEOUT
        self._mcp_auth_value: Optional[Union[str, Dict[str, str]]] = None
        self.stdio_config: Optional[MCPStdioConfig] = stdio_config
        self.extra_headers: Optional[Dict[str, str]] = extra_headers
        self.ssl_verify: Optional[VerifyTypes] = ssl_verify
        self._aws_auth: Optional[httpx.Auth] = aws_auth
        # handle the basic auth value if provided
        if auth_value:
            self.update_auth_value(auth_value)
|
||||
|
||||
def _create_transport_context(
|
||||
self,
|
||||
) -> Tuple[Any, Optional[httpx.AsyncClient]]:
|
||||
"""
|
||||
Create the appropriate transport context based on transport type.
|
||||
|
||||
Returns:
|
||||
Tuple of (transport_context, http_client).
|
||||
http_client is only set for HTTP transport and needs cleanup.
|
||||
"""
|
||||
http_client: Optional[httpx.AsyncClient] = None
|
||||
|
||||
if self.transport_type == MCPTransport.stdio:
|
||||
if not self.stdio_config:
|
||||
raise ValueError("stdio_config is required for stdio transport")
|
||||
server_params = StdioServerParameters(
|
||||
command=self.stdio_config.get("command", ""),
|
||||
args=self.stdio_config.get("args", []),
|
||||
env=self.stdio_config.get("env", {}),
|
||||
)
|
||||
return stdio_client(server_params), None
|
||||
|
||||
if self.transport_type == MCPTransport.sse:
|
||||
headers = self._get_auth_headers()
|
||||
httpx_client_factory = self._create_httpx_client_factory()
|
||||
return (
|
||||
sse_client(
|
||||
url=self.server_url,
|
||||
timeout=self.timeout,
|
||||
headers=headers,
|
||||
httpx_client_factory=httpx_client_factory,
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# HTTP transport (default)
|
||||
if streamable_http_client is None:
|
||||
raise ImportError(
|
||||
"streamable_http_client is not available. "
|
||||
"Please install mcp with HTTP support."
|
||||
)
|
||||
|
||||
headers = self._get_auth_headers()
|
||||
httpx_client_factory = self._create_httpx_client_factory()
|
||||
verbose_logger.debug("litellm headers for streamable_http_client: %s", headers)
|
||||
http_client = httpx_client_factory(
|
||||
headers=headers,
|
||||
timeout=httpx.Timeout(self.timeout),
|
||||
)
|
||||
transport_ctx = streamable_http_client(
|
||||
url=self.server_url,
|
||||
http_client=http_client,
|
||||
)
|
||||
return transport_ctx, http_client
|
||||
|
||||
async def _execute_session_operation(
|
||||
self,
|
||||
transport_ctx: Any,
|
||||
operation: Callable[[ClientSession], Awaitable[TSessionResult]],
|
||||
) -> TSessionResult:
|
||||
"""
|
||||
Execute an operation within a transport and session context.
|
||||
|
||||
Handles entering/exiting contexts and running the operation.
|
||||
"""
|
||||
transport = await transport_ctx.__aenter__()
|
||||
try:
|
||||
read_stream, write_stream = transport[0], transport[1]
|
||||
session_ctx = ClientSession(read_stream, write_stream)
|
||||
session = await session_ctx.__aenter__()
|
||||
try:
|
||||
await session.initialize()
|
||||
return await operation(session)
|
||||
finally:
|
||||
try:
|
||||
await session_ctx.__aexit__(None, None, None)
|
||||
except BaseException as e:
|
||||
verbose_logger.debug(f"Error during session context exit: {e}")
|
||||
finally:
|
||||
try:
|
||||
await transport_ctx.__aexit__(None, None, None)
|
||||
except BaseException as e:
|
||||
verbose_logger.debug(f"Error during transport context exit: {e}")
|
||||
|
||||
async def run_with_session(
|
||||
self, operation: Callable[[ClientSession], Awaitable[TSessionResult]]
|
||||
) -> TSessionResult:
|
||||
"""Open a session, run the provided coroutine, and clean up."""
|
||||
http_client: Optional[httpx.AsyncClient] = None
|
||||
try:
|
||||
transport_ctx, http_client = self._create_transport_context()
|
||||
return await self._execute_session_operation(transport_ctx, operation)
|
||||
except Exception:
|
||||
verbose_logger.warning(
|
||||
"MCP client run_with_session failed for %s", self.server_url or "stdio"
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
if http_client is not None:
|
||||
try:
|
||||
await http_client.aclose()
|
||||
except BaseException as e:
|
||||
verbose_logger.debug(f"Error during http_client cleanup: {e}")
|
||||
|
||||
def update_auth_value(self, mcp_auth_value: Union[str, Dict[str, str]]):
|
||||
"""
|
||||
Set the authentication header for the MCP client.
|
||||
"""
|
||||
if isinstance(mcp_auth_value, dict):
|
||||
self._mcp_auth_value = mcp_auth_value
|
||||
else:
|
||||
if self.auth_type == MCPAuth.basic:
|
||||
# Assuming mcp_auth_value is in format "username:password", convert it when updating
|
||||
mcp_auth_value = to_basic_auth(mcp_auth_value)
|
||||
self._mcp_auth_value = mcp_auth_value
|
||||
|
||||
def _get_auth_headers(self) -> dict:
|
||||
"""Generate authentication headers based on auth type."""
|
||||
headers = {}
|
||||
|
||||
if self._mcp_auth_value:
|
||||
if isinstance(self._mcp_auth_value, str):
|
||||
if self.auth_type == MCPAuth.bearer_token:
|
||||
headers["Authorization"] = f"Bearer {self._mcp_auth_value}"
|
||||
elif self.auth_type == MCPAuth.basic:
|
||||
headers["Authorization"] = f"Basic {self._mcp_auth_value}"
|
||||
elif self.auth_type == MCPAuth.api_key:
|
||||
headers["X-API-Key"] = self._mcp_auth_value
|
||||
elif self.auth_type == MCPAuth.authorization:
|
||||
headers["Authorization"] = self._mcp_auth_value
|
||||
elif self.auth_type == MCPAuth.oauth2:
|
||||
headers["Authorization"] = f"Bearer {self._mcp_auth_value}"
|
||||
elif self.auth_type == MCPAuth.token:
|
||||
headers["Authorization"] = f"token {self._mcp_auth_value}"
|
||||
elif isinstance(self._mcp_auth_value, dict):
|
||||
headers.update(self._mcp_auth_value)
|
||||
# Note: aws_sigv4 auth is not handled here — SigV4 requires per-request
|
||||
# signing (including the body hash), so it uses httpx.Auth flow instead
|
||||
# of static headers. See MCPSigV4Auth and _create_httpx_client_factory().
|
||||
|
||||
# update the headers with the extra headers
|
||||
if self.extra_headers:
|
||||
headers.update(self.extra_headers)
|
||||
|
||||
return headers
|
||||
|
||||
def _create_httpx_client_factory(self) -> Callable[..., httpx.AsyncClient]:
|
||||
"""
|
||||
Create a custom httpx client factory that uses LiteLLM's SSL configuration.
|
||||
|
||||
This factory follows the same CA bundle path logic as http_handler.py:
|
||||
1. Check ssl_verify parameter (can be SSLContext, bool, or path to CA bundle)
|
||||
2. Check SSL_VERIFY environment variable
|
||||
3. Check SSL_CERT_FILE environment variable
|
||||
4. Fall back to certifi CA bundle
|
||||
"""
|
||||
|
||||
def factory(
|
||||
*,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: Optional[httpx.Timeout] = None,
|
||||
auth: Optional[httpx.Auth] = None,
|
||||
) -> httpx.AsyncClient:
|
||||
"""Create an httpx.AsyncClient with LiteLLM's SSL configuration."""
|
||||
# Get unified SSL configuration using the same logic as http_handler.py
|
||||
ssl_config = get_ssl_configuration(self.ssl_verify)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"MCP client using SSL configuration: {type(ssl_config).__name__}"
|
||||
)
|
||||
|
||||
# Use SigV4 auth if configured and no explicit auth provided.
|
||||
# The MCP SDK's sse_client and streamable_http_client call this
|
||||
# factory without passing auth=, so self._aws_auth is used.
|
||||
# For non-SigV4 clients, self._aws_auth is None — no behavior change.
|
||||
effective_auth = auth if auth is not None else self._aws_auth
|
||||
|
||||
return httpx.AsyncClient(
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
auth=effective_auth,
|
||||
verify=ssl_config,
|
||||
follow_redirects=True,
|
||||
)
|
||||
|
||||
return factory
|
||||
|
||||
async def list_tools(self) -> List[MCPTool]:
|
||||
"""List available tools from the server."""
|
||||
verbose_logger.debug(
|
||||
f"MCP client listing tools from {self.server_url or 'stdio'}"
|
||||
)
|
||||
|
||||
async def _list_tools_operation(session: ClientSession):
|
||||
return await session.list_tools()
|
||||
|
||||
try:
|
||||
result = await self.run_with_session(_list_tools_operation)
|
||||
tool_count = len(result.tools)
|
||||
tool_names = [tool.name for tool in result.tools]
|
||||
verbose_logger.info(
|
||||
f"MCP client listed {tool_count} tools from {self.server_url or 'stdio'}: {tool_names}"
|
||||
)
|
||||
return result.tools
|
||||
except asyncio.CancelledError:
|
||||
verbose_logger.warning("MCP client list_tools was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
verbose_logger.exception(
|
||||
f"MCP client list_tools failed - "
|
||||
f"Error Type: {error_type}, "
|
||||
f"Error: {str(e)}, "
|
||||
f"Server: {self.server_url or 'stdio'}, "
|
||||
f"Transport: {self.transport_type}"
|
||||
)
|
||||
|
||||
# Check if it's a stream/connection error
|
||||
if "BrokenResourceError" in error_type or "Broken" in error_type:
|
||||
verbose_logger.error(
|
||||
"MCP client detected broken connection/stream during list_tools - "
|
||||
"the MCP server may have crashed, disconnected, or timed out"
|
||||
)
|
||||
|
||||
# Return empty list instead of raising to allow graceful degradation
|
||||
return []
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
call_tool_request_params: MCPCallToolRequestParams,
|
||||
host_progress_callback: Optional[Callable] = None,
|
||||
) -> MCPCallToolResult:
|
||||
"""
|
||||
Call an MCP Tool.
|
||||
"""
|
||||
verbose_logger.info(
|
||||
f"MCP client calling tool '{call_tool_request_params.name}' with arguments: {call_tool_request_params.arguments}"
|
||||
)
|
||||
|
||||
async def on_progress(
|
||||
progress: float, total: float | None, message: str | None
|
||||
):
|
||||
percentage = (progress / total * 100) if total else 0
|
||||
verbose_logger.info(
|
||||
f"MCP Tool '{call_tool_request_params.name}' progress: "
|
||||
f"{progress}/{total} ({percentage:.0f}%) - {message or ''}"
|
||||
)
|
||||
|
||||
# Forward to Host if callback provided
|
||||
if host_progress_callback:
|
||||
try:
|
||||
await host_progress_callback(progress, total)
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Failed to forward to Host: {e}")
|
||||
|
||||
async def _call_tool_operation(session: ClientSession):
|
||||
verbose_logger.debug("MCP client sending tool call to session")
|
||||
return await session.call_tool(
|
||||
name=call_tool_request_params.name,
|
||||
arguments=call_tool_request_params.arguments,
|
||||
progress_callback=on_progress,
|
||||
)
|
||||
|
||||
try:
|
||||
tool_result = await self.run_with_session(_call_tool_operation)
|
||||
verbose_logger.info(
|
||||
f"MCP client tool call '{call_tool_request_params.name}' completed successfully"
|
||||
)
|
||||
return tool_result
|
||||
except asyncio.CancelledError:
|
||||
verbose_logger.warning("MCP client tool call was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_trace = traceback.format_exc()
|
||||
verbose_logger.debug(f"MCP client tool call traceback:\n{error_trace}")
|
||||
|
||||
# Log detailed error information
|
||||
error_type = type(e).__name__
|
||||
verbose_logger.error(
|
||||
f"MCP client call_tool failed - "
|
||||
f"Error Type: {error_type}, "
|
||||
f"Error: {str(e)}, "
|
||||
f"Tool: {call_tool_request_params.name}, "
|
||||
f"Server: {self.server_url or 'stdio'}, "
|
||||
f"Transport: {self.transport_type}"
|
||||
)
|
||||
|
||||
# Check if it's a stream/connection error
|
||||
if "BrokenResourceError" in error_type or "Broken" in error_type:
|
||||
verbose_logger.error(
|
||||
"MCP client detected broken connection/stream - "
|
||||
"the MCP server may have crashed, disconnected, or timed out."
|
||||
)
|
||||
|
||||
# Return a default error result instead of raising
|
||||
return MCPCallToolResult(
|
||||
content=[
|
||||
TextContent(type="text", text=f"{error_type}: {str(e)}")
|
||||
], # Empty content for error case
|
||||
isError=True,
|
||||
)
|
||||
|
||||
async def list_prompts(self) -> List[Prompt]:
|
||||
"""List available prompts from the server."""
|
||||
verbose_logger.debug(
|
||||
f"MCP client listing tools from {self.server_url or 'stdio'}"
|
||||
)
|
||||
|
||||
async def _list_prompts_operation(session: ClientSession):
|
||||
return await session.list_prompts()
|
||||
|
||||
try:
|
||||
result = await self.run_with_session(_list_prompts_operation)
|
||||
prompt_count = len(result.prompts)
|
||||
prompt_names = [prompt.name for prompt in result.prompts]
|
||||
verbose_logger.info(
|
||||
f"MCP client listed {prompt_count} tools from {self.server_url or 'stdio'}: {prompt_names}"
|
||||
)
|
||||
return result.prompts
|
||||
except asyncio.CancelledError:
|
||||
verbose_logger.warning("MCP client list_prompts was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
verbose_logger.error(
|
||||
f"MCP client list_prompts failed - "
|
||||
f"Error Type: {error_type}, "
|
||||
f"Error: {str(e)}, "
|
||||
f"Server: {self.server_url or 'stdio'}, "
|
||||
f"Transport: {self.transport_type}"
|
||||
)
|
||||
|
||||
# Check if it's a stream/connection error
|
||||
if "BrokenResourceError" in error_type or "Broken" in error_type:
|
||||
verbose_logger.error(
|
||||
"MCP client detected broken connection/stream during list_tools - "
|
||||
"the MCP server may have crashed, disconnected, or timed out"
|
||||
)
|
||||
|
||||
# Return empty list instead of raising to allow graceful degradation
|
||||
return []
|
||||
|
||||
async def get_prompt(
|
||||
self, get_prompt_request_params: GetPromptRequestParams
|
||||
) -> GetPromptResult:
|
||||
"""Fetch a prompt definition from the MCP server."""
|
||||
verbose_logger.info(
|
||||
f"MCP client fetching prompt '{get_prompt_request_params.name}' with arguments: {get_prompt_request_params.arguments}"
|
||||
)
|
||||
|
||||
async def _get_prompt_operation(session: ClientSession):
|
||||
verbose_logger.debug("MCP client sending get_prompt request to session")
|
||||
return await session.get_prompt(
|
||||
name=get_prompt_request_params.name,
|
||||
arguments=get_prompt_request_params.arguments,
|
||||
)
|
||||
|
||||
try:
|
||||
get_prompt_result = await self.run_with_session(_get_prompt_operation)
|
||||
verbose_logger.info(
|
||||
f"MCP client get_prompt '{get_prompt_request_params.name}' completed successfully"
|
||||
)
|
||||
return get_prompt_result
|
||||
except asyncio.CancelledError:
|
||||
verbose_logger.warning("MCP client get_prompt was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_trace = traceback.format_exc()
|
||||
verbose_logger.debug(f"MCP client get_prompt traceback:\n{error_trace}")
|
||||
|
||||
# Log detailed error information
|
||||
error_type = type(e).__name__
|
||||
verbose_logger.error(
|
||||
f"MCP client get_prompt failed - "
|
||||
f"Error Type: {error_type}, "
|
||||
f"Error: {str(e)}, "
|
||||
f"Prompt: {get_prompt_request_params.name}, "
|
||||
f"Server: {self.server_url or 'stdio'}, "
|
||||
f"Transport: {self.transport_type}"
|
||||
)
|
||||
|
||||
# Check if it's a stream/connection error
|
||||
if "BrokenResourceError" in error_type or "Broken" in error_type:
|
||||
verbose_logger.error(
|
||||
"MCP client detected broken connection/stream during get_prompt - "
|
||||
"the MCP server may have crashed, disconnected, or timed out."
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
async def list_resources(self) -> list[Resource]:
|
||||
"""List available resources from the server."""
|
||||
verbose_logger.debug(
|
||||
f"MCP client listing resources from {self.server_url or 'stdio'}"
|
||||
)
|
||||
|
||||
async def _list_resources_operation(session: ClientSession):
|
||||
return await session.list_resources()
|
||||
|
||||
try:
|
||||
result = await self.run_with_session(_list_resources_operation)
|
||||
resource_count = len(result.resources)
|
||||
resource_names = [resource.name for resource in result.resources]
|
||||
verbose_logger.info(
|
||||
f"MCP client listed {resource_count} resources from {self.server_url or 'stdio'}: {resource_names}"
|
||||
)
|
||||
return result.resources
|
||||
except asyncio.CancelledError:
|
||||
verbose_logger.warning("MCP client list_resources was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
verbose_logger.error(
|
||||
f"MCP client list_resources failed - "
|
||||
f"Error Type: {error_type}, "
|
||||
f"Error: {str(e)}, "
|
||||
f"Server: {self.server_url or 'stdio'}, "
|
||||
f"Transport: {self.transport_type}"
|
||||
)
|
||||
|
||||
# Check if it's a stream/connection error
|
||||
if "BrokenResourceError" in error_type or "Broken" in error_type:
|
||||
verbose_logger.error(
|
||||
"MCP client detected broken connection/stream during list_resources - "
|
||||
"the MCP server may have crashed, disconnected, or timed out"
|
||||
)
|
||||
|
||||
# Return empty list instead of raising to allow graceful degradation
|
||||
return []
|
||||
|
||||
async def list_resource_templates(self) -> list[ResourceTemplate]:
|
||||
"""List available resource templates from the server."""
|
||||
verbose_logger.debug(
|
||||
f"MCP client listing resource templates from {self.server_url or 'stdio'}"
|
||||
)
|
||||
|
||||
async def _list_resource_templates_operation(session: ClientSession):
|
||||
return await session.list_resource_templates()
|
||||
|
||||
try:
|
||||
result = await self.run_with_session(_list_resource_templates_operation)
|
||||
resource_template_count = len(result.resourceTemplates)
|
||||
resource_template_names = [
|
||||
resourceTemplate.name for resourceTemplate in result.resourceTemplates
|
||||
]
|
||||
verbose_logger.info(
|
||||
f"MCP client listed {resource_template_count} resource templates from {self.server_url or 'stdio'}: {resource_template_names}"
|
||||
)
|
||||
return result.resourceTemplates
|
||||
except asyncio.CancelledError:
|
||||
verbose_logger.warning("MCP client list_resource_templates was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
error_type = type(e).__name__
|
||||
verbose_logger.error(
|
||||
f"MCP client list_resource_templates failed - "
|
||||
f"Error Type: {error_type}, "
|
||||
f"Error: {str(e)}, "
|
||||
f"Server: {self.server_url or 'stdio'}, "
|
||||
f"Transport: {self.transport_type}"
|
||||
)
|
||||
|
||||
# Check if it's a stream/connection error
|
||||
if "BrokenResourceError" in error_type or "Broken" in error_type:
|
||||
verbose_logger.error(
|
||||
"MCP client detected broken connection/stream during list_resource_templates - "
|
||||
"the MCP server may have crashed, disconnected, or timed out"
|
||||
)
|
||||
|
||||
# Return empty list instead of raising to allow graceful degradation
|
||||
return []
|
||||
|
||||
async def read_resource(self, url: AnyUrl) -> ReadResourceResult:
|
||||
"""Fetch resource contents from the MCP server."""
|
||||
verbose_logger.info(f"MCP client fetching resource '{url}'")
|
||||
|
||||
async def _read_resource_operation(session: ClientSession):
|
||||
verbose_logger.debug("MCP client sending read_resource request to session")
|
||||
return await session.read_resource(url)
|
||||
|
||||
try:
|
||||
read_resource_result = await self.run_with_session(_read_resource_operation)
|
||||
verbose_logger.info(
|
||||
f"MCP client read_resource '{url}' completed successfully"
|
||||
)
|
||||
return read_resource_result
|
||||
except asyncio.CancelledError:
|
||||
verbose_logger.warning("MCP client read_resource was cancelled")
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_trace = traceback.format_exc()
|
||||
verbose_logger.debug(f"MCP client read_resource traceback:\n{error_trace}")
|
||||
|
||||
# Log detailed error information
|
||||
error_type = type(e).__name__
|
||||
verbose_logger.error(
|
||||
f"MCP client read_resource failed - "
|
||||
f"Error Type: {error_type}, "
|
||||
f"Error: {str(e)}, "
|
||||
f"Url: {url}, "
|
||||
f"Server: {self.server_url or 'stdio'}, "
|
||||
f"Transport: {self.transport_type}"
|
||||
)
|
||||
|
||||
# Check if it's a stream/connection error
|
||||
if "BrokenResourceError" in error_type or "Broken" in error_type:
|
||||
verbose_logger.error(
|
||||
"MCP client detected broken connection/stream during read_resource - "
|
||||
"the MCP server may have crashed, disconnected, or timed out."
|
||||
)
|
||||
|
||||
raise
|
||||
@@ -0,0 +1,159 @@
|
||||
import json
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.types import CallToolRequestParams as MCPCallToolRequestParams
|
||||
from mcp.types import CallToolResult as MCPCallToolResult
|
||||
from mcp.types import Tool as MCPTool
|
||||
from openai.types.chat import ChatCompletionToolParam
|
||||
from openai.types.responses.function_tool_param import FunctionToolParam
|
||||
from openai.types.shared_params.function_definition import FunctionDefinition
|
||||
|
||||
from litellm.types.utils import ChatCompletionMessageToolCall
|
||||
|
||||
|
||||
########################################################
|
||||
# List MCP Tool functions
|
||||
########################################################
|
||||
def transform_mcp_tool_to_openai_tool(mcp_tool: MCPTool) -> ChatCompletionToolParam:
    """Convert an MCP tool definition into an OpenAI chat-completions tool param."""
    function_definition = FunctionDefinition(
        name=mcp_tool.name,
        description=mcp_tool.description or "",
        # Input schemas from MCP servers may be missing required keys;
        # normalize so OpenAI accepts them.
        parameters=_normalize_mcp_input_schema(mcp_tool.inputSchema),
        strict=False,
    )
    return ChatCompletionToolParam(type="function", function=function_definition)
||||
|
||||
def _normalize_mcp_input_schema(input_schema: dict) -> dict:
|
||||
"""
|
||||
Normalize MCP input schema to ensure it's valid for OpenAI function calling.
|
||||
|
||||
OpenAI requires that function parameters have:
|
||||
- type: 'object'
|
||||
- properties: dict (can be empty)
|
||||
- additionalProperties: false (recommended)
|
||||
"""
|
||||
if not input_schema:
|
||||
return {"type": "object", "properties": {}, "additionalProperties": False}
|
||||
|
||||
# Make a copy to avoid modifying the original
|
||||
normalized_schema = dict(input_schema)
|
||||
|
||||
# Ensure type is 'object'
|
||||
if "type" not in normalized_schema:
|
||||
normalized_schema["type"] = "object"
|
||||
|
||||
# Ensure properties exists (can be empty)
|
||||
if "properties" not in normalized_schema:
|
||||
normalized_schema["properties"] = {}
|
||||
|
||||
# Add additionalProperties if not present (recommended by OpenAI)
|
||||
if "additionalProperties" not in normalized_schema:
|
||||
normalized_schema["additionalProperties"] = False
|
||||
|
||||
return normalized_schema
|
||||
|
||||
|
||||
def transform_mcp_tool_to_openai_responses_api_tool(
    mcp_tool: MCPTool,
) -> FunctionToolParam:
    """Convert an MCP tool definition into an OpenAI Responses API function tool."""
    # Normalize the server-provided schema so it satisfies OpenAI's
    # requirements (object type, properties dict, additionalProperties).
    return FunctionToolParam(
        type="function",
        name=mcp_tool.name,
        description=mcp_tool.description or "",
        parameters=_normalize_mcp_input_schema(mcp_tool.inputSchema),
        strict=False,
    )
||||
|
||||
|
||||
async def load_mcp_tools(
    session: ClientSession, format: Literal["mcp", "openai"] = "mcp"
) -> Union[List[MCPTool], List[ChatCompletionToolParam]]:
    """
    Load all available MCP tools

    Args:
        session: The MCP session to use
        format: The format to convert the tools to
        By default, the tools are returned in MCP format.

        If format is set to "openai", the tools are converted to OpenAI API compatible tools.
    """
    list_result = await session.list_tools()
    mcp_tools = list_result.tools
    if format != "openai":
        # Default: return the tools exactly as the server reported them.
        return mcp_tools
    return [transform_mcp_tool_to_openai_tool(mcp_tool=t) for t in mcp_tools]
||||
|
||||
|
||||
########################################################
|
||||
# Call MCP Tool functions
|
||||
########################################################
|
||||
|
||||
|
||||
async def call_mcp_tool(
    session: ClientSession,
    call_tool_request_params: MCPCallToolRequestParams,
) -> MCPCallToolResult:
    """Invoke an MCP tool over the given session and return its raw result."""
    return await session.call_tool(
        name=call_tool_request_params.name,
        arguments=call_tool_request_params.arguments,
    )
||||
|
||||
|
||||
def _get_function_arguments(function: FunctionDefinition) -> dict:
    """Safely extract a tool call's arguments as a dict.

    Handles the three shapes seen in practice: an arguments dict (returned
    as-is), a JSON string (parsed), and anything else (empty dict).
    """
    raw_arguments = function.get("arguments", {})
    if isinstance(raw_arguments, str):
        # OpenAI serializes arguments as a JSON string; bad JSON -> {}.
        try:
            raw_arguments = json.loads(raw_arguments)
        except json.JSONDecodeError:
            raw_arguments = {}
    return raw_arguments if isinstance(raw_arguments, dict) else {}
||||
|
||||
|
||||
def transform_openai_tool_call_request_to_mcp_tool_call_request(
    openai_tool: Union[ChatCompletionMessageToolCall, Dict],
) -> MCPCallToolRequestParams:
    """Convert an OpenAI ChatCompletionMessageToolCall to an MCP CallToolRequestParams."""
    function_payload = openai_tool["function"]
    tool_name = function_payload["name"]
    # Arguments may be a JSON string or a dict; the helper normalizes both.
    tool_arguments = _get_function_arguments(function_payload)
    return MCPCallToolRequestParams(name=tool_name, arguments=tool_arguments)
||||
|
||||
|
||||
async def call_openai_tool(
    session: ClientSession,
    openai_tool: ChatCompletionMessageToolCall,
) -> MCPCallToolResult:
    """
    Call an OpenAI tool using MCP client.

    Args:
        session: The MCP session to use
        openai_tool: The OpenAI tool to call. You can get this from the `choices[0].message.tool_calls[0]` of the response from the OpenAI API.
    Returns:
        The result of the MCP tool call.
    """
    # Translate the OpenAI tool-call shape into MCP request params,
    # then dispatch it over the session.
    request_params = transform_openai_tool_call_request_to_mcp_tool_call_request(
        openai_tool=openai_tool,
    )
    return await call_mcp_tool(
        session=session,
        call_tool_request_params=request_params,
    )
||||
984
llm-gateway-competitors/litellm-wheel-src/litellm/files/main.py
Normal file
984
llm-gateway-competitors/litellm-wheel-src/litellm/files/main.py
Normal file
@@ -0,0 +1,984 @@
|
||||
"""
|
||||
Main File for Files API implementation
|
||||
|
||||
https://platform.openai.com/docs/api-reference/files
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import time
|
||||
import uuid as uuid_module
|
||||
from functools import partial
|
||||
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
# Type aliases for provider parameters
|
||||
FileCreateProvider = Literal[
|
||||
"openai",
|
||||
"azure",
|
||||
"gemini",
|
||||
"vertex_ai",
|
||||
"bedrock",
|
||||
"hosted_vllm",
|
||||
"manus",
|
||||
"anthropic",
|
||||
]
|
||||
FileRetrieveProvider = Literal[
|
||||
"openai", "azure", "gemini", "vertex_ai", "hosted_vllm", "manus", "anthropic"
|
||||
]
|
||||
FileDeleteProvider = Literal["openai", "azure", "gemini", "manus", "anthropic"]
|
||||
FileListProvider = Literal["openai", "azure", "manus", "anthropic"]
|
||||
FileContentProvider = Literal[
|
||||
"openai", "azure", "vertex_ai", "bedrock", "hosted_vllm", "anthropic", "manus"
|
||||
]
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret_str
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.azure.common_utils import get_azure_credentials
|
||||
from litellm.llms.azure.files.handler import AzureOpenAIFilesAPI
|
||||
from litellm.llms.bedrock.files.handler import BedrockFilesHandler
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
|
||||
from litellm.llms.openai.common_utils import get_openai_credentials
|
||||
from litellm.llms.openai.openai import FileDeleted, FileObject, OpenAIFilesAPI
|
||||
from litellm.llms.vertex_ai.files.handler import VertexAIFilesHandler
|
||||
from litellm.types.llms.openai import (
|
||||
CreateFileRequest,
|
||||
FileContentRequest,
|
||||
FileExpiresAfter,
|
||||
FileTypes,
|
||||
HttpxBinaryResponseContent,
|
||||
OpenAIFileObject,
|
||||
)
|
||||
from litellm.types.router import *
|
||||
from litellm.types.utils import (
|
||||
OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS,
|
||||
LlmProviders,
|
||||
)
|
||||
from litellm.utils import (
|
||||
ProviderConfigManager,
|
||||
client,
|
||||
get_litellm_params,
|
||||
supports_httpx_timeout,
|
||||
)
|
||||
|
||||
base_llm_http_handler = BaseLLMHTTPHandler()
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_files_instance = OpenAIFilesAPI()
|
||||
azure_files_instance = AzureOpenAIFilesAPI()
|
||||
vertex_ai_files_instance = VertexAIFilesHandler()
|
||||
bedrock_files_instance = BedrockFilesHandler()
|
||||
#################################################
|
||||
|
||||
|
||||
@client
async def acreate_file(
    file: FileTypes,
    purpose: Literal["assistants", "batch", "fine-tune", "messages"],
    expires_after: Optional[FileExpiresAfter] = None,
    custom_llm_provider: FileCreateProvider = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> OpenAIFileObject:
    """
    Async: Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.

    LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files

    Args:
        file: The file contents to upload.
        purpose: What the file will be used for.
        expires_after: Optional expiry policy for the uploaded file.
        custom_llm_provider: Which provider's Files API to use.
        extra_headers: Extra HTTP headers forwarded to the provider.
        extra_body: Extra body params forwarded to the provider.
        **kwargs: Additional litellm params passed through to create_file().
    """
    loop = asyncio.get_event_loop()
    # Signal the sync implementation that it should hand back a coroutine
    # instead of blocking on the provider call.
    kwargs["acreate_file"] = True

    func = partial(
        create_file,
        file=file,
        purpose=purpose,
        expires_after=expires_after,
        custom_llm_provider=custom_llm_provider,
        extra_headers=extra_headers,
        extra_body=extra_body,
        **kwargs,
    )
    # Run under a copy of the current context so contextvars (e.g. request
    # metadata) propagate into the executor thread.
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    init_response = await loop.run_in_executor(None, func_with_context)
    # create_file may return either a coroutine (async provider path) or a
    # finished response; await only in the former case.
    if asyncio.iscoroutine(init_response):
        return await init_response
    return init_response  # type: ignore
||||
|
||||
|
||||
@client
def create_file(
    file: FileTypes,
    purpose: Literal["assistants", "batch", "fine-tune", "messages"],
    expires_after: Optional[FileExpiresAfter] = None,
    custom_llm_provider: Optional[FileCreateProvider] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Union[OpenAIFileObject, Coroutine[Any, Any, OpenAIFileObject]]:
    """
    Files are used to upload documents that can be used with features like Assistants, Fine-tuning, and Batch API.

    LiteLLM Equivalent of POST: POST https://api.openai.com/v1/files

    Specify either provider_list or custom_llm_provider.

    Returns the created file object, or a coroutine resolving to it when
    called via ``acreate_file`` (which sets ``kwargs["acreate_file"] = True``).
    """
    try:
        # ``acreate_file`` is the private flag set by the async wrapper.
        _is_async = kwargs.pop("acreate_file", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)
        litellm_params_dict = dict(**kwargs)
        logging_obj = cast(
            Optional[LiteLLMLoggingObj], kwargs.get("litellm_logging_obj")
        )
        # Unlike the other files endpoints below, this one requires the
        # caller (the @client decorator) to have attached a logging object.
        if logging_obj is None:
            raise ValueError("logging_obj is required")
        client = kwargs.get("client")

        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default

        # Normalize timeout: collapse an httpx.Timeout to its read timeout for
        # providers that can't accept httpx.Timeout objects; coerce scalars to float.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        # Only include expires_after in the request when explicitly provided,
        # so providers that don't support it never see the key.
        if expires_after is not None:
            _create_file_request = CreateFileRequest(
                file=file,
                purpose=purpose,
                expires_after=expires_after,
                extra_headers=extra_headers,
                extra_body=extra_body,
            )
        else:
            _create_file_request = CreateFileRequest(
                file=file,
                purpose=purpose,
                extra_headers=extra_headers,
                extra_body=extra_body,
            )

        # Provider dispatch: a provider-specific files config (Bedrock, Manus,
        # Anthropic, ...) takes precedence over the legacy handler instances.
        provider_config = ProviderConfigManager.get_provider_files_config(
            model="",
            provider=LlmProviders(custom_llm_provider),
        )
        if provider_config is not None:
            response = base_llm_http_handler.create_file(
                provider_config=provider_config,
                litellm_params=litellm_params_dict,
                create_file_data=_create_file_request,
                headers=extra_headers or {},
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                logging_obj=logging_obj,
                _is_async=_is_async,
                # Only forward the caller-supplied client when it is an
                # actual litellm HTTP handler; anything else is ignored.
                client=(
                    client
                    if client is not None
                    and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
                    else None
                ),
                timeout=timeout,
            )
        elif custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
            openai_creds = get_openai_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                organization=optional_params.organization,
            )
            response = openai_files_instance.create_file(
                _is_async=_is_async,
                api_base=openai_creds.api_base,
                api_key=openai_creds.api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=openai_creds.organization,
                create_file_data=_create_file_request,
            )
        elif custom_llm_provider == "azure":
            azure_creds = get_azure_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                api_version=optional_params.api_version,
            )
            response = azure_files_instance.create_file(
                _is_async=_is_async,
                api_base=azure_creds.api_base,
                api_key=azure_creds.api_key,
                api_version=azure_creds.api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                create_file_data=_create_file_request,
                litellm_params=litellm_params_dict,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'create_file'. Only ['openai', 'azure', 'vertex_ai', 'manus', 'anthropic'] are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="create_file", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
async def afile_retrieve(
    file_id: str,
    custom_llm_provider: FileRetrieveProvider = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> OpenAIFileObject:
    """
    Async: Get file contents

    LiteLLM Equivalent of GET https://api.openai.com/v1/files
    """
    try:
        # Tell the sync entry point to return a coroutine.
        kwargs["is_async"] = True
        sync_call = partial(
            file_retrieve,
            file_id,
            custom_llm_provider,
            extra_headers,
            extra_body,
            **kwargs,
        )
        # Execute in the default executor with the current contextvars context.
        runner = partial(contextvars.copy_context().run, sync_call)
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(None, runner)
        if asyncio.iscoroutine(result):
            result = await result
        # Normalize the provider response into the OpenAI file schema.
        return OpenAIFileObject(**result.model_dump())
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
def file_retrieve(
    file_id: str,
    custom_llm_provider: FileRetrieveProvider = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> FileObject:
    """
    Retrieve the metadata object for the specified file.

    LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}

    Returns a FileObject (or a coroutine resolving to one when called via
    ``afile_retrieve``, which sets ``kwargs["is_async"] = True``).
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default

        # Normalize timeout: collapse httpx.Timeout to its read timeout for
        # providers that can't accept it; coerce scalars to float.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(custom_llm_provider) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        # ``is_async`` is the private flag set by the async wrapper.
        _is_async = kwargs.pop("is_async", False) is True

        if custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
            openai_creds = get_openai_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                organization=optional_params.organization,
            )
            response = openai_files_instance.retrieve_file(
                file_id=file_id,
                _is_async=_is_async,
                api_base=openai_creds.api_base,
                api_key=openai_creds.api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=openai_creds.organization,
            )
        elif custom_llm_provider == "azure":
            azure_creds = get_azure_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                api_version=optional_params.api_version,
            )
            response = azure_files_instance.retrieve_file(
                _is_async=_is_async,
                api_base=azure_creds.api_base,
                api_key=azure_creds.api_key,
                api_version=azure_creds.api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                file_id=file_id,
            )
        else:
            # Try using provider config pattern (for Manus, Bedrock, etc.)
            provider_config = ProviderConfigManager.get_provider_files_config(
                model="",
                provider=LlmProviders(custom_llm_provider),
            )
            if provider_config is not None:
                litellm_params_dict = get_litellm_params(**kwargs)
                litellm_params_dict["api_key"] = optional_params.api_key
                litellm_params_dict["api_base"] = optional_params.api_base

                logging_obj = kwargs.get("litellm_logging_obj")
                if logging_obj is None:
                    # Local import to avoid a circular import at module load.
                    from litellm.litellm_core_utils.litellm_logging import (
                        Logging as LiteLLMLoggingObj,
                    )

                    # Build a minimal logging object for this standalone call.
                    logging_obj = LiteLLMLoggingObj(
                        model="",
                        messages=[],
                        stream=False,
                        call_type="afile_retrieve" if _is_async else "file_retrieve",
                        start_time=time.time(),
                        litellm_call_id=kwargs.get(
                            "litellm_call_id", str(uuid_module.uuid4())
                        ),
                        function_id=str(kwargs.get("id") or ""),
                    )

                client = kwargs.get("client")
                response = base_llm_http_handler.retrieve_file(
                    file_id=file_id,
                    provider_config=provider_config,
                    litellm_params=litellm_params_dict,
                    headers=extra_headers or {},
                    logging_obj=logging_obj,
                    _is_async=_is_async,
                    # Only forward caller-supplied litellm HTTP handlers.
                    client=(
                        client
                        if client is not None
                        and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
                        else None
                    ),
                    timeout=timeout,
                )
            else:
                raise litellm.exceptions.BadRequestError(
                    message="LiteLLM doesn't support {} for 'file_retrieve'. Only 'openai', 'azure', 'manus', and 'anthropic' are supported.".format(
                        custom_llm_provider
                    ),
                    model="n/a",
                    llm_provider=custom_llm_provider,
                    response=httpx.Response(
                        status_code=400,
                        content="Unsupported provider",
                        request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                )

        return cast(FileObject, response)
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
# Delete file
|
||||
@client
async def afile_delete(
    file_id: str,
    custom_llm_provider: FileDeleteProvider = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> FileDeleted:
    """
    Async: Delete file

    LiteLLM Equivalent of DELETE https://api.openai.com/v1/files/{file_id}

    Args:
        file_id: ID of the file to delete.
        custom_llm_provider: Provider routing key (default "openai").
        extra_headers / extra_body: Forwarded to the provider request.
        **kwargs: May include ``model`` (used to infer the provider) and
            other litellm params.

    Returns:
        FileDeleted: the deletion confirmation object.

    NOTE: the return annotation was previously ``Coroutine[Any, Any, FileObject]``;
    on an ``async def`` the annotation must be the awaited value, which the body
    produces via ``cast(FileDeleted, response)``.
    """
    try:
        loop = asyncio.get_event_loop()
        # ``model`` is consumed here and passed positionally to file_delete.
        model = kwargs.pop("model", None)
        # Signal the sync entry point that it should return a coroutine.
        kwargs["is_async"] = True

        # Bind all arguments so the sync call can run in an executor thread.
        func = partial(
            file_delete,
            file_id,
            model,
            custom_llm_provider,
            extra_headers,
            extra_body,
            **kwargs,
        )

        # Preserve the current contextvars context across the thread hop.
        ctx = contextvars.copy_context()
        func_with_context = partial(ctx.run, func)
        init_response = await loop.run_in_executor(None, func_with_context)
        if asyncio.iscoroutine(init_response):
            response = await init_response
        else:
            response = init_response  # type: ignore

        return cast(FileDeleted, response)
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
def file_delete(
    file_id: str,
    model: Optional[str] = None,
    custom_llm_provider: Union[FileDeleteProvider, str] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> FileDeleted:
    """
    Delete file

    LiteLLM Equivalent of DELETE https://api.openai.com/v1/files/{file_id}

    When ``model`` is given, the provider is inferred from it (best effort)
    and overrides ``custom_llm_provider``.
    """
    try:
        # Best-effort provider inference from the model string; any failure
        # falls back to the explicitly passed custom_llm_provider.
        try:
            if model is not None:
                _, custom_llm_provider, _, _ = get_llm_provider(
                    model, custom_llm_provider
                )
        except Exception:
            pass
        optional_params = GenericLiteLLMParams(**kwargs)
        litellm_params_dict = get_litellm_params(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default
        client = kwargs.get("client")

        # Normalize timeout: collapse httpx.Timeout to its read timeout for
        # providers that can't accept it; coerce scalars to float.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(custom_llm_provider) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0
        # ``is_async`` is the private flag set by the async wrapper.
        _is_async = kwargs.pop("is_async", False) is True
        if custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
            openai_creds = get_openai_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                organization=optional_params.organization,
            )
            response = openai_files_instance.delete_file(
                file_id=file_id,
                _is_async=_is_async,
                api_base=openai_creds.api_base,
                api_key=openai_creds.api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=openai_creds.organization,
            )
        elif custom_llm_provider == "azure":
            azure_creds = get_azure_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                api_version=optional_params.api_version,
            )
            response = azure_files_instance.delete_file(
                _is_async=_is_async,
                api_base=azure_creds.api_base,
                api_key=azure_creds.api_key,
                api_version=azure_creds.api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                file_id=file_id,
                client=client,
                litellm_params=litellm_params_dict,
            )
        else:
            # Try using provider config pattern (for Manus, Bedrock, etc.)
            provider_config = ProviderConfigManager.get_provider_files_config(
                model="",
                provider=LlmProviders(custom_llm_provider),
            )
            if provider_config is not None:
                litellm_params_dict["api_key"] = optional_params.api_key
                litellm_params_dict["api_base"] = optional_params.api_base

                logging_obj = kwargs.get("litellm_logging_obj")
                if logging_obj is None:
                    # Local import to avoid a circular import at module load.
                    from litellm.litellm_core_utils.litellm_logging import (
                        Logging as LiteLLMLoggingObj,
                    )

                    # Build a minimal logging object for this standalone call.
                    logging_obj = LiteLLMLoggingObj(
                        model="",
                        messages=[],
                        stream=False,
                        call_type="afile_delete" if _is_async else "file_delete",
                        start_time=time.time(),
                        litellm_call_id=kwargs.get(
                            "litellm_call_id", str(uuid_module.uuid4())
                        ),
                        function_id=str(kwargs.get("id") or ""),
                    )

                response = base_llm_http_handler.delete_file(
                    file_id=file_id,
                    provider_config=provider_config,
                    litellm_params=litellm_params_dict,
                    headers=extra_headers or {},
                    logging_obj=logging_obj,
                    _is_async=_is_async,
                    # Only forward caller-supplied litellm HTTP handlers.
                    client=(
                        client
                        if client is not None
                        and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
                        else None
                    ),
                    timeout=timeout,
                )
            else:
                raise litellm.exceptions.BadRequestError(
                    message="LiteLLM doesn't support {} for 'file_delete'. Only 'openai', 'azure', 'gemini', 'manus', and 'anthropic' are supported.".format(
                        custom_llm_provider
                    ),
                    model="n/a",
                    llm_provider=custom_llm_provider,
                    response=httpx.Response(
                        status_code=400,
                        content="Unsupported provider",
                        request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                )
        return cast(FileDeleted, response)
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
# List files
|
||||
@client
async def afile_list(
    custom_llm_provider: FileListProvider = "openai",
    purpose: Optional[str] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    Async: List files

    LiteLLM Equivalent of GET https://api.openai.com/v1/files
    """
    try:
        # Tell the sync entry point to return a coroutine.
        kwargs["is_async"] = True
        sync_call = partial(
            file_list,
            custom_llm_provider,
            purpose,
            extra_headers,
            extra_body,
            **kwargs,
        )
        # Execute in the default executor with the current contextvars context.
        runner = partial(contextvars.copy_context().run, sync_call)
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(None, runner)
        if asyncio.iscoroutine(result):
            result = await result
        return result
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
def file_list(
    custom_llm_provider: FileListProvider = "openai",
    purpose: Optional[str] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    List files

    LiteLLM Equivalent of GET https://api.openai.com/v1/files

    ``purpose`` optionally filters the listing to files uploaded for that
    purpose. Returns the provider's file listing (or a coroutine when called
    via ``afile_list``, which sets ``kwargs["is_async"] = True``).
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        # set timeout for 10 minutes by default

        # Normalize timeout: collapse httpx.Timeout to its read timeout for
        # providers that can't accept it; coerce scalars to float.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(custom_llm_provider) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        # ``is_async`` is the private flag set by the async wrapper.
        _is_async = kwargs.pop("is_async", False) is True

        # Check if provider has a custom files config (e.g., Manus, Bedrock, Vertex AI)
        # NOTE: unlike file_retrieve/file_delete, the provider-config path is
        # checked FIRST here and returns early.
        provider_config = ProviderConfigManager.get_provider_files_config(
            model="",
            provider=LlmProviders(custom_llm_provider),
        )
        if provider_config is not None:
            litellm_params_dict = get_litellm_params(**kwargs)
            litellm_params_dict["api_key"] = optional_params.api_key
            litellm_params_dict["api_base"] = optional_params.api_base

            logging_obj = kwargs.get("litellm_logging_obj")
            if logging_obj is None:
                # Local import to avoid a circular import at module load.
                from litellm.litellm_core_utils.litellm_logging import (
                    Logging as LiteLLMLoggingObj,
                )

                # Build a minimal logging object for this standalone call.
                logging_obj = LiteLLMLoggingObj(
                    model="",
                    messages=[],
                    stream=False,
                    call_type="afile_list" if _is_async else "file_list",
                    start_time=time.time(),
                    litellm_call_id=kwargs.get(
                        "litellm_call_id", str(uuid_module.uuid4())
                    ),
                    function_id=str(kwargs.get("id", "")),
                )

            client = kwargs.get("client")
            response = base_llm_http_handler.list_files(
                purpose=purpose,
                provider_config=provider_config,
                litellm_params=litellm_params_dict,
                headers=extra_headers or {},
                logging_obj=logging_obj,
                _is_async=_is_async,
                # Only forward caller-supplied litellm HTTP handlers.
                client=(
                    client
                    if client is not None
                    and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
                    else None
                ),
                timeout=timeout,
            )
            return response
        elif custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
            openai_creds = get_openai_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                organization=optional_params.organization,
            )
            response = openai_files_instance.list_files(
                purpose=purpose,
                _is_async=_is_async,
                api_base=openai_creds.api_base,
                api_key=openai_creds.api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=openai_creds.organization,
            )
        elif custom_llm_provider == "azure":
            azure_creds = get_azure_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                api_version=optional_params.api_version,
            )
            response = azure_files_instance.list_files(
                _is_async=_is_async,
                api_base=azure_creds.api_base,
                api_key=azure_creds.api_key,
                api_version=azure_creds.api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                purpose=purpose,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'file_list'. Only 'openai', 'azure', 'manus', and 'anthropic' are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="file_list", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
async def afile_content(
    file_id: str,
    custom_llm_provider: FileContentProvider = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> HttpxBinaryResponseContent:
    """
    Async: Get file contents

    LiteLLM Equivalent of GET https://api.openai.com/v1/files
    """
    try:
        # Tell the sync entry point to return a coroutine.
        kwargs["afile_content"] = True
        # ``model`` is consumed here and passed positionally to file_content.
        model = kwargs.pop("model", None)
        sync_call = partial(
            file_content,
            file_id,
            model,
            custom_llm_provider,
            extra_headers,
            extra_body,
            **kwargs,
        )
        # Execute in the default executor with the current contextvars context.
        runner = partial(contextvars.copy_context().run, sync_call)
        loop = asyncio.get_event_loop()
        result = await loop.run_in_executor(None, runner)
        if asyncio.iscoroutine(result):
            result = await result
        return result  # type: ignore
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
def file_content(
    file_id: str,
    model: Optional[str] = None,
    custom_llm_provider: Optional[Union[FileContentProvider, str]] = None,
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Union[HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]]:
    """
    Returns the contents of the specified file.

    LiteLLM Equivalent of GET https://api.openai.com/v1/files/{file_id}/content

    When ``model`` is given, the provider is inferred from it (best effort)
    and overrides ``custom_llm_provider``. Returns binary file content, or a
    coroutine resolving to it when called via ``afile_content``.
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        litellm_params_dict = get_litellm_params(**kwargs)
        ### TIMEOUT LOGIC ###
        timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
        client = kwargs.get("client")
        # set timeout for 10 minutes by default

        # Best-effort provider inference from the model string; failures fall
        # back to the explicitly passed custom_llm_provider.
        try:
            if model is not None:
                _, custom_llm_provider, _, _ = get_llm_provider(
                    model, custom_llm_provider
                )
        except Exception:
            pass

        # Normalize timeout: collapse httpx.Timeout to its read timeout for
        # providers that can't accept it; coerce scalars to float.
        if (
            timeout is not None
            and isinstance(timeout, httpx.Timeout)
            and supports_httpx_timeout(cast(str, custom_llm_provider)) is False
        ):
            read_timeout = timeout.read or 600
            timeout = read_timeout  # default 10 min timeout
        elif timeout is not None and not isinstance(timeout, httpx.Timeout):
            timeout = float(timeout)  # type: ignore
        elif timeout is None:
            timeout = 600.0

        _file_content_request = FileContentRequest(
            file_id=file_id,
            extra_headers=extra_headers,
            extra_body=extra_body,
        )

        # ``afile_content`` is the private flag set by the async wrapper.
        _is_async = kwargs.pop("afile_content", False) is True

        # Check if provider has a custom files config (e.g., Anthropic, Manus)
        # NOTE: this path is checked first and returns early when matched.
        provider_config = ProviderConfigManager.get_provider_files_config(
            model="",
            provider=LlmProviders(custom_llm_provider),
        )
        if provider_config is not None:
            litellm_params_dict["api_key"] = optional_params.api_key
            litellm_params_dict["api_base"] = optional_params.api_base

            logging_obj = kwargs.get("litellm_logging_obj")
            if logging_obj is None:
                # NOTE(review): unlike file_retrieve/file_delete, this branch
                # uses the module-level LiteLLMLoggingObj name directly rather
                # than a local import — presumably it is already imported at
                # module scope; confirm.
                logging_obj = LiteLLMLoggingObj(
                    model="",
                    messages=[],
                    stream=False,
                    call_type="afile_content" if _is_async else "file_content",
                    start_time=time.time(),
                    litellm_call_id=kwargs.get(
                        "litellm_call_id", str(uuid_module.uuid4())
                    ),
                    function_id=str(kwargs.get("id") or ""),
                )

            response = base_llm_http_handler.retrieve_file_content(
                file_content_request=_file_content_request,
                provider_config=provider_config,
                litellm_params=litellm_params_dict,
                headers=extra_headers or {},
                logging_obj=logging_obj,
                _is_async=_is_async,
                # Only forward caller-supplied litellm HTTP handlers.
                client=(
                    client
                    if client is not None
                    and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
                    else None
                ),
                timeout=timeout,
            )
            return response

        if custom_llm_provider in OPENAI_COMPATIBLE_BATCH_AND_FILES_PROVIDERS:
            openai_creds = get_openai_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                organization=optional_params.organization,
            )
            response = openai_files_instance.file_content(
                _is_async=_is_async,
                file_content_request=_file_content_request,
                api_base=openai_creds.api_base,
                api_key=openai_creds.api_key,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                organization=openai_creds.organization,
            )
        elif custom_llm_provider == "azure":
            azure_creds = get_azure_credentials(
                api_base=optional_params.api_base,
                api_key=optional_params.api_key,
                api_version=optional_params.api_version,
            )
            response = azure_files_instance.file_content(
                _is_async=_is_async,
                api_base=azure_creds.api_base,
                api_key=azure_creds.api_key,
                api_version=azure_creds.api_version,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                file_content_request=_file_content_request,
                client=client,
                litellm_params=litellm_params_dict,
            )
        elif custom_llm_provider == "vertex_ai":
            api_base = optional_params.api_base or ""
            # Vertex credentials resolve from call params, then module-level
            # config, then environment secrets.
            vertex_ai_project = (
                optional_params.vertex_project
                or litellm.vertex_project
                or get_secret_str("VERTEXAI_PROJECT")
            )
            vertex_ai_location = (
                optional_params.vertex_location
                or litellm.vertex_location
                or get_secret_str("VERTEXAI_LOCATION")
            )
            vertex_credentials = optional_params.vertex_credentials or get_secret_str(
                "VERTEXAI_CREDENTIALS"
            )

            response = vertex_ai_files_instance.file_content(
                _is_async=_is_async,
                file_content_request=_file_content_request,
                api_base=api_base,
                vertex_credentials=vertex_credentials,
                vertex_project=vertex_ai_project,
                vertex_location=vertex_ai_location,
                timeout=timeout,
                max_retries=optional_params.max_retries,
            )
        elif custom_llm_provider == "bedrock":
            response = bedrock_files_instance.file_content(
                _is_async=_is_async,
                file_content_request=_file_content_request,
                api_base=optional_params.api_base,
                optional_params=litellm_params_dict,
                timeout=timeout,
                max_retries=optional_params.max_retries,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'file_content'. Supported providers are 'openai', 'azure', 'vertex_ai', 'bedrock', 'manus', 'anthropic'.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
@@ -0,0 +1,32 @@
|
||||
from typing import Optional
|
||||
|
||||
from litellm.types.llms.openai import CreateFileRequest
|
||||
from litellm.types.utils import ExtractedFileData
|
||||
|
||||
|
||||
class FilesAPIUtils:
    """
    Utils for files API interface on litellm.
    """

    @staticmethod
    def is_batch_jsonl_file(
        create_file_data: "CreateFileRequest",
        extracted_file_data: "ExtractedFileData",
    ) -> bool:
        """
        Check if the file is a batch jsonl file.

        A file qualifies when its purpose is "batch", its content type is an
        accepted JSONL type, and content bytes are present.
        """
        return (
            create_file_data.get("purpose") == "batch"
            and FilesAPIUtils.valid_content_type(
                extracted_file_data.get("content_type")
            )
            and extracted_file_data.get("content") is not None
        )

    @staticmethod
    def valid_content_type(content_type: Optional[str]) -> bool:
        """
        Check if the content type is valid for a batch JSONL upload.
        """
        # Set literal instead of set([...]) construction (ruff C405);
        # membership test is O(1) either way, this just avoids the
        # intermediate list.
        return content_type in {"application/jsonl", "application/octet-stream"}
|
||||
@@ -0,0 +1,826 @@
|
||||
"""
|
||||
Main File for Fine Tuning API implementation
|
||||
|
||||
https://platform.openai.com/docs/api-reference/fine-tuning
|
||||
|
||||
- fine_tuning.jobs.create()
|
||||
- fine_tuning.jobs.list()
|
||||
- client.fine_tuning.jobs.list_events()
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.azure.fine_tuning.handler import AzureOpenAIFineTuningAPI
|
||||
from litellm.llms.openai.fine_tuning.handler import OpenAIFineTuningAPI
|
||||
from litellm.llms.vertex_ai.fine_tuning.handler import VertexFineTuningAPI
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import FineTuningJobCreate, Hyperparameters
|
||||
from litellm.types.router import *
|
||||
from litellm.types.utils import LiteLLMFineTuningJob
|
||||
from litellm.utils import client, supports_httpx_timeout
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
# Module-level singleton handlers, shared by every fine-tuning call below.
openai_fine_tuning_apis_instance = OpenAIFineTuningAPI()
azure_fine_tuning_apis_instance = AzureOpenAIFineTuningAPI()
vertex_fine_tuning_apis_instance = VertexFineTuningAPI()
#################################################
|
||||
|
||||
|
||||
def _prepare_azure_extra_body(
|
||||
extra_body: Optional[Dict[str, Any]],
|
||||
kwargs: Dict[str, Any],
|
||||
azure_specific_hyperparams: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Prepare extra_body for Azure fine-tuning API by combining Azure-specific parameters.
|
||||
|
||||
Azure fine-tuning API accepts additional parameters beyond the standard OpenAI spec:
|
||||
- trainingType: Type of training (e.g., 1 for supervised fine-tuning)
|
||||
- prompt_loss_weight: Weight for prompt loss in training
|
||||
|
||||
These parameters must be passed in the extra_body field when calling the Azure OpenAI SDK.
|
||||
|
||||
Args:
|
||||
extra_body: Optional existing extra_body dict
|
||||
kwargs: Request kwargs that may contain Azure-specific parameters
|
||||
azure_specific_hyperparams: Dict of Azure-specific hyperparameters already extracted
|
||||
|
||||
Returns:
|
||||
Dict containing all Azure-specific parameters to be passed in extra_body
|
||||
"""
|
||||
if extra_body is None:
|
||||
extra_body = {}
|
||||
|
||||
# Azure-specific root-level parameters
|
||||
azure_specific_params = ["trainingType"]
|
||||
for param in azure_specific_params:
|
||||
if param in kwargs:
|
||||
extra_body[param] = kwargs[param]
|
||||
|
||||
# Add Azure-specific hyperparameters
|
||||
if azure_specific_hyperparams:
|
||||
extra_body.update(azure_specific_hyperparams)
|
||||
|
||||
return extra_body
|
||||
|
||||
|
||||
@client
async def acreate_fine_tuning_job(
    model: str,
    training_file: str,
    hyperparameters: Optional[dict] = {},
    suffix: Optional[str] = None,
    validation_file: Optional[str] = None,
    integrations: Optional[List[str]] = None,
    seed: Optional[int] = None,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> LiteLLMFineTuningJob:
    """
    Async: Create a fine-tuning job from an uploaded training file.

    Thin async wrapper around ``create_fine_tuning_job`` — runs the sync
    implementation in a thread-pool executor and awaits the coroutine it
    returns when the async flag is set.
    """
    verbose_logger.debug(
        "inside acreate_fine_tuning_job model=%s and kwargs=%s", model, kwargs
    )
    try:
        event_loop = asyncio.get_event_loop()
        # Signal the sync implementation that it should hand back a coroutine.
        kwargs["acreate_fine_tuning_job"] = True

        # Bind all positional arguments up front.
        bound_call = partial(
            create_fine_tuning_job,
            model,
            training_file,
            hyperparameters,
            suffix,
            validation_file,
            integrations,
            seed,
            custom_llm_provider,
            extra_headers,
            extra_body,
            **kwargs,
        )

        # Run under a copied context so contextvars survive the executor hop.
        ctx = contextvars.copy_context()
        initial = await event_loop.run_in_executor(None, partial(ctx.run, bound_call))

        # The sync path returns a coroutine when the async flag was set.
        if asyncio.iscoroutine(initial):
            return await initial
        return initial  # type: ignore
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
def _build_fine_tuning_job_data(
    model, training_file, hyperparameters, suffix, validation_file, integrations, seed
):
    """Assemble a typed FineTuningJobCreate payload from the raw job arguments."""
    job_fields = {
        "model": model,
        "training_file": training_file,
        "hyperparameters": hyperparameters,
        "suffix": suffix,
        "validation_file": validation_file,
        "integrations": integrations,
        "seed": seed,
    }
    return FineTuningJobCreate(**job_fields)
|
||||
|
||||
|
||||
def _resolve_fine_tuning_timeout(
    timeout: Any,
    custom_llm_provider: str,
) -> Union[float, httpx.Timeout]:
    """Normalise a raw timeout value to a float (seconds) or httpx.Timeout for fine-tuning calls."""
    # Fall back to the 10-minute default when nothing (or a falsy value) was supplied.
    if not timeout:
        timeout = 600.0
    if not isinstance(timeout, httpx.Timeout):
        return float(timeout)
    # Providers that can't take an httpx.Timeout get its read timeout as plain seconds.
    if supports_httpx_timeout(custom_llm_provider):
        return timeout
    return float(timeout.read or 600)
|
||||
|
||||
|
||||
@client
def create_fine_tuning_job(
    model: str,
    training_file: str,
    hyperparameters: Optional[dict] = None,
    suffix: Optional[str] = None,
    validation_file: Optional[str] = None,
    integrations: Optional[List[str]] = None,
    seed: Optional[int] = None,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
    """
    Creates a fine-tuning job which begins the process of creating a new model from a given dataset.

    Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete.

    Args:
        model: Base model to fine-tune.
        training_file: File id of the uploaded training data.
        hyperparameters: Optional hyperparameters. Azure-only keys
            (e.g. "prompt_loss_weight") are split out and sent via extra_body.
            Default changed from a mutable ``{}`` to ``None`` (behaviourally
            identical) — the old shared default dict was mutated via ``.pop()``
            on the Azure path.
        suffix: Optional suffix for the fine-tuned model name.
        validation_file: Optional file id of validation data.
        integrations: Optional list of integrations to enable for the job.
        seed: Optional seed for reproducibility.
        custom_llm_provider: One of "openai", "azure", "vertex_ai".
        extra_headers: Provider-specific extra headers (currently unused here).
        extra_body: Provider-specific extra body fields.

    Returns:
        A LiteLLMFineTuningJob, or a coroutine resolving to one when invoked
        through ``acreate_fine_tuning_job``.

    Raises:
        litellm.exceptions.BadRequestError: when custom_llm_provider is not supported.
    """
    try:
        _is_async = kwargs.pop("acreate_fine_tuning_job", False) is True
        optional_params = GenericLiteLLMParams(**kwargs)

        # handle hyperparameters
        # Copy so we never mutate the caller's dict (or a shared default)
        # when Azure-specific keys are popped below.
        hyperparameters = dict(hyperparameters) if hyperparameters else {}

        # For Azure, extract Azure-specific hyperparameters before creating OpenAI-spec hyperparameters
        azure_specific_hyperparams = {}
        if custom_llm_provider == "azure":
            azure_hyperparameter_keys = ["prompt_loss_weight"]
            for key in azure_hyperparameter_keys:
                if key in hyperparameters:
                    azure_specific_hyperparams[key] = hyperparameters.pop(key)

        _oai_hyperparameters: Hyperparameters = Hyperparameters(
            **hyperparameters
        )  # Typed Hyperparameters for OpenAI Spec
        timeout = _resolve_fine_tuning_timeout(
            optional_params.timeout or kwargs.get("request_timeout", 600),
            custom_llm_provider,
        )

        # OpenAI
        if custom_llm_provider == "openai":
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
                or litellm.api_base
                or os.getenv("OPENAI_BASE_URL")
                or os.getenv("OPENAI_API_BASE")
                or "https://api.openai.com/v1"
            )
            organization = (
                optional_params.organization
                or litellm.organization
                or os.getenv("OPENAI_ORGANIZATION", None)
                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
            )
            # set API KEY
            api_key = (
                optional_params.api_key
                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
                or litellm.openai_key
                or os.getenv("OPENAI_API_KEY")
            )

            create_fine_tuning_job_data_dict = _build_fine_tuning_job_data(
                model,
                training_file,
                _oai_hyperparameters,
                suffix,
                validation_file,
                integrations,
                seed,
            ).model_dump(exclude_none=True)

            response = openai_fine_tuning_apis_instance.create_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=optional_params.api_version,
                organization=organization,
                create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                client=kwargs.get(
                    "client", None
                ),  # note, when we add this to `GenericLiteLLMParams` it impacts a lot of other tests + linting
            )
        # Azure OpenAI
        elif custom_llm_provider == "azure":
            api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore

            api_version = (
                optional_params.api_version
                or litellm.api_version
                or get_secret_str("AZURE_API_VERSION")
            )  # type: ignore

            api_key = (
                optional_params.api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret_str("AZURE_OPENAI_API_KEY")
                or get_secret_str("AZURE_API_KEY")
            )  # type: ignore

            extra_body = optional_params.get("extra_body", {})
            if extra_body is not None:
                extra_body.pop("azure_ad_token", None)
            else:
                # NOTE(review): return value is discarded here — looks like a
                # leftover; confirm whether the AD token was meant to be used.
                get_secret_str("AZURE_AD_TOKEN")  # type: ignore

            # Prepare Azure-specific parameters for extra_body
            extra_body = _prepare_azure_extra_body(
                extra_body, kwargs, azure_specific_hyperparams
            )

            create_fine_tuning_job_data_dict = _build_fine_tuning_job_data(
                model,
                training_file,
                _oai_hyperparameters,
                suffix,
                validation_file,
                integrations,
                seed,
            ).model_dump(exclude_none=True)

            # Add extra_body if it has Azure-specific parameters
            if extra_body:
                create_fine_tuning_job_data_dict["extra_body"] = extra_body

            response = azure_fine_tuning_apis_instance.create_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=api_version,
                create_fine_tuning_job_data=create_fine_tuning_job_data_dict,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                organization=optional_params.organization,
            )
        elif custom_llm_provider == "vertex_ai":
            api_base = optional_params.api_base or ""
            vertex_ai_project = (
                optional_params.vertex_project
                or litellm.vertex_project
                or get_secret_str("VERTEXAI_PROJECT")
            )
            vertex_ai_location = (
                optional_params.vertex_location
                or litellm.vertex_location
                or get_secret_str("VERTEXAI_LOCATION")
            )
            vertex_credentials = optional_params.vertex_credentials or get_secret_str(
                "VERTEXAI_CREDENTIALS"
            )
            response = vertex_fine_tuning_apis_instance.create_fine_tuning_job(
                _is_async=_is_async,
                create_fine_tuning_job_data=_build_fine_tuning_job_data(
                    model,
                    training_file,
                    _oai_hyperparameters,
                    suffix,
                    validation_file,
                    integrations,
                    seed,
                ),
                vertex_credentials=vertex_credentials,
                vertex_project=vertex_ai_project,
                vertex_location=vertex_ai_location,
                timeout=timeout,
                api_base=api_base,
                kwargs=kwargs,
                original_hyperparameters=hyperparameters,
            )
        else:
            # BUGFIX: message previously said "create_batch ... Only 'openai' is
            # supported", which is wrong for this function.
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'create_fine_tuning_job'. Only 'openai', 'azure', and 'vertex_ai' are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="create_fine_tuning_job", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        verbose_logger.error("got exception in create_fine_tuning_job=%s", str(e))
        raise e
|
||||
|
||||
|
||||
@client
async def acancel_fine_tuning_job(
    fine_tuning_job_id: str,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> LiteLLMFineTuningJob:
    """
    Async: Immediately cancel a fine-tune job.

    Thin async wrapper around ``cancel_fine_tuning_job`` — runs the sync
    implementation in a thread-pool executor and awaits the coroutine it
    returns when the async flag is set.
    """
    try:
        event_loop = asyncio.get_event_loop()
        # Signal the sync implementation that it should hand back a coroutine.
        kwargs["acancel_fine_tuning_job"] = True

        bound_call = partial(
            cancel_fine_tuning_job,
            fine_tuning_job_id,
            custom_llm_provider,
            extra_headers,
            extra_body,
            **kwargs,
        )

        # Run under a copied context so contextvars survive the executor hop.
        ctx = contextvars.copy_context()
        initial = await event_loop.run_in_executor(None, partial(ctx.run, bound_call))

        if asyncio.iscoroutine(initial):
            return await initial
        return initial  # type: ignore
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
def cancel_fine_tuning_job(
    fine_tuning_job_id: str,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
    """
    Immediately cancel a fine-tune job.

    Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete.

    Args:
        fine_tuning_job_id: Id of the job to cancel.
        custom_llm_provider: One of "openai", "azure" ("vertex_ai" is not
            supported for cancel).
        extra_headers: Provider-specific extra headers (currently unused here).
        extra_body: Provider-specific extra body fields.

    Returns:
        A LiteLLMFineTuningJob, or a coroutine resolving to one when invoked
        through ``acancel_fine_tuning_job``.

    Raises:
        litellm.exceptions.BadRequestError: when custom_llm_provider is not supported.
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        ### TIMEOUT LOGIC ###
        # Normalise via the shared helper (10 min default), consistent with
        # create_fine_tuning_job instead of duplicating the logic inline.
        timeout = _resolve_fine_tuning_timeout(
            optional_params.timeout or kwargs.get("request_timeout", 600),
            custom_llm_provider,
        )

        _is_async = kwargs.pop("acancel_fine_tuning_job", False) is True

        # OpenAI
        if custom_llm_provider == "openai":
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
                or litellm.api_base
                or os.getenv("OPENAI_BASE_URL")
                or os.getenv("OPENAI_API_BASE")
                or "https://api.openai.com/v1"
            )
            organization = (
                optional_params.organization
                or litellm.organization
                or os.getenv("OPENAI_ORGANIZATION", None)
                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
            )
            # set API KEY
            api_key = (
                optional_params.api_key
                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
                or litellm.openai_key
                or os.getenv("OPENAI_API_KEY")
            )

            response = openai_fine_tuning_apis_instance.cancel_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=optional_params.api_version,
                organization=organization,
                fine_tuning_job_id=fine_tuning_job_id,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                client=kwargs.get("client", None),
            )
        # Azure OpenAI
        elif custom_llm_provider == "azure":
            # BUGFIX: was `get_secret` — that name is not imported in this
            # module (only `get_secret_str` is), so this path raised NameError.
            api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore

            api_version = (
                optional_params.api_version
                or litellm.api_version
                or get_secret_str("AZURE_API_VERSION")
            )  # type: ignore

            api_key = (
                optional_params.api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret_str("AZURE_OPENAI_API_KEY")
                or get_secret_str("AZURE_API_KEY")
            )  # type: ignore

            extra_body = optional_params.get("extra_body", {})
            if extra_body is not None:
                extra_body.pop("azure_ad_token", None)
            else:
                # NOTE(review): return value is discarded here — looks like a
                # leftover; confirm whether the AD token was meant to be used.
                get_secret_str("AZURE_AD_TOKEN")  # type: ignore

            response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=api_version,
                fine_tuning_job_id=fine_tuning_job_id,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                organization=optional_params.organization,
            )
        else:
            # BUGFIX: message previously referenced 'create_batch'.
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'cancel_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="cancel_fine_tuning_job", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
async def alist_fine_tuning_jobs(
    after: Optional[str] = None,
    limit: Optional[int] = None,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    Async: List your organization's fine-tuning jobs.

    Thin async wrapper around ``list_fine_tuning_jobs`` — runs the sync
    implementation in a thread-pool executor and awaits the coroutine it
    returns when the async flag is set.
    """
    try:
        event_loop = asyncio.get_event_loop()
        # Signal the sync implementation that it should hand back a coroutine.
        kwargs["alist_fine_tuning_jobs"] = True

        bound_call = partial(
            list_fine_tuning_jobs,
            after,
            limit,
            custom_llm_provider,
            extra_headers,
            extra_body,
            **kwargs,
        )

        # Run under a copied context so contextvars survive the executor hop.
        ctx = contextvars.copy_context()
        initial = await event_loop.run_in_executor(None, partial(ctx.run, bound_call))

        if asyncio.iscoroutine(initial):
            return await initial
        return initial  # type: ignore
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
def list_fine_tuning_jobs(
    after: Optional[str] = None,
    limit: Optional[int] = None,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
):
    """
    List your organization's fine-tuning jobs.

    Params:

    - after: Optional[str] = None, Identifier for the last job from the previous pagination request.
    - limit: Optional[int] = None, Number of fine-tuning jobs to retrieve. Defaults to 20.
    - custom_llm_provider: One of "openai", "azure" ("vertex_ai" is not supported for listing).

    Returns:
        The provider's job list, or a coroutine resolving to it when invoked
        through ``alist_fine_tuning_jobs``.

    Raises:
        litellm.exceptions.BadRequestError: when custom_llm_provider is not supported.
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        ### TIMEOUT LOGIC ###
        # Normalise via the shared helper (10 min default), consistent with
        # create_fine_tuning_job instead of duplicating the logic inline.
        timeout = _resolve_fine_tuning_timeout(
            optional_params.timeout or kwargs.get("request_timeout", 600),
            custom_llm_provider,
        )

        _is_async = kwargs.pop("alist_fine_tuning_jobs", False) is True

        # OpenAI
        if custom_llm_provider == "openai":
            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
            api_base = (
                optional_params.api_base
                or litellm.api_base
                or os.getenv("OPENAI_BASE_URL")
                or os.getenv("OPENAI_API_BASE")
                or "https://api.openai.com/v1"
            )
            organization = (
                optional_params.organization
                or litellm.organization
                or os.getenv("OPENAI_ORGANIZATION", None)
                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
            )
            # set API KEY
            api_key = (
                optional_params.api_key
                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
                or litellm.openai_key
                or os.getenv("OPENAI_API_KEY")
            )

            response = openai_fine_tuning_apis_instance.list_fine_tuning_jobs(
                api_base=api_base,
                api_key=api_key,
                api_version=optional_params.api_version,
                organization=organization,
                after=after,
                limit=limit,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                client=kwargs.get("client", None),
            )
        # Azure OpenAI
        elif custom_llm_provider == "azure":
            api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore

            api_version = (
                optional_params.api_version
                or litellm.api_version
                or get_secret_str("AZURE_API_VERSION")
            )  # type: ignore

            api_key = (
                optional_params.api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret_str("AZURE_OPENAI_API_KEY")
                or get_secret_str("AZURE_API_KEY")
            )  # type: ignore

            extra_body = optional_params.get("extra_body", {})
            if extra_body is not None:
                extra_body.pop("azure_ad_token", None)
            else:
                # BUGFIX: was `get_secret` — that name is not imported in this
                # module (only `get_secret_str` is), so this branch raised
                # NameError. NOTE(review): return value is still discarded;
                # confirm whether the AD token was meant to be used.
                get_secret_str("AZURE_AD_TOKEN")  # type: ignore

            response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
                api_base=api_base,
                api_key=api_key,
                api_version=api_version,
                after=after,
                limit=limit,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                organization=optional_params.organization,
            )
        else:
            # BUGFIX: message previously referenced 'create_batch'.
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'list_fine_tuning_jobs'. Only 'openai' and 'azure' are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="list_fine_tuning_jobs", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
async def aretrieve_fine_tuning_job(
    fine_tuning_job_id: str,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> LiteLLMFineTuningJob:
    """
    Async: Get info about a fine-tuning job.

    Thin async wrapper around ``retrieve_fine_tuning_job`` — runs the sync
    implementation in a thread-pool executor and awaits the coroutine it
    returns when the async flag is set.
    """
    try:
        event_loop = asyncio.get_event_loop()
        # Signal the sync implementation that it should hand back a coroutine.
        kwargs["aretrieve_fine_tuning_job"] = True

        bound_call = partial(
            retrieve_fine_tuning_job,
            fine_tuning_job_id,
            custom_llm_provider,
            extra_headers,
            extra_body,
            **kwargs,
        )

        # Run under a copied context so contextvars survive the executor hop.
        ctx = contextvars.copy_context()
        initial = await event_loop.run_in_executor(None, partial(ctx.run, bound_call))

        if asyncio.iscoroutine(initial):
            return await initial
        return initial  # type: ignore
    except Exception as e:
        raise e
|
||||
|
||||
|
||||
@client
def retrieve_fine_tuning_job(
    fine_tuning_job_id: str,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    extra_headers: Optional[Dict[str, str]] = None,
    extra_body: Optional[Dict[str, str]] = None,
    **kwargs,
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
    """
    Get info about a fine-tuning job.

    Args:
        fine_tuning_job_id: Id of the job to retrieve.
        custom_llm_provider: One of "openai", "azure" ("vertex_ai" is not
            supported for retrieval).
        extra_headers: Provider-specific extra headers (currently unused here).
        extra_body: Provider-specific extra body fields.

    Returns:
        A LiteLLMFineTuningJob, or a coroutine resolving to one when invoked
        through ``aretrieve_fine_tuning_job``.

    Raises:
        litellm.exceptions.BadRequestError: when custom_llm_provider is not supported.
    """
    try:
        optional_params = GenericLiteLLMParams(**kwargs)
        ### TIMEOUT LOGIC ###
        # Normalise via the shared helper (10 min default), consistent with
        # create_fine_tuning_job instead of duplicating the logic inline.
        timeout = _resolve_fine_tuning_timeout(
            optional_params.timeout or kwargs.get("request_timeout", 600),
            custom_llm_provider,
        )

        _is_async = kwargs.pop("aretrieve_fine_tuning_job", False) is True

        # OpenAI
        if custom_llm_provider == "openai":
            api_base = (
                optional_params.api_base
                or litellm.api_base
                or os.getenv("OPENAI_BASE_URL")
                or os.getenv("OPENAI_API_BASE")
                or "https://api.openai.com/v1"
            )
            organization = (
                optional_params.organization
                or litellm.organization
                or os.getenv("OPENAI_ORGANIZATION", None)
                or None
            )
            api_key = (
                optional_params.api_key
                or litellm.api_key
                or litellm.openai_key
                or os.getenv("OPENAI_API_KEY")
            )

            response = openai_fine_tuning_apis_instance.retrieve_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=optional_params.api_version,
                organization=organization,
                fine_tuning_job_id=fine_tuning_job_id,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                client=kwargs.get("client", None),
            )
        # Azure OpenAI
        elif custom_llm_provider == "azure":
            api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")  # type: ignore

            api_version = (
                optional_params.api_version
                or litellm.api_version
                or get_secret_str("AZURE_API_VERSION")
            )  # type: ignore

            api_key = (
                optional_params.api_key
                or litellm.api_key
                or litellm.azure_key
                or get_secret_str("AZURE_OPENAI_API_KEY")
                or get_secret_str("AZURE_API_KEY")
            )  # type: ignore

            extra_body = optional_params.get("extra_body", {})
            if extra_body is not None:
                extra_body.pop("azure_ad_token", None)
            else:
                # NOTE(review): return value is discarded here — looks like a
                # leftover; confirm whether the AD token was meant to be used.
                get_secret_str("AZURE_AD_TOKEN")  # type: ignore

            response = azure_fine_tuning_apis_instance.retrieve_fine_tuning_job(
                api_base=api_base,
                api_key=api_key,
                api_version=api_version,
                fine_tuning_job_id=fine_tuning_job_id,
                timeout=timeout,
                max_retries=optional_params.max_retries,
                _is_async=_is_async,
                organization=optional_params.organization,
            )
        else:
            raise litellm.exceptions.BadRequestError(
                message="LiteLLM doesn't support {} for 'retrieve_fine_tuning_job'. Only 'openai' and 'azure' are supported.".format(
                    custom_llm_provider
                ),
                model="n/a",
                llm_provider=custom_llm_provider,
                response=httpx.Response(
                    status_code=400,
                    content="Unsupported provider",
                    request=httpx.Request(method="retrieve_fine_tuning_job", url="https://github.com/BerriAI/litellm"),  # type: ignore
                ),
            )
        return response
    except Exception as e:
        raise e
|
||||
@@ -0,0 +1,123 @@
|
||||
# LiteLLM Google GenAI Interface
|
||||
|
||||
An interface for calling Google GenAI functions using Google's native request and response format.
|
||||
|
||||
## Overview
|
||||
|
||||
This module provides a native interface to Google's Generative AI API, allowing you to use Google's content generation capabilities with both streaming and non-streaming modes, in both synchronous and asynchronous contexts.
|
||||
|
||||
## Available Functions
|
||||
|
||||
### Non-Streaming Functions
|
||||
|
||||
- `generate_content()` - Synchronous content generation
|
||||
- `agenerate_content()` - Asynchronous content generation
|
||||
|
||||
### Streaming Functions
|
||||
|
||||
- `generate_content_stream()` - Synchronous streaming content generation
|
||||
- `agenerate_content_stream()` - Asynchronous streaming content generation
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Basic Non-Streaming Usage
|
||||
|
||||
```python
|
||||
from litellm.google_genai import generate_content, agenerate_content
|
||||
from google.genai.types import ContentDict, PartDict
|
||||
|
||||
# Synchronous usage
|
||||
contents = ContentDict(
|
||||
parts=[
|
||||
PartDict(text="Hello, can you tell me a short joke?")
|
||||
],
|
||||
)
|
||||
|
||||
response = generate_content(
|
||||
contents=contents,
|
||||
model="gemini-pro", # or your preferred model
|
||||
# Add other model-specific parameters as needed
|
||||
)
|
||||
|
||||
print(response)
|
||||
```
|
||||
|
||||
### Async Non-Streaming Usage
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from litellm.google_genai import agenerate_content
|
||||
from google.genai.types import ContentDict, PartDict
|
||||
|
||||
async def main():
|
||||
contents = ContentDict(
|
||||
parts=[
|
||||
PartDict(text="Hello, can you tell me a short joke?")
|
||||
],
|
||||
)
|
||||
|
||||
response = await agenerate_content(
|
||||
contents=contents,
|
||||
model="gemini-pro",
|
||||
# Add other model-specific parameters as needed
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
# Run the async function
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### Streaming Usage
|
||||
|
||||
```python
|
||||
from litellm.google_genai import generate_content_stream
|
||||
from google.genai.types import ContentDict, PartDict
|
||||
|
||||
# Synchronous streaming
|
||||
contents = ContentDict(
|
||||
parts=[
|
||||
PartDict(text="Tell me a story about space exploration")
|
||||
],
|
||||
)
|
||||
|
||||
for chunk in generate_content_stream(
|
||||
contents=contents,
|
||||
model="gemini-pro",
|
||||
):
|
||||
print(f"Chunk: {chunk}")
|
||||
```
|
||||
|
||||
### Async Streaming Usage
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from litellm.google_genai import agenerate_content_stream
|
||||
from google.genai.types import ContentDict, PartDict
|
||||
|
||||
async def main():
|
||||
contents = ContentDict(
|
||||
parts=[
|
||||
PartDict(text="Tell me a story about space exploration")
|
||||
],
|
||||
)
|
||||
|
||||
async for chunk in agenerate_content_stream(
|
||||
contents=contents,
|
||||
model="gemini-pro",
|
||||
):
|
||||
print(f"Async chunk: {chunk}")
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
|
||||
## Testing
|
||||
|
||||
This module includes comprehensive tests covering:
|
||||
- Sync and async non-streaming requests
|
||||
- Sync and async streaming requests
|
||||
- Response validation
|
||||
- Error handling scenarios
|
||||
|
||||
See `tests/unified_google_tests/base_google_test.py` for test implementation examples.
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
This allows using Google GenAI model in their native interface.
|
||||
|
||||
This module provides generate_content functionality for Google GenAI models.
|
||||
"""
|
||||
|
||||
from .main import (
|
||||
agenerate_content,
|
||||
agenerate_content_stream,
|
||||
generate_content,
|
||||
generate_content_stream,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"generate_content",
|
||||
"agenerate_content",
|
||||
"generate_content_stream",
|
||||
"agenerate_content_stream",
|
||||
]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user