chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,170 @@
# Dynamic Rate Limiter v3 - Saturation-Aware Priority-Based Rate Limiting
## Overview
The v3 dynamic rate limiter implements saturation-aware rate limiting with priority-based allocation. It balances resource efficiency (allowing unused capacity to be borrowed) with fairness guarantees (enforcing priorities during high load).
**Key Behavior:**
- When the system is under 80% capacity: Generous mode - allows priority borrowing
- When the system is at or above 80% capacity: Strict mode - enforces normalized priority limits
## How It Works
### Flow Diagram
```
┌─────────────────────────────────────────────────────────────┐
│ Incoming Request │
└────────────────────────┬────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ 1. Check Model Saturation │
│ - Query v3 limiter's Redis counters │
│ - Calculate: current_usage / capacity │
│ - Returns: 0.0 (empty) to 1.0+ (saturated) │
└────────────────────────┬────────────────────────────────────┘
┌────────┴────────┐
│ Saturation? │
└────────┬────────┘
┌───────────────┴───────────────┐
│ │
▼ ▼
< 80% (Generous) >= 80% (Strict)
│ │
▼ ▼
┌─────────────────────┐ ┌─────────────────────┐
│ Generous Mode │ │ Strict Mode │
│ │ │ │
│ - Enforce model- │ │ - Normalize │
│ wide capacity │ │ priority weights │
│ - No priority │ │ (if over 1.0) │
│ restrictions │ │ │
│ - Allows borrowing │ │ - Create priority- │
│ │ │ specific │
│ - First-come- │ │ descriptors │
│ first-served │ │ │
│ until capacity │ │ - Enforce strict │
│ │ │ limits per │
│ │ │ priority │
└──────────┬──────────┘ └──────────┬──────────┘
│ │
│ ▼
│ ┌──────────────────────┐
│ │ Track model usage │
│ │ for future │
│ │ saturation checks │
│ └──────────┬───────────┘
│ │
└───────────────┬───────────────┘
┌──────────────┐
│ v3 Limiter │
│ Check │
└──────┬───────┘
┌───────────────┴───────────────┐
│ │
▼ ▼
OVER_LIMIT OK
│ │
▼ ▼
Return 429 Error Allow Request
```
## Configuration
### Priority Reservation
Set priority weights in your proxy configuration:
```python
litellm.priority_reservation = {
"premium": 0.75, # 75% of capacity
"standard": 0.25 # 25% of capacity
}
```
### Priority Reservation Settings
Configure saturation-aware behavior:
```python
litellm.priority_reservation_settings = PriorityReservationSettings(
default_priority=0.5, # Default weight for users without explicit priority
saturation_threshold=0.80, # 80% - threshold for strict mode enforcement
tracking_multiplier=10 # 10x - multiplier for non-blocking tracking in strict mode
)
```
**Settings:**
- `default_priority` (default: 0.5) - Priority weight for users without explicit priority metadata
- `saturation_threshold` (default: 0.80) - Saturation level (0.0-1.0) at which strict priority enforcement begins
- `tracking_multiplier` (default: 10) - Multiplier for model-wide tracking limits in strict mode
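The arithmetic behind `tracking_multiplier` is simple; a minimal sketch (the function name here is illustrative, not the actual implementation):

```python
def tracking_limit(model_limit: int, tracking_multiplier: int = 10) -> int:
    # In strict mode, model-wide usage is tracked against a deliberately
    # generous limit (limit * multiplier) so tracking never blocks requests.
    return model_limit * tracking_multiplier

print(tracking_limit(100))  # 1000 -- a 100 RPM model is tracked up to 1000
```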
### User Priority Assignment
Set priority in user metadata:
```python
user_api_key_dict.metadata = {"priority": "premium"}
```
## Priority Weight Normalization
If priorities sum to > 1.0, they are automatically normalized:
```
Input: {key_a: 0.60, key_b: 0.80} = 1.40 total
Output: {key_a: 0.43, key_b: 0.57} = 1.00 total
```
This ensures total allocation never exceeds model capacity.
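A minimal sketch of this normalization (illustrative only; the handler's `_normalize_priority_weights()` may differ in rounding and edge-case handling):

```python
def normalize_priority_weights(weights: dict) -> dict:
    # Scale weights down proportionally when they sum to more than 1.0,
    # so total allocation never exceeds model capacity.
    total = sum(weights.values())
    if total <= 1.0:
        return dict(weights)
    return {k: round(v / total, 2) for k, v in weights.items()}

print(normalize_priority_weights({"key_a": 0.60, "key_b": 0.80}))
# {'key_a': 0.43, 'key_b': 0.57}
```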
## Implementation Details
### Saturation Detection
- Queries v3 limiter's Redis counters for model-wide usage
- Checks both RPM and TPM, returns higher saturation value
- Non-blocking reads (doesn't increment counters)
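The saturation computation described above can be sketched as follows (the signature and names are assumptions, not the actual `_check_model_saturation()` API):

```python
def model_saturation(rpm_used: int, rpm_limit: int, tpm_used: int, tpm_limit: int) -> float:
    # Check both RPM and TPM and return the higher saturation ratio.
    # 0.0 means idle; values >= 1.0 mean the model is fully saturated.
    ratios = []
    if rpm_limit:
        ratios.append(rpm_used / rpm_limit)
    if tpm_limit:
        ratios.append(tpm_used / tpm_limit)
    return max(ratios, default=0.0)

print(model_saturation(50, 100, 9000, 10000))  # 0.9 -- TPM is the binding constraint
```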
### Mode Selection
**Generous Mode (< 80% saturation):**
- Creates single model-wide descriptor
- Enforces total capacity only
- Allows any priority to use available capacity
- Prevents over-subscription via model-wide limit
**Strict Mode (>= 80% saturation):**
- Creates priority-specific descriptors with normalized weights
- Each priority gets its reserved allocation
- Tracks model-wide usage separately (non-blocking, 10x multiplier)
- Ensures fairness under load
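Putting the two modes together, the routing decision reduces to a threshold comparison (a sketch under the defaults above, not the handler's exact code):

```python
SATURATION_THRESHOLD = 0.80  # mirrors priority_reservation_settings.saturation_threshold

def select_mode(saturation: float) -> str:
    # Below the threshold, any priority may borrow unused capacity (generous);
    # at or above it, normalized per-priority limits are enforced (strict).
    return "generous" if saturation < SATURATION_THRESHOLD else "strict"

print(select_mode(0.45))  # generous
print(select_mode(0.92))  # strict
```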
### Test Coverage
Test scenarios covered:
1. No rate limiting when under capacity
2. Priority queue behavior during saturation
3. Spillover capacity for default keys
4. Over-allocated priorities with normalization
5. Default priority value handling
### `_PROXY_DynamicRateLimitHandlerV3`
Main handler class inheriting from `CustomLogger`.
**Key Methods:**
- `async_pre_call_hook()` - Main entry point, routes to generous/strict mode
- `_check_model_saturation()` - Queries Redis for current usage
- `_handle_generous_mode()` - Enforces model-wide capacity only
- `_handle_strict_mode()` - Enforces normalized priority limits
- `_normalize_priority_weights()` - Handles over-allocation
- `_create_priority_based_descriptors()` - Creates rate limit descriptors

View File

@@ -0,0 +1,60 @@
import os
from typing import Literal, Union
from . import *
from .cache_control_check import _PROXY_CacheControlCheck
from .litellm_skills import SkillsInjectionHook
from .max_budget_limiter import _PROXY_MaxBudgetLimiter
from .max_budget_per_session_limiter import _PROXY_MaxBudgetPerSessionHandler
from .max_iterations_limiter import _PROXY_MaxIterationsHandler
from .parallel_request_limiter import _PROXY_MaxParallelRequestsHandler
from .parallel_request_limiter_v3 import _PROXY_MaxParallelRequestsHandler_v3
from .responses_id_security import ResponsesIDSecurity
### CHECK IF ENTERPRISE HOOKS ARE AVAILABLE ####
try:
    from enterprise.enterprise_hooks import ENTERPRISE_PROXY_HOOKS
except ImportError:
    ENTERPRISE_PROXY_HOOKS = {}

# List of all available hooks that can be enabled
PROXY_HOOKS = {
    "max_budget_limiter": _PROXY_MaxBudgetLimiter,
    "parallel_request_limiter": _PROXY_MaxParallelRequestsHandler_v3,
    "cache_control_check": _PROXY_CacheControlCheck,
    "responses_id_security": ResponsesIDSecurity,
    "litellm_skills": SkillsInjectionHook,
    "max_iterations_limiter": _PROXY_MaxIterationsHandler,
    "max_budget_per_session_limiter": _PROXY_MaxBudgetPerSessionHandler,
}

## FEATURE FLAG HOOKS ##
if os.getenv("LEGACY_MULTI_INSTANCE_RATE_LIMITING", "false").lower() == "true":
    PROXY_HOOKS["parallel_request_limiter"] = _PROXY_MaxParallelRequestsHandler

### update PROXY_HOOKS with ENTERPRISE_PROXY_HOOKS ###
PROXY_HOOKS.update(ENTERPRISE_PROXY_HOOKS)


def get_proxy_hook(
    hook_name: Union[
        Literal[
            "max_budget_limiter",
            "managed_files",
            "parallel_request_limiter",
            "cache_control_check",
        ],
        str,
    ],
):
    """
    Factory method to get a proxy hook instance by name
    """
    if hook_name not in PROXY_HOOKS:
        raise ValueError(
            f"Unknown hook: {hook_name}. Available hooks: {list(PROXY_HOOKS.keys())}"
        )
    return PROXY_HOOKS[hook_name]

View File

@@ -0,0 +1,156 @@
import traceback
from typing import Optional
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_AzureContentSafety(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    # Class variables or attributes
    def __init__(self, endpoint, api_key, thresholds=None):
        try:
            from azure.ai.contentsafety.aio import ContentSafetyClient
            from azure.ai.contentsafety.models import (
                AnalyzeTextOptions,
                AnalyzeTextOutputType,
                TextCategory,
            )
            from azure.core.credentials import AzureKeyCredential
            from azure.core.exceptions import HttpResponseError
        except Exception as e:
            raise Exception(
                f"\033[91mAzure Content-Safety not installed, try running 'pip install azure-ai-contentsafety' to fix this error: {e}\n{traceback.format_exc()}\033[0m"
            )
        self.endpoint = endpoint
        self.api_key = api_key
        self.text_category = TextCategory
        self.analyze_text_options = AnalyzeTextOptions
        self.analyze_text_output_type = AnalyzeTextOutputType
        self.azure_http_error = HttpResponseError

        self.thresholds = self._configure_thresholds(thresholds)
        self.client = ContentSafetyClient(
            self.endpoint, AzureKeyCredential(self.api_key)
        )

    def _configure_thresholds(self, thresholds=None):
        default_thresholds = {
            self.text_category.HATE: 4,
            self.text_category.SELF_HARM: 4,
            self.text_category.SEXUAL: 4,
            self.text_category.VIOLENCE: 4,
        }
        if thresholds is None:
            return default_thresholds
        for key, default in default_thresholds.items():
            if key not in thresholds:
                thresholds[key] = default
        return thresholds

    def _compute_result(self, response):
        result = {}
        category_severity = {
            item.category: item.severity for item in response.categories_analysis
        }
        for category in self.text_category:
            severity = category_severity.get(category)
            if severity is not None:
                result[category] = {
                    "filtered": severity >= self.thresholds[category],
                    "severity": severity,
                }
        return result

    async def test_violation(self, content: str, source: Optional[str] = None):
        verbose_proxy_logger.debug("Testing Azure Content-Safety for: %s", content)

        # Construct a request
        request = self.analyze_text_options(
            text=content,
            output_type=self.analyze_text_output_type.EIGHT_SEVERITY_LEVELS,
        )

        # Analyze text
        try:
            response = await self.client.analyze_text(request)
        except self.azure_http_error:
            verbose_proxy_logger.debug(
                "Error in Azure Content-Safety: %s", traceback.format_exc()
            )
            verbose_proxy_logger.debug(traceback.format_exc())
            raise

        result = self._compute_result(response)
        verbose_proxy_logger.debug("Azure Content-Safety Result: %s", result)

        for key, value in result.items():
            if value["filtered"]:
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Violated content safety policy",
                        "source": source,
                        "category": key,
                        "severity": value["severity"],
                    },
                )

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        verbose_proxy_logger.debug("Inside Azure Content-Safety Pre-Call Hook")
        try:
            if call_type == "completion" and "messages" in data:
                for m in data["messages"]:
                    if "content" in m and isinstance(m["content"], str):
                        await self.test_violation(content=m["content"], source="input")
        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.azure_content_safety.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_post_call_success_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        response,
    ):
        verbose_proxy_logger.debug("Inside Azure Content-Safety Post-Call Hook")
        if isinstance(response, litellm.ModelResponse) and isinstance(
            response.choices[0], litellm.utils.Choices
        ):
            await self.test_violation(
                content=response.choices[0].message.content or "", source="output"
            )

    # async def async_post_call_streaming_hook(
    #     self,
    #     user_api_key_dict: UserAPIKeyAuth,
    #     response: str,
    # ):
    #     verbose_proxy_logger.debug("Inside Azure Content-Safety Call-Stream Hook")
    #     await self.test_violation(content=response, source="output")

View File

@@ -0,0 +1,456 @@
"""
Batch Rate Limiter Hook
This hook implements rate limiting for batch API requests by:
1. Reading batch input files to count requests and estimate tokens at submission
2. Validating actual usage from output files when batches complete
3. Integrating with the existing parallel request limiter infrastructure
## Integration & Calling
This hook is automatically registered and called by the proxy system.
See BATCH_RATE_LIMITER_INTEGRATION.md for complete integration details.
Quick summary:
- Add to PROXY_HOOKS in litellm/proxy/hooks/__init__.py
- Gets auto-instantiated on proxy startup via _add_proxy_hooks()
- async_pre_call_hook() fires on POST /v1/batches (batch submission)
- async_log_success_event() fires on GET /v1/batches/{id} (batch completion)
"""
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
from fastapi import HTTPException
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.batch_utils import (
_get_batch_job_input_file_usage,
_get_file_content_as_dictionary,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitDescriptor as _RateLimitDescriptor,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        RateLimitStatus as _RateLimitStatus,
    )
    from litellm.proxy.hooks.parallel_request_limiter_v3 import (
        _PROXY_MaxParallelRequestsHandler_v3 as _ParallelRequestLimiter,
    )
    from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
    from litellm.router import Router as _Router

    Span = Union[_Span, Any]
    InternalUsageCache = _InternalUsageCache
    Router = _Router
    ParallelRequestLimiter = _ParallelRequestLimiter
    RateLimitStatus = _RateLimitStatus
    RateLimitDescriptor = _RateLimitDescriptor
else:
    Span = Any
    InternalUsageCache = Any
    Router = Any
    ParallelRequestLimiter = Any
    RateLimitStatus = Dict[str, Any]
    RateLimitDescriptor = Dict[str, Any]
class BatchFileUsage(BaseModel):
    """
    Internal model for batch file usage tracking, used for batch rate limiting
    """

    total_tokens: int
    request_count: int
class _PROXY_BatchRateLimiter(CustomLogger):
    """
    Rate limiter for batch API requests.

    Handles rate limiting at two points:
    1. Batch submission - reads input file and reserves capacity
    2. Batch completion - reads output file and adjusts for actual usage
    """

    def __init__(
        self,
        internal_usage_cache: InternalUsageCache,
        parallel_request_limiter: ParallelRequestLimiter,
    ):
        """
        Initialize the batch rate limiter.

        Note: These dependencies are automatically injected by ProxyLogging._add_proxy_hooks()
        when this hook is registered in PROXY_HOOKS. See BATCH_RATE_LIMITER_INTEGRATION.md.

        Args:
            internal_usage_cache: Cache for storing rate limit data (auto-injected)
            parallel_request_limiter: Existing rate limiter to integrate with (needs custom injection)
        """
        self.internal_usage_cache = internal_usage_cache
        self.parallel_request_limiter = parallel_request_limiter
    def _raise_rate_limit_error(
        self,
        status: "RateLimitStatus",
        descriptors: List["RateLimitDescriptor"],
        batch_usage: BatchFileUsage,
        limit_type: str,
    ) -> None:
        """Raise HTTPException for rate limit exceeded."""
        from datetime import datetime

        # Find the descriptor for this status
        descriptor_index = next(
            (
                i
                for i, d in enumerate(descriptors)
                if d.get("key") == status.get("descriptor_key")
            ),
            0,
        )
        descriptor: RateLimitDescriptor = (
            descriptors[descriptor_index]
            if descriptors
            else {"key": "", "value": "", "rate_limit": None}
        )
        now = datetime.now().timestamp()
        window_size = self.parallel_request_limiter.window_size
        reset_time = now + window_size
        reset_time_formatted = datetime.fromtimestamp(reset_time).strftime(
            "%Y-%m-%d %H:%M:%S UTC"
        )
        remaining_display = max(0, status["limit_remaining"])
        current_limit = status["current_limit"]

        if limit_type == "requests":
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.request_count} requests but only {remaining_display} requests remaining "
                f"out of {current_limit} RPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )
        else:  # tokens
            detail = (
                f"Batch rate limit exceeded for {descriptor.get('key', 'unknown')}: {descriptor.get('value', 'unknown')}. "
                f"Batch contains {batch_usage.total_tokens} tokens but only {remaining_display} tokens remaining "
                f"out of {current_limit} TPM limit. "
                f"Limit resets at: {reset_time_formatted}"
            )

        raise HTTPException(
            status_code=429,
            detail=detail,
            headers={
                "retry-after": str(window_size),
                "rate_limit_type": limit_type,
                "reset_at": reset_time_formatted,
            },
        )
    async def _check_and_increment_batch_counters(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        data: Dict,
        batch_usage: BatchFileUsage,
    ) -> None:
        """
        Check rate limits and increment counters by the batch amounts.

        Raises HTTPException if any limit would be exceeded.
        """
        from litellm.types.caching import RedisPipelineIncrementOperation

        # Create descriptors and check if batch would exceed limits
        descriptors = self.parallel_request_limiter._create_rate_limit_descriptors(
            user_api_key_dict=user_api_key_dict,
            data=data,
            rpm_limit_type=None,
            tpm_limit_type=None,
            model_has_failures=False,
        )

        # Check current usage without incrementing
        rate_limit_response = await self.parallel_request_limiter.should_rate_limit(
            descriptors=descriptors,
            parent_otel_span=user_api_key_dict.parent_otel_span,
            read_only=True,
        )

        # Verify batch won't exceed any limits
        for status in rate_limit_response["statuses"]:
            rate_limit_type = status["rate_limit_type"]
            limit_remaining = status["limit_remaining"]
            required_capacity = (
                batch_usage.request_count
                if rate_limit_type == "requests"
                else batch_usage.total_tokens
                if rate_limit_type == "tokens"
                else 0
            )
            if required_capacity > limit_remaining:
                self._raise_rate_limit_error(
                    status, descriptors, batch_usage, rate_limit_type
                )

        # Build pipeline operations for batch increments
        # Reuse the same keys that descriptors check
        pipeline_operations: List[RedisPipelineIncrementOperation] = []
        for descriptor in descriptors:
            key = descriptor["key"]
            value = descriptor["value"]
            rate_limit = descriptor.get("rate_limit")
            if rate_limit is None:
                continue

            # Add RPM increment if limit is set
            if rate_limit.get("requests_per_unit") is not None:
                rpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="requests"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=rpm_key,
                        increment_value=batch_usage.request_count,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

            # Add TPM increment if limit is set
            if rate_limit.get("tokens_per_unit") is not None:
                tpm_key = self.parallel_request_limiter.create_rate_limit_keys(
                    key=key, value=value, rate_limit_type="tokens"
                )
                pipeline_operations.append(
                    RedisPipelineIncrementOperation(
                        key=tpm_key,
                        increment_value=batch_usage.total_tokens,
                        ttl=self.parallel_request_limiter.window_size,
                    )
                )

        # Execute increments
        if pipeline_operations:
            await self.parallel_request_limiter.async_increment_tokens_with_ttl_preservation(
                pipeline_operations=pipeline_operations,
                parent_otel_span=user_api_key_dict.parent_otel_span,
            )
    async def count_input_file_usage(
        self,
        file_id: str,
        custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
        user_api_key_dict: Optional[UserAPIKeyAuth] = None,
    ) -> BatchFileUsage:
        """
        Count number of requests and tokens in a batch input file.

        Args:
            file_id: The file ID to read
            custom_llm_provider: The custom LLM provider to use for token encoding
            user_api_key_dict: User authentication information for file access (required for managed files)

        Returns:
            BatchFileUsage with total_tokens and request_count
        """
        try:
            # Check if this is a managed file (base64 encoded unified file ID)
            from litellm.proxy.openai_files_endpoints.common_utils import (
                _is_base64_encoded_unified_file_id,
            )

            # Managed files require bypassing the HTTP endpoint (which runs access-check hooks)
            # and calling the managed files hook directly with the user's credentials.
            is_managed_file = _is_base64_encoded_unified_file_id(file_id)

            if is_managed_file and user_api_key_dict is not None:
                file_content = await self._fetch_managed_file_content(
                    file_id=file_id,
                    user_api_key_dict=user_api_key_dict,
                )
            else:
                # For non-managed files, use the standard litellm.afile_content
                file_content = await litellm.afile_content(
                    file_id=file_id,
                    custom_llm_provider=custom_llm_provider,
                    user_api_key_dict=user_api_key_dict,
                )

            file_content_as_dict = _get_file_content_as_dictionary(file_content.content)
            input_file_usage = _get_batch_job_input_file_usage(
                file_content_dictionary=file_content_as_dict,
                custom_llm_provider=custom_llm_provider,
            )
            request_count = len(file_content_as_dict)

            return BatchFileUsage(
                total_tokens=input_file_usage.total_tokens,
                request_count=request_count,
            )
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error counting input file usage for {file_id}: {str(e)}"
            )
            raise
    async def _fetch_managed_file_content(
        self,
        file_id: str,
        user_api_key_dict: UserAPIKeyAuth,
    ) -> Any:
        """
        Fetch file content from managed files hook.

        This is needed for managed files because they require proper user context
        to verify file ownership and access permissions.

        Args:
            file_id: The managed file ID (base64 encoded)
            user_api_key_dict: User authentication information

        Returns:
            HttpxBinaryResponseContent with the file content
        """
        from litellm.llms.base_llm.files.transformation import BaseFileEndpoints

        # Import proxy_server dependencies at runtime to avoid circular imports
        try:
            from litellm.proxy.proxy_server import llm_router, proxy_logging_obj
        except ImportError as e:
            raise ValueError(
                f"Cannot import proxy_server dependencies: {str(e)}. "
                "Managed files require proxy_server to be initialized."
            )

        # Get the managed files hook
        if proxy_logging_obj is None:
            raise ValueError(
                "proxy_logging_obj not available. Cannot access managed files hook."
            )

        managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
        if managed_files_obj is None:
            raise ValueError(
                "Managed files hook not found. Cannot access managed file."
            )
        if not isinstance(managed_files_obj, BaseFileEndpoints):
            raise ValueError("Managed files hook is not a BaseFileEndpoints instance.")

        if llm_router is None:
            raise ValueError("llm_router not available. Cannot access managed files.")

        # Use the managed files hook to get file content
        # This properly handles user permissions and file ownership
        file_content = await managed_files_obj.afile_content(
            file_id=file_id,
            litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
            llm_router=llm_router,
        )
        return file_content
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: Any,
        data: Dict,
        call_type: str,
    ) -> Union[Exception, str, Dict, None]:
        """
        Pre-call hook for batch operations.

        Only handles batch creation (acreate_batch):
        - Reads input file
        - Counts tokens and requests
        - Reserves rate limit capacity via parallel_request_limiter

        Args:
            user_api_key_dict: User authentication information
            cache: Cache instance (not used directly)
            data: Request data
            call_type: Type of call being made

        Returns:
            Modified data dict or None

        Raises:
            HTTPException: 429 if rate limit would be exceeded
        """
        # Only handle batch creation
        if call_type != "acreate_batch":
            verbose_proxy_logger.debug(
                f"Batch rate limiter: Not handling batch creation rate limiting for call type: {call_type}"
            )
            return data

        verbose_proxy_logger.debug(
            "Batch rate limiter: Handling batch creation rate limiting"
        )

        try:
            # Extract input_file_id from data
            input_file_id = data.get("input_file_id")
            if not input_file_id:
                verbose_proxy_logger.debug(
                    "No input_file_id in batch request, skipping rate limiting"
                )
                return data

            # Get custom_llm_provider for token counting
            custom_llm_provider = data.get("custom_llm_provider", "openai")

            # Count tokens and requests from input file
            verbose_proxy_logger.debug(
                f"Counting tokens from batch input file: {input_file_id}"
            )
            batch_usage = await self.count_input_file_usage(
                file_id=input_file_id,
                custom_llm_provider=custom_llm_provider,
                user_api_key_dict=user_api_key_dict,
            )
            verbose_proxy_logger.debug(
                f"Batch input file usage - Tokens: {batch_usage.total_tokens}, "
                f"Requests: {batch_usage.request_count}"
            )

            # Store batch usage in data for later reference
            data["_batch_token_count"] = batch_usage.total_tokens
            data["_batch_request_count"] = batch_usage.request_count

            # Directly increment counters by batch amounts (check happens atomically)
            # This will raise HTTPException if limits are exceeded
            await self._check_and_increment_batch_counters(
                user_api_key_dict=user_api_key_dict,
                data=data,
                batch_usage=batch_usage,
            )

            verbose_proxy_logger.debug(
                "Batch rate limit check passed, counters incremented"
            )
            return data
        except HTTPException:
            # Re-raise HTTP exceptions (rate limit exceeded)
            raise
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error in batch rate limiting: {str(e)}", exc_info=True
            )
            # Don't block the request if rate limiting fails
            return data

View File

@@ -0,0 +1,149 @@
# What this does?
## Gets a key's Redis cache and stores it in memory for 1 minute.
## This reduces the number of Redis GET requests the proxy makes during high traffic.
### [BETA] This is in beta and might change.
import traceback
from typing import Literal, Optional
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache, InMemoryCache, RedisCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_BatchRedisRequests(CustomLogger):
    # Class variables or attributes
    in_memory_cache: Optional[InMemoryCache] = None

    def __init__(self):
        if litellm.cache is not None:
            litellm.cache.async_get_cache = (
                self.async_get_cache
            )  # map the litellm 'get_cache' function to our custom function

    def print_verbose(
        self, print_statement, debug_level: Literal["INFO", "DEBUG"] = "DEBUG"
    ):
        if debug_level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)
        elif debug_level == "INFO":
            verbose_proxy_logger.debug(print_statement)
        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            """
            Get the user key

            Check if a key starting with `litellm:<api_key>:<call_type:` exists in-memory

            If no, then get relevant cache from redis
            """
            api_key = user_api_key_dict.api_key
            cache_key_name = f"litellm:{api_key}:{call_type}"
            self.in_memory_cache = cache.in_memory_cache

            key_value_dict = {}
            in_memory_cache_exists = False
            for key in cache.in_memory_cache.cache_dict.keys():
                if isinstance(key, str) and key.startswith(cache_key_name):
                    in_memory_cache_exists = True

            if in_memory_cache_exists is False and litellm.cache is not None:
                """
                - Check if `litellm.Cache` is redis
                - Get the relevant values
                """
                if litellm.cache.type is not None and isinstance(
                    litellm.cache.cache, RedisCache
                ):
                    # Initialize an empty list to store the keys
                    keys = []
                    self.print_verbose(f"cache_key_name: {cache_key_name}")
                    # Use the SCAN iterator to fetch keys matching the pattern
                    keys = await litellm.cache.cache.async_scan_iter(
                        pattern=cache_key_name, count=100
                    )
                    # If you need the truly "last" based on time or another criteria,
                    # ensure your key naming or storage strategy allows this determination
                    # Here you would sort or filter the keys as needed based on your strategy
                    self.print_verbose(f"redis keys: {keys}")
                    if len(keys) > 0:
                        key_value_dict = (
                            await litellm.cache.cache.async_batch_get_cache(
                                key_list=keys
                            )
                        )

                        ## Add to cache
                        if len(key_value_dict.items()) > 0:
                            await cache.in_memory_cache.async_set_cache_pipeline(
                                cache_list=list(key_value_dict.items()), ttl=60
                            )
                        ## Set cache namespace if it's a miss
                        data["metadata"]["redis_namespace"] = cache_key_name
        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_proxy_logger.error(
                "litellm.proxy.hooks.batch_redis_get.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )
            verbose_proxy_logger.debug(traceback.format_exc())

    async def async_get_cache(self, *args, **kwargs):
        """
        - Check if the cache key is in-memory
        - Else:
            - add missing cache key from REDIS
            - update in-memory cache
            - return redis cache request
        """
        try:  # never block execution
            cache_key: Optional[str] = None
            if "cache_key" in kwargs:
                cache_key = kwargs["cache_key"]
            elif litellm.cache is not None:
                cache_key = litellm.cache.get_cache_key(
                    *args, **kwargs
                )  # returns "<cache_key_name>:<hash>" - we pass redis_namespace in async_pre_call_hook. Done to avoid rewriting the async_set_cache logic

            if (
                cache_key is not None
                and self.in_memory_cache is not None
                and litellm.cache is not None
            ):
                cache_control_args = kwargs.get("cache", {})
                max_age = cache_control_args.get(
                    "s-max-age", cache_control_args.get("s-maxage", float("inf"))
                )
                cached_result = self.in_memory_cache.get_cache(
                    cache_key, *args, **kwargs
                )
                if cached_result is None:
                    cached_result = await litellm.cache.cache.async_get_cache(
                        cache_key, *args, **kwargs
                    )
                    if cached_result is not None:
                        await self.in_memory_cache.async_set_cache(
                            cache_key, cached_result, ttl=60
                        )
                return litellm.cache._get_cache_logic(
                    cached_result=cached_result, max_age=max_age
                )
        except Exception:
            return None

View File

@@ -0,0 +1,58 @@
# What this does?
## Checks if key is allowed to use the cache controls passed in to the completion() call
from fastapi import HTTPException
from litellm import verbose_logger
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_CacheControlCheck(CustomLogger):
    # Class variables or attributes
    def __init__(self):
        pass

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        try:
            verbose_proxy_logger.debug("Inside Cache Control Check Pre-Call Hook")
            allowed_cache_controls = user_api_key_dict.allowed_cache_controls

            if data.get("cache", None) is None:
                return

            cache_args = data.get("cache", None)
            if isinstance(cache_args, dict):
                for k, v in cache_args.items():
                    if (
                        (allowed_cache_controls is not None)
                        and (isinstance(allowed_cache_controls, list))
                        and (
                            len(allowed_cache_controls) > 0
                        )  # assume empty list to be nullable - https://github.com/prisma/prisma/issues/847#issuecomment-546895663
                        and k not in allowed_cache_controls
                    ):
                        raise HTTPException(
                            status_code=403,
                            detail=f"Not allowed to set {k} as a cache control. Contact admin to change permissions.",
                        )
            else:  # invalid cache
                return
        except HTTPException as e:
            raise e
        except Exception as e:
            verbose_logger.exception(
                "litellm.proxy.hooks.cache_control_check.py::async_pre_call_hook(): Exception occurred - {}".format(
                    str(e)
                )
            )

View File

@@ -0,0 +1,303 @@
# What is this?
## Allocates dynamic tpm/rpm quota for a project based on current traffic
## Tracks num active projects per minute
import asyncio
import os
from typing import List, Optional, Tuple, Union
from fastapi import HTTPException
import litellm
from litellm import ModelResponse, Router
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.router import ModelGroupInfo
from litellm.types.utils import CallTypesLiteral
from litellm.utils import get_utc_datetime
from .rate_limiter_utils import convert_priority_to_percent
class DynamicRateLimiterCache:
"""
Thin wrapper on DualCache for this file.
Track number of active projects calling a model.
"""
def __init__(self, cache: DualCache) -> None:
self.cache = cache
self.ttl = 60 # 1 min ttl
async def async_get_cache(self, model: str) -> Optional[int]:
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
key_name = "{}:{}".format(current_minute, model)
_response = await self.cache.async_get_cache(key=key_name)
response: Optional[int] = None
if _response is not None:
response = len(_response)
return response
async def async_set_cache_sadd(self, model: str, value: List):
"""
Add value to set.
Parameters:
- model: str, the name of the model group
- value: str, the team id
Returns:
- None
Raises:
- Exception, if unable to connect to cache client (if redis caching enabled)
"""
try:
dt = get_utc_datetime()
current_minute = dt.strftime("%H-%M")
key_name = "{}:{}".format(current_minute, model)
await self.cache.async_set_cache_sadd(
key=key_name, value=value, ttl=self.ttl
)
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.hooks.dynamic_rate_limiter.py::async_set_cache_sadd(): Exception occured - {}".format(
str(e)
)
)
raise e
class _PROXY_DynamicRateLimitHandler(CustomLogger):
# Class variables or attributes
def __init__(self, internal_usage_cache: DualCache):
self.internal_usage_cache = DynamicRateLimiterCache(cache=internal_usage_cache)
def update_variables(self, llm_router: Router):
self.llm_router = llm_router
async def check_available_usage(
self, model: str, priority: Optional[str] = None
) -> Tuple[
Optional[int], Optional[int], Optional[int], Optional[int], Optional[int]
]:
"""
        For a given model, get its available tpm and rpm
Params:
- model: str, the name of the model in the router model_list
- priority: Optional[str], the priority for the request.
        Returns:
        - Tuple[available_tpm, available_rpm, remaining_model_tpm, remaining_model_rpm, active_projects]
        - available_tpm: int or null - always 0 or positive.
        - available_rpm: int or null - always 0 or positive.
        - remaining_model_tpm: int or null. If available tpm is int, then this will be too.
        - remaining_model_rpm: int or null. If available rpm is int, then this will be too.
        - active_projects: int or null
"""
try:
# Get model info first for conversion
model_group_info: Optional[
ModelGroupInfo
] = self.llm_router.get_model_group_info(model_group=model)
weight: float = 1
if (
litellm.priority_reservation is None
or priority not in litellm.priority_reservation
):
verbose_proxy_logger.error(
"Priority Reservation not set. priority={}, but litellm.priority_reservation is {}.".format(
priority, litellm.priority_reservation
)
)
elif priority is not None and litellm.priority_reservation is not None:
if os.getenv("LITELLM_LICENSE", None) is None:
verbose_proxy_logger.error(
"PREMIUM FEATURE: Reserving tpm/rpm by priority is a premium feature. Please add a 'LITELLM_LICENSE' to your .env to enable this.\nGet a license: https://docs.litellm.ai/docs/proxy/enterprise."
)
else:
value = litellm.priority_reservation[priority]
weight = convert_priority_to_percent(value, model_group_info)
active_projects = await self.internal_usage_cache.async_get_cache(
model=model
)
(
current_model_tpm,
current_model_rpm,
) = await self.llm_router.get_model_group_usage(model_group=model)
total_model_tpm: Optional[int] = None
total_model_rpm: Optional[int] = None
if model_group_info is not None:
if model_group_info.tpm is not None:
total_model_tpm = model_group_info.tpm
if model_group_info.rpm is not None:
total_model_rpm = model_group_info.rpm
remaining_model_tpm: Optional[int] = None
if total_model_tpm is not None and current_model_tpm is not None:
remaining_model_tpm = total_model_tpm - current_model_tpm
elif total_model_tpm is not None:
remaining_model_tpm = total_model_tpm
remaining_model_rpm: Optional[int] = None
if total_model_rpm is not None and current_model_rpm is not None:
remaining_model_rpm = total_model_rpm - current_model_rpm
elif total_model_rpm is not None:
remaining_model_rpm = total_model_rpm
available_tpm: Optional[int] = None
if remaining_model_tpm is not None:
if active_projects is not None:
available_tpm = int(remaining_model_tpm * weight / active_projects)
else:
available_tpm = int(remaining_model_tpm * weight)
if available_tpm is not None and available_tpm < 0:
available_tpm = 0
available_rpm: Optional[int] = None
if remaining_model_rpm is not None:
if active_projects is not None:
available_rpm = int(remaining_model_rpm * weight / active_projects)
else:
available_rpm = int(remaining_model_rpm * weight)
if available_rpm is not None and available_rpm < 0:
available_rpm = 0
return (
available_tpm,
available_rpm,
remaining_model_tpm,
remaining_model_rpm,
active_projects,
)
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.hooks.dynamic_rate_limiter.py::check_available_usage: Exception occurred - {}".format(
str(e)
)
)
return None, None, None, None, None
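The allocation math above can be sketched in isolation (illustrative numbers only, not from the source): remaining capacity is scaled by the priority weight, split evenly across active projects, and floored at zero.

```python
from typing import Optional

def available_quota(remaining: int, weight: float, active_projects: Optional[int]) -> int:
    # Mirrors check_available_usage: scale remaining capacity by the
    # priority weight, split across active projects, floor at zero.
    quota = remaining * weight
    if active_projects:
        quota = quota / active_projects
    return max(int(quota), 0)

# 1000 TPM remaining, 60% reservation, 3 active projects -> 200 TPM each
print(available_quota(1000, 0.6, 3))
```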
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: CallTypesLiteral,
) -> Optional[
Union[Exception, str, dict]
]: # raise exception if invalid, return a str for the user to receive - if rejected, or return a modified dictionary for passing into litellm
"""
- For a model group
- Check if tpm/rpm available
- Raise RateLimitError if no tpm/rpm available
"""
if "model" in data:
key_priority: Optional[str] = user_api_key_dict.metadata.get(
"priority", None
)
(
available_tpm,
available_rpm,
model_tpm,
model_rpm,
active_projects,
) = await self.check_available_usage(
model=data["model"], priority=key_priority
)
### CHECK TPM ###
if available_tpm is not None and available_tpm == 0:
raise HTTPException(
status_code=429,
detail={
"error": "Key={} over available TPM={}. Model TPM={}, Active keys={}".format(
user_api_key_dict.api_key,
available_tpm,
model_tpm,
active_projects,
)
},
)
### CHECK RPM ###
elif available_rpm is not None and available_rpm == 0:
raise HTTPException(
status_code=429,
detail={
"error": "Key={} over available RPM={}. Model RPM={}, Active keys={}".format(
user_api_key_dict.api_key,
available_rpm,
model_rpm,
active_projects,
)
},
)
elif available_rpm is not None or available_tpm is not None:
## UPDATE CACHE WITH ACTIVE PROJECT
asyncio.create_task(
self.internal_usage_cache.async_set_cache_sadd( # this is a set
model=data["model"], # type: ignore
value=[user_api_key_dict.token or "default_key"],
)
)
return None
async def async_post_call_success_hook(
self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
):
try:
if isinstance(response, ModelResponse):
model_info = self.llm_router.get_model_info(
id=response._hidden_params["model_id"]
)
assert (
model_info is not None
), "Model info for model with id={} is None".format(
response._hidden_params["model_id"]
)
key_priority: Optional[str] = user_api_key_dict.metadata.get(
"priority", None
)
(
available_tpm,
available_rpm,
model_tpm,
model_rpm,
active_projects,
) = await self.check_available_usage(
model=model_info["model_name"], priority=key_priority
)
response._hidden_params[
"additional_headers"
] = { # Add additional response headers - easier debugging
"x-litellm-model_group": model_info["model_name"],
"x-ratelimit-remaining-litellm-project-tokens": available_tpm,
"x-ratelimit-remaining-litellm-project-requests": available_rpm,
"x-ratelimit-remaining-model-tokens": model_tpm,
"x-ratelimit-remaining-model-requests": model_rpm,
"x-ratelimit-current-active-projects": active_projects,
}
return response
return await super().async_post_call_success_hook(
data=data,
user_api_key_dict=user_api_key_dict,
response=response,
)
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.hooks.dynamic_rate_limiter.py::async_post_call_success_hook(): Exception occured - {}".format(
str(e)
)
)
return response


@@ -0,0 +1,809 @@
"""
Dynamic rate limiter v3 - Saturation-aware priority-based rate limiting
"""
import os
from datetime import datetime
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union
from fastapi import HTTPException
import litellm
from litellm import ModelResponse, Router
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter_v3 import (
RateLimitDescriptor,
RateLimitDescriptorRateLimitObject,
_PROXY_MaxParallelRequestsHandler_v3,
)
from litellm.proxy.hooks.rate_limiter_utils import convert_priority_to_percent
from litellm.proxy.utils import InternalUsageCache
from litellm.types.router import ModelGroupInfo
from litellm.types.utils import CallTypesLiteral
if TYPE_CHECKING:
from litellm.types.utils import PriorityReservationSettings
def _get_priority_settings() -> "PriorityReservationSettings":
"""
Get the priority reservation settings, guaranteed to be non-None.
The settings are lazy-loaded in litellm.__init__ and always return an instance.
This helper provides proper type narrowing for mypy.
"""
settings = litellm.priority_reservation_settings
if settings is None:
# This should never happen due to lazy loading, but satisfy mypy
from litellm.types.utils import PriorityReservationSettings
return PriorityReservationSettings()
return settings
class _PROXY_DynamicRateLimitHandlerV3(CustomLogger):
"""
Saturation-aware priority-based rate limiter using v3 infrastructure.
Key features:
1. Model capacity ALWAYS enforced at 100% (prevents over-allocation)
2. Priority usage tracked from first request (accurate accounting)
3. Priority limits only enforced when saturated >= threshold
4. Three-phase checking prevents partial counter increments
5. Reuses v3 limiter's Redis-based tracking (multi-instance safe)
How it works:
- Phase 1: Read-only check of ALL limits (no increments)
- Phase 2: Decide enforcement based on saturation
- Phase 3: Increment counters only if request allowed
- When under-saturated: priorities can borrow unused capacity (generous)
- When saturated: strict priority-based limits enforced (fair)
- Uses v3 limiter's atomic Lua scripts for race-free increments
"""
def __init__(
self,
internal_usage_cache: DualCache,
time_provider: Optional[Callable[[], datetime]] = None,
):
self.internal_usage_cache = InternalUsageCache(dual_cache=internal_usage_cache)
self.v3_limiter = _PROXY_MaxParallelRequestsHandler_v3(
self.internal_usage_cache, time_provider=time_provider
)
def update_variables(self, llm_router: Router):
self.llm_router = llm_router
def _get_saturation_check_cache_ttl(self) -> int:
"""Get the configurable TTL for local cache when reading saturation values."""
return _get_priority_settings().saturation_check_cache_ttl
async def _get_saturation_value_from_cache(
self,
counter_key: str,
) -> Optional[str]:
"""
Get saturation value with configurable local cache TTL.
Uses DualCache with configurable TTL for local cache storage.
TTL is configurable via litellm.priority_reservation_settings.saturation_check_cache_ttl
Args:
counter_key: The cache key for the saturation counter
Returns:
Counter value as string, or None if not found
"""
local_cache_ttl = self._get_saturation_check_cache_ttl()
return await self.internal_usage_cache.async_get_cache(
key=counter_key,
litellm_parent_otel_span=None,
local_only=False,
ttl=local_cache_ttl,
)
def _get_priority_weight(
self, priority: Optional[str], model_info: Optional[ModelGroupInfo] = None
) -> float:
"""Get the weight for a given priority from litellm.priority_reservation"""
weight: float = _get_priority_settings().default_priority
if (
litellm.priority_reservation is None
or priority not in litellm.priority_reservation
):
verbose_proxy_logger.debug(
"Priority Reservation not set for the given priority."
)
elif priority is not None and litellm.priority_reservation is not None:
if os.getenv("LITELLM_LICENSE", None) is None:
verbose_proxy_logger.error(
"PREMIUM FEATURE: Reserving tpm/rpm by priority is a premium feature. Please add a 'LITELLM_LICENSE' to your .env to enable this.\nGet a license: https://docs.litellm.ai/docs/proxy/enterprise."
)
else:
value = litellm.priority_reservation[priority]
weight = convert_priority_to_percent(value, model_info)
return weight
def _get_priority_from_user_api_key_dict(
self, user_api_key_dict: UserAPIKeyAuth
) -> Optional[str]:
"""
Get priority from user_api_key_dict.
Checks team metadata first (takes precedence), then falls back to key metadata.
Args:
user_api_key_dict: User authentication info
Returns:
Priority string if found, None otherwise
"""
priority: Optional[str] = None
# Check team metadata first (takes precedence)
if user_api_key_dict.team_metadata is not None:
priority = user_api_key_dict.team_metadata.get("priority", None)
# Fall back to key metadata
if priority is None:
priority = user_api_key_dict.metadata.get("priority", None)
return priority
def _normalize_priority_weights(
self, model_info: ModelGroupInfo
) -> Dict[str, float]:
"""
Normalize priority weights if they sum to > 1.0
Handles over-allocation: {key_a: 0.60, key_b: 0.80} -> {key_a: 0.43, key_b: 0.57}
Converts absolute rpm/tpm values to percentages based on model capacity.
"""
if litellm.priority_reservation is None:
return {}
# Convert all values to percentages first
weights: Dict[str, float] = {}
for k, v in litellm.priority_reservation.items():
weights[k] = convert_priority_to_percent(v, model_info)
total_weight = sum(weights.values())
if total_weight > 1.0:
normalized = {k: v / total_weight for k, v in weights.items()}
verbose_proxy_logger.debug(
f"Normalized over-allocated priorities: {weights} -> {normalized}"
)
return normalized
return weights
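The normalization step above, reduced to a standalone sketch (assuming weights are already converted to percentages): reservations summing past 1.0 are rescaled proportionally, anything at or under 1.0 is left unchanged.

```python
from typing import Dict

def normalize_weights(weights: Dict[str, float]) -> Dict[str, float]:
    # Over-allocated reservations (summing past 1.0) are rescaled
    # proportionally; under- or exactly-allocated ones pass through.
    total = sum(weights.values())
    if total > 1.0:
        return {k: v / total for k, v in weights.items()}
    return weights

# {0.60, 0.80} sums to 1.4 -> rescaled to roughly {0.43, 0.57}
print(normalize_weights({"key_a": 0.60, "key_b": 0.80}))
```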
def _get_priority_allocation(
self,
model: str,
priority: Optional[str],
normalized_weights: Dict[str, float],
model_info: Optional[ModelGroupInfo] = None,
) -> tuple[float, str]:
"""
Get priority weight and pool key for a given priority.
For explicit priorities: returns specific allocation and unique pool key
For default priority: returns default allocation and shared pool key
Args:
model: Model name
priority: Priority level (None for default)
normalized_weights: Pre-computed normalized weights
model_info: Model configuration (optional, for fallback conversion)
Returns:
tuple: (priority_weight, priority_key)
"""
# Check if this key has an explicit priority in litellm.priority_reservation
has_explicit_priority = (
priority is not None
and litellm.priority_reservation is not None
and priority in litellm.priority_reservation
)
if has_explicit_priority and priority is not None:
# Explicit priority: get its specific allocation
priority_weight = normalized_weights.get(
priority, self._get_priority_weight(priority, model_info)
)
# Use unique key per priority level
priority_key = f"{model}:{priority}"
else:
# No explicit priority: share the default_priority pool with ALL other default keys
priority_weight = _get_priority_settings().default_priority
# Use shared key for all default-priority requests
priority_key = f"{model}:default_pool"
return priority_weight, priority_key
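The pool-key selection can be condensed to a small sketch (the `reservations` dict below stands in for `litellm.priority_reservation`): explicit priorities each get their own pool key, while every key without an explicit reservation shares one `default_pool`.

```python
from typing import Dict, Optional

def pool_key(model: str, priority: Optional[str],
             reservations: Optional[Dict[str, float]]) -> str:
    # Explicit priorities get a unique per-priority pool; all other
    # requests share a single default pool for the model.
    if priority is not None and reservations is not None and priority in reservations:
        return f"{model}:{priority}"
    return f"{model}:default_pool"

reservations = {"prod": 0.75}
print(pool_key("gpt-4", "prod", reservations))  # dedicated pool
print(pool_key("gpt-4", None, reservations))    # shared default pool
```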
async def _check_model_saturation(
self,
model: str,
model_group_info: ModelGroupInfo,
) -> float:
"""
Check current saturation by directly querying v3 limiter's cache keys.
Reuses v3 limiter's Redis-based tracking (works across multiple instances).
Reads counters WITHOUT incrementing them.
Returns:
float: Saturation ratio (0.0 = empty, 1.0 = at capacity, >1.0 = over)
"""
try:
max_saturation = 0.0
# Query RPM saturation - always read from Redis for multi-node consistency
if model_group_info.rpm is not None and model_group_info.rpm > 0:
# Use v3 limiter's key format: {key:value}:rate_limit_type
counter_key = self.v3_limiter.create_rate_limit_keys(
key="model_saturation_check",
value=model,
rate_limit_type="requests",
)
# Query Redis directly for current counter value (skip local cache for consistency)
counter_value = await self._get_saturation_value_from_cache(
counter_key=counter_key
)
if counter_value is not None:
current_requests = int(counter_value)
rpm_saturation = current_requests / model_group_info.rpm
max_saturation = max(max_saturation, rpm_saturation)
verbose_proxy_logger.debug(
f"Model {model} RPM: {current_requests}/{model_group_info.rpm} "
f"({rpm_saturation:.1%})"
)
# Query TPM saturation
if model_group_info.tpm is not None and model_group_info.tpm > 0:
counter_key = self.v3_limiter.create_rate_limit_keys(
key="model_saturation_check",
value=model,
rate_limit_type="tokens",
)
counter_value = await self._get_saturation_value_from_cache(
counter_key=counter_key
)
if counter_value is not None:
current_tokens = float(counter_value)
tpm_saturation = current_tokens / model_group_info.tpm
max_saturation = max(max_saturation, tpm_saturation)
verbose_proxy_logger.debug(
f"Model {model} TPM: {current_tokens}/{model_group_info.tpm} "
f"({tpm_saturation:.1%})"
)
verbose_proxy_logger.debug(
f"Model {model} overall saturation: {max_saturation:.1%}"
)
return max_saturation
except Exception as e:
verbose_proxy_logger.error(
f"Error checking saturation for {model}: {str(e)}"
)
# Fail open: assume not saturated on error
return 0.0
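The saturation computation above boils down to taking the worst of the RPM and TPM ratios, with missing counters contributing nothing (so errors fail open). A minimal sketch:

```python
from typing import Optional

def saturation_ratio(current_rpm: Optional[int], rpm_limit: Optional[int],
                     current_tpm: Optional[float], tpm_limit: Optional[int]) -> float:
    # Saturation is the max of the RPM and TPM usage ratios; absent
    # counters or limits are skipped, defaulting to 0.0 (fail open).
    ratios = []
    if rpm_limit and current_rpm is not None:
        ratios.append(current_rpm / rpm_limit)
    if tpm_limit and current_tpm is not None:
        ratios.append(current_tpm / tpm_limit)
    return max(ratios, default=0.0)

# 85/100 requests, 4000/10000 tokens -> 0.85 (past an 80% threshold)
print(saturation_ratio(85, 100, 4000.0, 10000))
```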
def _create_priority_based_descriptors(
self,
model: str,
user_api_key_dict: UserAPIKeyAuth,
priority: Optional[str],
) -> List[RateLimitDescriptor]:
"""
Create rate limit descriptors with normalized priority weights.
Uses normalized weights to handle over-allocation scenarios.
For explicit priorities: each priority gets its own pool (e.g., prod gets 75%)
For default priority: ALL keys without explicit priority share ONE pool (e.g., all share 25%)
"""
descriptors: List[RateLimitDescriptor] = []
if litellm.priority_reservation is None:
return descriptors
# Get model group info
model_group_info: Optional[
ModelGroupInfo
] = self.llm_router.get_model_group_info(model_group=model)
if model_group_info is None:
return descriptors
# Get normalized priority weight and pool key
normalized_weights = self._normalize_priority_weights(model_group_info)
priority_weight, priority_key = self._get_priority_allocation(
model=model,
priority=priority,
normalized_weights=normalized_weights,
model_info=model_group_info,
)
rate_limit_config: RateLimitDescriptorRateLimitObject = {}
# Apply priority weight to model limits
if model_group_info.tpm is not None:
reserved_tpm = int(model_group_info.tpm * priority_weight)
rate_limit_config["tokens_per_unit"] = reserved_tpm
if model_group_info.rpm is not None:
reserved_rpm = int(model_group_info.rpm * priority_weight)
rate_limit_config["requests_per_unit"] = reserved_rpm
if rate_limit_config:
rate_limit_config["window_size"] = self.v3_limiter.window_size
descriptors.append(
RateLimitDescriptor(
key="priority_model",
value=priority_key,
rate_limit=rate_limit_config,
)
)
return descriptors
def _create_model_tracking_descriptor(
self,
model: str,
model_group_info: ModelGroupInfo,
high_limit_multiplier: int = 1,
) -> RateLimitDescriptor:
"""
Create a descriptor for tracking model-wide usage.
Args:
model: Model name
model_group_info: Model configuration with RPM/TPM limits
high_limit_multiplier: Multiplier for limits (use >1 for tracking-only)
Returns:
Rate limit descriptor for model-wide tracking
"""
return RateLimitDescriptor(
key="model_saturation_check",
value=model,
rate_limit={
"requests_per_unit": (
model_group_info.rpm * high_limit_multiplier
if model_group_info.rpm
else None
),
"tokens_per_unit": (
model_group_info.tpm * high_limit_multiplier
if model_group_info.tpm
else None
),
"window_size": self.v3_limiter.window_size,
},
)
async def _check_rate_limits(
self,
model: str,
model_group_info: ModelGroupInfo,
user_api_key_dict: UserAPIKeyAuth,
priority: Optional[str],
saturation: float,
data: dict,
) -> None:
"""
Check rate limits using THREE-PHASE approach to prevent partial increments.
Phase 1: Read-only check of ALL limits (no increments)
Phase 2: Decide which limits to enforce based on saturation
Phase 3: Increment ALL counters atomically (model + priority)
This prevents the bug where:
- Model counter increments in stage 1
- Priority check fails in stage 2
- Request blocked but model counter already incremented
Key behaviors:
- All checks performed first (read-only)
- Only increment counters if request will be allowed
- Model capacity: Always enforced at 100%
- Priority limits: Only enforced when saturated >= threshold
- Both counters tracked from first request (accurate accounting)
Args:
model: Model name
model_group_info: Model configuration
user_api_key_dict: User authentication info
priority: User's priority level
saturation: Current saturation level
data: Request data dictionary
Raises:
HTTPException: If any limit is exceeded
"""
import json
saturation_threshold = _get_priority_settings().saturation_threshold
should_enforce_priority = saturation >= saturation_threshold
# Build ALL descriptors upfront
descriptors_to_check: List[RateLimitDescriptor] = []
# Model-wide descriptor (always enforce)
model_wide_descriptor = self._create_model_tracking_descriptor(
model=model,
model_group_info=model_group_info,
high_limit_multiplier=1,
)
descriptors_to_check.append(model_wide_descriptor)
# Priority descriptors (always track, conditionally enforce)
priority_descriptors = self._create_priority_based_descriptors(
model=model,
user_api_key_dict=user_api_key_dict,
priority=priority,
)
if priority_descriptors:
descriptors_to_check.extend(priority_descriptors)
# PHASE 1: Read-only check of ALL limits (no increments)
check_response = await self.v3_limiter.should_rate_limit(
descriptors=descriptors_to_check,
parent_otel_span=user_api_key_dict.parent_otel_span,
read_only=True, # CRITICAL: Don't increment counters yet
)
verbose_proxy_logger.debug(
f"Read-only check: {json.dumps(check_response, indent=2)}"
)
# PHASE 2: Decide which limits to enforce
if check_response["overall_code"] == "OVER_LIMIT":
for status in check_response["statuses"]:
if status["code"] == "OVER_LIMIT":
descriptor_key = status["descriptor_key"]
# Model-wide limit exceeded (ALWAYS enforce)
if descriptor_key == "model_saturation_check":
raise HTTPException(
status_code=429,
detail={
"error": f"Model capacity reached for {model}. "
f"Priority: {priority}, "
f"Rate limit type: {status['rate_limit_type']}, "
f"Remaining: {status['limit_remaining']}"
},
headers={
"retry-after": str(self.v3_limiter.window_size),
"rate_limit_type": str(status["rate_limit_type"]),
"x-litellm-priority": priority or "default",
},
)
# Priority limit exceeded (ONLY enforce when saturated)
elif descriptor_key == "priority_model" and should_enforce_priority:
verbose_proxy_logger.debug(
f"Enforcing priority limits for {model}, saturation: {saturation:.1%}, "
f"priority: {priority}"
)
raise HTTPException(
status_code=429,
detail={
"error": f"Priority-based rate limit exceeded. "
f"Priority: {priority}, "
f"Rate limit type: {status['rate_limit_type']}, "
f"Remaining: {status['limit_remaining']}, "
f"Model saturation: {saturation:.1%}"
},
headers={
"retry-after": str(self.v3_limiter.window_size),
"rate_limit_type": str(status["rate_limit_type"]),
"x-litellm-priority": priority or "default",
"x-litellm-saturation": f"{saturation:.2%}",
},
)
# PHASE 3: Increment counters separately to avoid early-exit issues
# Model counter must ALWAYS increment, but priority counter might be over limit
# If we increment them together, v3_limiter's in-memory check will exit early
# and skip incrementing the model counter
# Step 3a: Increment model-wide counter (always)
model_increment_response = await self.v3_limiter.should_rate_limit(
descriptors=[model_wide_descriptor],
parent_otel_span=user_api_key_dict.parent_otel_span,
read_only=False,
)
# Step 3b: Increment priority counter (may be over limit, but we still track it)
if priority_descriptors:
priority_increment_response = await self.v3_limiter.should_rate_limit(
descriptors=priority_descriptors,
parent_otel_span=user_api_key_dict.parent_otel_span,
read_only=False,
)
# Combine responses for post-call hook
combined_response = {
"overall_code": model_increment_response["overall_code"],
"statuses": model_increment_response["statuses"]
+ priority_increment_response["statuses"],
}
data["litellm_proxy_rate_limit_response"] = combined_response
else:
data["litellm_proxy_rate_limit_response"] = model_increment_response
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: CallTypesLiteral,
) -> Optional[Union[Exception, str, dict]]:
"""
Saturation-aware pre-call hook for priority-based rate limiting.
Flow:
1. Check current saturation level
2. THREE-PHASE rate limit check:
- PHASE 1: Read-only check of ALL limits (no increments)
- PHASE 2: Decide which limits to enforce based on saturation
- PHASE 3: Increment ALL counters atomically if request allowed
This three-phase approach ensures:
- Model capacity is NEVER exceeded (always enforced at 100%)
- Priority usage tracked from first request (accurate metrics)
- Counters only increment when request will be allowed (prevents phantom usage)
- When under-saturated: priorities can borrow unused capacity (generous)
- When saturated: fair allocation based on normalized priority weights (strict)
Example with 100 RPM model, 60% priority allocation, 80% threshold:
- Saturation < 80%: Priority can use up to 100 RPM (model limit enforced only)
- Saturation >= 80%: Priority limited to 60 RPM (both limits enforced)
Prevents bugs where:
- Model counter increments but priority check fails → model over-capacity
- Priority counter increments but not enforced → inaccurate metrics
Args:
user_api_key_dict: User authentication and metadata
cache: Dual cache instance
data: Request data containing model name
call_type: Type of API call being made
Returns:
None if request is allowed, otherwise raises HTTPException
"""
if "model" not in data:
return None
model = data["model"]
priority = self._get_priority_from_user_api_key_dict(
user_api_key_dict=user_api_key_dict
)
# Get model configuration
model_group_info: Optional[
ModelGroupInfo
] = self.llm_router.get_model_group_info(model_group=model)
if model_group_info is None:
verbose_proxy_logger.debug(
f"No model group info for {model}, allowing request"
)
return None
try:
# STEP 1: Check current saturation level
saturation = await self._check_model_saturation(model, model_group_info)
saturation_threshold = _get_priority_settings().saturation_threshold
verbose_proxy_logger.debug(
f"[Dynamic Rate Limiter] Model={model}, Saturation={saturation:.1%}, "
f"Threshold={saturation_threshold:.1%}, Priority={priority}"
)
# STEP 2: Check rate limits in THREE phases
# Phase 1: Read-only check of ALL limits (no increments)
# Phase 2: Decide which limits to enforce (based on saturation)
# Phase 3: Increment ALL counters only if request will be allowed
# This prevents partial increments and ensures accurate tracking
await self._check_rate_limits(
model=model,
model_group_info=model_group_info,
user_api_key_dict=user_api_key_dict,
priority=priority,
saturation=saturation,
data=data,
)
except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.error(
f"Error in dynamic rate limiter: {str(e)}, allowing request"
)
# Fail open on unexpected errors
return None
return None
async def async_post_call_success_hook(
self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
):
"""
Post-call hook to add rate limit headers to response.
Leverages v3 limiter's post-call hook functionality.
"""
try:
# Call v3 limiter's post-call hook to add standard rate limit headers
await self.v3_limiter.async_post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)
# Add additional priority-specific headers
if isinstance(response, ModelResponse):
priority = self._get_priority_from_user_api_key_dict(
user_api_key_dict=user_api_key_dict
)
# Get existing additional headers
additional_headers = (
getattr(response, "_hidden_params", {}).get(
"additional_headers", {}
)
or {}
)
# Add priority information
additional_headers["x-litellm-priority"] = priority or "default"
additional_headers["x-litellm-rate-limiter-version"] = "v3"
# Update response
if not hasattr(response, "_hidden_params"):
response._hidden_params = {}
response._hidden_params["additional_headers"] = additional_headers
return response
except Exception as e:
verbose_proxy_logger.exception(
f"Error in dynamic rate limiter v3 post-call hook: {str(e)}"
)
return response
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Update token usage for priority-based rate limiting after successful API calls.
Increments token counters for:
- model_saturation_check: Model-wide token tracking
- priority_model: Priority-specific token tracking
"""
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
)
from litellm.proxy.common_utils.callback_utils import (
get_model_group_from_litellm_kwargs,
)
from litellm.types.caching import RedisPipelineIncrementOperation
from litellm.types.utils import Usage
try:
verbose_proxy_logger.debug(
"INSIDE dynamic rate limiter ASYNC SUCCESS LOGGING"
)
litellm_parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
# Get metadata from standard_logging_object
standard_logging_object = kwargs.get("standard_logging_object") or {}
standard_logging_metadata = standard_logging_object.get("metadata") or {}
# Get model and priority
model_group = get_model_group_from_litellm_kwargs(kwargs)
if not model_group:
return
# Get priority from user_api_key_auth_metadata in standard_logging_metadata
# This is where user_api_key_dict.metadata is stored during pre-call
user_api_key_auth_metadata = (
standard_logging_metadata.get("user_api_key_auth_metadata") or {}
)
key_priority: Optional[str] = user_api_key_auth_metadata.get("priority")
# Get total tokens from response
total_tokens = 0
rate_limit_type = self.v3_limiter.get_rate_limit_type()
if isinstance(response_obj, ModelResponse):
_usage = getattr(response_obj, "usage", None)
if _usage and isinstance(_usage, Usage):
if rate_limit_type == "output":
total_tokens = _usage.completion_tokens
elif rate_limit_type == "input":
total_tokens = _usage.prompt_tokens
elif rate_limit_type == "total":
total_tokens = _usage.total_tokens
if total_tokens == 0:
return
# Create pipeline operations for token increments
pipeline_operations: List[RedisPipelineIncrementOperation] = []
# Model-wide token tracking (model_saturation_check)
model_token_key = self.v3_limiter.create_rate_limit_keys(
key="model_saturation_check",
value=model_group,
rate_limit_type="tokens",
)
pipeline_operations.append(
RedisPipelineIncrementOperation(
key=model_token_key,
increment_value=total_tokens,
ttl=self.v3_limiter.window_size,
)
)
# Priority-specific token tracking (priority_model)
# Determine priority key (same logic as _get_priority_allocation)
has_explicit_priority = (
key_priority is not None
and litellm.priority_reservation is not None
and key_priority in litellm.priority_reservation
)
if has_explicit_priority and key_priority is not None:
priority_key = f"{model_group}:{key_priority}"
else:
priority_key = f"{model_group}:default_pool"
priority_token_key = self.v3_limiter.create_rate_limit_keys(
key="priority_model",
value=priority_key,
rate_limit_type="tokens",
)
pipeline_operations.append(
RedisPipelineIncrementOperation(
key=priority_token_key,
increment_value=total_tokens,
ttl=self.v3_limiter.window_size,
)
)
# Execute token increments with TTL preservation
if pipeline_operations:
await self.v3_limiter.async_increment_tokens_with_ttl_preservation(
pipeline_operations=pipeline_operations,
parent_otel_span=litellm_parent_otel_span,
)
# Only log 'priority' if it's known safe; otherwise, redact.
SAFE_PRIORITIES = {"low", "medium", "high", "default"}
logged_priority = (
key_priority if key_priority in SAFE_PRIORITIES else "REDACTED"
)
verbose_proxy_logger.debug(
f"[Dynamic Rate Limiter] Incremented tokens by {total_tokens} for "
f"model={model_group}, priority={logged_priority}"
)
except Exception as e:
verbose_proxy_logger.exception(
f"Error in dynamic rate limiter success event: {str(e)}"
)
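The token-selection branch in the success hook above can be sketched on its own (a simplification; the real code reads a `Usage` object and the v3 limiter's configured rate-limit type):

```python
def tokens_to_count(prompt_tokens: int, completion_tokens: int,
                    rate_limit_type: str) -> int:
    # Chooses which usage figure feeds the TPM counters, mirroring the
    # "input" / "output" / "total" handling in async_log_success_event.
    if rate_limit_type == "output":
        return completion_tokens
    if rate_limit_type == "input":
        return prompt_tokens
    return prompt_tokens + completion_tokens  # "total"

# 120 prompt + 30 completion tokens counted under "total"
print(tokens_to_count(120, 30, "total"))
```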


@@ -0,0 +1,28 @@
[
{
"name": "Zip code Recognizer",
"supported_language": "en",
"patterns": [
{
"name": "zip code (weak)",
"regex": "(\\b\\d{5}(?:\\-\\d{4})?\\b)",
"score": 0.01
}
],
"context": ["zip", "code"],
"supported_entity": "ZIP"
},
{
"name": "Swiss AHV Number Recognizer",
"supported_language": "en",
"patterns": [
{
"name": "AHV number (strong)",
"regex": "(756\\.\\d{4}\\.\\d{4}\\.\\d{2})|(756\\d{10})",
"score": 0.95
}
],
"context": ["AHV", "social security", "Swiss"],
"supported_entity": "AHV_NUMBER"
}
]
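The recognizer patterns above can be sanity-checked with plain `re` (this assumes only the regex strings from the config, not Presidio itself; the sample values are made up):

```python
import re

# Regex strings copied from the recognizer config above
ZIP_RE = r"(\b\d{5}(?:\-\d{4})?\b)"
AHV_RE = r"(756\.\d{4}\.\d{4}\.\d{2})|(756\d{10})"

# ZIP+4 form is matched as one token thanks to the optional group
match = re.search(ZIP_RE, "ship to 90210-1234")
print(match.group(0) if match else None)
```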


@@ -0,0 +1,607 @@
import asyncio
import json
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional
import litellm
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy._types import (
GenerateKeyRequest,
GenerateKeyResponse,
KeyRequest,
LiteLLM_AuditLogs,
Litellm_EntityType,
LiteLLM_VerificationToken,
LitellmTableNames,
RegenerateKeyRequest,
UpdateKeyRequest,
UserAPIKeyAuth,
)
# NOTE: This is the prefix for all virtual keys stored in AWS Secrets Manager
LITELLM_PREFIX_STORED_VIRTUAL_KEYS = "litellm/"
class KeyManagementEventHooks:
@staticmethod
async def async_key_generated_hook(
data: GenerateKeyRequest,
response: GenerateKeyResponse,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Hook that runs after a successful /key/generate request
Handles the following:
- Sending Email with Key Details
- Storing Audit Logs for key generation
- Storing Generated Key in DB
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name
# Send email notification - non-blocking, independent operation
if data.send_invite_email is True:
try:
await KeyManagementEventHooks._send_key_created_email(
response.model_dump(exclude_none=True)
)
except Exception as e:
verbose_proxy_logger.warning(f"Failed to send key created email: {e}")
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = response.model_dump_json(exclude_none=True)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=response.token_id or "",
action="created",
updated_values=_updated_values,
before_value=None,
)
)
)
# Store the generated key in the secret manager - non-blocking, independent operation
try:
await KeyManagementEventHooks._store_virtual_key_in_secret_manager(
secret_name=data.key_alias or f"virtual-key-{response.token_id}",
secret_token=response.key,
team_id=data.team_id,
)
except Exception as e:
verbose_proxy_logger.warning(
f"Failed to store virtual key in secret manager: {e}"
)
@staticmethod
async def async_key_updated_hook(
data: UpdateKeyRequest,
existing_key_row: Any,
response: Any,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Post /key/update processing hook
Handles the following:
- Storing Audit Logs for key update
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
if litellm.store_audit_logs is True:
_updated_values = json.dumps(data.json(exclude_none=True), default=str)
_before_value = existing_key_row.json(exclude_none=True)
_before_value = json.dumps(_before_value, default=str)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=data.key,
action="updated",
updated_values=_updated_values,
before_value=_before_value,
)
)
)
@staticmethod
async def async_key_rotated_hook(
data: Optional[RegenerateKeyRequest],
existing_key_row: LiteLLM_VerificationToken,
response: GenerateKeyResponse,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name
# Store the generated key in the secret manager - non-blocking, independent operation
if data is not None and response.token_id is not None:
try:
initial_secret_name = (
existing_key_row.key_alias
or f"virtual-key-{existing_key_row.token}"
)
new_secret_name = (
response.key_alias or data.key_alias or initial_secret_name
)
verbose_proxy_logger.info(
"Updating secret in secret manager: secret_name=%s",
new_secret_name,
)
team_id = getattr(existing_key_row, "team_id", None)
await KeyManagementEventHooks._rotate_virtual_key_in_secret_manager(
current_secret_name=initial_secret_name,
new_secret_name=new_secret_name,
new_secret_value=response.key,
team_id=team_id,
)
verbose_proxy_logger.info(
"Secret updated in secret manager: secret_name=%s",
new_secret_name,
)
except Exception as e:
verbose_proxy_logger.warning(
f"Failed to rotate virtual key in secret manager: {e}"
)
# Send key rotated email if configured - non-blocking, independent operation
try:
await KeyManagementEventHooks._send_key_rotated_email(
response=response.model_dump(exclude_none=True),
existing_key_alias=existing_key_row.key_alias,
)
except Exception as e:
verbose_proxy_logger.warning(f"Failed to send key rotated email: {e}")
# store the audit log
if litellm.store_audit_logs is True and existing_key_row.token is not None:
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.token,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=existing_key_row.token,
action="rotated",
updated_values=response.model_dump_json(exclude_none=True),
before_value=existing_key_row.model_dump_json(
exclude_none=True
),
)
)
)
@staticmethod
async def async_key_deleted_hook(
data: KeyRequest,
keys_being_deleted: List[LiteLLM_VerificationToken],
response: dict,
user_api_key_dict: UserAPIKeyAuth,
litellm_changed_by: Optional[str] = None,
):
"""
Post /key/delete processing hook
Handles the following:
- Storing Audit Logs for key deletion
"""
from litellm.proxy.management_helpers.audit_logs import (
create_audit_log_for_update,
)
from litellm.proxy.proxy_server import litellm_proxy_admin_name
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        # Runs after the validation loop above: audit logs should only be inserted once validation has passed
if litellm.store_audit_logs is True and data.keys is not None:
# make an audit log for each key deleted
for key in keys_being_deleted:
if key.token is None:
continue
_key_row = key.model_dump_json(exclude_none=True)
asyncio.create_task(
create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.token,
table_name=LitellmTableNames.KEY_TABLE_NAME,
object_id=key.token,
action="deleted",
updated_values="{}",
before_value=_key_row,
)
)
)
# delete the keys from the secret manager
await KeyManagementEventHooks._delete_virtual_keys_from_secret_manager(
keys_being_deleted=keys_being_deleted
)
@staticmethod
async def _store_virtual_key_in_secret_manager(
secret_name: str, secret_token: str, team_id: Optional[str] = None
):
"""
Store a virtual key in the secret manager
Args:
secret_name: Name of the virtual key
secret_token: Value of the virtual key (example: sk-1234)
            team_id: Optional team ID used to look up team-specific secret manager settings
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.base_secret_manager import (
BaseSecretManager,
)
# store the key in the secret manager
if isinstance(litellm.secret_manager_client, BaseSecretManager):
tags = getattr(litellm._key_management_settings, "tags", None)
description = getattr(
litellm._key_management_settings, "description", None
)
optional_params = await KeyManagementEventHooks._get_secret_manager_optional_params(
team_id
)
verbose_proxy_logger.debug(
f"Creating secret with {secret_name} and tags={tags} and description={description}"
)
await litellm.secret_manager_client.async_write_secret(
secret_name=KeyManagementEventHooks._get_secret_name(
secret_name
),
description=description,
secret_value=secret_token,
tags=tags,
optional_params=optional_params,
)
@staticmethod
async def _rotate_virtual_key_in_secret_manager(
current_secret_name: str,
new_secret_name: str,
new_secret_value: str,
team_id: Optional[str] = None,
):
"""
Update a virtual key in the secret manager
Args:
current_secret_name: Current name of the virtual key
new_secret_name: New name of the virtual key
new_secret_value: New value of the virtual key (example: sk-1234)
team_id: Optional team ID to get team-specific secret manager settings
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.base_secret_manager import (
BaseSecretManager,
)
# store the key in the secret manager
if isinstance(litellm.secret_manager_client, BaseSecretManager):
optional_params = await KeyManagementEventHooks._get_secret_manager_optional_params(
team_id
)
await litellm.secret_manager_client.async_rotate_secret(
current_secret_name=KeyManagementEventHooks._get_secret_name(
current_secret_name
),
new_secret_name=KeyManagementEventHooks._get_secret_name(
new_secret_name
),
new_secret_value=new_secret_value,
optional_params=optional_params,
)
@staticmethod
def _get_secret_name(secret_name: str) -> str:
if litellm._key_management_settings.prefix_for_stored_virtual_keys.endswith(
"/"
):
return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}{secret_name}"
else:
return f"{litellm._key_management_settings.prefix_for_stored_virtual_keys}/{secret_name}"
@staticmethod
async def _delete_virtual_keys_from_secret_manager(
keys_being_deleted: List[LiteLLM_VerificationToken],
):
"""
Deletes virtual keys from the secret manager
Args:
            keys_being_deleted: List of keys being deleted; passed down from the /key/delete operation
"""
if litellm._key_management_settings is not None:
if litellm._key_management_settings.store_virtual_keys is True:
from litellm.secret_managers.base_secret_manager import (
BaseSecretManager,
)
if isinstance(litellm.secret_manager_client, BaseSecretManager):
team_settings_cache: Dict[Optional[str], Optional[dict]] = {}
for key in keys_being_deleted:
if key.key_alias is not None:
team_id = getattr(key, "team_id", None)
if team_id not in team_settings_cache:
team_settings_cache[
team_id
] = await KeyManagementEventHooks._get_secret_manager_optional_params(
team_id
)
optional_params = team_settings_cache[team_id]
await litellm.secret_manager_client.async_delete_secret(
secret_name=KeyManagementEventHooks._get_secret_name(
key.key_alias
),
optional_params=optional_params,
)
else:
verbose_proxy_logger.warning(
f"KeyManagementEventHooks._delete_virtual_key_from_secret_manager: Key alias not found for key {key.token}. Skipping deletion from secret manager."
)
@staticmethod
async def _get_secret_manager_optional_params(
team_id: Optional[str],
) -> Optional[dict]:
if team_id is None:
return None
try:
from litellm.proxy import proxy_server as proxy_server_module
except ImportError:
return None
prisma_client = getattr(proxy_server_module, "prisma_client", None)
user_api_key_cache = getattr(proxy_server_module, "user_api_key_cache", None)
if prisma_client is None or user_api_key_cache is None:
return None
try:
from litellm.proxy.auth.auth_checks import get_team_object
team_obj = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
except Exception as exc: # pragma: no cover - defensive logging
verbose_proxy_logger.debug(
f"Unable to load team metadata for team_id={team_id}: {exc}"
)
return None
metadata = getattr(team_obj, "metadata", None)
if metadata is None:
return None
if hasattr(metadata, "model_dump"):
metadata = metadata.model_dump()
if not isinstance(metadata, dict):
return None
team_settings = metadata.get("secret_manager_settings")
if isinstance(team_settings, dict) and team_settings:
return dict(team_settings)
return None
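The metadata normalization at the end of `_get_secret_manager_optional_params` is a pure function and can be sketched in isolation (the function name here is illustrative):

```python
def extract_team_secret_settings(metadata):
    """Normalize team metadata (pydantic model or plain dict) and pull out
    secret_manager_settings, returning None when nothing usable is set."""
    if metadata is None:
        return None
    if hasattr(metadata, "model_dump"):   # pydantic v2 model -> plain dict
        metadata = metadata.model_dump()
    if not isinstance(metadata, dict):
        return None
    settings = metadata.get("secret_manager_settings")
    if isinstance(settings, dict) and settings:
        return dict(settings)             # copy so callers can't mutate cached state
    return None

print(extract_team_secret_settings({"secret_manager_settings": {"kms_key": "k1"}}))
# {'kms_key': 'k1'}
print(extract_team_secret_settings({"secret_manager_settings": {}}))
# None
```

Returning `None` for an empty dict means callers can treat "no settings" and "empty settings" identically, and the defensive `isinstance` checks keep malformed metadata from raising.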
@staticmethod
def _is_email_sending_enabled() -> bool:
"""
Check if email sending is enabled via v2 enterprise loggers or v0 alerting config.
Returns True only if email is actually configured, preventing any email
processing when the user has not opted in.
"""
# Check v2 enterprise email loggers
try:
from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
BaseEmailLogger,
)
initialized_email_loggers = (
litellm.logging_callback_manager.get_custom_loggers_for_type(
callback_type=BaseEmailLogger
)
)
if len(initialized_email_loggers) > 0:
return True
except ImportError:
pass
# Check v0 alerting config
from litellm.proxy.proxy_server import general_settings
if "email" in general_settings.get("alerting", []):
return True
return False
@staticmethod
async def _send_key_created_email(response: dict):
"""
Send key created email if email sending is enabled.
This method is non-blocking - it will return silently if email is not
configured, and will log warnings instead of raising exceptions on failure.
"""
# Early exit if email is not enabled
if not KeyManagementEventHooks._is_email_sending_enabled():
verbose_proxy_logger.debug(
"Email sending not enabled, skipping key created email"
)
return
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
##########################
# v2 integration for emails (enterprise)
##########################
try:
from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
BaseEmailLogger,
)
from litellm_enterprise.types.enterprise_callbacks.send_emails import (
SendKeyCreatedEmailEvent,
)
initialized_email_loggers = (
litellm.logging_callback_manager.get_custom_loggers_for_type(
callback_type=BaseEmailLogger
)
)
if len(initialized_email_loggers) > 0:
event = SendKeyCreatedEmailEvent(
virtual_key=response.get("key", ""),
event="key_created",
event_group=Litellm_EntityType.KEY,
event_message="API Key Created",
token=response.get("token", ""),
spend=response.get("spend", 0.0),
max_budget=response.get("max_budget", 0.0),
user_id=response.get("user_id", None),
team_id=response.get("team_id", "Default Team"),
key_alias=response.get("key_alias", None),
)
for email_logger in initialized_email_loggers:
if isinstance(email_logger, BaseEmailLogger):
await email_logger.send_key_created_email(
send_key_created_email_event=event,
)
return
except ImportError:
pass
##########################
# v0 integration for emails
##########################
if "email" in general_settings.get("alerting", []):
from litellm.proxy._types import WebhookEvent
event = WebhookEvent(
event="key_created",
event_group=Litellm_EntityType.KEY,
event_message="API Key Created",
token=response.get("token", ""),
spend=response.get("spend", 0.0),
max_budget=response.get("max_budget", 0.0),
user_id=response.get("user_id", None),
team_id=response.get("team_id", "Default Team"),
key_alias=response.get("key_alias", None),
)
# If user configured email alerting - send an Email letting their end-user know the key was created
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
webhook_event=event,
)
)
@staticmethod
async def _send_key_rotated_email(
response: dict, existing_key_alias: Optional[str]
):
"""
Send key rotated email if email sending is enabled.
This method is non-blocking - it will return silently if email is not
configured, and will log warnings instead of raising exceptions on failure.
"""
# Early exit if email is not enabled
if not KeyManagementEventHooks._is_email_sending_enabled():
verbose_proxy_logger.debug(
"Email sending not enabled, skipping key rotated email"
)
return
try:
from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
BaseEmailLogger,
)
except ImportError:
# Enterprise package not installed - v0 doesn't support key rotated email
verbose_proxy_logger.debug(
"Enterprise package not installed, skipping key rotated email"
)
return
try:
from litellm_enterprise.types.enterprise_callbacks.send_emails import (
SendKeyRotatedEmailEvent,
)
except ImportError:
verbose_proxy_logger.debug(
"Enterprise types not available, skipping key rotated email"
)
return
event = SendKeyRotatedEmailEvent(
virtual_key=response.get("key", ""),
event="key_rotated",
event_group=Litellm_EntityType.KEY,
event_message="API Key Rotated",
token=response.get("token", ""),
spend=response.get("spend", 0.0),
max_budget=response.get("max_budget", 0.0),
user_id=response.get("user_id", None),
team_id=response.get("team_id", "Default Team"),
key_alias=response.get("key_alias", existing_key_alias),
)
##########################
# v2 integration for emails
##########################
initialized_email_loggers = (
litellm.logging_callback_manager.get_custom_loggers_for_type(
callback_type=BaseEmailLogger
)
)
if len(initialized_email_loggers) > 0:
for email_logger in initialized_email_loggers:
if isinstance(email_logger, BaseEmailLogger):
await email_logger.send_key_rotated_email(
send_key_rotated_email_event=event,
)
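`_get_secret_name` above reduces to a small prefix-joining rule; a standalone sketch (the prefix value mirrors the `litellm/` default defined at the top of the module):

```python
def join_secret_name(prefix: str, secret_name: str) -> str:
    """Join a configured prefix and a secret name, avoiding a double slash
    when the prefix already ends with '/'."""
    if prefix.endswith("/"):
        return f"{prefix}{secret_name}"
    return f"{prefix}/{secret_name}"

print(join_secret_name("litellm/", "my-key"))  # litellm/my-key
print(join_secret_name("litellm", "my-key"))   # litellm/my-key
```

Normalizing here keeps secret names stable regardless of how the operator wrote the prefix in config.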

View File

@@ -0,0 +1,39 @@
"""
LiteLLM Skills Hook - Proxy integration for skills
This module provides the CustomLogger hook for skills processing.
The actual skill logic is in litellm/llms/litellm_proxy/skills/.
Usage:
from litellm.proxy.hooks.litellm_skills import SkillsInjectionHook
# Register hook in proxy
litellm.callbacks.append(SkillsInjectionHook())
"""
# Re-export from the SDK location for convenience
from litellm.llms.litellm_proxy.skills import (
LITELLM_CODE_EXECUTION_TOOL,
CodeExecutionHandler,
LiteLLMInternalTools,
SkillPromptInjectionHandler,
SkillsSandboxExecutor,
code_execution_handler,
get_litellm_code_execution_tool,
)
from litellm.proxy.hooks.litellm_skills.main import (
SkillsInjectionHook,
skills_injection_hook,
)
__all__ = [
"SkillsInjectionHook",
"skills_injection_hook",
"CodeExecutionHandler",
"LiteLLMInternalTools",
"LITELLM_CODE_EXECUTION_TOOL",
"get_litellm_code_execution_tool",
"code_execution_handler",
"SkillPromptInjectionHandler",
"SkillsSandboxExecutor",
]

View File

@@ -0,0 +1,914 @@
"""
Skills Injection Hook for LiteLLM Proxy
Main hook that orchestrates skill processing:
- Fetches skills from LiteLLM DB
- Injects SKILL.md content into system prompt
- Adds litellm_code_execution tool for automatic code execution
- Handles agentic loop internally when litellm_code_execution is called
For non-Anthropic models (e.g., Bedrock, OpenAI, etc.):
- Skills are converted to OpenAI-style tools
- Skill file content (SKILL.md) is extracted and injected into the system prompt
- litellm_code_execution tool is added - when model calls it, LiteLLM handles
execution automatically and returns final response with file_ids
Usage:
# Simple - LiteLLM handles everything automatically via proxy
# The container parameter triggers the SkillsInjectionHook
response = await litellm.acompletion(
model="gpt-4o-mini",
messages=[{"role": "user", "content": "Create a bouncing ball GIF"}],
container={"skills": [{"skill_id": "litellm:skill_abc123"}]},
)
# Response includes file_ids for generated files
"""
import base64
import json
from typing import Any, Dict, List, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.litellm_proxy.skills.prompt_injection import (
SkillPromptInjectionHandler,
)
from litellm.proxy._types import LiteLLM_SkillsTable, UserAPIKeyAuth
from litellm.types.utils import CallTypes, CallTypesLiteral
class SkillsInjectionHook(CustomLogger):
"""
Pre/Post-call hook that processes skills from container.skills parameter.
Pre-call (async_pre_call_hook):
- Skills with 'litellm:' prefix are fetched from LiteLLM DB
- For Anthropic models: native skills pass through, LiteLLM skills converted to tools
- For non-Anthropic models: LiteLLM skills are converted to tools + execute_code tool
Post-call (async_post_call_success_deployment_hook):
- If response has litellm_code_execution tool call, automatically execute code
- Continue conversation loop until model gives final response
- Return response with generated files inline
This hook is called automatically by litellm during completion calls.
"""
def __init__(self, **kwargs):
from litellm.llms.litellm_proxy.skills.constants import (
DEFAULT_MAX_ITERATIONS,
DEFAULT_SANDBOX_TIMEOUT,
)
self.optional_params = kwargs
self.prompt_handler = SkillPromptInjectionHandler()
self.max_iterations = kwargs.get("max_iterations", DEFAULT_MAX_ITERATIONS)
self.sandbox_timeout = kwargs.get("sandbox_timeout", DEFAULT_SANDBOX_TIMEOUT)
super().__init__(**kwargs)
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: CallTypesLiteral,
) -> Optional[Union[Exception, str, dict]]:
"""
Process skills from container.skills before the LLM call.
1. Check if container.skills exists in request
2. Separate skills by prefix (litellm: vs native)
3. Fetch LiteLLM skills from database
4. For Anthropic: keep native skills in container
5. For non-Anthropic: convert LiteLLM skills to tools, inject content, add execute_code
"""
# Only process completion-type calls
if call_type not in ["completion", "acompletion", "anthropic_messages"]:
return data
container = data.get("container")
if not container or not isinstance(container, dict):
return data
skills = container.get("skills")
if not skills or not isinstance(skills, list):
return data
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Processing {len(skills)} skills"
)
litellm_skills: List[LiteLLM_SkillsTable] = []
anthropic_skills: List[Dict[str, Any]] = []
# Separate skills by prefix
for skill in skills:
if not isinstance(skill, dict):
continue
skill_id = skill.get("skill_id", "")
            if skill_id.startswith("litellm:"):
                # Fetch from LiteLLM DB (strip the 'litellm:' prefix first,
                # per _fetch_skill_from_db's contract)
                db_skill = await self._fetch_skill_from_db(
                    skill_id[len("litellm:") :]
                )
if db_skill:
litellm_skills.append(db_skill)
else:
verbose_proxy_logger.warning(
f"SkillsInjectionHook: Skill '{skill_id}' not found in LiteLLM DB"
)
else:
# Native Anthropic skill - pass through
anthropic_skills.append(skill)
# Check if using messages API spec (anthropic_messages call type)
# Messages API always uses Anthropic-style tool format
use_anthropic_format = call_type == "anthropic_messages"
if len(litellm_skills) > 0:
data = self._process_for_messages_api(
data=data,
litellm_skills=litellm_skills,
use_anthropic_format=use_anthropic_format,
)
return data
def _process_for_messages_api(
self,
data: dict,
litellm_skills: List[LiteLLM_SkillsTable],
use_anthropic_format: bool = True,
) -> dict:
"""
Process skills for messages API (Anthropic format tools).
- Converts skills to Anthropic-style tools (name, description, input_schema)
- Extracts and injects SKILL.md content into system prompt
- Adds litellm_code_execution tool for code execution
- Stores skill files in metadata for sandbox execution
"""
from litellm.llms.litellm_proxy.skills.code_execution import (
get_litellm_code_execution_tool_anthropic,
)
tools = data.get("tools", [])
skill_contents: List[str] = []
all_skill_files: Dict[str, Dict[str, bytes]] = {}
all_module_paths: List[str] = []
for skill in litellm_skills:
# Convert skill to Anthropic-style tool
tools.append(self.prompt_handler.convert_skill_to_anthropic_tool(skill))
# Extract skill content from file if available
content = self.prompt_handler.extract_skill_content(skill)
if content:
skill_contents.append(content)
# Extract all files for code execution
skill_files = self.prompt_handler.extract_all_files(skill)
if skill_files:
all_skill_files[skill.skill_id] = skill_files
for path in skill_files.keys():
if path.endswith(".py"):
all_module_paths.append(path)
if tools:
data["tools"] = tools
# Inject skill content into system prompt
# For Anthropic messages API, use top-level 'system' param instead of messages array
if skill_contents:
data = self.prompt_handler.inject_skill_content_to_messages(
data, skill_contents, use_anthropic_format=use_anthropic_format
)
# Add litellm_code_execution tool if we have skill files
if all_skill_files:
code_exec_tool = get_litellm_code_execution_tool_anthropic()
data["tools"] = data.get("tools", []) + [code_exec_tool]
# Store skill files in litellm_metadata for automatic code execution
data["litellm_metadata"] = data.get("litellm_metadata", {})
data["litellm_metadata"]["_skill_files"] = all_skill_files
data["litellm_metadata"]["_litellm_code_execution_enabled"] = True
# Remove container (not supported by underlying providers)
data.pop("container", None)
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Messages API - converted {len(litellm_skills)} skills to Anthropic tools, "
f"injected {len(skill_contents)} skill contents, "
f"added litellm_code_execution tool with {len(all_module_paths)} modules"
)
return data
def _process_non_anthropic_model(
self,
data: dict,
litellm_skills: List[LiteLLM_SkillsTable],
) -> dict:
"""
Process skills for non-Anthropic models (OpenAI format tools).
- Converts skills to OpenAI-style tools
- Extracts and injects SKILL.md content
- Adds execute_code tool for code execution
- Stores skill files in metadata for sandbox execution
"""
tools = data.get("tools", [])
skill_contents: List[str] = []
all_skill_files: Dict[str, Dict[str, bytes]] = {}
all_module_paths: List[str] = []
for skill in litellm_skills:
# Convert skill to OpenAI-style tool
tools.append(self.prompt_handler.convert_skill_to_tool(skill))
# Extract skill content from file if available
content = self.prompt_handler.extract_skill_content(skill)
if content:
skill_contents.append(content)
# Extract all files for code execution
skill_files = self.prompt_handler.extract_all_files(skill)
if skill_files:
all_skill_files[skill.skill_id] = skill_files
# Collect Python module paths
for path in skill_files.keys():
if path.endswith(".py"):
all_module_paths.append(path)
if tools:
data["tools"] = tools
# Inject skill content into system prompt
if skill_contents:
data = self.prompt_handler.inject_skill_content_to_messages(
data, skill_contents
)
# Add litellm_code_execution tool if we have skill files
if all_skill_files:
from litellm.llms.litellm_proxy.skills.code_execution import (
get_litellm_code_execution_tool,
)
data["tools"] = data.get("tools", []) + [get_litellm_code_execution_tool()]
# Store skill files in litellm_metadata for automatic code execution
# Using litellm_metadata instead of metadata to avoid conflicts with user metadata
data["litellm_metadata"] = data.get("litellm_metadata", {})
data["litellm_metadata"]["_skill_files"] = all_skill_files
data["litellm_metadata"]["_litellm_code_execution_enabled"] = True
# Remove container for non-Anthropic (they don't support it)
data.pop("container", None)
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Non-Anthropic model - converted {len(litellm_skills)} skills to tools, "
f"injected {len(skill_contents)} skill contents, "
f"added execute_code tool with {len(all_module_paths)} modules"
)
return data
async def _fetch_skill_from_db(
self, skill_id: str
) -> Optional[LiteLLM_SkillsTable]:
"""
Fetch a skill from the LiteLLM database.
Args:
skill_id: The skill ID (without 'litellm:' prefix)
Returns:
LiteLLM_SkillsTable or None if not found
"""
try:
from litellm.llms.litellm_proxy.skills.handler import LiteLLMSkillsHandler
return await LiteLLMSkillsHandler.fetch_skill_from_db(skill_id)
except Exception as e:
verbose_proxy_logger.warning(
f"SkillsInjectionHook: Error fetching skill {skill_id}: {e}"
)
return None
def _is_anthropic_model(self, model: str) -> bool:
"""
Check if the model is an Anthropic model using get_llm_provider.
Args:
model: The model name/identifier
Returns:
True if Anthropic model, False otherwise
"""
try:
from litellm.litellm_core_utils.get_llm_provider_logic import (
get_llm_provider,
)
_, custom_llm_provider, _, _ = get_llm_provider(model=model)
return custom_llm_provider == "anthropic"
except Exception:
# Fallback to simple check if get_llm_provider fails
return "claude" in model.lower() or model.lower().startswith("anthropic/")
async def async_post_call_success_deployment_hook(
self,
request_data: dict,
response: Any,
call_type: Optional[CallTypes],
) -> Optional[Any]:
"""
Post-call hook to handle automatic code execution.
Handles both OpenAI format (response.choices) and Anthropic/messages API
format (response["content"]).
If the response contains a tool call (litellm_code_execution or skill tool):
1. Execute the code in sandbox
2. Add result to messages
3. Make another LLM call
4. Repeat until model gives final response
5. Return modified response with generated files
"""
from litellm.llms.litellm_proxy.skills.code_execution import (
LiteLLMInternalTools,
)
# Check if code execution is enabled for this request
litellm_metadata = request_data.get("litellm_metadata") or {}
metadata = request_data.get("metadata") or {}
code_exec_enabled = litellm_metadata.get(
"_litellm_code_execution_enabled"
) or metadata.get("_litellm_code_execution_enabled")
if not code_exec_enabled:
return None
# Get skill files
skill_files_by_id = litellm_metadata.get("_skill_files") or metadata.get(
"_skill_files", {}
)
all_skill_files: Dict[str, bytes] = {}
for files_dict in skill_files_by_id.values():
all_skill_files.update(files_dict)
if not all_skill_files:
verbose_proxy_logger.warning(
"SkillsInjectionHook: No skill files found, cannot execute code"
)
return None
# Check for tool calls - handle both Anthropic and OpenAI formats
tool_calls = self._extract_tool_calls(response)
if not tool_calls:
return None
# Check if any tool call needs execution (litellm_code_execution or skill tool)
has_executable_tool = False
for tc in tool_calls:
tool_name = tc.get("name", "")
# Execute if it's litellm_code_execution OR a skill tool (skill_xxx)
if (
tool_name == LiteLLMInternalTools.CODE_EXECUTION.value
or tool_name.startswith("skill_")
):
has_executable_tool = True
break
if not has_executable_tool:
return None
verbose_proxy_logger.debug(
"SkillsInjectionHook: Detected tool call, starting execution loop"
)
# Start the agentic loop
return await self._execute_code_loop_messages_api(
data=request_data,
response=response,
skill_files=all_skill_files,
)
def _extract_tool_calls(self, response: Any) -> List[Dict[str, Any]]:
"""Extract tool calls from response, handling both formats."""
tool_calls = []
# Get content - handle both dict and object responses
content = None
if isinstance(response, dict):
content = response.get("content", [])
elif hasattr(response, "content"):
content = response.content
# Anthropic/messages API format: response has "content" list with tool_use blocks
if content:
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_use":
tool_calls.append(
{
"id": block.get("id"),
"name": block.get("name"),
"input": block.get("input", {}),
}
)
elif (
hasattr(block, "type")
and getattr(block, "type", None) == "tool_use"
):
tool_calls.append(
{
"id": getattr(block, "id", None),
"name": getattr(block, "name", None),
"input": getattr(block, "input", {}),
}
)
# OpenAI format: response has choices[0].message.tool_calls
if not tool_calls and hasattr(response, "choices") and response.choices: # type: ignore[union-attr]
msg = response.choices[0].message # type: ignore[union-attr]
if hasattr(msg, "tool_calls") and msg.tool_calls:
for tc in msg.tool_calls:
tool_calls.append(
{
"id": tc.id,
"name": tc.function.name,
"input": json.loads(tc.function.arguments)
if tc.function.arguments
else {},
}
)
return tool_calls
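`_extract_tool_calls` normalizes both response shapes into the same `{id, name, input}` list; a dict-only sketch of the Anthropic branch (the sample response is fabricated):

```python
def extract_tool_use_blocks(response: dict):
    """Pull tool_use blocks from an Anthropic-style response dict into the
    normalized {id, name, input} shape used by the execution loop."""
    calls = []
    for block in response.get("content") or []:
        if isinstance(block, dict) and block.get("type") == "tool_use":
            calls.append(
                {
                    "id": block.get("id"),
                    "name": block.get("name"),
                    "input": block.get("input", {}),
                }
            )
    return calls

sample = {
    "stop_reason": "tool_use",
    "content": [
        {"type": "text", "text": "Running the skill..."},
        {"type": "tool_use", "id": "tu_1", "name": "litellm_code_execution",
         "input": {"code": "print('hi')"}},
    ],
}
print(extract_tool_use_blocks(sample))
# [{'id': 'tu_1', 'name': 'litellm_code_execution', 'input': {'code': "print('hi')"}}]
```

Because both branches emit this one shape, the downstream loop never needs to know which provider format the response arrived in.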
async def _execute_code_loop_messages_api(
self,
data: dict,
response: Any,
skill_files: Dict[str, bytes],
) -> Any:
"""
Execute the code execution loop for messages API (Anthropic format).
Returns the final response with generated files inline.
"""
import litellm
from litellm.llms.litellm_proxy.skills.code_execution import (
LiteLLMInternalTools,
)
from litellm.llms.litellm_proxy.skills.sandbox_executor import (
SkillsSandboxExecutor,
)
# Ensure response is not None
if response is None:
verbose_proxy_logger.error(
"SkillsInjectionHook: Response is None, cannot execute code loop"
)
return None
model = data.get("model", "")
messages = list(data.get("messages", []))
tools = data.get("tools", [])
max_tokens = data.get("max_tokens", 4096)
executor = SkillsSandboxExecutor(timeout=self.sandbox_timeout)
generated_files: List[Dict[str, Any]] = []
current_response = response
for iteration in range(self.max_iterations):
# Extract tool calls from current response
tool_calls = self._extract_tool_calls(current_response)
stop_reason = (
current_response.get("stop_reason")
if isinstance(current_response, dict)
else getattr(current_response, "stop_reason", None)
)
# Get content for assistant message - convert to plain dicts
raw_content = (
current_response.get("content", [])
if isinstance(current_response, dict)
else getattr(current_response, "content", [])
)
content_blocks = []
for block in raw_content or []:
if isinstance(block, dict):
content_blocks.append(block)
elif hasattr(block, "model_dump"):
content_blocks.append(block.model_dump())
elif hasattr(block, "__dict__"):
content_blocks.append(dict(block.__dict__))
else:
content_blocks.append({"type": "text", "text": str(block)})
# Build assistant message for conversation history (Anthropic format)
assistant_msg = {"role": "assistant", "content": content_blocks}
messages.append(assistant_msg)
# Check if we're done (no tool calls)
if stop_reason != "tool_use" or not tool_calls:
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Loop completed after {iteration + 1} iterations, "
f"{len(generated_files)} files generated"
)
return self._attach_files_to_response(current_response, generated_files)
# Process tool calls
tool_results = []
for tc in tool_calls:
tool_name = tc.get("name", "")
tool_id = tc.get("id", "")
tool_input = tc.get("input", {})
# Execute if it's litellm_code_execution OR a skill tool
if tool_name == LiteLLMInternalTools.CODE_EXECUTION.value:
code = tool_input.get("code", "")
result = await self._execute_code(
code, skill_files, executor, generated_files
)
elif tool_name.startswith("skill_"):
# Skill tool - execute the skill's code
result = await self._execute_skill_tool(
tool_name, tool_input, skill_files, executor, generated_files
)
else:
result = f"Tool '{tool_name}' not handled"
tool_results.append(
{
"type": "tool_result",
"tool_use_id": tool_id,
"content": result,
}
)
# Add tool results to messages (Anthropic format)
messages.append({"role": "user", "content": tool_results})
# Make next LLM call
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Making LLM call iteration {iteration + 2}"
)
try:
current_response = await litellm.anthropic.acreate(
model=model,
messages=messages,
tools=tools,
max_tokens=max_tokens,
)
if current_response is None:
verbose_proxy_logger.error(
"SkillsInjectionHook: LLM call returned None"
)
return self._attach_files_to_response(response, generated_files)
except Exception as e:
verbose_proxy_logger.error(f"SkillsInjectionHook: LLM call failed: {e}")
return self._attach_files_to_response(response, generated_files)
verbose_proxy_logger.warning(
f"SkillsInjectionHook: Max iterations ({self.max_iterations}) reached"
)
return self._attach_files_to_response(current_response, generated_files)
async def _execute_code(
self,
code: str,
skill_files: Dict[str, bytes],
executor: Any,
generated_files: List[Dict[str, Any]],
) -> str:
"""Execute code in sandbox and return result string."""
try:
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Executing code ({len(code)} chars)"
)
exec_result = executor.execute(code=code, skill_files=skill_files)
result = exec_result.get("output", "") or ""
# Collect generated files
if exec_result.get("files"):
for f in exec_result["files"]:
generated_files.append(
{
"name": f["name"],
"mime_type": f["mime_type"],
"content_base64": f["content_base64"],
"size": len(base64.b64decode(f["content_base64"])),
}
)
result += f"\n\nGenerated file: {f['name']}"
if exec_result.get("error"):
result += f"\n\nError: {exec_result['error']}"
return result or "Code executed successfully"
except Exception as e:
return f"Code execution failed: {str(e)}"
async def _execute_skill_tool(
self,
tool_name: str,
tool_input: Dict[str, Any],
skill_files: Dict[str, bytes],
executor: Any,
generated_files: List[Dict[str, Any]],
) -> str:
"""Execute a skill tool by generating and running code based on skill content."""
# Generate code based on available skill modules
# Look for Python modules in the skill
python_modules = [
p
for p in skill_files.keys()
if p.endswith(".py") and not p.endswith("__init__.py")
]
# Try to find the main builder/creator module
main_module = None
for mod in python_modules:
if (
"builder" in mod.lower()
or "creator" in mod.lower()
or "generator" in mod.lower()
):
main_module = mod
break
if not main_module and python_modules:
# Use first non-init module
main_module = python_modules[0]
if main_module:
# Convert path to import: "core/gif_builder.py" -> "core.gif_builder"
import_path = main_module.replace("/", ".").replace(".py", "")
# Generate code that imports and uses the module
code = f"""
# Auto-generated code to execute skill
import sys
sys.path.insert(0, '/sandbox')
from {import_path} import *
# Try to find and use a Builder/Creator class
import inspect
module = __import__('{import_path}', fromlist=[''])
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) and name != 'object':
try:
instance = obj()
# Try common methods
if hasattr(instance, 'create'):
result = instance.create()
elif hasattr(instance, 'build'):
result = instance.build()
elif hasattr(instance, 'generate'):
result = instance.generate()
elif hasattr(instance, 'save'):
instance.save('output.gif')
print(f'Used {{name}} class')
break
except Exception as e:
print(f'Error with {{name}}: {{e}}')
continue
# List generated files
import os
for f in os.listdir('.'):
if f.endswith(('.gif', '.png', '.jpg')):
print(f'Generated: {{f}}')
"""
else:
# Fallback generic code
code = """
print('No executable skill module found')
"""
return await self._execute_code(code, skill_files, executor, generated_files)
async def _execute_code_loop(
self,
data: dict,
response: Any,
skill_files: Dict[str, bytes],
) -> Any:
"""
Execute the code execution loop until model gives final response.
Returns the final response with generated files inline.
"""
import litellm
from litellm.llms.litellm_proxy.skills.code_execution import (
LiteLLMInternalTools,
)
from litellm.llms.litellm_proxy.skills.sandbox_executor import (
SkillsSandboxExecutor,
)
model = data.get("model", "")
messages = list(data.get("messages", []))
tools = data.get("tools", [])
# Keys to exclude when passing through to acompletion
# These are either handled explicitly or are internal LiteLLM fields
_EXCLUDED_ACOMPLETION_KEYS = frozenset(
{
"messages",
"model",
"tools",
"metadata",
"litellm_metadata",
"container",
}
)
kwargs = {k: v for k, v in data.items() if k not in _EXCLUDED_ACOMPLETION_KEYS}
executor = SkillsSandboxExecutor(timeout=self.sandbox_timeout)
generated_files: List[Dict[str, Any]] = []
current_response: Any = response
for iteration in range(self.max_iterations):
# OpenAI format response has choices[0].message
assistant_message = current_response.choices[0].message # type: ignore[union-attr]
stop_reason = current_response.choices[0].finish_reason # type: ignore[union-attr]
# Build assistant message for conversation history
assistant_msg_dict: Dict[str, Any] = {
"role": "assistant",
"content": assistant_message.content,
}
if assistant_message.tool_calls:
assistant_msg_dict["tool_calls"] = [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in assistant_message.tool_calls
]
messages.append(assistant_msg_dict)
# Check if we're done (no tool calls)
if stop_reason != "tool_calls" or not assistant_message.tool_calls:
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Code execution loop completed after "
f"{iteration + 1} iterations, {len(generated_files)} files generated"
)
# Attach generated files to response
return self._attach_files_to_response(current_response, generated_files)
# Process tool calls
for tool_call in assistant_message.tool_calls:
tool_name = tool_call.function.name
if tool_name == LiteLLMInternalTools.CODE_EXECUTION.value:
tool_result = await self._execute_code_tool(
tool_call=tool_call,
skill_files=skill_files,
executor=executor,
generated_files=generated_files,
)
else:
# Non-code-execution tool - cannot handle
tool_result = f"Tool '{tool_name}' not handled automatically"
messages.append(
{
"role": "tool",
"tool_call_id": tool_call.id,
"content": tool_result,
}
)
# Make next LLM call via acompletion (OpenAI format), passing
# through the remaining request params collected in kwargs
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Making LLM call iteration {iteration + 2}"
)
current_response = await litellm.acompletion(
model=model,
messages=messages,
tools=tools,
**kwargs,
)
# Max iterations reached
verbose_proxy_logger.warning(
f"SkillsInjectionHook: Max iterations ({self.max_iterations}) reached"
)
return self._attach_files_to_response(current_response, generated_files)
async def _execute_code_tool(
self,
tool_call: Any,
skill_files: Dict[str, bytes],
executor: Any,
generated_files: List[Dict[str, Any]],
) -> str:
"""Execute a litellm_code_execution tool call and return result string."""
try:
args = json.loads(tool_call.function.arguments)
code = args.get("code", "")
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Executing code ({len(code)} chars)"
)
exec_result = executor.execute(
code=code,
skill_files=skill_files,
)
# Build tool result content
tool_result = exec_result.get("output", "") or ""
# Collect generated files
if exec_result.get("files"):
tool_result += "\n\nGenerated files:"
for f in exec_result["files"]:
file_content = base64.b64decode(f["content_base64"])
generated_files.append(
{
"name": f["name"],
"mime_type": f["mime_type"],
"content_base64": f["content_base64"],
"size": len(file_content),
}
)
tool_result += f"\n- {f['name']} ({len(file_content)} bytes)"
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Generated file {f['name']} "
f"({len(file_content)} bytes)"
)
if exec_result.get("error"):
tool_result += f"\n\nError:\n{exec_result['error']}"
return tool_result
except Exception as e:
verbose_proxy_logger.error(
f"SkillsInjectionHook: Code execution failed: {e}"
)
return f"Code execution failed: {str(e)}"
def _attach_files_to_response(
self,
response: Any,
generated_files: List[Dict[str, Any]],
) -> Any:
"""
Attach generated files to the response object.
Files are added to response._litellm_generated_files for easy access.
For dict responses, files are added as a key.
"""
if not generated_files:
return response
# Handle dict response (Anthropic/messages API format)
if isinstance(response, dict):
response["_litellm_generated_files"] = generated_files
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Attached {len(generated_files)} files to dict response"
)
return response
# Handle object response (OpenAI format)
try:
response._litellm_generated_files = generated_files
except AttributeError:
pass
# Also add to model_extra if available (for serialization)
if hasattr(response, "model_extra"):
if response.model_extra is None:
response.model_extra = {}
response.model_extra["_litellm_generated_files"] = generated_files
verbose_proxy_logger.debug(
f"SkillsInjectionHook: Attached {len(generated_files)} files to response"
)
return response
# Global instance for registration
skills_injection_hook = SkillsInjectionHook()
import litellm
litellm.logging_callback_manager.add_litellm_callback(skills_injection_hook)

View File

@@ -0,0 +1,49 @@
from fastapi import HTTPException
from litellm import verbose_logger
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
class _PROXY_MaxBudgetLimiter(CustomLogger):
# Class variables or attributes
def __init__(self):
pass
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
try:
verbose_proxy_logger.debug("Inside Max Budget Limiter Pre-Call Hook")
cache_key = f"{user_api_key_dict.user_id}_user_api_key_user_id"
user_row = await cache.async_get_cache(
cache_key, parent_otel_span=user_api_key_dict.parent_otel_span
)
if user_row is None: # value not yet cached
return
max_budget = user_row["max_budget"]
curr_spend = user_row["spend"]
if max_budget is None:
return
if curr_spend is None:
return
# CHECK IF REQUEST ALLOWED
if curr_spend >= max_budget:
raise HTTPException(status_code=429, detail="Max budget limit reached.")
except HTTPException as e:
raise e
except Exception as e:
verbose_logger.exception(
"litellm.proxy.hooks.max_budget_limiter.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)

View File

@@ -0,0 +1,267 @@
"""
Per-Session Budget Limiter for LiteLLM Proxy.
Enforces a dollar-amount cap per session (identified by `session_id` /
`x-litellm-trace-id`). After each successful LLM call the response cost is
accumulated against the session. When the accumulated spend exceeds
`max_budget_per_session` (configured in agent litellm_params), subsequent
requests for that session receive a 429.
Note: trace-id enforcement (require_trace_id_on_calls_by_agent) is handled
separately in auth_checks.py at the agent level, not in this hook.
Works across multiple proxy instances via DualCache (in-memory + Redis).
Follows the same pattern as max_iterations_limiter.py.
"""
import os
from typing import TYPE_CHECKING, Any, Optional, Union
from fastapi import HTTPException
from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
if TYPE_CHECKING:
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
InternalUsageCache = _InternalUsageCache
else:
InternalUsageCache = Any
# Redis Lua script for atomic float increment with TTL.
# INCRBYFLOAT returns the new value as a string.
# Only sets EXPIRE on first call (when prior value was nil).
MAX_BUDGET_SESSION_INCREMENT_SCRIPT = """
local key = KEYS[1]
local amount = ARGV[1]
local ttl = tonumber(ARGV[2])
local existed = redis.call('EXISTS', key)
local new_val = redis.call('INCRBYFLOAT', key, amount)
if existed == 0 then
redis.call('EXPIRE', key, ttl)
end
return new_val
"""
# Default TTL for session budget counters (1 hour)
DEFAULT_MAX_BUDGET_PER_SESSION_TTL = 3600
class _PROXY_MaxBudgetPerSessionHandler(CustomLogger):
"""
Pre-call hook that enforces max_budget_per_session.
Configuration (set in agent litellm_params):
- max_budget_per_session: dollar cap per session_id
Cache key pattern:
{session_budget:<session_id>}:spend
"""
def __init__(self, internal_usage_cache: InternalUsageCache):
self.internal_usage_cache = internal_usage_cache
self.ttl = int(
os.getenv(
"LITELLM_MAX_BUDGET_PER_SESSION_TTL",
DEFAULT_MAX_BUDGET_PER_SESSION_TTL,
)
)
if self.internal_usage_cache.dual_cache.redis_cache is not None:
self.increment_script = (
self.internal_usage_cache.dual_cache.redis_cache.async_register_script(
MAX_BUDGET_SESSION_INCREMENT_SCRIPT
)
)
else:
self.increment_script = None
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
) -> Optional[Union[Exception, str, dict]]:
"""
Before each LLM call, check if max_budget_per_session is set and
whether accumulated spend exceeds the budget (429 if so).
"""
max_budget = self._get_max_budget_per_session(user_api_key_dict)
session_id = self._get_session_id(data)
if max_budget is None or session_id is None:
return None
max_budget = float(max_budget)
cache_key = self._make_cache_key(session_id)
current_spend = await self._get_current_spend(cache_key)
verbose_proxy_logger.debug(
"MaxBudgetPerSessionHandler: session_id=%s, spend=%.4f, max=%.2f",
session_id,
current_spend,
max_budget,
)
if current_spend >= max_budget:
raise HTTPException(
status_code=429,
detail=(
f"Session budget exceeded for session {session_id}. "
f"Current spend: ${current_spend:.4f}, "
f"max_budget_per_session: ${max_budget:.2f}."
),
)
return None
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
After a successful LLM call, increment the session spend by the response cost.
"""
try:
litellm_params = kwargs.get("litellm_params") or {}
metadata = litellm_params.get("metadata") or {}
session_id = metadata.get("session_id")
if session_id is None:
return
agent_id = metadata.get("agent_id")
if agent_id is None:
return
from litellm.proxy.agent_endpoints.agent_registry import (
global_agent_registry,
)
agent = global_agent_registry.get_agent_by_id(agent_id=str(agent_id))
if agent is None:
return
agent_litellm_params = agent.litellm_params or {}
max_budget = agent_litellm_params.get("max_budget_per_session")
if max_budget is None:
return
response_cost = kwargs.get("response_cost") or 0.0
if response_cost <= 0:
return
cache_key = self._make_cache_key(str(session_id))
await self._increment_spend(cache_key, float(response_cost))
verbose_proxy_logger.debug(
"MaxBudgetPerSessionHandler: incremented session %s spend by %.6f",
session_id,
response_cost,
)
except Exception as e:
verbose_proxy_logger.warning(
"MaxBudgetPerSessionHandler: error in async_log_success_event: %s",
str(e),
)
def _get_session_id(self, data: dict) -> Optional[str]:
"""Extract session_id from request metadata."""
metadata = data.get("metadata") or {}
session_id = metadata.get("session_id")
if session_id is not None:
return str(session_id)
litellm_metadata = data.get("litellm_metadata") or {}
session_id = litellm_metadata.get("session_id")
if session_id is not None:
return str(session_id)
return None
def _get_max_budget_per_session(
self, user_api_key_dict: UserAPIKeyAuth
) -> Optional[float]:
"""Extract max_budget_per_session from agent litellm_params."""
agent_id = user_api_key_dict.agent_id
if agent_id is None:
return None
from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry
agent = global_agent_registry.get_agent_by_id(agent_id=agent_id)
if agent is None:
return None
litellm_params = agent.litellm_params or {}
max_budget = litellm_params.get("max_budget_per_session")
if max_budget is not None:
return float(max_budget)
return None
def _make_cache_key(self, session_id: str) -> str:
return f"{{session_budget:{session_id}}}:spend"
async def _get_current_spend(self, cache_key: str) -> float:
"""Read current accumulated spend for a session."""
if self.internal_usage_cache.dual_cache.redis_cache is not None:
try:
result = await self.internal_usage_cache.dual_cache.redis_cache.async_get_cache(
key=cache_key
)
if result is not None:
return float(result)
return 0.0
except Exception as e:
verbose_proxy_logger.warning(
"MaxBudgetPerSessionHandler: Redis GET failed, "
"falling back to in-memory: %s",
str(e),
)
result = await self.internal_usage_cache.async_get_cache(
key=cache_key,
litellm_parent_otel_span=None,
local_only=True,
)
if result is not None:
return float(result)
return 0.0
async def _increment_spend(self, cache_key: str, amount: float) -> float:
"""Atomically increment the session spend and return the new value."""
if self.increment_script is not None:
try:
result = await self.increment_script(
keys=[cache_key],
args=[str(amount), self.ttl],
)
return float(result)
except Exception as e:
verbose_proxy_logger.warning(
"MaxBudgetPerSessionHandler: Redis INCRBYFLOAT failed, "
"falling back to in-memory: %s",
str(e),
)
return await self._in_memory_increment_spend(cache_key, amount)
async def _in_memory_increment_spend(self, cache_key: str, amount: float) -> float:
current = await self.internal_usage_cache.async_get_cache(
key=cache_key,
litellm_parent_otel_span=None,
local_only=True,
)
new_value = (float(current) if current is not None else 0.0) + amount
await self.internal_usage_cache.async_set_cache(
key=cache_key,
value=new_value,
ttl=self.ttl,
litellm_parent_otel_span=None,
local_only=True,
)
return new_value

View File

@@ -0,0 +1,221 @@
"""
Max Iterations Limiter for LiteLLM Proxy.
Enforces a per-session cap on the number of LLM calls an agentic loop can make.
Callers send a `session_id` with each request (via `x-litellm-session-id` header
or `metadata.session_id`), and this hook counts calls per session. When the count
exceeds `max_iterations` (configured in agent litellm_params or key metadata), returns 429.
Works across multiple proxy instances via DualCache (in-memory + Redis).
Follows the same pattern as parallel_request_limiter_v3.py.
"""
import os
from typing import TYPE_CHECKING, Any, Optional, Union
from fastapi import HTTPException
from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth
if TYPE_CHECKING:
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
InternalUsageCache = _InternalUsageCache
else:
InternalUsageCache = Any
# Redis Lua script for atomic increment with TTL.
# Returns the new count after increment.
# Only sets EXPIRE on first increment (when count becomes 1).
MAX_ITERATIONS_INCREMENT_SCRIPT = """
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local current = redis.call('INCR', key)
if current == 1 then
redis.call('EXPIRE', key, ttl)
end
return current
"""
# Default TTL for session iteration counters (1 hour)
DEFAULT_MAX_ITERATIONS_TTL = 3600
class _PROXY_MaxIterationsHandler(CustomLogger):
"""
Pre-call hook that enforces max_iterations per session.
Configuration:
- max_iterations: set in agent litellm_params (preferred)
e.g. litellm_params={"max_iterations": 25}
Falls back to key metadata max_iterations for backwards compatibility.
- session_id: sent by caller via x-litellm-session-id header or
metadata.session_id in request body
Cache key pattern:
{session_iterations:<session_id>}:count
Multi-instance support:
Uses Redis Lua script for atomic increment (same pattern as
parallel_request_limiter_v3). Falls back to in-memory cache
when Redis is unavailable.
"""
def __init__(self, internal_usage_cache: InternalUsageCache):
self.internal_usage_cache = internal_usage_cache
self.ttl = int(
os.getenv("LITELLM_MAX_ITERATIONS_TTL", DEFAULT_MAX_ITERATIONS_TTL)
)
# Register Lua script with Redis if available (same pattern as v3 limiter)
if self.internal_usage_cache.dual_cache.redis_cache is not None:
self.increment_script = (
self.internal_usage_cache.dual_cache.redis_cache.async_register_script(
MAX_ITERATIONS_INCREMENT_SCRIPT
)
)
else:
self.increment_script = None
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
) -> Optional[Union[Exception, str, dict]]:
"""
Check session iteration count before making the API call.
Extracts session_id from request metadata and max_iterations from
agent litellm_params. If the session has exceeded max_iterations, raises 429.
"""
# Extract session_id from request data
session_id = self._get_session_id(data)
if session_id is None:
return None
max_iterations = self._get_max_iterations(user_api_key_dict)
if max_iterations is None:
return None
verbose_proxy_logger.debug(
"MaxIterationsHandler: session_id=%s, max_iterations=%s",
session_id,
max_iterations,
)
# Increment and check
cache_key = self._make_cache_key(session_id)
current_count = await self._increment_and_get(cache_key)
if current_count > max_iterations:
raise HTTPException(
status_code=429,
detail=(
f"Max iterations exceeded for session {session_id}. "
f"Current count: {current_count}, max_iterations: {max_iterations}."
),
)
verbose_proxy_logger.debug(
"MaxIterationsHandler: session_id=%s, count=%s/%s",
session_id,
current_count,
max_iterations,
)
return None
def _get_session_id(self, data: dict) -> Optional[str]:
"""Extract session_id from request metadata."""
metadata = data.get("metadata") or {}
session_id = metadata.get("session_id")
if session_id is not None:
return str(session_id)
# Also check litellm_metadata (used for /thread and /assistant endpoints)
litellm_metadata = data.get("litellm_metadata") or {}
session_id = litellm_metadata.get("session_id")
if session_id is not None:
return str(session_id)
return None
def _get_max_iterations(self, user_api_key_dict: UserAPIKeyAuth) -> Optional[int]:
"""Extract max_iterations from agent litellm_params, with fallback to key metadata."""
# Try agent litellm_params first
agent_id = user_api_key_dict.agent_id
if agent_id is not None:
from litellm.proxy.agent_endpoints.agent_registry import (
global_agent_registry,
)
agent = global_agent_registry.get_agent_by_id(agent_id=agent_id)
if agent is not None:
litellm_params = agent.litellm_params or {}
max_iterations = litellm_params.get("max_iterations")
if max_iterations is not None:
return int(max_iterations)
# Fallback to key metadata for backwards compatibility
metadata = user_api_key_dict.metadata or {}
max_iterations = metadata.get("max_iterations")
if max_iterations is not None:
return int(max_iterations)
return None
def _make_cache_key(self, session_id: str) -> str:
"""
Create cache key for session iteration counter.
Uses Redis hash-tag pattern {session_iterations:<session_id>} so all
keys for a session land on the same Redis Cluster slot.
"""
return f"{{session_iterations:{session_id}}}:count"
async def _increment_and_get(self, cache_key: str) -> int:
"""
Atomically increment the session counter and return the new value.
Tries Redis first (via registered Lua script for atomicity across
instances), falls back to in-memory cache.
"""
if self.increment_script is not None:
try:
result = await self.increment_script(
keys=[cache_key],
args=[self.ttl],
)
return int(result)
except Exception as e:
verbose_proxy_logger.warning(
"MaxIterationsHandler: Redis failed, falling back to in-memory: %s",
str(e),
)
# Fallback: in-memory cache
return await self._in_memory_increment(cache_key)
async def _in_memory_increment(self, cache_key: str) -> int:
"""Increment counter in in-memory cache with TTL."""
current = await self.internal_usage_cache.async_get_cache(
key=cache_key,
litellm_parent_otel_span=None,
local_only=True,
)
new_value = (int(current) if current is not None else 0) + 1
await self.internal_usage_cache.async_set_cache(
key=cache_key,
value=new_value,
ttl=self.ttl,
litellm_parent_otel_span=None,
local_only=True,
)
return new_value

View File

@@ -0,0 +1,96 @@
# MCP Semantic Tool Filter Architecture
## Why Filter MCP Tools
When multiple MCP servers are connected, the proxy may expose hundreds of tools. Sending all tools in every request wastes context window tokens and increases cost. The semantic filter keeps only the top-K most relevant tools based on embedding similarity.
```mermaid
sequenceDiagram
participant Client
participant Hook as SemanticToolFilterHook
participant Filter as SemanticMCPToolFilter
participant Router as semantic-router
participant LLM
Client->>Hook: POST /chat/completions
Note over Client,Hook: tools: [100+ MCP tools]
Note over Client,Hook: messages: [{"role": "user", "content": "Get my Jira issues"}]
rect rgb(240, 240, 240)
Note over Hook: 1. Extract User Query
Hook->>Filter: filter_tools("Get my Jira issues", tools)
end
rect rgb(240, 240, 240)
Note over Filter: 2. Convert Tools → Routes
Note over Filter: Tool name + description → Route
end
rect rgb(240, 240, 240)
Note over Filter: 3. Semantic Matching
Filter->>Router: router(query)
Router->>Router: Embeddings + similarity
Router-->>Filter: [top 10 matches]
end
rect rgb(240, 240, 240)
Note over Filter: 4. Return Filtered Tools
Filter-->>Hook: [10 relevant tools]
end
Hook->>LLM: POST /chat/completions
Note over Hook,LLM: tools: [10 Jira-related tools] ← FILTERED
Note over Hook,LLM: messages: [...] ← UNCHANGED
LLM-->>Client: Response (unchanged)
```
## Filter Operations
The hook intercepts requests before they reach the LLM:
| Operation | Description |
|-----------|-------------|
| **Extract query** | Get user message from `messages[-1]` |
| **Convert to Routes** | Transform MCP tools into semantic-router Routes |
| **Semantic match** | Use `semantic-router` to find top-K similar tools |
| **Filter tools** | Replace request `tools` with filtered subset |
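
The extract → match → filter pipeline above can be sketched without the semantic-router dependency. This is a minimal illustration of the idea only: the `embed` callable and the `filter_tools` signature are stand-ins for the real encoder and hook internals, not LiteLLM's actual API.

```python
import math
from typing import Callable, Dict, List


def cosine(a: List[float], b: List[float]) -> float:
    # Cosine similarity between two embedding vectors.
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def filter_tools(
    query: str,
    tools: List[Dict],
    embed: Callable[[str], List[float]],
    top_k: int = 10,
    threshold: float = 0.3,
) -> List[Dict]:
    # Score each tool by similarity between the query embedding and an
    # embedding of its name + description (the "tool -> Route" step).
    q_vec = embed(query)
    scored = []
    for tool in tools:
        fn = tool.get("function", {})
        text = f"{fn.get('name', '')}: {fn.get('description', '')}"
        score = cosine(q_vec, embed(text))
        if score >= threshold:
            scored.append((score, tool))
    # Keep the top-k highest-scoring tools.
    scored.sort(key=lambda s: s[0], reverse=True)
    kept = [tool for _, tool in scored[:top_k]]
    # Graceful fallback: if nothing matched, keep all tools.
    return kept or tools
```

In the real hook the embedding call is backed by `Router.aembedding()` via `LiteLLMRouterEncoder`, and the similarity/top-K selection is delegated to `semantic-router`.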
## Trigger Conditions
The filter only runs when:
- Call type is `completion`, `acompletion`, or `aresponses`
- Request contains `tools` field
- Request contains `messages` field
- Filter is enabled in config
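
These conditions reduce to a simple predicate; a sketch (the function name is illustrative, and `aresponses` is included because the hook's `async_pre_call_hook` also accepts that call type):

```python
def should_run_filter(data: dict, call_type: str, enabled: bool = True) -> bool:
    # All conditions must hold before the filter touches a request.
    return (
        enabled
        and call_type in ("completion", "acompletion", "aresponses")
        and bool(data.get("tools"))
        and bool(data.get("messages"))
    )
```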
## What Does NOT Change
- Request messages
- Response body
- Non-tool parameters
## Integration with semantic-router
Reuses existing LiteLLM infrastructure:
- `semantic-router` - Already an optional dependency
- `LiteLLMRouterEncoder` - Wraps `Router.aembedding()` for embeddings
- `SemanticRouter` - Handles similarity calculation and top-K selection
## Configuration
```yaml
litellm_settings:
mcp_semantic_tool_filter:
enabled: true
embedding_model: "openai/text-embedding-3-small"
top_k: 10
similarity_threshold: 0.3
```
## Error Handling
The filter fails gracefully:
- If filtering fails → Return all tools (no impact on functionality)
- If query extraction fails → Skip filtering
- If no matches found → Return all tools
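
The fail-open behaviour above can be captured in a small guard. A sketch, assuming the filter is exposed as a plain callable (the `filter_fn` parameter is illustrative, not the hook's actual signature):

```python
from typing import Callable, Dict, List, Optional


def safe_filter(
    tools: List[Dict],
    query: Optional[str],
    filter_fn: Callable[[str, List[Dict]], List[Dict]],
) -> List[Dict]:
    # Fail open: any problem means the request proceeds with all tools.
    if not query:  # query extraction failed -> skip filtering
        return tools
    try:
        filtered = filter_fn(query, tools)
    except Exception:  # filtering failed -> no impact on functionality
        return tools
    return filtered or tools  # no matches -> return all tools
```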

View File

@@ -0,0 +1,9 @@
"""
MCP Semantic Tool Filter Hook
Semantic filtering for MCP tools to reduce context window size
and improve tool selection accuracy.
"""
from litellm.proxy.hooks.mcp_semantic_filter.hook import SemanticToolFilterHook
__all__ = ["SemanticToolFilterHook"]

View File

@@ -0,0 +1,375 @@
"""
Semantic Tool Filter Hook
Pre-call hook that filters MCP tools semantically before LLM inference.
Reduces context window size and improves tool selection accuracy.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.constants import (
DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL,
DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD,
DEFAULT_MCP_SEMANTIC_FILTER_TOP_K,
)
from litellm.integrations.custom_logger import CustomLogger
if TYPE_CHECKING:
from litellm.caching.caching import DualCache
from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
SemanticMCPToolFilter,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.router import Router
class SemanticToolFilterHook(CustomLogger):
"""
Pre-call hook that filters MCP tools semantically.
This hook:
1. Extracts the user query from messages
2. Filters tools based on semantic similarity to the query
3. Returns only the top-k most relevant tools to the LLM
"""
def __init__(self, semantic_filter: "SemanticMCPToolFilter"):
"""
Initialize the hook.
Args:
semantic_filter: SemanticMCPToolFilter instance
"""
super().__init__()
self.filter = semantic_filter
verbose_proxy_logger.debug(
f"Initialized SemanticToolFilterHook with filter: "
f"enabled={semantic_filter.enabled}, top_k={semantic_filter.top_k}"
)
def _should_expand_mcp_tools(self, tools: List[Any]) -> bool:
"""
Check if tools contain MCP references with server_url="litellm_proxy".
Only expands MCP tools pointing to litellm proxy, not external MCP servers.
"""
from litellm.responses.mcp.litellm_proxy_mcp_handler import (
LiteLLM_Proxy_MCP_Handler,
)
return LiteLLM_Proxy_MCP_Handler._should_use_litellm_mcp_gateway(tools)
async def _expand_mcp_tools(
self,
tools: List[Any],
user_api_key_dict: "UserAPIKeyAuth",
) -> List[Dict[str, Any]]:
"""
Expand MCP references to actual tool definitions.
Reuses LiteLLM_Proxy_MCP_Handler._process_mcp_tools_to_openai_format
which internally does: parse -> fetch -> filter -> deduplicate -> transform
"""
from litellm.responses.mcp.litellm_proxy_mcp_handler import (
LiteLLM_Proxy_MCP_Handler,
)
# Parse to separate MCP tools from other tools
mcp_tools, _ = LiteLLM_Proxy_MCP_Handler._parse_mcp_tools(tools)
if not mcp_tools:
return []
# Use single combined method instead of 3 separate calls
# This already handles: fetch -> filter by allowed_tools -> deduplicate -> transform
(
openai_tools,
_,
) = await LiteLLM_Proxy_MCP_Handler._process_mcp_tools_to_openai_format(
user_api_key_auth=user_api_key_dict, mcp_tools_with_litellm_proxy=mcp_tools
)
# Convert Pydantic models to dicts for compatibility
openai_tools_as_dicts = []
for tool in openai_tools:
if hasattr(tool, "model_dump"):
tool_dict = tool.model_dump(exclude_none=True)
verbose_proxy_logger.debug(
f"Converted Pydantic tool to dict: {type(tool).__name__} -> dict with keys: {list(tool_dict.keys())}"
)
openai_tools_as_dicts.append(tool_dict)
elif hasattr(tool, "dict"):
tool_dict = tool.dict(exclude_none=True)
verbose_proxy_logger.debug(
f"Converted Pydantic tool (v1) to dict: {type(tool).__name__} -> dict"
)
openai_tools_as_dicts.append(tool_dict)
elif isinstance(tool, dict):
verbose_proxy_logger.debug(
f"Tool is already a dict with keys: {list(tool.keys())}"
)
openai_tools_as_dicts.append(tool)
else:
verbose_proxy_logger.warning(
f"Tool is unknown type: {type(tool)}, passing as-is"
)
openai_tools_as_dicts.append(tool)
verbose_proxy_logger.debug(
f"Expanded {len(mcp_tools)} MCP reference(s) to {len(openai_tools_as_dicts)} tools (all as dicts)"
)
return openai_tools_as_dicts
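The dict-normalization branches above can be sketched standalone; `FakeTool` in the sketch's usage and the merged pass-through branch are illustrative assumptions, not actual LiteLLM tool types:

```python
from typing import Any, List


def to_dicts(tools: List[Any]) -> List[Any]:
    """Normalize tool objects to plain dicts, mirroring the branch order above."""
    out: List[Any] = []
    for tool in tools:
        if hasattr(tool, "model_dump"):  # Pydantic v2
            out.append(tool.model_dump(exclude_none=True))
        elif hasattr(tool, "dict") and not isinstance(tool, dict):  # Pydantic v1
            out.append(tool.dict(exclude_none=True))
        else:  # already a dict, or unknown type passed through as-is
            out.append(tool)
    return out
```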
def _get_metadata_variable_name(self, data: dict) -> str:
if "litellm_metadata" in data:
return "litellm_metadata"
return "metadata"
async def async_pre_call_hook(
self,
user_api_key_dict: "UserAPIKeyAuth",
cache: "DualCache",
data: dict,
call_type: str,
) -> Optional[Union[Exception, str, dict]]:
"""
Filter tools before LLM call based on user query.
This hook is called before the LLM request is made. It filters the
tools list to only include semantically relevant tools.
Args:
user_api_key_dict: User authentication
cache: Cache instance
data: Request data containing messages and tools
call_type: Type of call (completion, acompletion, etc.)
Returns:
Modified data dict with filtered tools, or None if no changes
"""
# Only filter endpoints that support tools
if call_type not in ("completion", "acompletion", "aresponses"):
verbose_proxy_logger.debug(
f"Skipping semantic filter for call_type={call_type}"
)
return None
# Check if tools are present
tools = data.get("tools")
if not tools:
verbose_proxy_logger.debug("No tools in request, skipping semantic filter")
return None
original_tool_count = len(tools)
# Check for MCP references (server_url="litellm_proxy") and expand them
if self._should_expand_mcp_tools(tools):
verbose_proxy_logger.debug(
"Detected litellm_proxy MCP references, expanding before semantic filtering"
)
try:
expanded_tools = await self._expand_mcp_tools(tools, user_api_key_dict)
if not expanded_tools:
verbose_proxy_logger.warning(
"No tools expanded from MCP references"
)
return None
verbose_proxy_logger.info(
f"Expanded {len(tools)} MCP reference(s) to {len(expanded_tools)} tools"
)
# Update tools for filtering
tools = expanded_tools
original_tool_count = len(tools)
except Exception as e:
verbose_proxy_logger.error(
f"Failed to expand MCP references: {e}", exc_info=True
)
return None
# Check if messages are present (try both "messages" and "input" for responses API)
messages = data.get("messages", [])
if not messages:
messages = data.get("input", [])
if not messages:
verbose_proxy_logger.debug(
"No messages in request, skipping semantic filter"
)
return None
# Check if filter is enabled
if not self.filter.enabled:
verbose_proxy_logger.debug("Semantic filter disabled, skipping")
return None
try:
# Extract user query from messages
user_query = self.filter.extract_user_query(messages)
if not user_query:
verbose_proxy_logger.debug(
"No user query found, skipping semantic filter"
)
return None
verbose_proxy_logger.debug(
f"Applying semantic filter to {len(tools)} tools "
f"with query: '{user_query[:50]}...'"
)
# Filter tools semantically
filtered_tools = await self.filter.filter_tools(
query=user_query,
available_tools=tools, # type: ignore
)
# Always update tools and emit header (even if count unchanged)
data["tools"] = filtered_tools
# Store filter stats and tool names for response header
filter_stats = f"{original_tool_count}->{len(filtered_tools)}"
tool_names_csv = self._get_tool_names_csv(filtered_tools)
            _metadata_variable_name = self._get_metadata_variable_name(data)
            # guard against requests that arrive without a metadata dict
            _request_metadata = data.setdefault(_metadata_variable_name, {})
            _request_metadata["litellm_semantic_filter_stats"] = filter_stats
            _request_metadata["litellm_semantic_filter_tools"] = tool_names_csv
verbose_proxy_logger.info(f"Semantic tool filter: {filter_stats} tools")
return data
except Exception as e:
verbose_proxy_logger.warning(
f"Semantic tool filter hook failed: {e}. Proceeding with all tools."
)
return None
async def async_post_call_response_headers_hook(
self,
data: dict,
user_api_key_dict: "UserAPIKeyAuth",
response: Any,
request_headers: Optional[Dict[str, str]] = None,
litellm_call_info: Optional[Dict[str, Any]] = None,
) -> Optional[Dict[str, str]]:
"""Add semantic filter stats and tool names to response headers."""
from litellm.constants import MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH
        _metadata_variable_name = self._get_metadata_variable_name(data)
        metadata = data.get(_metadata_variable_name) or {}
filter_stats = metadata.get("litellm_semantic_filter_stats")
if not filter_stats:
return None
headers = {"x-litellm-semantic-filter": filter_stats}
# Add CSV of filtered tool names (nginx-safe length)
tool_names_csv = metadata.get("litellm_semantic_filter_tools", "")
if tool_names_csv:
if len(tool_names_csv) > MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH:
tool_names_csv = (
tool_names_csv[: MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH - 3]
+ "..."
)
headers["x-litellm-semantic-filter-tools"] = tool_names_csv
return headers
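The truncation rule above, as a standalone sketch; the 64-char `MAX_LEN` here is an assumed stand-in for `MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH`:

```python
MAX_LEN = 64  # assumed stand-in for MAX_MCP_SEMANTIC_FILTER_TOOLS_HEADER_LENGTH


def truncate_header_value(value: str, max_len: int = MAX_LEN) -> str:
    """Cap a header value at max_len chars, ending in '...' when cut."""
    if len(value) > max_len:
        return value[: max_len - 3] + "..."
    return value
```

The `-3` reserves room for the ellipsis so the result never exceeds `max_len`.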
def _get_tool_names_csv(self, tools: List[Any]) -> str:
"""Extract tool names and return as CSV string."""
if not tools:
return ""
tool_names = []
for tool in tools:
name = (
tool.get("name", "")
if isinstance(tool, dict)
else getattr(tool, "name", "")
)
if name:
tool_names.append(name)
return ",".join(tool_names)
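`_get_tool_names_csv` reduces to this standalone sketch (hypothetical tool dicts, not LiteLLM objects):

```python
from typing import Any, List


def tool_names_csv(tools: List[Any]) -> str:
    """Join tool names into a CSV string, skipping entries without a name."""
    names = []
    for tool in tools:
        name = tool.get("name", "") if isinstance(tool, dict) else getattr(tool, "name", "")
        if name:
            names.append(name)
    return ",".join(names)
```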
@staticmethod
async def initialize_from_config(
config: Optional[Dict[str, Any]],
llm_router: Optional["Router"],
) -> Optional["SemanticToolFilterHook"]:
"""
Initialize semantic tool filter from proxy config.
Args:
config: Proxy configuration dict (litellm_settings.mcp_semantic_tool_filter)
llm_router: LiteLLM router instance for embeddings
Returns:
SemanticToolFilterHook instance if enabled, None otherwise
"""
from litellm.proxy._experimental.mcp_server.semantic_tool_filter import (
SemanticMCPToolFilter,
)
if not config or not config.get("enabled", False):
verbose_proxy_logger.debug("Semantic tool filter not enabled in config")
return None
if llm_router is None:
verbose_proxy_logger.warning(
"Cannot initialize semantic filter: llm_router is None"
)
return None
try:
embedding_model = config.get(
"embedding_model", DEFAULT_MCP_SEMANTIC_FILTER_EMBEDDING_MODEL
)
top_k = config.get("top_k", DEFAULT_MCP_SEMANTIC_FILTER_TOP_K)
similarity_threshold = config.get(
"similarity_threshold", DEFAULT_MCP_SEMANTIC_FILTER_SIMILARITY_THRESHOLD
)
semantic_filter = SemanticMCPToolFilter(
embedding_model=embedding_model,
litellm_router_instance=llm_router,
top_k=top_k,
similarity_threshold=similarity_threshold,
enabled=True,
)
# Build router from MCP registry on startup
await semantic_filter.build_router_from_mcp_registry()
hook = SemanticToolFilterHook(semantic_filter)
verbose_proxy_logger.info(
f"✅ MCP Semantic Tool Filter enabled: "
f"embedding_model={embedding_model}, top_k={top_k}, "
f"similarity_threshold={similarity_threshold}"
)
return hook
except ImportError as e:
verbose_proxy_logger.warning(
f"semantic-router not installed. Install with: "
f"pip install 'litellm[semantic-router]'. Error: {e}"
)
return None
except Exception as e:
verbose_proxy_logger.exception(
f"Failed to initialize MCP semantic tool filter: {e}"
)
return None


@@ -0,0 +1,318 @@
import json
from typing import List, Optional
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_logger import Span
from litellm.proxy._types import UserAPIKeyAuth
from litellm.router_strategy.budget_limiter import RouterBudgetLimiting
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
BudgetConfig,
GenericBudgetConfigType,
StandardLoggingPayload,
)
VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX = "virtual_key_spend"
END_USER_SPEND_CACHE_KEY_PREFIX = "end_user_model_spend"
class _PROXY_VirtualKeyModelMaxBudgetLimiter(RouterBudgetLimiting):
"""
Handles budgets for model + virtual key
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
"""
def __init__(self, dual_cache: DualCache):
self.dual_cache = dual_cache
self.redis_increment_operation_queue = []
async def is_key_within_model_budget(
self,
user_api_key_dict: UserAPIKeyAuth,
model: str,
) -> bool:
"""
Check if the user_api_key_dict is within the model budget
Raises:
BudgetExceededError: If the user_api_key_dict has exceeded the model budget
"""
_model_max_budget = user_api_key_dict.model_max_budget
internal_model_max_budget: GenericBudgetConfigType = {}
for _model, _budget_info in _model_max_budget.items():
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
verbose_proxy_logger.debug(
"internal_model_max_budget %s",
json.dumps(internal_model_max_budget, indent=4, default=str),
)
# check if current model is in internal_model_max_budget
_current_model_budget_info = self._get_request_model_budget_config(
model=model, internal_model_max_budget=internal_model_max_budget
)
if _current_model_budget_info is None:
verbose_proxy_logger.debug(
f"Model {model} not found in internal_model_max_budget"
)
return True
# check if current model is within budget
if (
_current_model_budget_info.max_budget
and _current_model_budget_info.max_budget > 0
):
_current_spend = await self._get_virtual_key_spend_for_model(
user_api_key_hash=user_api_key_dict.token,
model=model,
key_budget_config=_current_model_budget_info,
)
if (
_current_spend is not None
and _current_model_budget_info.max_budget is not None
and _current_spend > _current_model_budget_info.max_budget
):
raise litellm.BudgetExceededError(
message=f"LiteLLM Virtual Key: {user_api_key_dict.token}, key_alias: {user_api_key_dict.key_alias}, exceeded budget for model={model}",
current_cost=_current_spend,
max_budget=_current_model_budget_info.max_budget,
)
return True
async def is_end_user_within_model_budget(
self,
end_user_id: str,
end_user_model_max_budget: dict,
model: str,
) -> bool:
"""
Check if the end_user is within the model budget
Raises:
BudgetExceededError: If the end_user has exceeded the model budget
"""
internal_model_max_budget: GenericBudgetConfigType = {}
for _model, _budget_info in end_user_model_max_budget.items():
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
verbose_proxy_logger.debug(
"end_user internal_model_max_budget %s",
json.dumps(internal_model_max_budget, indent=4, default=str),
)
# check if current model is in internal_model_max_budget
_current_model_budget_info = self._get_request_model_budget_config(
model=model, internal_model_max_budget=internal_model_max_budget
)
if _current_model_budget_info is None:
verbose_proxy_logger.debug(
f"Model {model} not found in end_user_model_max_budget"
)
return True
# check if current model is within budget
if (
_current_model_budget_info.max_budget
and _current_model_budget_info.max_budget > 0
):
_current_spend = await self._get_end_user_spend_for_model(
end_user_id=end_user_id,
model=model,
key_budget_config=_current_model_budget_info,
)
if (
_current_spend is not None
and _current_model_budget_info.max_budget is not None
and _current_spend > _current_model_budget_info.max_budget
):
raise litellm.BudgetExceededError(
message=f"LiteLLM End User: {end_user_id}, exceeded budget for model={model}",
current_cost=_current_spend,
max_budget=_current_model_budget_info.max_budget,
)
return True
async def _get_end_user_spend_for_model(
self,
end_user_id: str,
model: str,
key_budget_config: BudgetConfig,
) -> Optional[float]:
# 1. model: directly look up `model`
end_user_model_spend_cache_key = f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{model}:{key_budget_config.budget_duration}"
_current_spend = await self.dual_cache.async_get_cache(
key=end_user_model_spend_cache_key,
)
        if _current_spend is None:
            # 2. If (1) does not exist, look up the model name without its {custom_llm_provider}/ prefix
            end_user_model_spend_cache_key = f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.budget_duration}"
_current_spend = await self.dual_cache.async_get_cache(
key=end_user_model_spend_cache_key,
)
return _current_spend
async def _get_virtual_key_spend_for_model(
self,
user_api_key_hash: Optional[str],
model: str,
key_budget_config: BudgetConfig,
) -> Optional[float]:
"""
Get the current spend for a virtual key for a model
Lookup model in this order:
1. model: directly look up `model`
        2. If (1) does not exist, look up the model name without its {custom_llm_provider}/ prefix
"""
# 1. model: directly look up `model`
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{model}:{key_budget_config.budget_duration}"
_current_spend = await self.dual_cache.async_get_cache(
key=virtual_key_model_spend_cache_key,
)
if _current_spend is None:
            # 2. If (1) does not exist, look up the model name without its {custom_llm_provider}/ prefix
            # if "/" in model, keep only the segment after the last "/" - eg. openai/o1-preview -> o1-preview
virtual_key_model_spend_cache_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{user_api_key_hash}:{self._get_model_without_custom_llm_provider(model)}:{key_budget_config.budget_duration}"
_current_spend = await self.dual_cache.async_get_cache(
key=virtual_key_model_spend_cache_key,
)
return _current_spend
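Both spend lookups above follow the same two-step key fallback; a synchronous sketch over a plain dict standing in for the cache (the real code awaits `DualCache`):

```python
from typing import Dict, Optional

PREFIX = "virtual_key_spend"  # mirrors VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX


def lookup_spend(
    cache: Dict[str, float], key_hash: str, model: str, duration: str
) -> Optional[float]:
    """Try the exact model name first, then the name without its provider prefix."""
    for candidate in (model, model.split("/")[-1]):
        spend = cache.get(f"{PREFIX}:{key_hash}:{candidate}:{duration}")
        if spend is not None:
            return spend
    return None
```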
def _get_request_model_budget_config(
self, model: str, internal_model_max_budget: GenericBudgetConfigType
) -> Optional[BudgetConfig]:
"""
Get the budget config for the request model
1. Check if `model` is in `internal_model_max_budget`
2. If not, check if `model` without custom llm provider is in `internal_model_max_budget`
"""
return internal_model_max_budget.get(
model, None
) or internal_model_max_budget.get(
self._get_model_without_custom_llm_provider(model), None
)
def _get_model_without_custom_llm_provider(self, model: str) -> str:
if "/" in model:
return model.split("/")[-1]
return model
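The prefix-stripping rule is small enough to state directly:

```python
def strip_provider_prefix(model: str) -> str:
    """Keep only the segment after the last '/', e.g. 'openai/o1-preview' -> 'o1-preview'."""
    return model.split("/")[-1] if "/" in model else model
```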
async def async_filter_deployments(
self,
model: str,
healthy_deployments: List,
messages: Optional[List[AllMessageValues]],
request_kwargs: Optional[dict] = None,
parent_otel_span: Optional[Span] = None, # type: ignore
) -> List[dict]:
return healthy_deployments
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
"""
Track spend for virtual key + model in DualCache
Example: key=sk-1234567890, model=gpt-4o, max_budget=100, time_period=1d
"""
        verbose_proxy_logger.debug(
            "in _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event"
        )
standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
if standard_logging_payload is None:
verbose_proxy_logger.debug(
"Skipping _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event: standard_logging_payload is None"
)
return
_litellm_params: dict = kwargs.get("litellm_params", {}) or {}
_metadata: dict = _litellm_params.get("metadata", {}) or {}
user_api_key_model_max_budget: Optional[dict] = _metadata.get(
"user_api_key_model_max_budget", None
)
user_api_key_end_user_model_max_budget: Optional[dict] = _metadata.get(
"user_api_key_end_user_model_max_budget", None
)
if (
user_api_key_model_max_budget is None
or len(user_api_key_model_max_budget) == 0
) and (
user_api_key_end_user_model_max_budget is None
or len(user_api_key_end_user_model_max_budget) == 0
):
verbose_proxy_logger.debug(
"Not running _PROXY_VirtualKeyModelMaxBudgetLimiter.async_log_success_event because user_api_key_model_max_budget and user_api_key_end_user_model_max_budget are None or empty."
)
return
response_cost: float = standard_logging_payload.get("response_cost", 0)
model = standard_logging_payload.get("model")
virtual_key = standard_logging_payload.get("metadata", {}).get(
"user_api_key_hash"
)
end_user_id = standard_logging_payload.get(
"end_user"
) or standard_logging_payload.get("metadata", {}).get(
"user_api_key_end_user_id"
)
if model is None:
return
if (
virtual_key is not None
and user_api_key_model_max_budget is not None
and len(user_api_key_model_max_budget) > 0
):
internal_model_max_budget: GenericBudgetConfigType = {}
for _model, _budget_info in user_api_key_model_max_budget.items():
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
key_budget_config = self._get_request_model_budget_config(
model=model, internal_model_max_budget=internal_model_max_budget
)
if key_budget_config is not None and key_budget_config.budget_duration:
virtual_spend_key = f"{VIRTUAL_KEY_SPEND_CACHE_KEY_PREFIX}:{virtual_key}:{model}:{key_budget_config.budget_duration}"
virtual_start_time_key = f"virtual_key_budget_start_time:{virtual_key}"
await self._increment_spend_for_key(
budget_config=key_budget_config,
spend_key=virtual_spend_key,
start_time_key=virtual_start_time_key,
response_cost=response_cost,
)
if (
end_user_id is not None
and user_api_key_end_user_model_max_budget is not None
and len(user_api_key_end_user_model_max_budget) > 0
):
internal_model_max_budget: GenericBudgetConfigType = {}
for _model, _budget_info in user_api_key_end_user_model_max_budget.items():
internal_model_max_budget[_model] = BudgetConfig(**_budget_info)
key_budget_config = self._get_request_model_budget_config(
model=model, internal_model_max_budget=internal_model_max_budget
)
if key_budget_config is not None and key_budget_config.budget_duration:
end_user_spend_key = f"{END_USER_SPEND_CACHE_KEY_PREFIX}:{end_user_id}:{model}:{key_budget_config.budget_duration}"
end_user_start_time_key = f"end_user_budget_start_time:{end_user_id}"
await self._increment_spend_for_key(
budget_config=key_budget_config,
spend_key=end_user_spend_key,
start_time_key=end_user_start_time_key,
response_cost=response_cost,
)
verbose_proxy_logger.debug(
"current state of in memory cache %s",
json.dumps(
self.dual_cache.in_memory_cache.cache_dict, indent=4, default=str
),
)

View File

@@ -0,0 +1,869 @@
import asyncio
import sys
from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, List, Literal, Optional, Tuple, Union
from fastapi import HTTPException
from pydantic import BaseModel
from typing_extensions import TypedDict
import litellm
from litellm import DualCache, ModelResponse
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
from litellm.proxy._types import CommonProxyErrors, CurrentItemRateLimit, UserAPIKeyAuth
from litellm.proxy.auth.auth_utils import (
get_key_model_rpm_limit,
get_key_model_tpm_limit,
)
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
Span = Union[_Span, Any]
InternalUsageCache = _InternalUsageCache
else:
Span = Any
InternalUsageCache = Any
class CacheObject(TypedDict):
current_global_requests: Optional[dict]
request_count_api_key: Optional[dict]
request_count_api_key_model: Optional[dict]
request_count_user_id: Optional[dict]
request_count_team_id: Optional[dict]
request_count_end_user_id: Optional[dict]
class _PROXY_MaxParallelRequestsHandler(CustomLogger):
# Class variables or attributes
def __init__(self, internal_usage_cache: InternalUsageCache):
self.internal_usage_cache = internal_usage_cache
def print_verbose(self, print_statement):
try:
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except Exception:
pass
async def check_key_in_limits(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
max_parallel_requests: int,
tpm_limit: int,
rpm_limit: int,
current: Optional[dict],
request_count_api_key: str,
rate_limit_type: Literal["key", "model_per_key", "user", "customer", "team"],
values_to_update_in_cache: List[Tuple[Any, Any]],
) -> dict:
verbose_proxy_logger.info(
f"Current Usage of {rate_limit_type} in this minute: {current}"
)
if current is None:
if max_parallel_requests == 0 or tpm_limit == 0 or rpm_limit == 0:
                # base case: a limit of 0 means no requests are allowed
                self.raise_rate_limit_error(
                    additional_details=f"{CommonProxyErrors.max_parallel_request_limit_reached.value}. Hit limit for {rate_limit_type}. Current limits: max_parallel_requests: {max_parallel_requests}, tpm_limit: {tpm_limit}, rpm_limit: {rpm_limit}"
                )
new_val = {
"current_requests": 1,
"current_tpm": 0,
"current_rpm": 1,
}
values_to_update_in_cache.append((request_count_api_key, new_val))
elif (
int(current["current_requests"]) < max_parallel_requests
and current["current_tpm"] < tpm_limit
and current["current_rpm"] < rpm_limit
):
# Increase count for this token
new_val = {
"current_requests": current["current_requests"] + 1,
"current_tpm": current["current_tpm"],
"current_rpm": current["current_rpm"] + 1,
}
values_to_update_in_cache.append((request_count_api_key, new_val))
else:
raise HTTPException(
status_code=429,
detail=f"LiteLLM Rate Limit Handler for rate limit type = {rate_limit_type}. {CommonProxyErrors.max_parallel_request_limit_reached.value}. current rpm: {current['current_rpm']}, rpm limit: {rpm_limit}, current tpm: {current['current_tpm']}, tpm limit: {tpm_limit}, current max_parallel_requests: {current['current_requests']}, max_parallel_requests: {max_parallel_requests}",
headers={"retry-after": str(self.time_to_next_minute())},
)
await self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
local_only=True,
)
return new_val
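The admit/deny decision in `check_key_in_limits` reduces to this pure-function sketch (cache writes and HTTP 429 handling omitted):

```python
import sys
from typing import Optional, Tuple


def admit(
    current: Optional[dict],
    max_parallel: int = sys.maxsize,
    tpm_limit: int = sys.maxsize,
    rpm_limit: int = sys.maxsize,
) -> Tuple[bool, dict]:
    """Return (allowed, new_counter_value) for one request in the current minute."""
    if current is None:
        # first request this minute: deny only if some limit is literally 0
        if 0 in (max_parallel, tpm_limit, rpm_limit):
            return False, {}
        return True, {"current_requests": 1, "current_tpm": 0, "current_rpm": 1}
    if (
        current["current_requests"] < max_parallel
        and current["current_tpm"] < tpm_limit
        and current["current_rpm"] < rpm_limit
    ):
        return True, {
            "current_requests": current["current_requests"] + 1,
            "current_tpm": current["current_tpm"],
            "current_rpm": current["current_rpm"] + 1,
        }
    return False, {}
```

Note that TPM is admitted on the pre-request token count; tokens are only added to the counter after the response, in the success-event hook.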
def time_to_next_minute(self) -> float:
# Get the current time
now = datetime.now()
# Calculate the next minute
next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
# Calculate the difference in seconds
seconds_to_next_minute = (next_minute - now).total_seconds()
return seconds_to_next_minute
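`time_to_next_minute` computes the `retry-after` window; the same logic with an explicit `now` parameter for testability:

```python
from datetime import datetime, timedelta


def seconds_to_next_minute(now: datetime) -> float:
    """Seconds remaining until the next minute boundary (0 < result <= 60)."""
    next_minute = (now + timedelta(minutes=1)).replace(second=0, microsecond=0)
    return (next_minute - now).total_seconds()
```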
def raise_rate_limit_error(
self, additional_details: Optional[str] = None
) -> HTTPException:
"""
Raise an HTTPException with a 429 status code and a retry-after header
"""
        error_message = "Max parallel request limit reached"
        if additional_details is not None:
            error_message = error_message + " " + additional_details
        raise HTTPException(
            status_code=429,
            detail=error_message,
            headers={"retry-after": str(self.time_to_next_minute())},
        )
async def get_all_cache_objects(
self,
current_global_requests: Optional[str],
request_count_api_key: Optional[str],
request_count_api_key_model: Optional[str],
request_count_user_id: Optional[str],
request_count_team_id: Optional[str],
request_count_end_user_id: Optional[str],
parent_otel_span: Optional[Span] = None,
) -> CacheObject:
keys = [
current_global_requests,
request_count_api_key,
request_count_api_key_model,
request_count_user_id,
request_count_team_id,
request_count_end_user_id,
]
results = await self.internal_usage_cache.async_batch_get_cache(
keys=keys,
parent_otel_span=parent_otel_span,
)
if results is None:
return CacheObject(
current_global_requests=None,
request_count_api_key=None,
request_count_api_key_model=None,
request_count_user_id=None,
request_count_team_id=None,
request_count_end_user_id=None,
)
return CacheObject(
current_global_requests=results[0],
request_count_api_key=results[1],
request_count_api_key_model=results[2],
request_count_user_id=results[3],
request_count_team_id=results[4],
request_count_end_user_id=results[5],
)
async def async_pre_call_hook( # noqa: PLR0915
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str,
):
self.print_verbose("Inside Max Parallel Request Pre-Call Hook")
api_key = user_api_key_dict.api_key
max_parallel_requests = user_api_key_dict.max_parallel_requests
if max_parallel_requests is None:
max_parallel_requests = sys.maxsize
if data is None:
data = {}
global_max_parallel_requests = data.get("metadata", {}).get(
"global_max_parallel_requests", None
)
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
if tpm_limit is None:
tpm_limit = sys.maxsize
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
if rpm_limit is None:
rpm_limit = sys.maxsize
values_to_update_in_cache: List[
Tuple[Any, Any]
] = (
[]
) # values that need to get updated in cache, will run a batch_set_cache after this function
# ------------
# Setup values
# ------------
new_val: Optional[dict] = None
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
current_global_requests = await self.internal_usage_cache.async_get_cache(
key=_key,
local_only=True,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
# check if below limit
if current_global_requests is None:
current_global_requests = 1
# if above -> raise error
            if current_global_requests >= global_max_parallel_requests:
                self.raise_rate_limit_error(
                    additional_details=f"Hit Global Limit: Limit={global_max_parallel_requests}, current: {current_global_requests}"
                )
# if below -> increment
else:
await self.internal_usage_cache.async_increment_cache(
key=_key,
value=1,
local_only=True,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
_model = data.get("model", None)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
cache_objects: CacheObject = await self.get_all_cache_objects(
current_global_requests=(
"global_max_parallel_requests"
if global_max_parallel_requests is not None
else None
),
request_count_api_key=(
f"{api_key}::{precise_minute}::request_count"
if api_key is not None
else None
),
request_count_api_key_model=(
f"{api_key}::{_model}::{precise_minute}::request_count"
if api_key is not None and _model is not None
else None
),
request_count_user_id=(
f"{user_api_key_dict.user_id}::{precise_minute}::request_count"
if user_api_key_dict.user_id is not None
else None
),
request_count_team_id=(
f"{user_api_key_dict.team_id}::{precise_minute}::request_count"
if user_api_key_dict.team_id is not None
else None
),
request_count_end_user_id=(
f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
if user_api_key_dict.end_user_id is not None
else None
),
parent_otel_span=user_api_key_dict.parent_otel_span,
)
if api_key is not None:
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
# CHECK IF REQUEST ALLOWED for key
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=max_parallel_requests,
current=cache_objects["request_count_api_key"],
request_count_api_key=request_count_api_key,
tpm_limit=tpm_limit,
rpm_limit=rpm_limit,
rate_limit_type="key",
values_to_update_in_cache=values_to_update_in_cache,
)
# Check if request under RPM/TPM per model for a given API Key
if (
get_key_model_tpm_limit(user_api_key_dict) is not None
or get_key_model_rpm_limit(user_api_key_dict) is not None
):
_model = data.get("model", None)
request_count_api_key = (
f"{api_key}::{_model}::{precise_minute}::request_count"
)
_tpm_limit_for_key_model = get_key_model_tpm_limit(user_api_key_dict)
_rpm_limit_for_key_model = get_key_model_rpm_limit(user_api_key_dict)
tpm_limit_for_model = None
rpm_limit_for_model = None
if _model is not None:
if _tpm_limit_for_key_model:
tpm_limit_for_model = _tpm_limit_for_key_model.get(_model)
if _rpm_limit_for_key_model:
rpm_limit_for_model = _rpm_limit_for_key_model.get(_model)
new_val = await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a model
current=cache_objects["request_count_api_key_model"],
request_count_api_key=request_count_api_key,
tpm_limit=tpm_limit_for_model or sys.maxsize,
rpm_limit=rpm_limit_for_model or sys.maxsize,
rate_limit_type="model_per_key",
values_to_update_in_cache=values_to_update_in_cache,
)
_remaining_tokens = None
_remaining_requests = None
# Add remaining tokens, requests to metadata
if new_val:
if tpm_limit_for_model is not None:
_remaining_tokens = tpm_limit_for_model - new_val["current_tpm"]
if rpm_limit_for_model is not None:
_remaining_requests = rpm_limit_for_model - new_val["current_rpm"]
_remaining_limits_data = {
f"litellm-key-remaining-tokens-{_model}": _remaining_tokens,
f"litellm-key-remaining-requests-{_model}": _remaining_requests,
}
if "metadata" not in data:
data["metadata"] = {}
data["metadata"].update(_remaining_limits_data)
# check if REQUEST ALLOWED for user_id
user_id = user_api_key_dict.user_id
if user_id is not None:
user_tpm_limit = user_api_key_dict.user_tpm_limit
user_rpm_limit = user_api_key_dict.user_rpm_limit
if user_tpm_limit is None:
user_tpm_limit = sys.maxsize
if user_rpm_limit is None:
user_rpm_limit = sys.maxsize
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
current=cache_objects["request_count_user_id"],
request_count_api_key=request_count_api_key,
tpm_limit=user_tpm_limit,
rpm_limit=user_rpm_limit,
rate_limit_type="user",
values_to_update_in_cache=values_to_update_in_cache,
)
# TEAM RATE LIMITS
## get team tpm/rpm limits
team_id = user_api_key_dict.team_id
if team_id is not None:
team_tpm_limit = user_api_key_dict.team_tpm_limit
team_rpm_limit = user_api_key_dict.team_rpm_limit
if team_tpm_limit is None:
team_tpm_limit = sys.maxsize
if team_rpm_limit is None:
team_rpm_limit = sys.maxsize
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a team
current=cache_objects["request_count_team_id"],
request_count_api_key=request_count_api_key,
tpm_limit=team_tpm_limit,
rpm_limit=team_rpm_limit,
rate_limit_type="team",
values_to_update_in_cache=values_to_update_in_cache,
)
# End-User Rate Limits
# Only enforce if user passed `user` to /chat, /completions, /embeddings
if user_api_key_dict.end_user_id:
end_user_tpm_limit = getattr(
user_api_key_dict, "end_user_tpm_limit", sys.maxsize
)
end_user_rpm_limit = getattr(
user_api_key_dict, "end_user_rpm_limit", sys.maxsize
)
if end_user_tpm_limit is None:
end_user_tpm_limit = sys.maxsize
if end_user_rpm_limit is None:
end_user_rpm_limit = sys.maxsize
# now do the same tpm/rpm checks
request_count_api_key = (
f"{user_api_key_dict.end_user_id}::{precise_minute}::request_count"
)
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
await self.check_key_in_limits(
user_api_key_dict=user_api_key_dict,
cache=cache,
data=data,
call_type=call_type,
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for an End-User
request_count_api_key=request_count_api_key,
current=cache_objects["request_count_end_user_id"],
tpm_limit=end_user_tpm_limit,
rpm_limit=end_user_rpm_limit,
rate_limit_type="customer",
values_to_update_in_cache=values_to_update_in_cache,
)
asyncio.create_task(
self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
) # don't block execution for cache updates
)
return
async def async_log_success_event( # noqa: PLR0915
self, kwargs, response_obj, start_time, end_time
):
from litellm.proxy.common_utils.callback_utils import (
get_model_group_from_litellm_kwargs,
)
litellm_parent_otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(
kwargs=kwargs
)
try:
self.print_verbose("INSIDE parallel request limiter ASYNC SUCCESS LOGGING")
global_max_parallel_requests = kwargs["litellm_params"]["metadata"].get(
"global_max_parallel_requests", None
)
user_api_key = kwargs["litellm_params"]["metadata"]["user_api_key"]
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_user_id", None
)
user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
"user_api_key_team_id", None
)
user_api_key_model_max_budget = kwargs["litellm_params"]["metadata"].get(
"user_api_key_model_max_budget", None
)
user_api_key_end_user_id = kwargs.get("user")
user_api_key_metadata = (
kwargs["litellm_params"]["metadata"].get("user_api_key_metadata", {})
or {}
)
# ------------
# Setup values
# ------------
if global_max_parallel_requests is not None:
# get value from cache
_key = "global_max_parallel_requests"
# decrement
await self.internal_usage_cache.async_increment_cache(
key=_key,
value=-1,
local_only=True,
litellm_parent_otel_span=litellm_parent_otel_span,
)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens # type: ignore
# ------------
# Update usage - API Key
# ------------
values_to_update_in_cache = []
if user_api_key is not None:
request_count_api_key = (
f"{user_api_key}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": 0,
"current_rpm": 0,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"],
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
values_to_update_in_cache.append((request_count_api_key, new_val))
# ------------
# Update usage - model group + API Key
# ------------
model_group = get_model_group_from_litellm_kwargs(kwargs)
if (
user_api_key is not None
and model_group is not None
and (
"model_rpm_limit" in user_api_key_metadata
or "model_tpm_limit" in user_api_key_metadata
or user_api_key_model_max_budget is not None
)
):
request_count_api_key = (
f"{user_api_key}::{model_group}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": 0,
"current_rpm": 0,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"],
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
values_to_update_in_cache.append((request_count_api_key, new_val))
# ------------
# Update usage - User
# ------------
if user_api_key_user_id is not None:
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens # type: ignore
request_count_api_key = (
f"{user_api_key_user_id}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"],
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
values_to_update_in_cache.append((request_count_api_key, new_val))
# ------------
# Update usage - Team
# ------------
if user_api_key_team_id is not None:
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens # type: ignore
request_count_api_key = (
f"{user_api_key_team_id}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"],
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
values_to_update_in_cache.append((request_count_api_key, new_val))
# ------------
# Update usage - End User
# ------------
if user_api_key_end_user_id is not None:
total_tokens = 0
if isinstance(response_obj, ModelResponse):
total_tokens = response_obj.usage.total_tokens # type: ignore
request_count_api_key = (
f"{user_api_key_end_user_id}::{precise_minute}::request_count"
)
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": total_tokens,
"current_rpm": 1,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"] + total_tokens,
"current_rpm": current["current_rpm"],
}
self.print_verbose(
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
)
values_to_update_in_cache.append((request_count_api_key, new_val))
await self.internal_usage_cache.async_batch_set_cache(
cache_list=values_to_update_in_cache,
ttl=60,
litellm_parent_otel_span=litellm_parent_otel_span,
)
except Exception as e:
self.print_verbose(e) # noqa
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.print_verbose("Inside Max Parallel Request Failure Hook")
litellm_parent_otel_span: Union[
Span, None
] = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
_metadata = kwargs["litellm_params"].get("metadata", {}) or {}
global_max_parallel_requests = _metadata.get(
"global_max_parallel_requests", None
)
user_api_key = _metadata.get("user_api_key", None)
self.print_verbose(f"user_api_key: {user_api_key}")
if user_api_key is None:
return
## decrement call count if call failed
if CommonProxyErrors.max_parallel_request_limit_reached.value in str(
kwargs["exception"]
):
pass # ignore failed calls due to max limit being reached
else:
# ------------
# Setup values
# ------------
if global_max_parallel_requests is not None:
_key = "global_max_parallel_requests"
# decrement
await self.internal_usage_cache.async_increment_cache(
key=_key,
value=-1,
local_only=True,
litellm_parent_otel_span=litellm_parent_otel_span,
)
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = (
f"{user_api_key}::{precise_minute}::request_count"
)
# ------------
# Update usage
# ------------
current = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=litellm_parent_otel_span,
) or {
"current_requests": 1,
"current_tpm": 0,
"current_rpm": 0,
}
new_val = {
"current_requests": max(current["current_requests"] - 1, 0),
"current_tpm": current["current_tpm"],
"current_rpm": current["current_rpm"],
}
self.print_verbose(f"updated_value in failure call: {new_val}")
await self.internal_usage_cache.async_set_cache(
request_count_api_key,
new_val,
ttl=60,
litellm_parent_otel_span=litellm_parent_otel_span,
) # save in cache for up to 1 min.
except Exception as e:
verbose_proxy_logger.exception(
"Inside Parallel Request Limiter: An exception occurred - {}".format(
str(e)
)
)
async def get_internal_user_object(
self,
user_id: str,
user_api_key_dict: UserAPIKeyAuth,
) -> Optional[dict]:
"""
Helper to get the 'Internal User Object'.
It uses the `get_user_object` function from `litellm.proxy.auth.auth_checks`.
We need this because the UserAPIKeyAuth object does not contain the rpm/tpm limits for a user, and re-reading the UserTable on every request could hurt performance.
"""
from litellm._logging import verbose_proxy_logger
from litellm.proxy.auth.auth_checks import get_user_object
from litellm.proxy.proxy_server import prisma_client
try:
_user_id_rate_limits = await get_user_object(
user_id=user_id,
prisma_client=prisma_client,
user_api_key_cache=self.internal_usage_cache.dual_cache,
user_id_upsert=False,
parent_otel_span=user_api_key_dict.parent_otel_span,
proxy_logging_obj=None,
)
if _user_id_rate_limits is None:
return None
return _user_id_rate_limits.model_dump()
except Exception as e:
verbose_proxy_logger.debug(
"Parallel Request Limiter: Error getting user object - %s", str(e)
)
return None
async def async_post_call_success_hook(
self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
):
"""
Retrieve the key's remaining rate limits.
"""
api_key = user_api_key_dict.api_key
current_date = datetime.now().strftime("%Y-%m-%d")
current_hour = datetime.now().strftime("%H")
current_minute = datetime.now().strftime("%M")
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
request_count_api_key = f"{api_key}::{precise_minute}::request_count"
current: Optional[
CurrentItemRateLimit
] = await self.internal_usage_cache.async_get_cache(
key=request_count_api_key,
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
)
key_remaining_rpm_limit: Optional[int] = None
key_rpm_limit: Optional[int] = None
key_remaining_tpm_limit: Optional[int] = None
key_tpm_limit: Optional[int] = None
if current is not None:
if user_api_key_dict.rpm_limit is not None:
key_remaining_rpm_limit = (
user_api_key_dict.rpm_limit - current["current_rpm"]
)
key_rpm_limit = user_api_key_dict.rpm_limit
if user_api_key_dict.tpm_limit is not None:
key_remaining_tpm_limit = (
user_api_key_dict.tpm_limit - current["current_tpm"]
)
key_tpm_limit = user_api_key_dict.tpm_limit
if hasattr(response, "_hidden_params"):
_hidden_params = getattr(response, "_hidden_params")
else:
_hidden_params = None
if _hidden_params is not None and (
isinstance(_hidden_params, BaseModel) or isinstance(_hidden_params, dict)
):
if isinstance(_hidden_params, BaseModel):
_hidden_params = _hidden_params.model_dump()
_additional_headers = _hidden_params.get("additional_headers", {}) or {}
if key_remaining_rpm_limit is not None:
_additional_headers[
"x-ratelimit-remaining-requests"
] = key_remaining_rpm_limit
if key_rpm_limit is not None:
_additional_headers["x-ratelimit-limit-requests"] = key_rpm_limit
if key_remaining_tpm_limit is not None:
_additional_headers[
"x-ratelimit-remaining-tokens"
] = key_remaining_tpm_limit
if key_tpm_limit is not None:
_additional_headers["x-ratelimit-limit-tokens"] = key_tpm_limit
setattr(
response,
"_hidden_params",
{**_hidden_params, "additional_headers": _additional_headers},
)
return await super().async_post_call_success_hook(
data, user_api_key_dict, response
)

@@ -0,0 +1,284 @@
# +------------------------------------+
#
# Prompt Injection Detection
#
# +------------------------------------+
# Thank you users! We ❤️ you! - Krrish & Ishaan
## Reject a call if it contains a prompt injection attack.
from difflib import SequenceMatcher
from typing import List, Literal, Optional
from fastapi import HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.constants import DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.prompt_templates.factory import (
prompt_injection_detection_default_pt,
)
from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth
from litellm.router import Router
from litellm.utils import get_formatted_prompt
class _OPTIONAL_PromptInjectionDetection(CustomLogger):
# Class variables or attributes
def __init__(
self,
prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
):
self.prompt_injection_params = prompt_injection_params
self.llm_router: Optional[Router] = None
self.verbs = [
"Ignore",
"Disregard",
"Skip",
"Forget",
"Neglect",
"Overlook",
"Omit",
"Bypass",
"Pay no attention to",
"Do not follow",
"Do not obey",
]
self.adjectives = [
"",
"prior",
"previous",
"preceding",
"above",
"foregoing",
"earlier",
"initial",
]
self.prepositions = [
"",
"and start over",
"and start anew",
"and begin afresh",
"and start from scratch",
]
def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
if level == "INFO":
verbose_proxy_logger.info(print_statement)
elif level == "DEBUG":
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose is True:
print(print_statement) # noqa
def update_environment(self, router: Optional[Router] = None):
self.llm_router = router
if (
self.prompt_injection_params is not None
and self.prompt_injection_params.llm_api_check is True
):
if self.llm_router is None:
raise Exception(
"PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
)
self.print_verbose(
f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
)
if (
self.prompt_injection_params.llm_api_name is None
or self.prompt_injection_params.llm_api_name
not in self.llm_router.model_names
):
raise Exception(
"PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
)
def generate_injection_keywords(self) -> List[str]:
combinations = []
for verb in self.verbs:
for adj in self.adjectives:
for prep in self.prepositions:
phrase = " ".join(filter(None, [verb, adj, prep])).strip()
if (
len(phrase.split()) > 2
): # additional check to ensure more than 2 words
combinations.append(phrase.lower())
return combinations
def check_user_input_similarity(
self,
user_input: str,
similarity_threshold: float = DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD,
) -> bool:
user_input_lower = user_input.lower()
keywords = self.generate_injection_keywords()
for keyword in keywords:
# Calculate the length of the keyword to extract substrings of the same length from user input
keyword_length = len(keyword)
for i in range(len(user_input_lower) - keyword_length + 1):
# Extract a substring of the same length as the keyword
substring = user_input_lower[i : i + keyword_length]
# Calculate similarity
match_ratio = SequenceMatcher(None, substring, keyword).ratio()
if match_ratio > similarity_threshold:
self.print_verbose(
print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
level="INFO",
)
return True # Found a highly similar substring
return False # No substring crossed the threshold
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: str, # "completion", "embeddings", "image_generation", "moderation"
):
try:
"""
- check if user id part of call
- check if user id part of blocked list
"""
self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
try:
assert call_type in [
"acompletion",
"completion",
"text_completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
]
except Exception:
self.print_verbose(
f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
)
return data
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
is_prompt_attack = False
if self.prompt_injection_params is not None:
# 1. check if heuristics check turned on
if self.prompt_injection_params.heuristics_check is True:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={
"error": "Rejected message. This is a prompt injection attack."
},
)
# 2. check if vector db similarity check turned on [TODO] Not Implemented yet
if self.prompt_injection_params.vector_db_check is True:
pass
else:
is_prompt_attack = self.check_user_input_similarity(
user_input=formatted_prompt
)
if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={
"error": "Rejected message. This is a prompt injection attack."
},
)
return data
except HTTPException as e:
if (
e.status_code == 400
and isinstance(e.detail, dict)
and "error" in e.detail # type: ignore
and self.prompt_injection_params is not None
and self.prompt_injection_params.reject_as_response
):
return e.detail.get("error")
raise e
except Exception as e:
verbose_proxy_logger.exception(
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
str(e)
)
)
async def async_moderation_hook( # type: ignore
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
call_type: Literal[
"acompletion",
"completion",
"embeddings",
"image_generation",
"moderation",
"audio_transcription",
],
) -> Optional[bool]:
self.print_verbose(
f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
)
if self.prompt_injection_params is None:
return None
formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore
is_prompt_attack = False
prompt_injection_system_prompt = getattr(
self.prompt_injection_params,
"llm_api_system_prompt",
prompt_injection_detection_default_pt(),
)
# 3. check if llm api check turned on
if (
self.prompt_injection_params.llm_api_check is True
and self.prompt_injection_params.llm_api_name is not None
and self.llm_router is not None
):
# make a call to the llm api
response = await self.llm_router.acompletion(
model=self.prompt_injection_params.llm_api_name,
messages=[
{
"role": "system",
"content": prompt_injection_system_prompt,
},
{"role": "user", "content": formatted_prompt},
],
)
self.print_verbose(f"Received LLM Moderation response: {response}")
self.print_verbose(
f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}"
)
if isinstance(response, litellm.ModelResponse) and isinstance(
response.choices[0], litellm.Choices
):
if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore
is_prompt_attack = True
if is_prompt_attack is True:
raise HTTPException(
status_code=400,
detail={
"error": "Rejected message. This is a prompt injection attack."
},
)
return is_prompt_attack
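The heuristics check above slides a keyword-length window across the user input and scores each window with `difflib.SequenceMatcher`. A self-contained sketch of that idea (the 0.7 threshold is illustrative; the real default comes from `DEFAULT_PROMPT_INJECTION_SIMILARITY_THRESHOLD`):

```python
# Standalone sketch of the sliding-window similarity heuristic used by
# check_user_input_similarity above. The threshold value is illustrative.
from difflib import SequenceMatcher


def contains_similar_phrase(text: str, keyword: str, threshold: float = 0.7) -> bool:
    text, keyword = text.lower(), keyword.lower()
    k = len(keyword)
    # Score every substring of the same length as the keyword.
    for i in range(len(text) - k + 1):
        if SequenceMatcher(None, text[i : i + k], keyword).ratio() > threshold:
            return True
    return False
```

This is quadratic in the input length, which is why the real hook caps phrase generation to multi-word combinations rather than single verbs.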

@@ -0,0 +1,370 @@
import asyncio
import traceback
from datetime import datetime
from typing import Any, List, Optional, Union, cast
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
)
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.auth_checks import (
get_key_object,
get_team_object,
log_db_metrics,
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import ProxyUpdateSpend
from litellm.types.utils import (
StandardLoggingPayload,
StandardLoggingUserAPIKeyMetadata,
)
from litellm.utils import get_end_user_id_for_cost_tracking
class _ProxyDBLogger(CustomLogger):
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
await self._PROXY_track_cost_callback(
kwargs, response_obj, start_time, end_time
)
async def async_post_call_failure_hook(
self,
request_data: dict,
original_exception: Exception,
user_api_key_dict: UserAPIKeyAuth,
traceback_str: Optional[str] = None,
):
request_route = user_api_key_dict.request_route
if _ProxyDBLogger._should_track_errors_in_db() is False:
return
elif request_route is not None and not (
RouteChecks.is_llm_api_route(route=request_route)
or RouteChecks.is_info_route(route=request_route)
):
return
from litellm.proxy.proxy_server import proxy_logging_obj
_metadata = dict(
StandardLoggingUserAPIKeyMetadata(
user_api_key_hash=user_api_key_dict.api_key,
user_api_key_alias=user_api_key_dict.key_alias,
user_api_key_spend=user_api_key_dict.spend,
user_api_key_max_budget=user_api_key_dict.max_budget,
user_api_key_budget_reset_at=(
user_api_key_dict.budget_reset_at.isoformat()
if user_api_key_dict.budget_reset_at
else None
),
user_api_key_user_email=user_api_key_dict.user_email,
user_api_key_user_id=user_api_key_dict.user_id,
user_api_key_team_id=user_api_key_dict.team_id,
user_api_key_org_id=user_api_key_dict.org_id,
user_api_key_project_id=user_api_key_dict.project_id,
user_api_key_team_alias=user_api_key_dict.team_alias,
user_api_key_end_user_id=user_api_key_dict.end_user_id,
user_api_key_request_route=user_api_key_dict.request_route,
user_api_key_auth_metadata=user_api_key_dict.metadata,
)
)
_metadata["user_api_key"] = user_api_key_dict.api_key
_metadata["status"] = "failure"
_metadata[
"error_information"
] = StandardLoggingPayloadSetup.get_error_information(
original_exception=original_exception,
traceback_str=traceback_str,
)
_metadata = await _ProxyDBLogger._enrich_failure_metadata_with_key_info(
metadata=_metadata,
)
existing_metadata: dict = request_data.get("metadata", None) or {}
existing_metadata.update(_metadata)
if "litellm_params" not in request_data:
request_data["litellm_params"] = {}
existing_litellm_params = request_data.get("litellm_params", {})
existing_litellm_metadata = existing_litellm_params.get("metadata", {}) or {}
# Preserve tags from existing metadata
if existing_litellm_metadata.get("tags"):
existing_metadata["tags"] = existing_litellm_metadata.get("tags")
request_data["litellm_params"]["proxy_server_request"] = (
request_data.get("proxy_server_request")
or existing_litellm_params.get("proxy_server_request")
or {}
)
request_data["litellm_params"]["metadata"] = existing_metadata
# Preserve model name and custom_llm_provider
if "model" not in request_data:
request_data["model"] = existing_litellm_params.get(
"model"
) or request_data.get("model", "")
if "custom_llm_provider" not in request_data:
request_data["custom_llm_provider"] = existing_litellm_params.get(
"custom_llm_provider"
) or request_data.get("custom_llm_provider", "")
await proxy_logging_obj.db_spend_update_writer.update_database(
token=user_api_key_dict.api_key,
response_cost=0.0,
user_id=user_api_key_dict.user_id,
end_user_id=user_api_key_dict.end_user_id,
team_id=user_api_key_dict.team_id,
kwargs=request_data,
completion_response=original_exception,
start_time=datetime.now(),
end_time=datetime.now(),
org_id=user_api_key_dict.org_id,
)
@log_db_metrics
async def _PROXY_track_cost_callback(
self,
kwargs, # kwargs to completion
completion_response: Optional[
Union[litellm.ModelResponse, Any]
], # response from completion
start_time=None,
end_time=None, # start/end time for completion
):
from litellm.proxy.proxy_server import proxy_logging_obj, update_cache
verbose_proxy_logger.debug("INSIDE _PROXY_track_cost_callback")
try:
verbose_proxy_logger.debug(
f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}"
)
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs=kwargs)
litellm_params = kwargs.get("litellm_params", {}) or {}
end_user_id = get_end_user_id_for_cost_tracking(litellm_params)
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
user_id = cast(Optional[str], metadata.get("user_api_key_user_id", None))
team_id = cast(Optional[str], metadata.get("user_api_key_team_id", None))
org_id = cast(Optional[str], metadata.get("user_api_key_org_id", None))
key_alias = cast(Optional[str], metadata.get("user_api_key_alias", None))
end_user_max_budget = metadata.get("user_api_end_user_max_budget", None)
sl_object: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
)
response_cost = (
sl_object.get("response_cost", None)
if sl_object is not None
else kwargs.get("response_cost", None)
)
tags: Optional[List[str]] = (
sl_object.get("request_tags", None) if sl_object is not None else None
)
if response_cost is not None:
user_api_key = metadata.get("user_api_key", None)
if kwargs.get("cache_hit", False) is True:
response_cost = 0.0
verbose_proxy_logger.debug(
f"Cache Hit: response_cost {response_cost}, for user_id {user_id}"
)
verbose_proxy_logger.debug(
f"user_api_key {user_api_key}, user_id {user_id}, team_id {team_id}, end_user_id {end_user_id}"
)
if _should_track_cost_callback(
user_api_key=user_api_key,
user_id=user_id,
team_id=team_id,
end_user_id=end_user_id,
):
## UPDATE DATABASE
await proxy_logging_obj.db_spend_update_writer.update_database(
token=user_api_key,
response_cost=response_cost,
user_id=user_id,
end_user_id=end_user_id,
team_id=team_id,
kwargs=kwargs,
completion_response=completion_response,
start_time=start_time,
end_time=end_time,
org_id=org_id,
)
# update cache
asyncio.create_task(
update_cache(
token=user_api_key,
user_id=user_id,
end_user_id=end_user_id,
response_cost=response_cost,
team_id=team_id,
parent_otel_span=parent_otel_span,
tags=tags,
)
)
await proxy_logging_obj.slack_alerting_instance.customer_spend_alert(
token=user_api_key,
key_alias=key_alias,
end_user_id=end_user_id,
response_cost=response_cost,
max_budget=end_user_max_budget,
)
else:
# Non-model call types (health checks, afile_delete) have no model or standard_logging_object.
# Use .get() for "stream" to avoid KeyError on health checks.
if sl_object is None and not kwargs.get("model"):
verbose_proxy_logger.warning(
"Cost tracking - skipping, no standard_logging_object and no model for call_type=%s",
kwargs.get("call_type", "unknown"),
)
return
if kwargs.get("stream") is not True or (
kwargs.get("stream") is True
and "complete_streaming_response" in kwargs
):
if sl_object is not None:
cost_tracking_failure_debug_info: Union[dict, str] = (
sl_object["response_cost_failure_debug_info"] # type: ignore
or "response_cost_failure_debug_info is None in standard_logging_object"
)
else:
cost_tracking_failure_debug_info = (
"standard_logging_object not found"
)
model = kwargs.get("model")
raise Exception(
f"Cost tracking failed for model={model}.\nDebug info - {cost_tracking_failure_debug_info}\nAdd custom pricing - https://docs.litellm.ai/docs/proxy/custom_pricing"
)
except Exception as e:
error_msg = f"Error in tracking cost callback - {str(e)}\n Traceback:{traceback.format_exc()}"
model = kwargs.get("model", "")
metadata = get_litellm_metadata_from_kwargs(kwargs=kwargs)
litellm_metadata = kwargs.get("litellm_params", {}).get(
"litellm_metadata", {}
)
old_metadata = kwargs.get("litellm_params", {}).get("metadata", {})
call_type = kwargs.get("call_type", "")
error_msg += f"\n Args to _PROXY_track_cost_callback\n model: {model}\n chosen_metadata: {metadata}\n litellm_metadata: {litellm_metadata}\n old_metadata: {old_metadata}\n call_type: {call_type}\n"
asyncio.create_task(
proxy_logging_obj.failed_tracking_alert(
error_message=error_msg,
failing_model=model,
)
)
verbose_proxy_logger.exception(
"Error in tracking cost callback - %s", str(e)
)
@staticmethod
async def _enrich_failure_metadata_with_key_info(metadata: dict) -> dict:
"""
Enriches failure spend log metadata by looking up the key object (and team object)
from cache/DB when key fields are missing.
This handles two scenarios:
1. Auth errors (401): UserAPIKeyAuth is created with only api_key set, all other
fields are null. We look up the full key object to fill in alias, user_id,
team_id, etc.
2. Post-auth failures (provider errors, rate limits): key fields are populated
but team_alias is missing because LiteLLM_VerificationTokenView SQL view
doesn't include it. We look up the team object to fill in team_alias.
"""
api_key_hash = metadata.get("user_api_key")
if not api_key_hash:
return metadata
from litellm.proxy.proxy_server import (
prisma_client,
proxy_logging_obj,
user_api_key_cache,
)
# Step 1: If key fields are missing, look up the full key object
if metadata.get("user_api_key_alias") is None:
try:
key_obj = await get_key_object(
hashed_token=api_key_hash,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
if metadata.get("user_api_key_alias") is None:
metadata["user_api_key_alias"] = key_obj.key_alias
if metadata.get("user_api_key_user_id") is None:
metadata["user_api_key_user_id"] = key_obj.user_id
if metadata.get("user_api_key_team_id") is None:
metadata["user_api_key_team_id"] = key_obj.team_id
if metadata.get("user_api_key_org_id") is None:
metadata["user_api_key_org_id"] = key_obj.org_id
except Exception:
verbose_proxy_logger.debug(
"Failed to enrich failure metadata with key info for api_key=%s",
api_key_hash,
)
# Step 2: If team_id is known but team_alias is missing, look up the team object
team_id = metadata.get("user_api_key_team_id")
if team_id and metadata.get("user_api_key_team_alias") is None:
try:
team_obj = await get_team_object(
team_id=team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
if team_obj.team_alias is not None:
metadata["user_api_key_team_alias"] = team_obj.team_alias
except Exception:
verbose_proxy_logger.debug(
"Failed to enrich failure metadata with team_alias for team_id=%s",
team_id,
)
return metadata
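The enrichment above follows a fill-only-if-missing pattern: looked-up values never overwrite fields the caller already populated. Sketched standalone (names are illustrative):

```python
# Sketch of the fill-only-if-missing merge used by
# _enrich_failure_metadata_with_key_info above: existing non-None fields win.
def fill_missing(metadata: dict, looked_up: dict) -> dict:
    for key, value in looked_up.items():
        if metadata.get(key) is None:
            metadata[key] = value
    return metadata
```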
@staticmethod
def _should_track_errors_in_db():
"""
Returns True if errors should be tracked in the database
By default, errors are tracked in the database
If users want to disable error tracking, they can set the disable_error_logs flag in the general_settings
"""
from litellm.proxy.proxy_server import general_settings
if general_settings.get("disable_error_logs") is True:
return False
return True
def _should_track_cost_callback(
user_api_key: Optional[str],
user_id: Optional[str],
team_id: Optional[str],
end_user_id: Optional[str],
) -> bool:
"""
Determine if the cost callback should be tracked based on the kwargs
"""
# don't run track cost callback if user opted into disabling spend
if ProxyUpdateSpend.disable_spend_updates() is True:
return False
if (
user_api_key is not None
or user_id is not None
or team_id is not None
or end_user_id is not None
):
return True
return False
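The cost-tracking gate above reduces to two small rules: a cache hit zeroes the billable cost, and spend is only written when at least one accounting entity (key, user, team, or end user) is known. Sketched standalone (function names are illustrative):

```python
# Illustrative restatement of the gating logic in _PROXY_track_cost_callback
# and _should_track_cost_callback above.
from typing import Optional


def should_track(user_api_key: Optional[str], user_id: Optional[str],
                 team_id: Optional[str], end_user_id: Optional[str]) -> bool:
    # Spend is only written when it can be attributed to some entity.
    return any(v is not None for v in (user_api_key, user_id, team_id, end_user_id))


def effective_cost(response_cost: float, cache_hit: bool) -> float:
    # Cached responses cost nothing; no provider call was made.
    return 0.0 if cache_hit else response_cost
```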

@@ -0,0 +1,45 @@
"""
Shared utility functions for rate limiter hooks.
"""
from typing import Optional, Union
from litellm.types.router import ModelGroupInfo
from litellm.types.utils import PriorityReservationDict
def convert_priority_to_percent(
value: Union[float, PriorityReservationDict], model_info: Optional[ModelGroupInfo]
) -> float:
"""
Convert priority reservation value to percentage (0.0-1.0).
Supports four formats:
1. Plain float/int: 0.9 -> 0.9 (90%)
2. Dict with percent: {"type": "percent", "value": 0.9} -> 0.9
3. Dict with rpm: {"type": "rpm", "value": 900} -> 900/model_rpm
4. Dict with tpm: {"type": "tpm", "value": 900000} -> 900000/model_tpm
Args:
value: Priority value as float or dict with type/value keys
model_info: Model configuration containing rpm/tpm limits
Returns:
float: Percentage value between 0.0 and 1.0
"""
if isinstance(value, (int, float)):
return float(value)
if isinstance(value, dict):
val_type = value.get("type", "percent")
val_num = value.get("value", 1.0)
if val_type == "percent":
return float(val_num)
elif val_type == "rpm" and model_info and model_info.rpm and model_info.rpm > 0:
return float(val_num) / model_info.rpm
elif val_type == "tpm" and model_info and model_info.tpm and model_info.tpm > 0:
return float(val_num) / model_info.tpm
# Fallback: treat as percent
return float(val_num)
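A runnable restatement of the conversion rules above, using a minimal stand-in for `ModelGroupInfo` (only the `rpm`/`tpm` attributes the function reads; all names here are illustrative):

```python
# Sketch of convert_priority_to_percent's dispatch, with a stub model info.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class ModelInfoStub:
    rpm: Optional[int] = None
    tpm: Optional[int] = None


def to_percent(value: Union[float, dict], info: Optional[ModelInfoStub]) -> float:
    if isinstance(value, (int, float)):
        return float(value)  # plain number is already a fraction
    val_type = value.get("type", "percent")
    val_num = value.get("value", 1.0)
    if val_type == "rpm" and info and info.rpm:
        return float(val_num) / info.rpm  # absolute RPM -> share of model RPM
    if val_type == "tpm" and info and info.tpm:
        return float(val_num) / info.tpm  # absolute TPM -> share of model TPM
    return float(val_num)  # "percent" and fallback
```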

@@ -0,0 +1,296 @@
"""
Security hook to prevent user B from seeing response from user A.
This hook uses the DBSpendUpdateWriter to batch-write response IDs to the database
instead of writing immediately on each request.
"""
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, Tuple, Union, cast
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import LitellmUserRoles
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.types.llms.openai import (
BaseLiteLLMOpenAIResponseObject,
ResponsesAPIResponse,
)
from litellm.types.utils import CallTypesLiteral, LLMResponseTypes, SpecialEnums
if TYPE_CHECKING:
from litellm.caching.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
class ResponsesIDSecurity(CustomLogger):
def __init__(self):
pass
async def async_pre_call_hook(
self,
user_api_key_dict: "UserAPIKeyAuth",
cache: "DualCache",
data: dict,
call_type: CallTypesLiteral,
) -> Optional[Union[Exception, str, dict]]:
# MAP all the responses api response ids to the encrypted response ids
responses_api_call_types = {
"aresponses",
"aget_responses",
"adelete_responses",
"acancel_responses",
}
if call_type not in responses_api_call_types:
return None
if call_type == "aresponses":
# check 'previous_response_id' if present in the data
previous_response_id = data.get("previous_response_id")
if previous_response_id and self._is_encrypted_response_id(
previous_response_id
):
original_response_id, user_id, team_id = self._decrypt_response_id(
previous_response_id
)
self.check_user_access_to_response_id(
user_id, team_id, user_api_key_dict
)
data["previous_response_id"] = original_response_id
elif call_type in {"aget_responses", "adelete_responses", "acancel_responses"}:
response_id = data.get("response_id")
if response_id and self._is_encrypted_response_id(response_id):
original_response_id, user_id, team_id = self._decrypt_response_id(
response_id
)
self.check_user_access_to_response_id(
user_id, team_id, user_api_key_dict
)
data["response_id"] = original_response_id
return data
def check_user_access_to_response_id(
self,
response_id_user_id: Optional[str],
response_id_team_id: Optional[str],
user_api_key_dict: "UserAPIKeyAuth",
) -> bool:
from litellm.proxy.proxy_server import general_settings
if (
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
):
return True
if response_id_user_id and response_id_user_id != user_api_key_dict.user_id:
if general_settings.get("disable_responses_id_security", False):
verbose_proxy_logger.debug(
f"Responses ID Security is disabled. User {user_api_key_dict.user_id} is accessing response id {response_id_user_id} which is not associated with them."
)
return True
raise HTTPException(
status_code=403,
detail="Forbidden. The response id is not associated with the user that this key belongs to. To disable this security feature, set general_settings::disable_responses_id_security to True in the config.yaml file.",
)
if response_id_team_id and response_id_team_id != user_api_key_dict.team_id:
if general_settings.get("disable_responses_id_security", False):
verbose_proxy_logger.debug(
f"Responses ID Security is disabled. Response belongs to team {response_id_team_id} but user {user_api_key_dict.user_id} is accessing it with team id {user_api_key_dict.team_id}."
)
return True
raise HTTPException(
status_code=403,
detail="Forbidden. The response id is not associated with the team that this key belongs to. To disable this security feature, set general_settings::disable_responses_id_security to True in the config.yaml file.",
)
return True
def _is_encrypted_response_id(self, response_id: str) -> bool:
split_result = response_id.split("resp_")
if len(split_result) < 2:
return False
remaining_string = split_result[1]
decrypted_value = decrypt_value_helper(
value=remaining_string, key="response_id", return_original_value=True
)
if decrypted_value is None:
return False
if decrypted_value.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
return True
return False
def _decrypt_response_id(
self, response_id: str
) -> Tuple[str, Optional[str], Optional[str]]:
"""
Returns:
- original_response_id: the original response id
- user_id: the user id
- team_id: the team id
"""
split_result = response_id.split("resp_")
if len(split_result) < 2:
return response_id, None, None
remaining_string = split_result[1]
decrypted_value = decrypt_value_helper(
value=remaining_string, key="response_id", return_original_value=True
)
if decrypted_value is None:
return response_id, None, None
if decrypted_value.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
# Expected format: "litellm_proxy:responses_api:response_id:{response_id};user_id:{user_id};team_id:{team_id}"
parts = decrypted_value.split(";")
if len(parts) >= 3:
# Extract response_id from "litellm_proxy:responses_api:response_id:{response_id}"
response_id_part = parts[0]
original_response_id = response_id_part.split("response_id:")[-1]
# Extract user_id from "user_id:{user_id}"
user_id_part = parts[1]
user_id = user_id_part.split("user_id:")[-1]
# Extract team_id from "team_id:{team_id}"
team_id_part = parts[2]
team_id = team_id_part.split("team_id:")[-1]
return original_response_id, user_id, team_id
else:
# Fallback if the payload has fewer than 3 segments
return response_id, None, None
return response_id, None, None
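The payload split performed by `_decrypt_response_id` can be illustrated in isolation. The payload string and ids below are made up for the example; the segment layout is taken from the format comment in the code above:

```python
# Illustrative parsing of the decrypted payload format (sample ids are invented)
payload = "litellm_proxy:responses_api:response_id:resp_abc123;user_id:u-1;team_id:t-9"

parts = payload.split(";")
# parts[0] = "litellm_proxy:responses_api:response_id:resp_abc123"
response_id = parts[0].split("response_id:")[-1]
# parts[1] = "user_id:u-1"
user_id = parts[1].split("user_id:")[-1]
# parts[2] = "team_id:t-9"
team_id = parts[2].split("team_id:")[-1]

print(response_id, user_id, team_id)  # resp_abc123 u-1 t-9
```

Splitting on the label followed by `[-1]` keeps the code tolerant of the prefix segments before `response_id:`, which is why the same idiom is used for all three fields.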
def _get_signing_key(self) -> Optional[str]:
"""Get the signing key for encryption/decryption."""
import os
from litellm.proxy.proxy_server import master_key
salt_key = os.getenv("LITELLM_SALT_KEY", None)
if salt_key is None:
salt_key = master_key
return salt_key
def _encrypt_response_id(
self,
response: BaseLiteLLMOpenAIResponseObject,
user_api_key_dict: "UserAPIKeyAuth",
request_cache: Optional[dict[str, str]] = None,
) -> BaseLiteLLMOpenAIResponseObject:
# encrypt the response id using the symmetric key
# encrypt the response id, and encode the user id and response id in base64
# Check if signing key is available
signing_key = self._get_signing_key()
if signing_key is None:
verbose_proxy_logger.debug(
"Response ID encryption is enabled but no signing key is configured. "
"Please set LITELLM_SALT_KEY environment variable or configure a master_key. "
"Skipping response ID encryption. "
"See: https://docs.litellm.ai/docs/proxy/prod#5-set-litellm-salt-key"
)
return response
response_id = getattr(response, "id", None)
response_obj = getattr(response, "response", None)
if (
response_id
and isinstance(response_id, str)
and response_id.startswith("resp_")
):
# Check request-scoped cache first (for streaming consistency)
if request_cache is not None and response_id in request_cache:
setattr(response, "id", request_cache[response_id])
else:
encrypted_response_id = SpecialEnums.LITELLM_MANAGED_RESPONSE_API_RESPONSE_ID_COMPLETE_STR.value.format(
response_id,
user_api_key_dict.user_id or "",
user_api_key_dict.team_id or "",
)
encoded_user_id_and_response_id = encrypt_value_helper(
value=encrypted_response_id
)
encrypted_id = f"resp_{encoded_user_id_and_response_id}"
if request_cache is not None:
request_cache[response_id] = encrypted_id
setattr(response, "id", encrypted_id)
elif response_obj and isinstance(response_obj, ResponsesAPIResponse):
# Check request-scoped cache first (for streaming consistency)
if request_cache is not None and response_obj.id in request_cache:
setattr(response_obj, "id", request_cache[response_obj.id])
else:
encrypted_response_id = SpecialEnums.LITELLM_MANAGED_RESPONSE_API_RESPONSE_ID_COMPLETE_STR.value.format(
response_obj.id,
user_api_key_dict.user_id or "",
user_api_key_dict.team_id or "",
)
encoded_user_id_and_response_id = encrypt_value_helper(
value=encrypted_response_id
)
encrypted_id = f"resp_{encoded_user_id_and_response_id}"
if request_cache is not None:
request_cache[response_obj.id] = encrypted_id
setattr(response_obj, "id", encrypted_id)
setattr(response, "response", response_obj)
return response
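The request-scoped cache in `_encrypt_response_id` exists so that every chunk of a stream carrying the same plaintext id is rewritten to the same encrypted id. A minimal sketch of that invariant, with a counter-based stub standing in for `encrypt_value_helper` (the real helper uses symmetric encryption with the salt/master key):

```python
# Sketch: one plaintext id -> one encrypted id per request, via a per-request cache
import itertools

_counter = itertools.count()


def fake_encrypt(response_id: str) -> str:
    # stand-in for encrypt_value_helper; each call yields a fresh value,
    # which is exactly why a cache is needed for streaming consistency
    return f"resp_enc{next(_counter)}"


request_cache: dict = {}


def encrypt_with_cache(response_id: str) -> str:
    if response_id not in request_cache:
        request_cache[response_id] = fake_encrypt(response_id)
    return request_cache[response_id]


first = encrypt_with_cache("resp_123")
second = encrypt_with_cache("resp_123")  # same id seen in a later chunk
print(first == second)  # True
```

Without the cache, each chunk would receive a differently encrypted id and the client could not correlate the chunks of one response.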
async def async_post_call_success_hook(
self,
data: dict,
user_api_key_dict: "UserAPIKeyAuth",
response: LLMResponseTypes,
) -> Any:
"""
Queue response IDs for batch processing instead of writing directly to DB.
This method adds response IDs to an in-memory queue, which are then
batch-processed by the DBSpendUpdateWriter during regular database update cycles.
"""
from litellm.proxy.proxy_server import general_settings
if general_settings.get("disable_responses_id_security", False):
return response
if isinstance(response, ResponsesAPIResponse):
response = cast(
ResponsesAPIResponse,
self._encrypt_response_id(
response, user_api_key_dict, request_cache=None
),
)
return response
async def async_post_call_streaming_iterator_hook( # type: ignore
self, user_api_key_dict: "UserAPIKeyAuth", response: Any, request_data: dict
) -> AsyncGenerator[BaseLiteLLMOpenAIResponseObject, None]:
from litellm.proxy.proxy_server import general_settings
# Create a request-scoped cache for consistent encryption across streaming chunks.
request_encryption_cache: dict[str, str] = {}
async for chunk in response:
if (
isinstance(chunk, BaseLiteLLMOpenAIResponseObject)
and user_api_key_dict.request_route
== "/v1/responses" # only encrypt the response id for the responses api
and not general_settings.get("disable_responses_id_security", False)
):
chunk = self._encrypt_response_id(
chunk, user_api_key_dict, request_encryption_cache
)
yield chunk

View File

@@ -0,0 +1,209 @@
"""
Hooks that are triggered when a litellm user event occurs
"""
import asyncio
from litellm._uuid import uuid
from datetime import datetime, timezone
from typing import Optional
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
AUDIT_ACTIONS,
CommonProxyErrors,
LiteLLM_AuditLogs,
Litellm_EntityType,
LiteLLM_UserTable,
LitellmTableNames,
NewUserRequest,
NewUserResponse,
UserAPIKeyAuth,
WebhookEvent,
)
from litellm.proxy.management_helpers.audit_logs import create_audit_log_for_update
class UserManagementEventHooks:
@staticmethod
async def async_user_created_hook(
data: NewUserRequest,
response: NewUserResponse,
user_api_key_dict: UserAPIKeyAuth,
):
"""
This hook is called when a new user is created on litellm
Handles:
- Creating an audit log for the user creation
- Sending a user invitation email to the user
"""
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
#########################################################
########## Send User Invitation Email ################
#########################################################
await UserManagementEventHooks.async_send_user_invitation_email(
data=data,
response=response,
user_api_key_dict=user_api_key_dict,
)
#########################################################
########## CREATE AUDIT LOG ################
#########################################################
try:
if prisma_client is None:
raise Exception(CommonProxyErrors.db_not_connected_error.value)
user_row: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_first(
where={"user_id": response.user_id}
)
)
if user_row is None:
raise Exception(
"User {} not found in DB after creation".format(response.user_id)
)
user_row_litellm_typed = LiteLLM_UserTable(
**user_row.model_dump(exclude_none=True)
)
asyncio.create_task(
UserManagementEventHooks.create_internal_user_audit_log(
user_id=user_row_litellm_typed.user_id,
action="created",
litellm_changed_by=user_api_key_dict.user_id,
user_api_key_dict=user_api_key_dict,
litellm_proxy_admin_name=litellm_proxy_admin_name,
before_value=None,
after_value=user_row_litellm_typed.model_dump_json(
exclude_none=True
),
)
)
except Exception as e:
verbose_proxy_logger.warning(
"Unable to create audit log for user on `/user/new` - {}".format(str(e))
)
@staticmethod
async def async_send_user_invitation_email(
data: NewUserRequest,
response: NewUserResponse,
user_api_key_dict: UserAPIKeyAuth,
):
"""
Send a user invitation email to the user
"""
event = WebhookEvent(
event="internal_user_created",
event_group=Litellm_EntityType.USER,
event_message="Welcome to LiteLLM Proxy",
token=response.token,
spend=response.spend or 0.0,
max_budget=response.max_budget,
user_id=response.user_id,
user_email=response.user_email,
team_id=response.team_id,
key_alias=response.key_alias,
)
#########################################################
########## V2 USER INVITATION EMAIL ################
#########################################################
try:
from litellm_enterprise.enterprise_callbacks.send_emails.base_email import (
BaseEmailLogger,
)
use_enterprise_email_hooks = True
except ImportError:
verbose_proxy_logger.warning(
"Defaulting to using Legacy Email Hooks. "
+ CommonProxyErrors.missing_enterprise_package.value
)
use_enterprise_email_hooks = False
if use_enterprise_email_hooks and (data.send_invite_email is True):
initialized_email_loggers = litellm.logging_callback_manager.get_custom_loggers_for_type(
callback_type=BaseEmailLogger # type: ignore
)
if len(initialized_email_loggers) > 0:
for email_logger in initialized_email_loggers:
if isinstance(email_logger, BaseEmailLogger): # type: ignore
await email_logger.send_user_invitation_email( # type: ignore
event=event,
)
#########################################################
########## LEGACY V1 USER INVITATION EMAIL ################
#########################################################
if data.send_invite_email is True:
await UserManagementEventHooks.send_legacy_v1_user_invitation_email(
data=data,
response=response,
user_api_key_dict=user_api_key_dict,
event=event,
)
@staticmethod
async def send_legacy_v1_user_invitation_email(
data: NewUserRequest,
response: NewUserResponse,
user_api_key_dict: UserAPIKeyAuth,
event: WebhookEvent,
):
"""
Send a user invitation email to the user
"""
from litellm.proxy.proxy_server import general_settings, proxy_logging_obj
# check if user has setup email alerting
if "email" not in general_settings.get("alerting", []):
raise ValueError(
"Email alerting not set up on config.yaml. Please set `alerting=['email']`.\nDocs: https://docs.litellm.ai/docs/proxy/email"
)
# If user configured email alerting - send an Email letting their end-user know the key was created
asyncio.create_task(
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
webhook_event=event,
)
)
@staticmethod
async def create_internal_user_audit_log(
user_id: str,
action: AUDIT_ACTIONS,
litellm_changed_by: Optional[str],
user_api_key_dict: UserAPIKeyAuth,
litellm_proxy_admin_name: Optional[str],
before_value: Optional[str] = None,
after_value: Optional[str] = None,
):
"""
Create an audit log for an internal user.
Parameters:
- user_id: str - The id of the user to create the audit log for.
- action: AUDIT_ACTIONS - The action to create the audit log for.
- litellm_changed_by: Optional[str] - The user id of the user who is changing the user.
- user_api_key_dict: UserAPIKeyAuth - The user api key dictionary.
- litellm_proxy_admin_name: Optional[str] - The name of the proxy admin.
"""
if not litellm.store_audit_logs:
return
await create_audit_log_for_update(
request_data=LiteLLM_AuditLogs(
id=str(uuid.uuid4()),
updated_at=datetime.now(timezone.utc),
changed_by=litellm_changed_by
or user_api_key_dict.user_id
or litellm_proxy_admin_name,
changed_by_api_key=user_api_key_dict.api_key,
table_name=LitellmTableNames.USER_TABLE_NAME,
object_id=user_id,
action=action,
updated_values=after_value,
before_value=before_value,
)
)