chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,169 @@
def show_missing_vars_in_env():
    """Return an HTMLResponse describing required-but-unset env vars, or None.

    Checks the proxy's ``prisma_client`` (requires DATABASE_URL) and
    ``master_key`` (requires LITELLM_MASTER_KEY) and, if either is unset,
    renders the setup-instructions page listing the missing variable names.
    """
    from fastapi.responses import HTMLResponse
    from litellm.proxy.proxy_server import master_key, prisma_client

    # Collect missing variables in the order the original page listed them.
    unset_vars = []
    if prisma_client is None:
        unset_vars.append("DATABASE_URL")
    if master_key is None:
        unset_vars.append("LITELLM_MASTER_KEY")
    if unset_vars:
        return HTMLResponse(
            content=missing_keys_form(missing_key_names=", ".join(unset_vars)),
            status_code=200,
        )
    # Everything required is configured.
    return None
def missing_keys_form(missing_key_names: str):
    """Render the environment-setup instructions page as an HTML string.

    Args:
        missing_key_names: Comma-separated env var names to display in the
            "Missing Environment Variables" section.

    Returns:
        The rendered HTML page (a plain string, not a Response object).
    """
    # NOTE: CSS braces are doubled ({{ }}) because this template is passed
    # through str.format(), which treats single braces as placeholders.
    missing_keys_html_form = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}}
.container {{
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
font-size: 24px;
margin-bottom: 20px;
}}
pre {{
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}}
.env-var {{
font-weight: normal;
}}
.comment {{
font-weight: normal;
color: #777;
}}
</style>
<title>Environment Setup Instructions</title>
</head>
<body>
<div class="container">
<h1>Environment Setup Instructions</h1>
<p>Please add the following variables to your environment variables:</p>
<pre>
<span class="env-var">LITELLM_MASTER_KEY="sk-1234"</span> <span class="comment"># Your master key for the proxy server. Can use this to send /chat/completion requests etc</span>
<span class="env-var">LITELLM_SALT_KEY="sk-XXXXXXXX"</span> <span class="comment"># Can NOT CHANGE THIS ONCE SET - It is used to encrypt/decrypt credentials stored in DB. If value of 'LITELLM_SALT_KEY' changes your models cannot be retrieved from DB</span>
<span class="env-var">DATABASE_URL="postgres://..."</span> <span class="comment"># Need a postgres database? (Check out Supabase, Neon, etc)</span>
<span class="comment">## OPTIONAL ##</span>
<span class="env-var">PORT=4000</span> <span class="comment"># DO THIS FOR RENDER/RAILWAY</span>
<span class="env-var">STORE_MODEL_IN_DB="True"</span> <span class="comment"># Allow storing models in db</span>
</pre>
<h1>Missing Environment Variables</h1>
<p>{missing_keys}</p>
</div>
<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""
    return missing_keys_html_form.format(missing_keys=missing_key_names)
def admin_ui_disabled():
    """Return an HTMLResponse explaining that the Admin UI is disabled.

    Shown when the administrator has disabled the UI; instructs the user to
    set DISABLE_ADMIN_UI="False" and restart.

    Returns:
        fastapi HTMLResponse with status 200 containing the static page.
    """
    from fastapi.responses import HTMLResponse

    # NOTE: CSS braces are doubled ({{ }}) to mirror the sibling template that
    # goes through str.format(); this template itself is returned verbatim.
    ui_disabled_html = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {{
font-family: Arial, sans-serif;
background-color: #f4f4f9;
color: #333;
margin: 20px;
line-height: 1.6;
}}
.container {{
max-width: 800px;
margin: auto;
padding: 20px;
background: #fff;
border: 1px solid #ddd;
border-radius: 5px;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}}
h1 {{
font-size: 24px;
margin-bottom: 20px;
}}
pre {{
background: #f8f8f8;
padding: 1px;
border: 1px solid #ccc;
border-radius: 4px;
overflow-x: auto;
font-size: 14px;
}}
.env-var {{
font-weight: normal;
}}
.comment {{
font-weight: normal;
color: #777;
}}
</style>
<title>Admin UI Disabled</title>
</head>
<body>
<div class="container">
<h1>Admin UI is Disabled</h1>
<p>The Admin UI has been disabled by the administrator. To re-enable it, please update the following environment variable:</p>
<pre>
<span class="env-var">DISABLE_ADMIN_UI="False"</span> <span class="comment"># Set this to "False" to enable the Admin UI.</span>
</pre>
<p>After making this change, restart the application for it to take effect.</p>
</div>
<div class="container">
<h1>Need Help? Support</h1>
<p>Discord: <a href="https://discord.com/invite/wuPM9dRgDw" target="_blank">https://discord.com/invite/wuPM9dRgDw</a></p>
<p>Docs: <a href="https://docs.litellm.ai/docs/" target="_blank">https://docs.litellm.ai/docs/</a></p>
</div>
</body>
</html>
"""
    return HTMLResponse(
        content=ui_disabled_html,
        status_code=200,
    )

View File

@@ -0,0 +1,17 @@
# LiteLLM ASCII banner
LITELLM_BANNER = """ ██╗ ██╗████████╗███████╗██╗ ██╗ ███╗ ███╗
██║ ██║╚══██╔══╝██╔════╝██║ ██║ ████╗ ████║
██║ ██║ ██║ █████╗ ██║ ██║ ██╔████╔██║
██║ ██║ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║
███████╗██║ ██║ ███████╗███████╗███████╗██║ ╚═╝ ██║
╚══════╝╚═╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝"""


def show_banner():
    """Display the LiteLLM CLI banner (blank line when click is unavailable)."""
    banner_text = f"\n{LITELLM_BANNER}\n"
    try:
        import click

        click.echo(banner_text)
    except ImportError:
        # click not installed: keep CLI output spacing consistent anyway.
        print("\n")  # noqa: T201

View File

@@ -0,0 +1,190 @@
"""
Event-driven cache coordinator to prevent cache stampede.
Use this when many requests can miss the same cache key at once (e.g. after
expiry or restart). Without coordination, they would all run the expensive
load (DB query, API call) in parallel and overload the backend.
This module ensures only one request performs the load; the rest wait for a
signal and then read the freshly cached value. Reuse it for any cache-aside
pattern: global spend, feature flags, config, or other shared read-through data.
"""
import asyncio
import time
from typing import Any, Awaitable, Callable, Optional, Protocol, TypeVar
from litellm._logging import verbose_proxy_logger
T = TypeVar("T")
class AsyncCacheProtocol(Protocol):
    """Protocol for cache backends used by EventDrivenCacheCoordinator."""

    async def async_get_cache(self, key: str, **kwargs: Any) -> Any:
        # Read side; the coordinator treats a None return as a cache miss.
        ...

    async def async_set_cache(self, key: str, value: Any, **kwargs: Any) -> Any:
        # Write side; stores `value` under `key`.
        ...
class EventDrivenCacheCoordinator:
    """
    Coordinates a single in-flight load per logical resource to prevent cache stampede.
    Pattern:
    - First request: loads data (e.g. DB query), caches it, then signals waiters.
    - Other requests: wait for the signal, then read from cache.
    Create one instance per resource (e.g. one for global spend, one for feature flags).
    """

    def __init__(self, log_prefix: str = "[CACHE]"):
        # Serializes loader/waiter role selection and event lifecycle.
        self._lock = asyncio.Lock()
        # Created when a load starts; waiters block on it; replaced each cycle.
        self._event: Optional[asyncio.Event] = None
        # True while exactly one request is executing load_fn.
        self._query_in_progress = False
        # Debug-log prefix; a falsy value (e.g. "") disables this class's logging.
        self._log_prefix = log_prefix

    async def _get_cached(
        self, cache_key: str, cache: AsyncCacheProtocol
    ) -> Optional[Any]:
        """Return value from cache if present, else None."""
        return await cache.async_get_cache(key=cache_key)

    def _log_cache_hit(self, value: T) -> None:
        if self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Cache hit, value: %s", self._log_prefix, value
            )

    def _log_cache_miss(self) -> None:
        if self._log_prefix:
            verbose_proxy_logger.debug("%s Cache miss", self._log_prefix)

    async def _claim_role(self) -> Optional[asyncio.Event]:
        """
        Under lock: return event to wait on if load is in progress, else set us as loader and return None.
        """
        async with self._lock:
            if self._query_in_progress and self._event is not None:
                if self._log_prefix:
                    verbose_proxy_logger.debug(
                        "%s Load in flight, waiting for signal", self._log_prefix
                    )
                return self._event
            # No load in flight: this request becomes the loader.
            self._query_in_progress = True
            self._event = asyncio.Event()
            if self._log_prefix:
                verbose_proxy_logger.debug(
                    "%s Starting load (will signal others when done)",
                    self._log_prefix,
                )
            return None

    async def _wait_for_signal_and_get(
        self,
        event: asyncio.Event,
        cache_key: str,
        cache: AsyncCacheProtocol,
    ) -> Optional[T]:
        """Wait for loader to finish, then read from cache."""
        await event.wait()
        if self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Signal received, reading from cache", self._log_prefix
            )
        # Cache may still be empty here if the loader's load_fn raised.
        value: Optional[T] = await cache.async_get_cache(key=cache_key)
        if value is not None and self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Cache filled by other request, value: %s",
                self._log_prefix,
                value,
            )
        elif value is None and self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Signal received but cache still empty", self._log_prefix
            )
        return value

    async def _load_and_cache(
        self,
        cache_key: str,
        cache: AsyncCacheProtocol,
        load_fn: Callable[[], Awaitable[T]],
    ) -> Optional[T]:
        """Double-check cache, run load_fn, set cache, return value. Caller must call _signal_done in finally."""
        # Double-check: another loader may have completed between our cache
        # miss and claiming the loader role.
        value = await cache.async_get_cache(key=cache_key)
        if value is not None:
            if self._log_prefix:
                verbose_proxy_logger.debug(
                    "%s Cache filled while acquiring lock, value: %s",
                    self._log_prefix,
                    value,
                )
            return value
        if self._log_prefix:
            verbose_proxy_logger.debug("%s Running load", self._log_prefix)
        start = time.perf_counter()
        value = await load_fn()
        elapsed_ms = (time.perf_counter() - start) * 1000
        if self._log_prefix:
            verbose_proxy_logger.debug(
                "%s Load completed in %.2fms, result: %s",
                self._log_prefix,
                elapsed_ms,
                value,
            )
        await cache.async_set_cache(key=cache_key, value=value)
        if self._log_prefix:
            verbose_proxy_logger.debug("%s Result cached", self._log_prefix)
        return value

    async def _signal_done(self) -> None:
        """Reset loader state and signal all waiters."""
        async with self._lock:
            self._query_in_progress = False
            if self._event is not None:
                if self._log_prefix:
                    verbose_proxy_logger.debug(
                        "%s Signaling all waiting requests", self._log_prefix
                    )
                # Waiters holding a reference to this event wake up; dropping
                # our reference starts a fresh cycle for the next miss.
                self._event.set()
                self._event = None

    async def get_or_load(
        self,
        cache_key: str,
        cache: AsyncCacheProtocol,
        load_fn: Callable[[], Awaitable[T]],
    ) -> Optional[T]:
        """
        Return cached value or load it once and signal waiters.
        - cache_key: Key to read/write in the cache.
        - cache: Object with async_get_cache(key) and async_set_cache(key, value).
        - load_fn: Async callable that performs the load (e.g. DB query). No args.
        Return value is cached and returned. If it raises, waiters are
        still signaled so they can retry or handle empty cache.
        Returns the value from cache or from load_fn, or None if load failed or
        cache was still empty after waiting.
        """
        # Fast path: a cache hit needs no coordination at all.
        value = await self._get_cached(cache_key, cache)
        if value is not None:
            self._log_cache_hit(value)
            return value
        self._log_cache_miss()
        # Loser path: an event means another request is already loading.
        event_to_wait = await self._claim_role()
        if event_to_wait is not None:
            return await self._wait_for_signal_and_get(event_to_wait, cache_key, cache)
        try:
            result = await self._load_and_cache(cache_key, cache, load_fn)
            return result
        finally:
            # Always wake waiters, even when load_fn raised.
            await self._signal_done()

View File

@@ -0,0 +1,526 @@
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Optional
import litellm
from litellm import get_secret
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.types_utils.utils import get_instance_fn
from litellm.types.utils import (
StandardLoggingGuardrailInformation,
StandardLoggingPayload,
)
# ANSI escape codes used to colorize proxy debug-log messages.
blue_color_code = "\033[94m"
reset_color_code = "\033[0m"
if TYPE_CHECKING:
    # Imported for type annotations only, avoiding a runtime import cycle.
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
def initialize_callbacks_on_proxy(  # noqa: PLR0915
    value: Any,
    premium_user: bool,
    config_file_path: str,
    litellm_settings: dict,
    callback_specific_params: dict = {},
):
    """
    Resolve the `callbacks` entry from the proxy config into live callback
    objects and register them on `litellm.callbacks`.

    Args:
        value: Either a list of callback names / CustomLogger instances, or a
            single import-path string for a custom callback.
        premium_user: Whether enterprise-only callbacks may be enabled.
        config_file_path: Proxy config path, used to resolve custom callback imports.
        litellm_settings: Full `litellm_settings` dict (source of callback-specific
            settings such as `prompt_injection_params` / `azure_content_safety_params`).
        callback_specific_params: Per-callback init kwargs keyed by callback name.
            NOTE: mutable default — treated as read-only in this function.

    Raises:
        Exception: when an enterprise-only callback is requested without the
            enterprise package installed or without a premium license.
    """
    from litellm.integrations.custom_logger import CustomLogger
    from litellm.litellm_core_utils.logging_callback_manager import (
        LoggingCallbackManager,
    )
    from litellm.proxy.proxy_server import prisma_client

    verbose_proxy_logger.debug(
        f"{blue_color_code}initializing callbacks={value} on proxy{reset_color_code}"
    )
    if isinstance(value, list):
        imported_list: List[Any] = []
        for callback in value:  # ["presidio", <my-custom-callback>]
            # check if callback is a custom logger compatible callback
            if isinstance(callback, str):
                callback = LoggingCallbackManager._add_custom_callback_generic_api_str(
                    callback
                )
            if (
                isinstance(callback, str)
                and callback in litellm._known_custom_logger_compatible_callbacks
            ):
                imported_list.append(callback)
            elif isinstance(callback, str) and callback == "presidio":
                from litellm.proxy.guardrails.guardrail_hooks.presidio import (
                    _OPTIONAL_PresidioPIIMasking,
                )

                presidio_logging_only: Optional[bool] = litellm_settings.get(
                    "presidio_logging_only", None
                )
                if presidio_logging_only is not None:
                    presidio_logging_only = bool(
                        presidio_logging_only
                    )  # validate boolean given
                _presidio_params = {}
                if "presidio" in callback_specific_params and isinstance(
                    callback_specific_params["presidio"], dict
                ):
                    _presidio_params = callback_specific_params["presidio"]
                params: Dict[str, Any] = {
                    "logging_only": presidio_logging_only,
                    **_presidio_params,
                }
                pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
                imported_list.append(pii_masking_object)
            elif isinstance(callback, str) and callback == "llamaguard_moderations":
                try:
                    from litellm_enterprise.enterprise_callbacks.llama_guard import (
                        _ENTERPRISE_LlamaGuard,
                    )
                except ImportError:
                    # BUGFIX: message previously read "MissingTrying to use Llama Guard"
                    raise Exception(
                        "Trying to use Llama Guard"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )
                if premium_user is not True:
                    raise Exception(
                        "Trying to use Llama Guard"
                        + CommonProxyErrors.not_premium_user.value
                    )
                llama_guard_object = _ENTERPRISE_LlamaGuard()
                imported_list.append(llama_guard_object)
            elif isinstance(callback, str) and callback == "hide_secrets":
                try:
                    from litellm_enterprise.enterprise_callbacks.secret_detection import (
                        _ENTERPRISE_SecretDetection,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Secret Detection"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )
                if premium_user is not True:
                    raise Exception(
                        "Trying to use secret hiding"
                        + CommonProxyErrors.not_premium_user.value
                    )
                _secret_detection_object = _ENTERPRISE_SecretDetection()
                imported_list.append(_secret_detection_object)
            elif isinstance(callback, str) and callback == "openai_moderations":
                try:
                    from enterprise.enterprise_hooks.openai_moderation import (
                        _ENTERPRISE_OpenAI_Moderation,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use OpenAI Moderations Check,"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )
                if premium_user is not True:
                    raise Exception(
                        "Trying to use OpenAI Moderations Check"
                        + CommonProxyErrors.not_premium_user.value
                    )
                openai_moderations_object = _ENTERPRISE_OpenAI_Moderation()
                imported_list.append(openai_moderations_object)
            elif isinstance(callback, str) and callback == "lakera_prompt_injection":
                from litellm.proxy.guardrails.guardrail_hooks.lakera_ai import (
                    lakeraAI_Moderation,
                )

                init_params = {}
                if "lakera_prompt_injection" in callback_specific_params:
                    init_params = callback_specific_params["lakera_prompt_injection"]
                lakera_moderations_object = lakeraAI_Moderation(**init_params)
                imported_list.append(lakera_moderations_object)
            elif isinstance(callback, str) and callback == "aporia_prompt_injection":
                from litellm.proxy.guardrails.guardrail_hooks.aporia_ai.aporia_ai import (
                    AporiaGuardrail,
                )

                aporia_guardrail_object = AporiaGuardrail()
                imported_list.append(aporia_guardrail_object)
            elif isinstance(callback, str) and callback == "google_text_moderation":
                try:
                    from enterprise.enterprise_hooks.google_text_moderation import (
                        _ENTERPRISE_GoogleTextModeration,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Google Text Moderation,"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )
                if premium_user is not True:
                    raise Exception(
                        "Trying to use Google Text Moderation"
                        + CommonProxyErrors.not_premium_user.value
                    )
                google_text_moderation_obj = _ENTERPRISE_GoogleTextModeration()
                imported_list.append(google_text_moderation_obj)
            elif isinstance(callback, str) and callback == "llmguard_moderations":
                try:
                    from litellm_enterprise.enterprise_callbacks.llm_guard import (
                        _ENTERPRISE_LLMGuard,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Llm Guard"
                        + CommonProxyErrors.missing_enterprise_package.value
                    )
                if premium_user is not True:
                    raise Exception(
                        "Trying to use Llm Guard"
                        + CommonProxyErrors.not_premium_user.value
                    )
                llm_guard_moderation_obj = _ENTERPRISE_LLMGuard()
                imported_list.append(llm_guard_moderation_obj)
            elif isinstance(callback, str) and callback == "blocked_user_check":
                try:
                    from enterprise.enterprise_hooks.blocked_user_list import (
                        _ENTERPRISE_BlockedUserList,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Blocked User List"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )
                if premium_user is not True:
                    raise Exception(
                        "Trying to use ENTERPRISE BlockedUser"
                        + CommonProxyErrors.not_premium_user.value
                    )
                blocked_user_list = _ENTERPRISE_BlockedUserList(
                    prisma_client=prisma_client
                )
                imported_list.append(blocked_user_list)
            elif isinstance(callback, str) and callback == "banned_keywords":
                try:
                    from enterprise.enterprise_hooks.banned_keywords import (
                        _ENTERPRISE_BannedKeywords,
                    )
                except ImportError:
                    raise Exception(
                        "Trying to use Banned Keywords"
                        + CommonProxyErrors.missing_enterprise_package_docker.value
                    )
                if premium_user is not True:
                    raise Exception(
                        "Trying to use ENTERPRISE BannedKeyword"
                        + CommonProxyErrors.not_premium_user.value
                    )
                banned_keywords_obj = _ENTERPRISE_BannedKeywords()
                imported_list.append(banned_keywords_obj)
            elif isinstance(callback, str) and callback == "detect_prompt_injection":
                from litellm.proxy.hooks.prompt_injection_detection import (
                    _OPTIONAL_PromptInjectionDetection,
                )

                prompt_injection_params = None
                if "prompt_injection_params" in litellm_settings:
                    prompt_injection_params_in_config = litellm_settings[
                        "prompt_injection_params"
                    ]
                    prompt_injection_params = LiteLLMPromptInjectionParams(
                        **prompt_injection_params_in_config
                    )
                prompt_injection_detection_obj = _OPTIONAL_PromptInjectionDetection(
                    prompt_injection_params=prompt_injection_params,
                )
                imported_list.append(prompt_injection_detection_obj)
            elif isinstance(callback, str) and callback == "batch_redis_requests":
                from litellm.proxy.hooks.batch_redis_get import (
                    _PROXY_BatchRedisRequests,
                )

                batch_redis_obj = _PROXY_BatchRedisRequests()
                imported_list.append(batch_redis_obj)
            elif isinstance(callback, str) and callback == "azure_content_safety":
                from litellm.proxy.hooks.azure_content_safety import (
                    _PROXY_AzureContentSafety,
                )

                azure_content_safety_params = litellm_settings[
                    "azure_content_safety_params"
                ]
                # Resolve "os.environ/..." indirections to real secret values.
                for k, v in azure_content_safety_params.items():
                    if (
                        v is not None
                        and isinstance(v, str)
                        and v.startswith("os.environ/")
                    ):
                        azure_content_safety_params[k] = get_secret(v)
                azure_content_safety_obj = _PROXY_AzureContentSafety(
                    **azure_content_safety_params,
                )
                imported_list.append(azure_content_safety_obj)
            elif isinstance(callback, str) and callback == "websearch_interception":
                from litellm.integrations.websearch_interception.handler import (
                    WebSearchInterceptionLogger,
                )

                websearch_interception_obj = (
                    WebSearchInterceptionLogger.initialize_from_proxy_config(
                        litellm_settings=litellm_settings,
                        callback_specific_params=callback_specific_params,
                    )
                )
                imported_list.append(websearch_interception_obj)
            elif isinstance(callback, str) and callback == "datadog_cost_management":
                from litellm.integrations.datadog.datadog_cost_management import (
                    DatadogCostManagementLogger,
                )

                datadog_cost_management_obj = DatadogCostManagementLogger()
                imported_list.append(datadog_cost_management_obj)
            elif isinstance(callback, CustomLogger):
                # Already-instantiated callback objects pass straight through.
                imported_list.append(callback)
            else:
                # Fallback: treat the entry as an import path to a custom callback.
                verbose_proxy_logger.debug(
                    f"{blue_color_code} attempting to import custom callback={callback} {reset_color_code}"
                )
                imported_list.append(
                    get_instance_fn(
                        value=callback,
                        config_file_path=config_file_path,
                    )
                )
        if isinstance(litellm.callbacks, list):
            litellm.callbacks.extend(imported_list)
        else:
            litellm.callbacks = imported_list  # type: ignore
        if "prometheus" in value:
            from litellm.integrations.prometheus import PrometheusLogger

            PrometheusLogger._mount_metrics_endpoint()
    else:
        litellm.callbacks = [
            get_instance_fn(
                value=value,
                config_file_path=config_file_path,
            )
        ]
    verbose_proxy_logger.debug(
        f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}"
    )
def get_model_group_from_litellm_kwargs(kwargs: dict) -> Optional[str]:
    """Extract `model_group` from the metadata nested in litellm_params, if any."""
    litellm_params = kwargs.get("litellm_params") or {}
    metadata_key = get_metadata_variable_name_from_kwargs(kwargs)
    metadata = litellm_params.get(metadata_key) or {}
    # .get returns None when absent, matching the explicit None fallback.
    return metadata.get("model_group")
def get_model_group_from_request_data(data: dict) -> Optional[str]:
    """Return the `model_group` stored in request metadata, or None."""
    metadata = data.get("metadata") or {}
    return metadata.get("model_group")
def get_remaining_tokens_and_requests_from_request_data(data: Dict) -> Dict[str, str]:
    """
    Helper function to return x-litellm-key-remaining-tokens-{model_group} and x-litellm-key-remaining-requests-{model_group}
    Returns {} when api_key + model rpm/tpm limit is not set
    """
    metadata = data.get("metadata") or {}
    model_group = metadata.get("model_group")
    # The h11 package considers "/" or ":" invalid and raise a LocalProtocolError,
    # so sanitize the model group before using it in a header name.
    safe_model_group = (
        model_group.replace("/", "-").replace(":", "-") if model_group else None
    )
    headers: Dict[str, str] = {}
    for limit_kind in ("requests", "tokens"):
        remaining = metadata.get(f"litellm-key-remaining-{limit_kind}-{model_group}")
        if remaining:
            headers[
                f"x-litellm-key-remaining-{limit_kind}-{safe_model_group}"
            ] = remaining
    return headers
def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]:
    """Build x-litellm-* response headers (guardrails, policies, caching info) from request metadata."""
    # Prefer "metadata"; fall back to "litellm_metadata" when it is absent/falsy.
    meta = request_data.get("metadata") or request_data.get("litellm_metadata")
    if not isinstance(meta, dict):
        meta = {}
    headers: Dict = {}
    if "applied_guardrails" in meta:
        headers["x-litellm-applied-guardrails"] = ",".join(meta["applied_guardrails"])
    if "applied_policies" in meta:
        headers["x-litellm-applied-policies"] = ",".join(meta["applied_policies"])
    if "policy_sources" in meta:
        sources = meta["policy_sources"]
        if isinstance(sources, dict) and sources:
            # Use ';' as delimiter — matched_via reasons may contain commas
            headers["x-litellm-policy-sources"] = "; ".join(
                f"{name}={reason}" for name, reason in sources.items()
            )
    if "semantic-similarity" in meta:
        headers["x-litellm-semantic-similarity"] = str(meta["semantic-similarity"])
    pillar_headers = meta.get("pillar_response_headers")
    if isinstance(pillar_headers, dict):
        headers.update(pillar_headers)
    elif "pillar_flagged" in meta:
        headers["x-pillar-flagged"] = str(meta["pillar_flagged"]).lower()
    return headers
def add_guardrail_to_applied_guardrails_header(
    request_data: Dict, guardrail_name: Optional[str]
):
    """Record an applied guardrail name in request metadata (no-op when name is None).

    Duplicates are intentionally kept, matching the original behavior.
    """
    if guardrail_name is None:
        return
    metadata = request_data.get("metadata") or {}
    metadata.setdefault("applied_guardrails", []).append(guardrail_name)
    # Ensure metadata is set back to request_data (important when metadata didn't exist)
    request_data["metadata"] = metadata
def add_policy_to_applied_policies_header(
    request_data: Dict, policy_name: Optional[str]
):
    """
    Add a policy name to the applied_policies list in request metadata.
    This is used to track which policies were applied to a request,
    similar to how applied_guardrails tracks guardrails.
    """
    if policy_name is None:
        return
    metadata = request_data.get("metadata") or {}
    applied = metadata.setdefault("applied_policies", [])
    # Unlike guardrails, policies are de-duplicated.
    if policy_name not in applied:
        applied.append(policy_name)
    # Ensure metadata is set back to request_data (important when metadata didn't exist)
    request_data["metadata"] = metadata
def add_policy_sources_to_metadata(request_data: Dict, policy_sources: Dict[str, str]):
    """
    Store policy match reasons in metadata for x-litellm-policy-sources header.
    Args:
        request_data: The request data dict
        policy_sources: Map of policy_name -> matched_via reason
    """
    if not policy_sources:
        return
    metadata = request_data.get("metadata") or {}
    current = metadata.get("policy_sources", {})
    if not isinstance(current, dict):
        # Discard malformed values so .update below is safe.
        current = {}
    current.update(policy_sources)
    metadata["policy_sources"] = current
    request_data["metadata"] = metadata
def add_guardrail_response_to_standard_logging_object(
    litellm_logging_obj: Optional["LiteLLMLogging"],
    guardrail_response: StandardLoggingGuardrailInformation,
):
    """Append a guardrail result to the standard logging payload, if one exists.

    Returns the updated standard_logging_object, or None when either the
    logging object or its standard_logging_object is missing.
    """
    if litellm_logging_obj is None:
        return
    logging_payload: Optional[
        StandardLoggingPayload
    ] = litellm_logging_obj.model_call_details.get("standard_logging_object")
    if logging_payload is None:
        return
    entries = logging_payload.get("guardrail_information", [])
    if entries is None:
        entries = []
    entries.append(guardrail_response)
    logging_payload["guardrail_information"] = entries
    return logging_payload
def get_metadata_variable_name_from_kwargs(
    kwargs: dict,
) -> Literal["metadata", "litellm_metadata"]:
    """
    Helper to return what the "metadata" field should be called in the request data
    - New endpoints return `litellm_metadata`
    - Old endpoints return `metadata`
    Context:
    - LiteLLM used `metadata` as an internal field for storing metadata
    - OpenAI then started using this field for their metadata
    - LiteLLM is now moving to using `litellm_metadata` for our metadata
    """
    if "litellm_metadata" in kwargs:
        return "litellm_metadata"
    return "metadata"
def process_callback(
    _callback: str, callback_type: str, environment_variables: dict
) -> dict:
    """Process a single callback and return its data with environment variables"""
    env_vars = CustomLogger.get_callback_env_vars(_callback)
    # Missing variables are reported explicitly as None.
    env_vars_dict: dict[str, str | None] = {
        var_name: environment_variables.get(var_name) for var_name in env_vars
    }
    return {"name": _callback, "variables": env_vars_dict, "type": callback_type}
def normalize_callback_names(callbacks: Iterable[Any]) -> List[Any]:
    """Lowercase string callback names; pass non-strings through; None -> []."""
    if callbacks is None:
        return []
    normalized: List[Any] = []
    for cb in callbacks:
        normalized.append(cb.lower() if isinstance(cb, str) else cb)
    return normalized

View File

@@ -0,0 +1,437 @@
from typing import Any, Dict, List, Optional, Type
from litellm._logging import verbose_proxy_logger
class CustomOpenAPISpec:
    """
    Handler for customizing OpenAPI specifications with Pydantic models
    for documentation purposes without runtime validation.
    """

    # OpenAI-compatible chat-completion routes whose docs get expanded bodies.
    CHAT_COMPLETION_PATHS = [
        "/v1/chat/completions",
        "/chat/completions",
        "/engines/{model}/chat/completions",
        "/openai/deployments/{model}/chat/completions",
    ]
    # OpenAI-compatible embedding routes.
    EMBEDDING_PATHS = [
        "/v1/embeddings",
        "/embeddings",
        "/engines/{model}/embeddings",
        "/openai/deployments/{model}/embeddings",
    ]
    # Responses API routes.
    RESPONSES_API_PATHS = ["/v1/responses", "/responses"]
@staticmethod
def get_pydantic_schema(model_class) -> Optional[Dict[str, Any]]:
"""
Get JSON schema from a Pydantic model, handling both v1 and v2 APIs.
Args:
model_class: Pydantic model class
Returns:
JSON schema dict or None if failed
"""
try:
# Try Pydantic v2 method first
return model_class.model_json_schema() # type: ignore
except AttributeError:
try:
# Fallback to Pydantic v1 method
return model_class.schema() # type: ignore
except AttributeError:
# If both methods fail, return None
return None
except Exception as e:
# FastAPI 0.120+ may fail schema generation for certain types (e.g., openai.Timeout)
# Log the error and return None to skip schema generation for this model
verbose_proxy_logger.debug(
f"Failed to generate schema for {model_class}: {e}"
)
return None
@staticmethod
def add_schema_to_components(
openapi_schema: Dict[str, Any], schema_name: str, schema_def: Dict[str, Any]
) -> None:
"""
Add a schema definition to the OpenAPI components/schemas section.
Args:
openapi_schema: The OpenAPI schema dict to modify
schema_name: Name for the schema component
schema_def: The schema definition
"""
# Ensure components/schemas structure exists
if "components" not in openapi_schema:
openapi_schema["components"] = {}
if "schemas" not in openapi_schema["components"]:
openapi_schema["components"]["schemas"] = {}
# Add the schema
CustomOpenAPISpec._move_defs_to_components(
openapi_schema, {schema_name: schema_def}
)
    @staticmethod
    def add_request_body_to_paths(
        openapi_schema: Dict[str, Any], paths: List[str], schema_ref: str
    ) -> None:
        """
        Add request body with expanded form fields for better Swagger UI display.
        This keeps the request body but expands it to show individual fields in the UI.
        Args:
            openapi_schema: The OpenAPI schema dict to modify
            paths: List of paths to update
            schema_ref: Reference to the schema component (e.g., "#/components/schemas/ModelName")
        """
        for path in paths:
            # Only touch paths that exist in the spec and have a POST operation.
            if (
                path in openapi_schema.get("paths", {})
                and "post" in openapi_schema["paths"][path]
            ):
                # Get the actual schema to extract ALL field definitions
                schema_name = schema_ref.split("/")[
                    -1
                ]  # Extract "ProxyChatCompletionRequest" from the ref
                actual_schema = (
                    openapi_schema.get("components", {})
                    .get("schemas", {})
                    .get(schema_name, {})
                )
                schema_properties = actual_schema.get("properties", {})
                required_fields = actual_schema.get("required", [])
                # Extract $defs and add them to components/schemas
                # This fixes Pydantic v2 $defs not being resolvable in Swagger/OpenAPI
                if "$defs" in actual_schema:
                    CustomOpenAPISpec._move_defs_to_components(
                        openapi_schema, actual_schema["$defs"]
                    )
                # Create an expanded inline schema instead of just a $ref
                # This makes Swagger UI show all individual fields in the request body editor
                expanded_schema = {
                    "type": "object",
                    "required": required_fields,
                    "properties": {},
                }
                # Add all properties with their full definitions
                for field_name, field_def in schema_properties.items():
                    expanded_field = CustomOpenAPISpec._expand_field_definition(
                        field_def
                    )
                    # Rewrite $defs references to use components/schemas instead
                    expanded_field = CustomOpenAPISpec._rewrite_defs_refs(
                        expanded_field
                    )
                    # Add a simple example for the messages field
                    if field_name == "messages":
                        expanded_field["example"] = [
                            {"role": "user", "content": "Hello, how are you?"}
                        ]
                    expanded_schema["properties"][field_name] = expanded_field
                # Set the request body with the expanded schema
                openapi_schema["paths"][path]["post"]["requestBody"] = {
                    "required": True,
                    "content": {"application/json": {"schema": expanded_schema}},
                }
                # Keep any existing parameters (like path parameters) but remove conflicting query params
                if "parameters" in openapi_schema["paths"][path]["post"]:
                    existing_params = openapi_schema["paths"][path]["post"][
                        "parameters"
                    ]
                    # Only keep path parameters, remove query params that conflict with request body
                    filtered_params = [
                        param for param in existing_params if param.get("in") == "path"
                    ]
                    openapi_schema["paths"][path]["post"][
                        "parameters"
                    ] = filtered_params
@staticmethod
def _move_defs_to_components(
openapi_schema: Dict[str, Any], defs: Dict[str, Any]
) -> None:
"""
Move $defs from Pydantic v2 schema to OpenAPI components/schemas.
This makes the definitions resolvable in Swagger/OpenAPI viewers.
Args:
openapi_schema: The OpenAPI schema dict to modify
defs: The $defs dictionary from Pydantic schema
"""
if not defs:
return
# Ensure components/schemas exists
if "components" not in openapi_schema:
openapi_schema["components"] = {}
if "schemas" not in openapi_schema["components"]:
openapi_schema["components"]["schemas"] = {}
# Add each definition to components/schemas
for def_name, def_schema in defs.items():
# Recursively rewrite any nested $defs references within this definition
rewritten_def = CustomOpenAPISpec._rewrite_defs_refs(def_schema)
openapi_schema["components"]["schemas"][def_name] = rewritten_def
# If this definition also has $defs, process them recursively
if "$defs" in def_schema:
CustomOpenAPISpec._move_defs_to_components(
openapi_schema, def_schema["$defs"]
)
@staticmethod
def _rewrite_defs_refs(schema: Any) -> Any:
    """
    Recursively convert ``#/$defs/...`` refs to ``#/components/schemas/...``.

    Pydantic v2 emits local ``$defs`` references; OpenAPI tooling expects
    component references instead. ``$defs`` keys themselves are dropped
    because their contents are moved to ``components/schemas`` separately.

    Args:
        schema: Any schema fragment (dict, list, or scalar).

    Returns:
        A new structure with rewritten references; scalars pass through.
    """
    if isinstance(schema, list):
        return [CustomOpenAPISpec._rewrite_defs_refs(entry) for entry in schema]
    if not isinstance(schema, dict):
        # Primitives (str/int/bool/None) need no processing.
        return schema
    rewritten: Dict[str, Any] = {}
    for key, value in schema.items():
        if key == "$defs":
            # Dropped here; these definitions live under components/schemas.
            continue
        if key == "$ref" and isinstance(value, str) and value.startswith("#/$defs/"):
            target = value.replace("#/$defs/", "")
            rewritten[key] = f"#/components/schemas/{target}"
        else:
            rewritten[key] = CustomOpenAPISpec._rewrite_defs_refs(value)
    return rewritten
@staticmethod
def _extract_field_schema(field_def: Dict[str, Any]) -> Dict[str, Any]:
"""
Extract a simple schema from a Pydantic field definition for parameter display.
Args:
field_def: Pydantic field definition
Returns:
Simplified schema for OpenAPI parameter
"""
# Handle simple types
if "type" in field_def:
return {"type": field_def["type"]}
# Handle anyOf (Optional fields in Pydantic v2)
if "anyOf" in field_def:
any_of = field_def["anyOf"]
# Find the non-null type
for option in any_of:
if option.get("type") != "null":
return option
# Fallback to string if all else fails
return {"type": "string"}
# Default fallback
return {"type": "string"}
@staticmethod
def _expand_field_definition(field_def: Dict[str, Any]) -> Dict[str, Any]:
"""
Expand a Pydantic field definition for inline use in OpenAPI schema.
This creates a full field definition that Swagger UI can render as individual form fields.
Args:
field_def: Pydantic field definition
Returns:
Expanded field definition for OpenAPI schema
"""
# Return the field definition as-is since Pydantic already provides proper schemas
return field_def.copy()
@staticmethod
def add_request_schema(
    openapi_schema: Dict[str, Any],
    model_class: Type,
    schema_name: str,
    paths: List[str],
    operation_name: str,
) -> Dict[str, Any]:
    """
    Attach a Pydantic model's schema as the request body for a set of paths.

    Best-effort: any failure is logged at debug level and the schema is
    returned unmodified, so documentation generation never breaks serving.

    Args:
        openapi_schema: The OpenAPI schema dict to modify.
        model_class: Pydantic model class providing the request schema.
        schema_name: Component name to register the schema under.
        paths: Endpoint paths that should reference the schema.
        operation_name: Human-readable label used only in log messages.

    Returns:
        The (possibly modified) OpenAPI schema.
    """
    try:
        request_schema = CustomOpenAPISpec.get_pydantic_schema(model_class)
        if request_schema is None:
            verbose_proxy_logger.debug(f"Could not get schema for {schema_name}")
        else:
            # Register under components, then point each endpoint's request
            # body at the registered component reference.
            CustomOpenAPISpec.add_schema_to_components(
                openapi_schema, schema_name, request_schema
            )
            CustomOpenAPISpec.add_request_body_to_paths(
                openapi_schema, paths, f"#/components/schemas/{schema_name}"
            )
            verbose_proxy_logger.debug(
                f"Successfully added {schema_name} schema to OpenAPI spec"
            )
    except Exception as e:
        # Non-fatal: the docs simply omit this request schema.
        verbose_proxy_logger.debug(
            f"Failed to add {operation_name} request schema: {str(e)}"
        )
    return openapi_schema
@staticmethod
def add_chat_completion_request_schema(
    openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Document the chat completion request body in the OpenAPI spec.

    Attaches the ``ProxyChatCompletionRequest`` schema to the chat
    completion endpoints for Swagger display only - no runtime validation
    is introduced.

    Args:
        openapi_schema: The OpenAPI schema dict to modify.

    Returns:
        The (possibly modified) OpenAPI schema.
    """
    try:
        from litellm.proxy._types import ProxyChatCompletionRequest
    except ImportError as e:
        verbose_proxy_logger.debug(
            f"Failed to import ProxyChatCompletionRequest: {str(e)}"
        )
        return openapi_schema
    # add_request_schema is itself best-effort and never raises.
    return CustomOpenAPISpec.add_request_schema(
        openapi_schema=openapi_schema,
        model_class=ProxyChatCompletionRequest,
        schema_name="ProxyChatCompletionRequest",
        paths=CustomOpenAPISpec.CHAT_COMPLETION_PATHS,
        operation_name="chat completion",
    )
@staticmethod
def add_embedding_request_schema(openapi_schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Document the embedding request body in the OpenAPI spec.

    Attaches the ``EmbeddingRequest`` schema to the embedding endpoints for
    Swagger display only - no runtime validation is introduced.

    Args:
        openapi_schema: The OpenAPI schema dict to modify.

    Returns:
        The (possibly modified) OpenAPI schema.
    """
    try:
        from litellm.types.embedding import EmbeddingRequest
    except ImportError as e:
        verbose_proxy_logger.debug(f"Failed to import EmbeddingRequest: {str(e)}")
        return openapi_schema
    # add_request_schema is itself best-effort and never raises.
    return CustomOpenAPISpec.add_request_schema(
        openapi_schema=openapi_schema,
        model_class=EmbeddingRequest,
        schema_name="EmbeddingRequest",
        paths=CustomOpenAPISpec.EMBEDDING_PATHS,
        operation_name="embedding",
    )
@staticmethod
def add_responses_api_request_schema(
    openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Document the responses API request body in the OpenAPI spec.

    Attaches the ``ResponsesAPIRequestParams`` schema to the responses API
    endpoints for Swagger display only - no runtime validation is
    introduced.

    Args:
        openapi_schema: The OpenAPI schema dict to modify.

    Returns:
        The (possibly modified) OpenAPI schema.
    """
    try:
        from litellm.types.llms.openai import ResponsesAPIRequestParams
    except ImportError as e:
        verbose_proxy_logger.debug(
            f"Failed to import ResponsesAPIRequestParams: {str(e)}"
        )
        return openapi_schema
    # add_request_schema is itself best-effort and never raises.
    return CustomOpenAPISpec.add_request_schema(
        openapi_schema=openapi_schema,
        model_class=ResponsesAPIRequestParams,
        schema_name="ResponsesAPIRequestParams",
        paths=CustomOpenAPISpec.RESPONSES_API_PATHS,
        operation_name="responses API",
    )
@staticmethod
def add_llm_api_request_schema_body(
    openapi_schema: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Document request bodies for every supported LLM API endpoint family.

    Chains the chat-completion, embedding, and responses-API schema helpers;
    each is best-effort and leaves the spec untouched on failure.

    Args:
        openapi_schema: The base OpenAPI schema.

    Returns:
        OpenAPI schema with request body schemas attached where possible.
    """
    for attach_schema in (
        CustomOpenAPISpec.add_chat_completion_request_schema,
        CustomOpenAPISpec.add_embedding_request_schema,
        CustomOpenAPISpec.add_responses_api_request_schema,
    ):
        openapi_schema = attach_schema(openapi_schema)
    return openapi_schema

View File

@@ -0,0 +1,832 @@
# Start tracing memory allocations
import asyncio
import gc
import json
import os
import sys
import tracemalloc
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, Query
from litellm import get_secret_str
from litellm._logging import verbose_proxy_logger
from litellm.constants import PYTHON_GC_THRESHOLD
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
router = APIRouter()
# Configure garbage collection thresholds from environment variables
def configure_gc_thresholds():
    """Apply GC thresholds from PYTHON_GC_THRESHOLD, then log the active ones.

    Expects a comma-separated triple like ``"1000,50,50"`` (gen0,gen1,gen2).
    Malformed values are logged as warnings and ignored; whatever thresholds
    end up active are always logged afterwards.
    """
    raw_thresholds = PYTHON_GC_THRESHOLD
    if raw_thresholds:
        try:
            # Accept e.g. "1000, 50, 50" - whitespace around parts is fine.
            parsed = [int(part.strip()) for part in raw_thresholds.split(",")]
            if len(parsed) == 3:
                gc.set_threshold(*parsed)
                verbose_proxy_logger.info(f"GC thresholds set to: {parsed}")
            else:
                verbose_proxy_logger.warning(
                    f"GC threshold not set: {raw_thresholds}. Expected format: 'gen0,gen1,gen2'"
                )
        except ValueError as e:
            verbose_proxy_logger.warning(
                f"Failed to parse GC threshold: {raw_thresholds}. Error: {e}"
            )
    # Always report the thresholds in effect, configured or default.
    gen0, gen1, gen2 = gc.get_threshold()
    verbose_proxy_logger.info(
        f"Current GC thresholds: gen0={gen0}, gen1={gen1}, gen2={gen2}"
    )
# Initialize GC configuration
configure_gc_thresholds()
@router.get("/debug/asyncio-tasks")
async def get_active_tasks_stats():
    """
    Snapshot the event loop's active asyncio tasks.

    Returns:
        total_active_tasks: int
        by_name: { coroutine_name: count }
    """
    MAX_TASKS_TO_CHECK = 5000
    # All tasks on this loop (including the task serving this request),
    # with anything already finished filtered out.
    active_tasks = [task for task in asyncio.all_tasks() if not task.done()]
    name_counts: Counter = Counter()
    for position, task in enumerate(active_tasks):
        # Circuit breaker: cap the work done on pathological task counts.
        if position >= MAX_TASKS_TO_CHECK:
            break
        coro = task.get_coro()
        # Best available human-readable label for the coroutine.
        label = (
            getattr(coro, "__qualname__", None)
            or getattr(coro, "__name__", None)
            or repr(coro)
        )
        name_counts[label] += 1
    return {
        "total_active_tasks": len(active_tasks),
        "by_name": dict(name_counts),
    }
# When LITELLM_PROFILE=true, print an objgraph growth/leak report at import
# time (requires the optional objgraph package to be installed).
if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
    try:
        import objgraph  # type: ignore
        print("growth of objects")  # noqa
        objgraph.show_growth()
        print("\n\nMost common types")  # noqa
        objgraph.show_most_common_types()
        roots = objgraph.get_leaking_objects()
        print("\n\nLeaking objects")  # noqa
        objgraph.show_most_common_types(objects=roots)
    except ImportError:
        raise ImportError(
            "objgraph not found. Please install objgraph to use this feature."
        )
# Start allocation tracing (keeping 10 frames per allocation) so the
# /memory-usage endpoint can report top allocation sites.
# NOTE(review): this runs unconditionally at import, not gated on
# LITELLM_PROFILE; tracemalloc adds per-allocation overhead - confirm intended.
tracemalloc.start(10)
@router.get("/memory-usage", include_in_schema=False)
async def memory_usage():
    """Report the top 50 allocation sites from the current tracemalloc snapshot."""
    snapshot = tracemalloc.take_snapshot()
    top_stats = snapshot.statistics("lineno")
    verbose_proxy_logger.debug("TOP STATS: %s", top_stats)
    # Format the heaviest 50 allocation sites as human-readable strings.
    formatted = [
        f"{stat.traceback.format(limit=10)}: {stat.size / 1024} KiB"
        for stat in top_stats[:50]
    ]
    return {"top_50_memory_usage": formatted}
@router.get("/memory-usage-in-mem-cache", include_in_schema=False)
async def memory_usage_in_mem_cache(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Return item counts for the proxy server's in-memory caches.

    Covered caches:
    1. user_api_key_cache
    2. router_cache
    3. proxy_logging_cache
    4. internal_usage_cache
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
    )

    def _count(mem_cache) -> int:
        # Each cache tracks values and per-key TTLs in two parallel dicts.
        return len(mem_cache.cache_dict) + len(mem_cache.ttl_dict)

    # Router may not be initialized yet; report zero in that case.
    router_items = (
        0 if llm_router is None else _count(llm_router.cache.in_memory_cache)
    )
    key_cache_items = _count(user_api_key_cache.in_memory_cache)
    logging_items = _count(
        proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache
    )
    return {
        "num_items_in_user_api_key_cache": key_cache_items,
        "num_items_in_llm_router_cache": router_items,
        "num_items_in_proxy_logging_obj_cache": logging_items,
    }
@router.get("/memory-usage-in-mem-cache-items", include_in_schema=False)
async def memory_usage_in_mem_cache_items(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Dump the raw contents (values and TTLs) of the proxy's in-memory caches.

    Covered caches:
    1. user_api_key_cache
    2. router_cache
    3. proxy_logging_cache
    4. internal_usage_cache
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
    )

    if llm_router is not None:
        router_cache_dict = llm_router.cache.in_memory_cache.cache_dict
        router_ttl_dict = llm_router.cache.in_memory_cache.ttl_dict
    else:
        # Router not initialized yet - nothing cached.
        router_cache_dict = {}
        router_ttl_dict = {}
    logging_mem_cache = (
        proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache
    )
    return {
        "user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict,
        "user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict,
        "llm_router_cache": router_cache_dict,
        "llm_router_ttl": router_ttl_dict,
        "proxy_logging_obj_cache": logging_mem_cache.cache_dict,
        "proxy_logging_obj_ttl": logging_mem_cache.ttl_dict,
    }
@router.get("/debug/memory/summary", include_in_schema=False)
async def get_memory_summary(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> Dict[str, Any]:
    """
    Get simplified memory usage summary for the proxy.
    Returns:
        - worker_pid: Process ID
        - status: Overall health based on memory usage
        - memory: Process memory usage and RAM info
        - caches: Cache item counts and descriptions
        - garbage_collector: GC status and pending object counts
    Example usage:
        curl http://localhost:4000/debug/memory/summary -H "Authorization: Bearer sk-1234"
    For detailed analysis, call GET /debug/memory/details
    For cache management, use the cache management endpoints
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
    )
    # Get process memory info
    # psutil is optional; errors are captured in the payload, never raised.
    process_memory = {}
    health_status = "healthy"
    try:
        import psutil
        process = psutil.Process()
        memory_info = process.memory_info()
        memory_mb = memory_info.rss / (1024 * 1024)
        memory_percent = process.memory_percent()
        process_memory = {
            "summary": f"{memory_mb:.1f} MB ({memory_percent:.1f}% of system memory)",
            "ram_usage_mb": round(memory_mb, 2),
            "system_memory_percent": round(memory_percent, 2),
        }
        # Check memory health status
        # NOTE(review): hard-coded cutoffs - >80% of system RAM = "critical",
        # >60% = "warning"; confirm these suit multi-worker deployments.
        if memory_percent > 80:
            health_status = "critical"
        elif memory_percent > 60:
            health_status = "warning"
        else:
            health_status = "healthy"
    except ImportError:
        process_memory[
            "error"
        ] = "Install psutil for memory monitoring: pip install psutil"
    except Exception as e:
        process_memory["error"] = str(e)
    # Get cache information
    # Item counts only (cheap); byte sizes live in /debug/memory/details.
    caches: Dict[str, Any] = {}
    total_cache_items = 0
    try:
        # User API key cache
        user_cache_items = len(user_api_key_cache.in_memory_cache.cache_dict)
        total_cache_items += user_cache_items
        caches["user_api_keys"] = {
            "count": user_cache_items,
            "count_readable": f"{user_cache_items:,}",
            "what_it_stores": "Validated API keys for faster authentication",
        }
        # Router cache
        # Router may be None before startup completes; skip its cache then.
        if llm_router is not None:
            router_cache_items = len(llm_router.cache.in_memory_cache.cache_dict)
            total_cache_items += router_cache_items
            caches["llm_responses"] = {
                "count": router_cache_items,
                "count_readable": f"{router_cache_items:,}",
                "what_it_stores": "LLM responses for identical requests",
            }
        # Proxy logging cache
        logging_cache_items = len(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
        )
        total_cache_items += logging_cache_items
        caches["usage_tracking"] = {
            "count": logging_cache_items,
            "count_readable": f"{logging_cache_items:,}",
            "what_it_stores": "Usage metrics before database write",
        }
    except Exception as e:
        caches["error"] = str(e)
    # Get garbage collector stats
    gc_enabled = gc.isenabled()
    # gen-0 count: objects allocated since the last gen-0 collection.
    objects_pending = gc.get_count()[0]
    # gc.garbage holds objects the collector found unreachable but cannot free.
    uncollectable = len(gc.garbage)
    gc_info = {
        "status": "enabled" if gc_enabled else "disabled",
        "objects_awaiting_collection": objects_pending,
    }
    # Add warning if garbage collection issues detected
    if uncollectable > 0:
        gc_info[
            "warning"
        ] = f"{uncollectable} uncollectable objects (possible memory leak)"
    return {
        "worker_pid": os.getpid(),
        "status": health_status,
        "memory": process_memory,
        "caches": {
            "total_items": total_cache_items,
            "breakdown": caches,
        },
        "garbage_collector": gc_info,
    }
def _get_gc_statistics() -> Dict[str, Any]:
"""Get garbage collector statistics."""
return {
"enabled": gc.isenabled(),
"thresholds": {
"generation_0": gc.get_threshold()[0],
"generation_1": gc.get_threshold()[1],
"generation_2": gc.get_threshold()[2],
"explanation": "Number of allocations before automatic collection for each generation",
},
"current_counts": {
"generation_0": gc.get_count()[0],
"generation_1": gc.get_count()[1],
"generation_2": gc.get_count()[2],
"explanation": "Current number of allocated objects in each generation",
},
"collection_history": [
{
"generation": i,
"total_collections": stat["collections"],
"total_collected": stat["collected"],
"uncollectable": stat["uncollectable"],
}
for i, stat in enumerate(gc.get_stats())
],
}
def _get_object_type_counts(top_n: int) -> Tuple[int, List[Dict[str, Any]]]:
"""Count objects by type and return total count and top N types."""
type_counts: Counter = Counter()
total_objects = 0
for obj in gc.get_objects():
total_objects += 1
obj_type = type(obj).__name__
type_counts[obj_type] += 1
top_object_types = [
{"type": obj_type, "count": count, "count_readable": f"{count:,}"}
for obj_type, count in type_counts.most_common(top_n)
]
return total_objects, top_object_types
def _get_uncollectable_objects_info() -> Dict[str, Any]:
"""Get information about uncollectable objects (potential memory leaks)."""
uncollectable = gc.garbage
return {
"count": len(uncollectable),
"sample_types": [type(obj).__name__ for obj in uncollectable[:10]],
"warning": "If count > 0, you may have reference cycles preventing garbage collection"
if len(uncollectable) > 0
else None,
}
def _get_cache_memory_stats(
    user_api_key_cache, llm_router, proxy_logging_obj, redis_usage_cache
) -> Dict[str, Any]:
    """Calculate memory usage for all caches.

    NOTE: sizes come from ``sys.getsizeof`` on the container dicts, which is
    shallow (it excludes the keys and values themselves), so the reported MB
    understate real memory usage.
    """
    cache_stats: Dict[str, Any] = {}
    try:
        # User API key cache
        user_cache_size = sys.getsizeof(user_api_key_cache.in_memory_cache.cache_dict)
        user_ttl_size = sys.getsizeof(user_api_key_cache.in_memory_cache.ttl_dict)
        cache_stats["user_api_key_cache"] = {
            "num_items": len(user_api_key_cache.in_memory_cache.cache_dict),
            "cache_dict_size_bytes": user_cache_size,
            "ttl_dict_size_bytes": user_ttl_size,
            "total_size_mb": round(
                (user_cache_size + user_ttl_size) / (1024 * 1024), 2
            ),
        }
        # Router cache
        # Router may be None before startup completes.
        if llm_router is not None:
            router_cache_size = sys.getsizeof(
                llm_router.cache.in_memory_cache.cache_dict
            )
            router_ttl_size = sys.getsizeof(llm_router.cache.in_memory_cache.ttl_dict)
            cache_stats["llm_router_cache"] = {
                "num_items": len(llm_router.cache.in_memory_cache.cache_dict),
                "cache_dict_size_bytes": router_cache_size,
                "ttl_dict_size_bytes": router_ttl_size,
                "total_size_mb": round(
                    (router_cache_size + router_ttl_size) / (1024 * 1024), 2
                ),
            }
        # Proxy logging cache
        logging_cache_size = sys.getsizeof(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
        )
        logging_ttl_size = sys.getsizeof(
            proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict
        )
        cache_stats["proxy_logging_cache"] = {
            "num_items": len(
                proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
            ),
            "cache_dict_size_bytes": logging_cache_size,
            "ttl_dict_size_bytes": logging_ttl_size,
            "total_size_mb": round(
                (logging_cache_size + logging_ttl_size) / (1024 * 1024), 2
            ),
        }
        # Redis cache info
        # Redis holds its data out-of-process, so only connection metadata
        # is reported, not byte sizes.
        if redis_usage_cache is not None:
            cache_stats["redis_usage_cache"] = {
                "enabled": True,
                "cache_type": type(redis_usage_cache).__name__,
            }
            # Try to get Redis connection pool info if available
            try:
                if (
                    hasattr(redis_usage_cache, "redis_client")
                    and redis_usage_cache.redis_client
                ):
                    if hasattr(redis_usage_cache.redis_client, "connection_pool"):
                        pool_info = redis_usage_cache.redis_client.connection_pool  # type: ignore
                        cache_stats["redis_usage_cache"]["connection_pool"] = {
                            "max_connections": pool_info.max_connections
                            if hasattr(pool_info, "max_connections")
                            else None,
                            "connection_class": pool_info.connection_class.__name__
                            if hasattr(pool_info, "connection_class")
                            else None,
                        }
            except Exception as e:
                verbose_proxy_logger.debug(f"Error getting Redis pool info: {e}")
        else:
            cache_stats["redis_usage_cache"] = {"enabled": False}
    except Exception as e:
        verbose_proxy_logger.debug(f"Error calculating cache stats: {e}")
        cache_stats["error"] = str(e)
    return cache_stats
def _get_router_memory_stats(llm_router) -> Dict[str, Any]:
    """Get memory usage statistics for LiteLLM router.

    NOTE: sizes come from shallow ``sys.getsizeof`` calls on the containers,
    so nested model/deployment objects are not included in the byte counts.
    """
    litellm_router_memory: Dict[str, Any] = {}
    try:
        if llm_router is not None:
            # Model list memory size
            if hasattr(llm_router, "model_list") and llm_router.model_list:
                model_list_size = sys.getsizeof(llm_router.model_list)
                litellm_router_memory["model_list"] = {
                    "num_models": len(llm_router.model_list),
                    "size_bytes": model_list_size,
                    "size_mb": round(model_list_size / (1024 * 1024), 4),
                }
            # Model names set
            if hasattr(llm_router, "model_names") and llm_router.model_names:
                model_names_size = sys.getsizeof(llm_router.model_names)
                litellm_router_memory["model_names_set"] = {
                    "num_model_groups": len(llm_router.model_names),
                    "size_bytes": model_names_size,
                    "size_mb": round(model_names_size / (1024 * 1024), 4),
                }
            # Deployment names list
            if hasattr(llm_router, "deployment_names") and llm_router.deployment_names:
                deployment_names_size = sys.getsizeof(llm_router.deployment_names)
                litellm_router_memory["deployment_names"] = {
                    "num_deployments": len(llm_router.deployment_names),
                    "size_bytes": deployment_names_size,
                    "size_mb": round(deployment_names_size / (1024 * 1024), 4),
                }
            # Deployment latency map
            if (
                hasattr(llm_router, "deployment_latency_map")
                and llm_router.deployment_latency_map
            ):
                latency_map_size = sys.getsizeof(llm_router.deployment_latency_map)
                litellm_router_memory["deployment_latency_map"] = {
                    "num_tracked_deployments": len(llm_router.deployment_latency_map),
                    "size_bytes": latency_map_size,
                    "size_mb": round(latency_map_size / (1024 * 1024), 4),
                }
            # Fallback configuration
            if hasattr(llm_router, "fallbacks") and llm_router.fallbacks:
                fallbacks_size = sys.getsizeof(llm_router.fallbacks)
                litellm_router_memory["fallbacks"] = {
                    "num_fallback_configs": len(llm_router.fallbacks),
                    "size_bytes": fallbacks_size,
                    "size_mb": round(fallbacks_size / (1024 * 1024), 4),
                }
            # Total router object size
            # Shallow size of the Router instance itself (not its attributes).
            router_obj_size = sys.getsizeof(llm_router)
            litellm_router_memory["router_object"] = {
                "size_bytes": router_obj_size,
                "size_mb": round(router_obj_size / (1024 * 1024), 4),
            }
        else:
            litellm_router_memory = {"note": "Router not initialized"}
    except Exception as e:
        verbose_proxy_logger.debug(f"Error getting router memory info: {e}")
        litellm_router_memory = {"error": str(e)}
    return litellm_router_memory
def _get_process_memory_info(
worker_pid: int, include_process_info: bool
) -> Optional[Dict[str, Any]]:
"""Get process-level memory information using psutil."""
if not include_process_info:
return None
try:
import psutil
process = psutil.Process()
memory_info = process.memory_info()
ram_usage_mb = round(memory_info.rss / (1024 * 1024), 2)
virtual_memory_mb = round(memory_info.vms / (1024 * 1024), 2)
memory_percent = round(process.memory_percent(), 2)
return {
"pid": worker_pid,
"summary": f"Worker PID {worker_pid} using {ram_usage_mb:.1f} MB of RAM ({memory_percent:.1f}% of system memory)",
"ram_usage": {
"megabytes": ram_usage_mb,
"description": "Actual physical RAM used by this process",
},
"virtual_memory": {
"megabytes": virtual_memory_mb,
"description": "Total virtual memory allocated (includes swapped memory)",
},
"system_memory_percent": {
"percent": memory_percent,
"description": "Percentage of total system RAM being used",
},
"open_file_handles": {
"count": process.num_fds()
if hasattr(process, "num_fds")
else "N/A (Windows)",
"description": "Number of open file descriptors/handles",
},
"threads": {
"count": process.num_threads(),
"description": "Number of active threads in this process",
},
}
except ImportError:
return {
"pid": worker_pid,
"error": "psutil not installed. Install with: pip install psutil",
}
except Exception as e:
verbose_proxy_logger.debug(f"Error getting process info: {e}")
return {"pid": worker_pid, "error": str(e)}
@router.get("/debug/memory/details", include_in_schema=False)
async def get_memory_details(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
    top_n: int = Query(20, description="Number of top object types to return"),
    include_process_info: bool = Query(True, description="Include process memory info"),
) -> Dict[str, Any]:
    """
    Detailed memory diagnostics for deep debugging.

    Returns:
        - worker_pid: Process ID
        - process_memory: RAM usage, virtual memory, file handles, threads
        - garbage_collector: GC thresholds, counts, collection history
        - objects: Total tracked objects plus the top object types
        - uncollectable: Objects the GC cannot free (potential leaks)
        - cache_memory: Memory usage of user_api_key, router, and logging caches
        - router_memory: Memory usage of router components

    Query Parameters:
        - top_n: Number of top object types to return (default: 20)
        - include_process_info: Include psutil process info (default: true)

    Example usage:
        curl "http://localhost:4000/debug/memory/details?top_n=30" -H "Authorization: Bearer sk-1234"

    All memory sizes are reported in both bytes and MB.
    """
    from litellm.proxy.proxy_server import (
        llm_router,
        proxy_logging_obj,
        user_api_key_cache,
        redis_usage_cache,
    )

    current_pid = os.getpid()
    # Each section is produced by its own helper; order preserved so object
    # counts are taken before the remaining diagnostics allocate.
    gc_section = _get_gc_statistics()
    tracked_total, tracked_top_types = _get_object_type_counts(top_n)
    uncollectable_section = _get_uncollectable_objects_info()
    cache_section = _get_cache_memory_stats(
        user_api_key_cache, llm_router, proxy_logging_obj, redis_usage_cache
    )
    router_section = _get_router_memory_stats(llm_router)
    process_section = _get_process_memory_info(current_pid, include_process_info)
    return {
        "worker_pid": current_pid,
        "process_memory": process_section,
        "garbage_collector": gc_section,
        "objects": {
            "total_tracked": tracked_total,
            "total_tracked_readable": f"{tracked_total:,}",
            "top_types": tracked_top_types,
        },
        "uncollectable": uncollectable_section,
        "cache_memory": cache_section,
        "router_memory": router_section,
    }
@router.post("/debug/memory/gc/configure", include_in_schema=False)
async def configure_gc_thresholds_endpoint(
    _: UserAPIKeyAuth = Depends(user_api_key_auth),
    generation_0: int = Query(700, description="Generation 0 threshold (default: 700)"),
    generation_1: int = Query(10, description="Generation 1 threshold (default: 10)"),
    generation_2: int = Query(10, description="Generation 2 threshold (default: 10)"),
) -> Dict[str, Any]:
    """
    Configure Python garbage collection thresholds at runtime.

    Lower thresholds trigger GC more often (less memory, more CPU overhead);
    higher thresholds trigger it less often (more memory, less CPU overhead).

    Query Parameters:
        - generation_0: allocations before a gen-0 collection (default: 700)
        - generation_1: gen-0 collections before a gen-1 collection (default: 10)
        - generation_2: gen-1 collections before a gen-2 collection (default: 10)

    Returns:
        message, previous_thresholds, new_thresholds,
        objects_awaiting_collection, and a tip about the next collection.

    Example for more aggressive collection:
        curl -X POST "http://localhost:4000/debug/memory/gc/configure?generation_0=500" -H "Authorization: Bearer sk-1234"
    Example for less aggressive collection:
        curl -X POST "http://localhost:4000/debug/memory/gc/configure?generation_0=1000" -H "Authorization: Bearer sk-1234"

    Monitor memory usage with GET /debug/memory/summary after changes.
    """
    # Remember the outgoing configuration for logging and the response.
    previous = gc.get_threshold()
    try:
        gc.set_threshold(generation_0, generation_1, generation_2)
        verbose_proxy_logger.info(
            f"GC thresholds updated from {previous} to "
            f"({generation_0}, {generation_1}, {generation_2})"
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Failed to set GC thresholds: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to set GC thresholds: {str(e)}"
        )
    # Show immediate impact: gen-0 objects currently pending collection.
    pending = gc.get_count()[0]
    return {
        "message": "GC thresholds updated",
        "previous_thresholds": f"{previous[0]}, {previous[1]}, {previous[2]}",
        "new_thresholds": f"{generation_0}, {generation_1}, {generation_2}",
        "objects_awaiting_collection": pending,
        "tip": f"Next collection will run after {generation_0 - pending} more allocations",
    }
@router.get("/otel-spans", include_in_schema=False)
async def get_otel_spans():
    """
    Debug endpoint: list recorded OTEL span names, grouped by parent trace.

    Returns:
        otel_spans: names of all recorded spans
        spans_grouped_by_parent: {parent_trace_id: [child span names]}
        most_recent_parent: parent trace id of the latest-starting child span
            (None when no parented spans were recorded)
    """
    from litellm.proxy.proxy_server import open_telemetry_logger

    if open_telemetry_logger is None:
        return {
            "otel_spans": [],
            "spans_grouped_by_parent": {},
            "most_recent_parent": None,
        }
    otel_exporter = open_telemetry_logger.OTEL_EXPORTER
    # Only in-memory exporters expose previously recorded spans.
    if hasattr(otel_exporter, "get_finished_spans"):
        recorded_spans = otel_exporter.get_finished_spans()  # type: ignore
    else:
        recorded_spans = []
    print("Spans: ", recorded_spans)  # noqa
    most_recent_parent = None
    # Start from 0 so every real span qualifies; the previous magic sentinel
    # (1_000_000) would silently skip spans starting at/below 1,000,000 ns.
    most_recent_start_time = 0
    spans_grouped_by_parent: dict = {}
    for span in recorded_spans:
        if span.parent is not None:
            parent_trace_id = span.parent.trace_id
            spans_grouped_by_parent.setdefault(parent_trace_id, []).append(span.name)
            # Track which parent owns the most recently started child span.
            if span.start_time > most_recent_start_time:
                most_recent_parent = parent_trace_id
                most_recent_start_time = span.start_time
    # these are otel spans - get the span name
    span_names = [span.name for span in recorded_spans]
    return {
        "otel_spans": span_names,
        "spans_grouped_by_parent": spans_grouped_by_parent,
        "most_recent_parent": most_recent_parent,
    }
# Helper functions for debugging
def init_verbose_loggers():
    """Set package/router/proxy logger levels from the WORKER_CONFIG setting.

    Reads WORKER_CONFIG via get_secret_str and only acts when it holds an
    inline JSON object. ``debug`` maps to INFO, ``detailed_debug`` to DEBUG;
    when both are explicitly False, the LITELLM_LOG env var ("INFO"/"DEBUG")
    controls the levels instead. Any failure is logged and swallowed so
    logger setup never blocks startup.
    """
    try:
        worker_config = get_secret_str("WORKER_CONFIG")
        # if not, assume it's a json string
        if worker_config is None:
            return
        # NOTE(review): file-path configs return early here - presumably
        # parsed by the config loader elsewhere; confirm.
        if os.path.isfile(worker_config):
            return
        _settings = json.loads(worker_config)
        if not isinstance(_settings, dict):
            return
        debug = _settings.get("debug", None)
        detailed_debug = _settings.get("detailed_debug", None)
        if debug is True:  # this needs to be first, so users can see Router init debugg
            import logging
            from litellm._logging import (
                verbose_logger,
                verbose_proxy_logger,
                verbose_router_logger,
            )
            # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
            verbose_logger.setLevel(level=logging.INFO)  # sets package logs to info
            verbose_router_logger.setLevel(
                level=logging.INFO
            )  # set router logs to info
            verbose_proxy_logger.setLevel(level=logging.INFO)  # set proxy logs to info
        if detailed_debug is True:
            import logging
            from litellm._logging import (
                verbose_logger,
                verbose_proxy_logger,
                verbose_router_logger,
            )
            verbose_logger.setLevel(level=logging.DEBUG)  # set package log to debug
            verbose_router_logger.setLevel(
                level=logging.DEBUG
            )  # set router logs to debug
            verbose_proxy_logger.setLevel(
                level=logging.DEBUG
            )  # set proxy logs to debug
        elif debug is False and detailed_debug is False:
            # users can control proxy debugging using env variable = 'LITELLM_LOG'
            litellm_log_setting = os.environ.get("LITELLM_LOG", "")
            # NOTE(review): the "" default means this can never be None -
            # the check below is redundant but harmless.
            if litellm_log_setting is not None:
                if litellm_log_setting.upper() == "INFO":
                    import logging
                    from litellm._logging import (
                        verbose_proxy_logger,
                        verbose_router_logger,
                    )
                    # this must ALWAYS remain logging.INFO, DO NOT MODIFY THIS
                    verbose_router_logger.setLevel(
                        level=logging.INFO
                    )  # set router logs to info
                    verbose_proxy_logger.setLevel(
                        level=logging.INFO
                    )  # set proxy logs to info
                elif litellm_log_setting.upper() == "DEBUG":
                    import logging
                    from litellm._logging import (
                        verbose_proxy_logger,
                        verbose_router_logger,
                    )
                    verbose_router_logger.setLevel(
                        level=logging.DEBUG
                    )  # set router logs to info
                    verbose_proxy_logger.setLevel(
                        level=logging.DEBUG
                    )  # set proxy logs to debug
    except Exception as e:
        import logging
        logging.warning(f"Failed to init verbose loggers: {str(e)}")

View File

@@ -0,0 +1,122 @@
import base64
import os
from typing import Literal, Optional
from litellm._logging import verbose_proxy_logger
def _get_salt_key():
    """Return the key used for credential encryption.

    Prefers the dedicated LITELLM_SALT_KEY env var; falls back to the proxy
    master key for deployments that never configured a salt key.
    """
    from litellm.proxy.proxy_server import master_key

    salt_key = os.getenv("LITELLM_SALT_KEY", None)
    if salt_key is not None:
        return salt_key
    return master_key
def encrypt_value_helper(value: str, new_encryption_key: Optional[str] = None):
    """Encrypt ``value`` with the salt key and return it URL-safe base64 encoded.

    Non-string values are returned unchanged, so callers can pass mixed
    credential dicts without pre-filtering.

    Args:
        value: Plaintext to encrypt; passed through untouched if not a str.
        new_encryption_key: Optional override key (e.g. during key rotation);
            defaults to the configured salt/master key.

    Returns:
        URL-safe base64 string of the ciphertext, or the original value when
        it is not a string.
    """
    # NOTE: the previous `try/except Exception as e: raise e` wrapper was
    # removed - it only re-raised, adding a redundant traceback frame.
    signing_key = new_encryption_key or _get_salt_key()
    if isinstance(value, str):
        encrypted_value = encrypt_value(value=value, signing_key=signing_key)  # type: ignore
        # Use urlsafe_b64encode for URL-safe base64 encoding (replaces + with - and / with _)
        return base64.urlsafe_b64encode(encrypted_value).decode("utf-8")
    verbose_proxy_logger.debug(
        f"Invalid value type passed to encrypt_value: {type(value)} for Value: {value}\n Value must be a string"
    )
    # if it's not a string - do not encrypt it and return the value
    return value
def decrypt_value_helper(
    value: str,
    key: str,  # this is just for debug purposes, showing the k,v pair that's invalid. not a signing key.
    exception_type: Literal["debug", "error"] = "error",
    return_original_value: bool = False,
):
    """Decrypt a base64-encoded value encrypted with the salt/master key.

    Args:
        value: Base64 (URL-safe or standard) ciphertext. Non-string values
            are returned unchanged.
        key: Name of the config/DB field being decrypted, used only for
            log messages.
        exception_type: "debug" logs failures quietly; "error" logs a full
            exception traceback.
        return_original_value: When decryption fails, return the original
            (still-encrypted) value instead of None.

    Returns:
        The decrypted string, the original value (non-string input or
        failure with return_original_value=True), or None on failure.
    """
    signing_key = _get_salt_key()
    try:
        if isinstance(value, str):
            # Try URL-safe base64 decoding first (new format)
            # Fall back to standard base64 decoding for backwards compatibility (old format)
            try:
                decoded_b64 = base64.urlsafe_b64decode(value)
            except Exception:
                # If URL-safe decoding fails, try standard base64 decoding for backwards compatibility
                decoded_b64 = base64.b64decode(value)
            value = decrypt_value(value=decoded_b64, signing_key=signing_key)  # type: ignore
            return value
        # if it's not str - do not decrypt it, return the value
        return value
    except Exception as e:
        error_message = f"Error decrypting value for key: {key}, Did your master_key/salt key change recently? \nError: {str(e)}\nSet permanent salt key - https://docs.litellm.ai/docs/proxy/prod#5-set-litellm-salt-key"
        if exception_type == "debug":
            verbose_proxy_logger.debug(error_message)
            return value if return_original_value else None

        verbose_proxy_logger.debug(
            f"Unable to decrypt value={value} for key: {key}, returning None"
        )
        if return_original_value:
            return value
        else:
            verbose_proxy_logger.exception(error_message)
            # [Non-Blocking Exception. - this should not block decrypting other values]
            return None
def encrypt_value(value: str, signing_key: str):
    """Symmetrically encrypt ``value`` with a NaCl SecretBox.

    The signing key is hashed with SHA-256 to derive the 32-byte
    SecretBox key; the UTF-8 encoded plaintext is then encrypted.
    Returns the raw encrypted bytes (nonce + ciphertext).
    """
    import hashlib

    import nacl.secret
    import nacl.utils

    # Derive the 32-byte SecretBox key from the signing key via SHA-256.
    box_key = hashlib.sha256(signing_key.encode()).digest()
    secret_box = nacl.secret.SecretBox(box_key)
    # Encrypt the UTF-8 encoded plaintext.
    return secret_box.encrypt(value.encode("utf-8"))
def decrypt_value(value: bytes, signing_key: str) -> str:
    """Decrypt bytes produced by :func:`encrypt_value`.

    Args:
        value: Raw encrypted bytes (nonce + ciphertext). Empty input is
            treated as an empty string rather than an error.
        signing_key: The same key used for encryption; hashed with SHA-256
            to re-derive the 32-byte SecretBox key.

    Returns:
        The decrypted plaintext, decoded as UTF-8.

    Raises:
        nacl.exceptions.CryptoError: if the ciphertext cannot be
            authenticated/decrypted with the derived key.
    """
    # Fix: removed the no-op `try: ... except Exception as e: raise e`
    # wrapper and the unused `nacl.utils` import.
    import hashlib

    import nacl.secret

    # get 32 byte master key #
    hash_bytes = hashlib.sha256(signing_key.encode()).digest()
    # initialize secret box #
    box = nacl.secret.SecretBox(hash_bytes)
    # An empty payload cannot contain a nonce/ciphertext; treat it as "".
    if len(value) == 0:
        return ""
    plaintext = box.decrypt(value)
    return plaintext.decode("utf-8")

View File

@@ -0,0 +1,82 @@
"""
Utility class for getting routes from a FastAPI app.
"""
from typing import Any, Dict, List, Optional
from starlette.routing import BaseRoute
from litellm._logging import verbose_logger
class GetRoutes:
    """Helpers for extracting route metadata from a FastAPI app."""

    @staticmethod
    def get_app_routes(
        route: BaseRoute,
        endpoint_route: Any,
    ) -> List[Dict[str, Any]]:
        """Get routes for a regular route."""
        has_endpoint = getattr(route, "endpoint", None)
        return [
            {
                "path": getattr(route, "path", None),
                "methods": getattr(route, "methods", None),
                "name": getattr(route, "name", None),
                "endpoint": endpoint_route.__name__ if has_endpoint else None,
            }
        ]

    @staticmethod
    def get_routes_for_mounted_app(
        route: BaseRoute,
    ) -> List[Dict[str, Any]]:
        """Get routes for a mounted sub-application."""
        collected: List[Dict[str, Any]] = []
        mount_prefix = getattr(route, "path", "")
        mounted_app = getattr(route, "app", None)
        if not (mounted_app and hasattr(mounted_app, "routes")):
            return collected
        for child in mounted_app.routes:
            # The handler may live on either the `endpoint` or `app` attribute.
            handler = getattr(child, "endpoint", None) or getattr(child, "app", None)
            if handler is None:
                continue
            collected.append(
                {
                    "path": mount_prefix.rstrip("/") + getattr(child, "path", ""),
                    "methods": getattr(child, "methods", ["GET", "POST"]),
                    "name": getattr(child, "name", None),
                    "endpoint": GetRoutes._safe_get_endpoint_name(handler),
                    "mounted_app": True,
                }
            )
        return collected

    @staticmethod
    def _safe_get_endpoint_name(endpoint_function: Any) -> Optional[str]:
        """Safely get the name of the endpoint function."""
        try:
            if hasattr(endpoint_function, "__name__"):
                return endpoint_function.__name__
            owning_class = getattr(endpoint_function, "__class__", None)
            if owning_class is not None and hasattr(owning_class, "__name__"):
                return owning_class.__name__
            return None
        except Exception:
            verbose_logger.exception(
                f"Error getting endpoint name for route: {endpoint_function}"
            )
            return None

View File

@@ -0,0 +1,207 @@
from litellm.proxy.common_utils.banner import LITELLM_BANNER
def render_cli_sso_success_page() -> str:
    """
    Renders the CLI SSO authentication success page with minimal styling

    Returns:
        str: HTML content for the success page
    """
    # NOTE: this is an f-string, so literal CSS/JS braces are doubled ({{ }});
    # {LITELLM_BANNER} is the only interpolated value.
    html_content = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>CLI Authentication Successful - LiteLLM</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
                background-color: #f8fafc;
                margin: 0;
                padding: 20px;
                display: flex;
                justify-content: center;
                align-items: center;
                min-height: 100vh;
                color: #1e293b;
            }}
            .container {{
                background-color: #fff;
                padding: 40px;
                border-radius: 8px;
                box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
                width: 450px;
                max-width: 100%;
                text-align: center;
            }}
            .logo-container {{
                margin-bottom: 20px;
            }}
            .logo {{
                font-size: 24px;
                font-weight: 600;
                color: #1e293b;
            }}
            h1 {{
                margin: 0 0 10px;
                color: #1e293b;
                font-size: 28px;
                font-weight: 600;
            }}
            .subtitle {{
                color: #64748b;
                margin: 0 0 30px;
                font-size: 16px;
            }}
            .banner {{
                background-color: #f8fafc;
                color: #334155;
                font-family: 'Courier New', Consolas, monospace;
                font-size: 10px;
                line-height: 1.1;
                white-space: pre;
                padding: 20px;
                border-radius: 6px;
                margin: 20px 0;
                text-align: center;
                border: 1px solid #e2e8f0;
                overflow-x: auto;
            }}
            .success-box {{
                background-color: #f8fafc;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 30px;
                border: 1px solid #e2e8f0;
            }}
            .success-header {{
                display: flex;
                align-items: center;
                justify-content: center;
                margin-bottom: 12px;
                color: #1e293b;
                font-weight: 600;
                font-size: 16px;
            }}
            .success-header svg {{
                margin-right: 8px;
            }}
            .success-box p {{
                color: #64748b;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}
            .instructions {{
                background-color: #f8fafc;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 20px;
                border: 1px solid #e2e8f0;
            }}
            .instructions-header {{
                display: flex;
                align-items: center;
                justify-content: center;
                margin-bottom: 12px;
                color: #1e293b;
                font-weight: 600;
                font-size: 16px;
            }}
            .instructions-header svg {{
                margin-right: 8px;
            }}
            .instructions p {{
                color: #64748b;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}
            .countdown {{
                color: #64748b;
                font-size: 14px;
                font-weight: 500;
                padding: 12px;
                background-color: #f8fafc;
                border-radius: 6px;
                border: 1px solid #e2e8f0;
            }}
        </style>
    </head>
    <body>
        <div class="container">
            <div class="logo-container">
                <div class="logo">
                    🚅 LiteLLM
                </div>
            </div>
            <div class="banner">{LITELLM_BANNER}</div>
            <h1>Authentication Successful!</h1>
            <p class="subtitle">Your CLI authentication is complete.</p>
            <div class="success-box">
                <div class="success-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <path d="M9 12l2 2 4-4"></path>
                        <circle cx="12" cy="12" r="10"></circle>
                    </svg>
                    CLI Authentication Complete
                </div>
                <p>Your LiteLLM CLI has been successfully authenticated and is ready to use.</p>
            </div>
            <div class="instructions">
                <div class="instructions-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <circle cx="12" cy="12" r="10"></circle>
                        <line x1="12" y1="16" x2="12" y2="12"></line>
                        <line x1="12" y1="8" x2="12.01" y2="8"></line>
                    </svg>
                    Next Steps
                </div>
                <p>Return to your terminal - the CLI will automatically detect the successful authentication.</p>
                <p>You can now use LiteLLM CLI commands with your authenticated session.</p>
            </div>
            <div class="countdown" id="countdown">This window will close in 3 seconds...</div>
        </div>
        <script>
            let seconds = 3;
            const countdownElement = document.getElementById('countdown');
            const countdown = setInterval(function() {{
                seconds--;
                if (seconds > 0) {{
                    countdownElement.textContent = `This window will close in ${{seconds}} second${{seconds === 1 ? '' : 's'}}...`;
                }} else {{
                    countdownElement.textContent = 'Closing...';
                    clearInterval(countdown);
                    window.close();
                }}
            }}, 1000);
        </script>
    </body>
    </html>
    """
    return html_content

View File

@@ -0,0 +1,284 @@
# JWT display template for SSO debug callback
jwt_display_template = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>LiteLLM SSO Debug - JWT Information</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
background-color: #f8fafc;
margin: 0;
padding: 20px;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
color: #333;
}
.container {
background-color: #fff;
padding: 40px;
border-radius: 8px;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
width: 800px;
max-width: 100%;
}
.logo-container {
text-align: center;
margin-bottom: 30px;
}
.logo {
font-size: 24px;
font-weight: 600;
color: #1e293b;
}
h2 {
margin: 0 0 10px;
color: #1e293b;
font-size: 28px;
font-weight: 600;
text-align: center;
}
.subtitle {
color: #64748b;
margin: 0 0 20px;
font-size: 16px;
text-align: center;
}
.info-box {
background-color: #f1f5f9;
border-radius: 6px;
padding: 20px;
margin-bottom: 30px;
border-left: 4px solid #2563eb;
}
.success-box {
background-color: #f0fdf4;
border-radius: 6px;
padding: 20px;
margin-bottom: 30px;
border-left: 4px solid #16a34a;
}
.info-header {
display: flex;
align-items: center;
margin-bottom: 12px;
color: #1e40af;
font-weight: 600;
font-size: 16px;
}
.success-header {
display: flex;
align-items: center;
margin-bottom: 12px;
color: #166534;
font-weight: 600;
font-size: 16px;
}
.info-header svg, .success-header svg {
margin-right: 8px;
}
.data-container {
margin-top: 20px;
}
.data-row {
display: flex;
border-bottom: 1px solid #e2e8f0;
padding: 12px 0;
}
.data-row:last-child {
border-bottom: none;
}
.data-label {
font-weight: 500;
color: #334155;
width: 180px;
flex-shrink: 0;
}
.data-value {
color: #475569;
word-break: break-all;
}
.jwt-container {
background-color: #f8fafc;
border-radius: 6px;
padding: 15px;
margin-top: 20px;
overflow-x: auto;
border: 1px solid #e2e8f0;
}
.jwt-text {
font-family: monospace;
white-space: pre-wrap;
word-break: break-all;
margin: 0;
color: #334155;
}
.back-button {
display: inline-block;
background-color: #6466E9;
color: #fff;
text-decoration: none;
padding: 10px 16px;
border-radius: 6px;
font-weight: 500;
margin-top: 20px;
text-align: center;
}
.back-button:hover {
background-color: #4138C2;
text-decoration: none;
}
.buttons {
display: flex;
gap: 10px;
margin-top: 20px;
}
.copy-button {
background-color: #e2e8f0;
color: #334155;
border: none;
padding: 8px 12px;
border-radius: 4px;
cursor: pointer;
font-size: 14px;
display: flex;
align-items: center;
}
.copy-button:hover {
background-color: #cbd5e1;
}
.copy-button svg {
margin-right: 6px;
}
</style>
</head>
<body>
<div class="container">
<div class="logo-container">
<div class="logo">
🚅 LiteLLM
</div>
</div>
<h2>SSO Debug Information</h2>
<p class="subtitle">Results from the SSO authentication process.</p>
<div class="success-box">
<div class="success-header">
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<path d="M22 11.08V12a10 10 0 1 1-5.93-9.14"></path>
<polyline points="22 4 12 14.01 9 11.01"></polyline>
</svg>
Authentication Successful
</div>
<p>The SSO authentication completed successfully. Below is the information returned by the provider.</p>
</div>
<div class="data-container" id="userData">
<!-- Data will be inserted here by JavaScript -->
</div>
<div class="info-box">
<div class="info-header">
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<circle cx="12" cy="12" r="10"></circle>
<line x1="12" y1="16" x2="12" y2="12"></line>
<line x1="12" y1="8" x2="12.01" y2="8"></line>
</svg>
JSON Representation
</div>
<div class="jwt-container">
<pre class="jwt-text" id="jsonData">Loading...</pre>
</div>
<div class="buttons">
<button class="copy-button" onclick="copyToClipboard('jsonData')">
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
<rect x="9" y="9" width="13" height="13" rx="2" ry="2"></rect>
<path d="M5 15H4a2 2 0 0 1-2-2V4a2 2 0 0 1 2-2h9a2 2 0 0 1 2 2v1"></path>
</svg>
Copy to Clipboard
</button>
</div>
</div>
<a href="/sso/debug/login" class="back-button">
Try Another SSO Login
</a>
</div>
<script>
// This will be populated with the actual data from the server
const userData = SSO_DATA;
function renderUserData() {
const container = document.getElementById('userData');
const jsonDisplay = document.getElementById('jsonData');
// Format JSON with indentation for display
jsonDisplay.textContent = JSON.stringify(userData, null, 2);
// Clear container
container.innerHTML = '';
// Add each key-value pair to the UI
for (const [key, value] of Object.entries(userData)) {
if (typeof value !== 'object' || value === null) {
const row = document.createElement('div');
row.className = 'data-row';
const label = document.createElement('div');
label.className = 'data-label';
label.textContent = key;
const dataValue = document.createElement('div');
dataValue.className = 'data-value';
dataValue.textContent = value !== null ? value : 'null';
row.appendChild(label);
row.appendChild(dataValue);
container.appendChild(row);
}
}
}
function copyToClipboard(elementId) {
const text = document.getElementById(elementId).textContent;
navigator.clipboard.writeText(text).then(() => {
alert('Copied to clipboard!');
}).catch(err => {
console.error('Could not copy text: ', err);
});
}
// Render the data when the page loads
document.addEventListener('DOMContentLoaded', renderUserData);
</script>
</body>
</html>
"""

View File

@@ -0,0 +1,269 @@
import os
from litellm.proxy.utils import get_custom_url
# Base URL the legacy login form POSTs to, assembled from deployment env vars.
url_to_redirect_to = os.getenv("PROXY_BASE_URL", "")
server_root_path = os.getenv("SERVER_ROOT_PATH", "")
# Respect a proxy served under a sub-path (SERVER_ROOT_PATH).
if server_root_path != "":
    url_to_redirect_to += server_root_path
url_to_redirect_to += "/login"
# Target of the deprecation-banner link: the newer UI login route.
new_ui_login_url = get_custom_url("", "ui/login")
def build_ui_login_form(show_deprecation_banner: bool = False) -> str:
    """Build the HTML for the legacy username/password login page.

    Args:
        show_deprecation_banner: When True, prepends a banner pointing users
            at the newer UI login page (``new_ui_login_url``).

    Returns:
        str: Complete HTML document for the login form, which POSTs to
        ``url_to_redirect_to``.
    """
    banner_html = (
        f"""
        <div class="deprecation-banner">
            <strong>Deprecated:</strong> Logging in with username and password on this page is deprecated.
            Please use the <a href="{new_ui_login_url}">new login page</a> instead.
            This page will be dedicated to signing in via SSO in the future.
        </div>
        """
        if show_deprecation_banner
        else ""
    )
    # NOTE: f-string — literal CSS/JS braces are doubled ({{ }}); interpolated
    # values are {banner_html} and {url_to_redirect_to}.
    return f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <title>LiteLLM Login</title>
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <style>
            body {{
                font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
                background-color: #f8fafc;
                margin: 0;
                padding: 20px;
                display: flex;
                justify-content: center;
                align-items: center;
                min-height: 100vh;
                color: #333;
            }}
            form {{
                background-color: #fff;
                padding: 40px;
                border-radius: 8px;
                box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
                width: 450px;
                max-width: 100%;
            }}
            .logo-container {{
                text-align: center;
                margin-bottom: 30px;
            }}
            .logo {{
                font-size: 24px;
                font-weight: 600;
                color: #1e293b;
            }}
            h2 {{
                margin: 0 0 10px;
                color: #1e293b;
                font-size: 28px;
                font-weight: 600;
                text-align: center;
            }}
            .subtitle {{
                color: #64748b;
                margin: 0 0 20px;
                font-size: 16px;
                text-align: center;
            }}
            .info-box {{
                background-color: #f1f5f9;
                border-radius: 6px;
                padding: 20px;
                margin-bottom: 30px;
                border-left: 4px solid #2563eb;
            }}
            .info-header {{
                display: flex;
                align-items: center;
                margin-bottom: 12px;
                color: #1e40af;
                font-weight: 600;
                font-size: 16px;
            }}
            .info-header svg {{
                margin-right: 8px;
            }}
            .info-box p {{
                color: #475569;
                margin: 8px 0;
                line-height: 1.5;
                font-size: 14px;
            }}
            label {{
                display: block;
                margin-bottom: 8px;
                font-weight: 500;
                color: #334155;
                font-size: 14px;
            }}
            .required {{
                color: #dc2626;
                margin-left: 2px;
            }}
            input[type="text"],
            input[type="password"] {{
                width: 100%;
                padding: 10px 14px;
                margin-bottom: 20px;
                box-sizing: border-box;
                border: 1px solid #e2e8f0;
                border-radius: 6px;
                font-size: 15px;
                color: #1e293b;
                background-color: #fff;
                transition: border-color 0.2s, box-shadow 0.2s;
            }}
            input[type="text"]:focus,
            input[type="password"]:focus {{
                outline: none;
                border-color: #3b82f6;
                box-shadow: 0 0 0 2px rgba(59, 130, 246, 0.2);
            }}
            .toggle-password {{
                display: flex;
                align-items: center;
                margin-top: -15px;
                margin-bottom: 20px;
            }}
            .toggle-password input[type="checkbox"] {{
                margin-right: 8px;
                vertical-align: middle;
                width: 16px;
                height: 16px;
            }}
            .toggle-password label {{
                margin-bottom: 0;
                font-size: 14px;
                cursor: pointer;
                line-height: 1;
            }}
            input[type="submit"] {{
                background-color: #6466E9;
                color: #fff;
                cursor: pointer;
                font-weight: 500;
                border: none;
                padding: 10px 16px;
                transition: background-color 0.2s;
                border-radius: 6px;
                margin-top: 10px;
                font-size: 14px;
                width: 100%;
            }}
            input[type="submit"]:hover {{
                background-color: #4138C2;
            }}
            a {{
                color: #3b82f6;
                text-decoration: none;
            }}
            a:hover {{
                text-decoration: underline;
            }}
            code {{
                background-color: #f1f5f9;
                padding: 2px 4px;
                border-radius: 4px;
                font-family: monospace;
                font-size: 13px;
                color: #334155;
            }}
            .help-text {{
                color: #64748b;
                font-size: 14px;
                margin-top: -12px;
                margin-bottom: 20px;
            }}
            .deprecation-banner {{
                background-color: #fee2e2;
                border: 1px solid #ef4444;
                color: #991b1b;
                padding: 14px 16px;
                border-radius: 6px;
                margin-bottom: 20px;
                font-size: 14px;
                line-height: 1.5;
            }}
            .deprecation-banner a {{
                color: #991b1b;
                font-weight: 600;
                text-decoration: underline;
            }}
        </style>
    </head>
    <body>
        <form action="{url_to_redirect_to}" method="post">
            {banner_html}
            <div class="logo-container">
                <div class="logo">
                    🚅 LiteLLM
                </div>
            </div>
            <h2>Login</h2>
            <p class="subtitle">Access your LiteLLM Admin UI.</p>
            <div class="info-box">
                <div class="info-header">
                    <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round">
                        <circle cx="12" cy="12" r="10"></circle>
                        <line x1="12" y1="16" x2="12" y2="12"></line>
                        <line x1="12" y1="8" x2="12.01" y2="8"></line>
                    </svg>
                    Default Credentials
                </div>
                <p>By default, Username is <code>admin</code> and Password is your set LiteLLM Proxy <code>MASTER_KEY</code>.</p>
                <p>Need to set UI credentials or SSO? <a href="https://docs.litellm.ai/docs/proxy/ui" target="_blank">Check the documentation</a>.</p>
            </div>
            <label for="username">Username<span class="required">*</span></label>
            <input type="text" id="username" name="username" required placeholder="Enter your username" autocomplete="username">
            <label for="password">Password<span class="required">*</span></label>
            <input type="password" id="password" name="password" required placeholder="Enter your password" autocomplete="current-password">
            <div class="toggle-password">
                <input type="checkbox" id="show-password" onclick="togglePasswordVisibility()">
                <label for="show-password">Show password</label>
            </div>
            <input type="submit" value="Login">
        </form>
        <script>
            function togglePasswordVisibility() {{
                var passwordField = document.getElementById("password");
                passwordField.type = passwordField.type === "password" ? "text" : "password";
            }}
        </script>
    </body>
    </html>
    """
html_form = build_ui_login_form(show_deprecation_banner=True)

View File

@@ -0,0 +1,522 @@
import json
import re
from typing import Any, Collection, Dict, List, Optional
import orjson
from fastapi import Request, UploadFile, status
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ProxyException
from litellm.proxy.common_utils.callback_utils import (
get_metadata_variable_name_from_kwargs,
)
from litellm.types.router import Deployment
async def _read_request_body(request: Optional[Request]) -> Dict:
    """
    Safely read the request body and parse it as JSON.

    Parameters:
    - request: The request object to read the body from

    Returns:
    - dict: Parsed request data as a dictionary or an empty dictionary if parsing fails

    Raises:
    - ProxyException (400): when the body is present but is not valid JSON
      even after stripping invalid surrogate characters.
    """
    try:
        if request is None:
            return {}

        # Check if we already read and parsed the body
        _cached_request_body: Optional[dict] = _safe_get_request_parsed_body(
            request=request
        )
        if _cached_request_body is not None:
            return _cached_request_body

        _request_headers: dict = _safe_get_request_headers(request=request)
        content_type = _request_headers.get("content-type", "")

        if "form" in content_type:
            # Form-encoded body; "metadata" may arrive as a JSON string.
            parsed_body = dict(await request.form())
            if "metadata" in parsed_body and isinstance(parsed_body["metadata"], str):
                parsed_body["metadata"] = json.loads(parsed_body["metadata"])
        else:
            # Read the request body
            body = await request.body()

            # Return empty dict if body is empty or None
            if not body:
                parsed_body = {}
            else:
                try:
                    # orjson is the fast path.
                    parsed_body = orjson.loads(body)
                except orjson.JSONDecodeError as e:
                    # First try the standard json module which is more forgiving
                    # First decode bytes to string if needed
                    body_str = body.decode("utf-8") if isinstance(body, bytes) else body

                    # Replace invalid surrogate pairs
                    # This regex finds incomplete surrogate pairs
                    body_str = re.sub(
                        r"[\uD800-\uDBFF](?![\uDC00-\uDFFF])", "", body_str
                    )
                    # This regex finds low surrogates without high surrogates
                    body_str = re.sub(
                        r"(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]", "", body_str
                    )

                    try:
                        parsed_body = json.loads(body_str)
                    except json.JSONDecodeError:
                        # If both orjson and json.loads fail, throw a proper error
                        verbose_proxy_logger.error(
                            f"Invalid JSON payload received: {str(e)}"
                        )
                        raise ProxyException(
                            message=f"Invalid JSON payload: {str(e)}",
                            type="invalid_request_error",
                            param="request_body",
                            code=status.HTTP_400_BAD_REQUEST,
                        )

        # Cache the parsed result
        _safe_set_request_parsed_body(request=request, parsed_body=parsed_body)
        return parsed_body

    except (json.JSONDecodeError, orjson.JSONDecodeError, ProxyException) as e:
        # Re-raise ProxyException as-is
        verbose_proxy_logger.error(f"Invalid JSON payload received: {str(e)}")
        raise
    except Exception as e:
        # Catch unexpected errors to avoid crashes
        verbose_proxy_logger.exception(
            "Unexpected error reading request body - {}".format(e)
        )
        return {}
def _safe_get_request_parsed_body(request: Optional[Request]) -> Optional[dict]:
    """Return the previously cached parsed body for ``request``, if any.

    The cache lives in ``request.scope["parsed_body"]`` as a tuple of
    (accepted_keys, parsed_body); only the accepted keys are returned.
    """
    if request is None or not hasattr(request, "scope"):
        return None
    cached = request.scope["parsed_body"] if "parsed_body" in request.scope else None
    if not isinstance(cached, tuple):
        return None
    accepted_keys, body = cached
    return {k: body[k] for k in accepted_keys}
def _safe_get_request_query_params(request: Optional[Request]) -> Dict:
    """Best-effort extraction of query params as a plain dict; never raises."""
    if request is None:
        return {}
    try:
        if not hasattr(request, "query_params"):
            return {}
        return dict(request.query_params)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error reading request query params - {}".format(e)
        )
        return {}
def _safe_set_request_parsed_body(
    request: Optional[Request],
    parsed_body: dict,
) -> None:
    """Cache ``parsed_body`` on the request scope for later reuse.

    Stored as (accepted_keys, parsed_body) so readers can rebuild the dict
    with only the keys present at parse time. Never raises.
    """
    if request is None:
        return
    try:
        request.scope["parsed_body"] = (tuple(parsed_body.keys()), parsed_body)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error setting request parsed body - {}".format(e)
        )
def _safe_get_request_headers(request: Optional[Request]) -> dict:
    """
    [Non-Blocking] Safely get the request headers.

    Caches the result on request.state to avoid re-creating
    dict(request.headers) per call.

    Warning: Callers must NOT mutate the returned dict — it is shared across
    all callers within the same request via the cache.
    """
    if request is None:
        return {}
    request_state = getattr(request, "state", None)
    cached_headers = getattr(request_state, "_cached_headers", None)
    if isinstance(cached_headers, dict):
        return cached_headers
    if cached_headers is not None:
        # Something non-dict occupied the cache slot; log and rebuild.
        verbose_proxy_logger.debug(
            "Unexpected cached request headers type - {}".format(type(cached_headers))
        )
    try:
        header_dict = dict(request.headers)
    except Exception as e:
        verbose_proxy_logger.debug(
            "Unexpected error reading request headers - {}".format(e)
        )
        header_dict = {}
    try:
        if request_state is not None:
            request_state._cached_headers = header_dict
    except Exception:
        pass  # request.state may not be available in all contexts
    return header_dict
def check_file_size_under_limit(
    request_data: dict,
    file: UploadFile,
    router_model_names: Collection[str],
) -> bool:
    """
    Check if any files passed in request are under max_file_size_mb

    Also records the file size (in MB) on ``request_data["metadata"]``.

    Returns True -> when file size is under max_file_size_mb limit
    Raises ProxyException -> when file size is over max_file_size_mb limit or not a premium_user
    """
    from litellm.proxy.proxy_server import (
        CommonProxyErrors,
        ProxyException,
        llm_router,
        premium_user,
    )

    file_contents_size = file.size or 0
    file_content_size_in_mb = file_contents_size / (1024 * 1024)
    if "metadata" not in request_data:
        request_data["metadata"] = {}
    request_data["metadata"]["file_size_in_mb"] = file_content_size_in_mb
    max_file_size_mb = None

    # Only look up a per-deployment limit when the model is one the router knows.
    if llm_router is not None and request_data["model"] in router_model_names:
        try:
            deployment: Optional[
                Deployment
            ] = llm_router.get_deployment_by_model_group_name(
                model_group_name=request_data["model"]
            )
            if (
                deployment
                and deployment.litellm_params is not None
                and deployment.litellm_params.max_file_size_mb is not None
            ):
                max_file_size_mb = deployment.litellm_params.max_file_size_mb
        except Exception as e:
            # Lookup failure is non-fatal: treat as "no limit configured".
            verbose_proxy_logger.error(
                "Got error when checking file size: %s", (str(e))
            )

    if max_file_size_mb is not None:
        verbose_proxy_logger.debug(
            "Checking file size, file content size=%s, max_file_size_mb=%s",
            file_content_size_in_mb,
            max_file_size_mb,
        )
        # max_file_size_mb is a premium (enterprise) feature.
        if not premium_user:
            raise ProxyException(
                message=f"Tried setting max_file_size_mb for /audio/transcriptions. {CommonProxyErrors.not_premium_user.value}",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )
        if file_content_size_in_mb > max_file_size_mb:
            raise ProxyException(
                message=f"File size is too large. Please check your file size. Passed file size: {file_content_size_in_mb} MB. Max file size: {max_file_size_mb} MB",
                code=status.HTTP_400_BAD_REQUEST,
                type="bad_request",
                param="file",
            )
    return True
async def get_form_data(request: Request) -> Dict[str, Any]:
    """
    Read form data from request

    Handles when OpenAI SDKs pass form keys as `timestamp_granularities[]="word"` instead of `timestamp_granularities=["word", "sentence"]`
    """
    raw_fields = dict(await request.form())
    normalized: dict[str, Any] = {}
    for field_name, field_value in raw_fields.items():
        if field_name.endswith("[]"):
            # Array-style key: strip the "[]" suffix and collect repeats.
            normalized.setdefault(field_name[:-2], []).append(field_value)
        else:
            normalized[field_name] = field_value
    return normalized
async def convert_upload_files_to_file_data(
    form_data: Dict[str, Any]
) -> Dict[str, Any]:
    """
    Convert FastAPI UploadFile objects to file data tuples for litellm.

    Converts UploadFile objects to tuples of (filename, content, content_type)
    which is the format expected by httpx and litellm's HTTP handlers.

    Args:
        form_data: Dictionary containing form data with potential UploadFile objects

    Returns:
        Dictionary with UploadFile objects converted to file data tuples

    Example:
        ```python
        form_data = await get_form_data(request)
        data = await convert_upload_files_to_file_data(form_data)
        # data["files"] is now [(filename, content, content_type), ...]
        ```
    """

    async def _to_file_tuple(upload: Any) -> tuple:
        # (filename, content, content_type) is the httpx file-tuple format.
        contents = await upload.read()
        return (upload.filename, contents, upload.content_type)

    converted: Dict[str, Any] = {}
    for field, payload in form_data.items():
        if isinstance(payload, list):
            # A list of uploads (detected via duck-typed `read`) becomes a
            # list of file tuples; any other list passes through untouched.
            if payload and hasattr(payload[0], "read"):
                converted[field] = [await _to_file_tuple(item) for item in payload]
            else:
                converted[field] = payload
        elif hasattr(payload, "read"):
            # Single upload: wrap in a list for consistency with the list case.
            converted[field] = [await _to_file_tuple(payload)]
        else:
            # Regular form field.
            converted[field] = payload
    return converted
async def get_request_body(request: Request) -> Dict[str, Any]:
    """
    Read the request body and parse it as JSON.

    Non-POST requests yield an empty dict; unsupported content types raise
    ValueError.
    """
    if request.method != "POST":
        return {}
    content_type = request.headers.get("content-type", "")
    if content_type == "application/json":
        return await _read_request_body(request)
    if (
        "multipart/form-data" in content_type
        or "application/x-www-form-urlencoded" in content_type
    ):
        return await get_form_data(request)
    raise ValueError(
        f"Unsupported content type: {request.headers.get('content-type')}"
    )
def extract_nested_form_metadata(
    form_data: Dict[str, Any], prefix: str = "litellm_metadata["
) -> Dict[str, Any]:
    """
    Reconstruct nested metadata dicts from bracket-notation form keys.

    Some SDKs/clients encode nested structures as form fields like
    ``litellm_metadata[spend_logs_metadata][owner] = "value"`` instead of
    sending JSON; this rebuilds the equivalent nested dictionary.

    Args:
        form_data: Form fields (e.g. from ``request.form()``)
        prefix: Only keys starting with this prefix are processed
            (default: "litellm_metadata[")

    Returns:
        Nested dict reconstructed from the bracket paths, without the prefix

    Example:
        Input:
            {"litellm_metadata[spend_logs_metadata][owner]": "john",
             "litellm_metadata[tags]": "production", "other": "x"}
        Output:
            {"spend_logs_metadata": {"owner": "john"}, "tags": "production"}
    """
    if not form_data:
        return {}

    result: Dict[str, Any] = {}
    for key, value in form_data.items():
        # Only bracket-notation keys under the requested prefix are relevant.
        if not isinstance(key, str) or not key.startswith(prefix):
            continue
        # File uploads never belong in metadata.
        if isinstance(value, UploadFile):
            verbose_proxy_logger.warning(
                f"Skipping UploadFile in metadata extraction for key: {key}"
            )
            continue
        try:
            # "litellm_metadata[a][b]" -> ["a", "b"]
            parts = key.replace(prefix, "").rstrip("]").split("][")
            if not parts or not parts[0]:
                verbose_proxy_logger.warning(
                    f"Invalid metadata key format (empty path): {key}"
                )
                continue

            # Walk/create the nested dicts down to the leaf's parent.
            node: Any = result
            reached_leaf = True
            for part in parts[:-1]:
                if not isinstance(node, dict):
                    verbose_proxy_logger.warning(
                        f"Cannot create nested path - intermediate value is not a dict at: {part}"
                    )
                    reached_leaf = False
                    break
                node = node.setdefault(part, {})

            if reached_leaf:
                if isinstance(node, dict):
                    node[parts[-1]] = value
                else:
                    verbose_proxy_logger.warning(
                        f"Cannot set value - parent is not a dict for key: {key}"
                    )
        except Exception as e:
            verbose_proxy_logger.error(f"Error parsing metadata key '{key}': {str(e)}")
            continue
    return result
def get_tags_from_request_body(request_body: dict) -> List[str]:
    """
    Extract tags from request body metadata.

    Args:
        request_body: The request body dictionary

    Returns:
        List of tag names (strings), empty list if no valid tags found
    """
    metadata_key = get_metadata_variable_name_from_kwargs(request_body)
    metadata = request_body.get(metadata_key) or {}

    candidates: List[Any] = []
    # Tags may appear under metadata and/or at the top level; only lists count.
    for raw_tags in (metadata.get("tags", []), request_body.get("tags", [])):
        if isinstance(raw_tags, list):
            candidates.extend(raw_tags)
    # Drop any non-string entries.
    return [tag for tag in candidates if isinstance(tag, str)]
def populate_request_with_path_params(request_data: dict, request: Request) -> dict:
    """
    Merge FastAPI path params and query params into the request payload so
    downstream checks (e.g. vector store RBAC, organization RBAC) see them
    the same way as body params.

    Since path_params may not be available during dependency injection, we
    fall back to parsing the URL path directly for known patterns.

    Args:
        request_data: The request data dictionary to populate
        request: The FastAPI Request object

    Returns:
        dict: Updated request_data with path parameters and query parameters added
    """
    # Query parameters (GET requests, etc.) never override body values.
    query_params = _safe_get_request_query_params(request)
    if query_params:
        for qp_key, qp_value in query_params.items():
            request_data.setdefault(qp_key, qp_value)

    # path_params is sometimes populated by FastAPI; prefer it when present.
    path_params = getattr(request, "path_params", None)
    if isinstance(path_params, dict) and path_params:
        for pp_key, pp_value in path_params.items():
            if pp_key == "vector_store_id":
                # Keep both the scalar field and the aggregate list in sync.
                request_data.setdefault("vector_store_id", pp_value)
                current_ids = request_data.get("vector_store_ids")
                if isinstance(current_ids, list):
                    if pp_value not in current_ids:
                        current_ids.append(pp_value)
                else:
                    request_data["vector_store_ids"] = [pp_value]
            else:
                request_data.setdefault(pp_key, pp_value)
        verbose_proxy_logger.debug(
            f"populate_request_with_path_params: Found path_params, vector_store_ids={request_data.get('vector_store_ids')}"
        )
        return request_data

    # Fallback: parse the URL path directly to extract vector_store_id.
    _add_vector_store_id_from_path(request_data=request_data, request=request)
    return request_data
def _add_vector_store_id_from_path(request_data: dict, request: Request) -> None:
    """
    Inspect the request path for a /vector_stores/{vector_store_id}/... segment
    and, when found, populate both vector_store_id and vector_store_ids in
    request_data without clobbering existing values.

    Args:
        request_data: The request data dictionary to populate
        request: The FastAPI Request object
    """
    path = request.url.path
    match = re.search(r"/vector_stores/([^/]+)/", path)
    if not match:
        verbose_proxy_logger.debug(
            f"populate_request_with_path_params: No vector_store_id present in path={path}"
        )
        return

    vector_store_id = match.group(1)
    verbose_proxy_logger.debug(
        f"populate_request_with_path_params: Extracted vector_store_id={vector_store_id} from path={path}"
    )
    # Scalar field: only set when absent.
    request_data.setdefault("vector_store_id", vector_store_id)
    # Aggregate list: append (deduplicated) or create.
    current_ids = request_data.get("vector_store_ids")
    if isinstance(current_ids, list):
        if vector_store_id not in current_ids:
            current_ids.append(vector_store_id)
    else:
        request_data["vector_store_ids"] = [vector_store_id]
    verbose_proxy_logger.debug(
        f"populate_request_with_path_params: Updated request_data with vector_store_ids={request_data.get('vector_store_ids')}"
    )

View File

@@ -0,0 +1,187 @@
"""
Key Rotation Manager - Automated key rotation based on rotation schedules
Handles finding keys that need rotation based on their individual schedules.
"""
from datetime import datetime, timezone
from typing import List
from litellm._logging import verbose_proxy_logger
from litellm.constants import (
LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
LITELLM_KEY_ROTATION_GRACE_PERIOD,
)
from litellm.proxy._types import (
GenerateKeyResponse,
LiteLLM_VerificationToken,
RegenerateKeyRequest,
)
from litellm.proxy.hooks.key_management_event_hooks import KeyManagementEventHooks
from litellm.proxy.management_endpoints.key_management_endpoints import (
_calculate_key_rotation_time,
regenerate_key_fn,
)
from litellm.proxy.utils import PrismaClient
class KeyRotationManager:
    """
    Manages automated key rotation based on individual key rotation schedules.

    Intended to be driven by a scheduler: call `process_rotations()` periodically;
    it cleans up expired deprecated keys, finds keys whose rotation time has
    arrived, and rotates each one via the existing `regenerate_key_fn`.
    """
    def __init__(self, prisma_client: PrismaClient):
        # Prisma client used for all reads/writes against the key tables.
        self.prisma_client = prisma_client
    async def process_rotations(self):
        """
        Main entry point - find and rotate keys that are due for rotation

        Per-key failures are logged and skipped so one bad key does not stop
        the rest of the batch; a failure outside the per-key loop is logged
        and swallowed (best-effort background job).
        """
        try:
            verbose_proxy_logger.info("Starting scheduled key rotation check...")
            # Clean up expired deprecated keys first
            await self._cleanup_expired_deprecated_keys()
            # Find keys that are due for rotation
            keys_to_rotate = await self._find_keys_needing_rotation()
            if not keys_to_rotate:
                verbose_proxy_logger.debug("No keys are due for rotation at this time")
                return
            verbose_proxy_logger.info(
                f"Found {len(keys_to_rotate)} keys due for rotation"
            )
            # Rotate each key
            for key in keys_to_rotate:
                try:
                    await self._rotate_key(key)
                    # Prefer the human-readable name; fall back to a token prefix
                    # so full secrets are never logged.
                    key_identifier = key.key_name or (
                        key.token[:8] + "..." if key.token else "unknown"
                    )
                    verbose_proxy_logger.info(
                        f"Successfully rotated key: {key_identifier}"
                    )
                except Exception as e:
                    key_identifier = key.key_name or (
                        key.token[:8] + "..." if key.token else "unknown"
                    )
                    verbose_proxy_logger.error(
                        f"Failed to rotate key {key_identifier}: {e}"
                    )
        except Exception as e:
            verbose_proxy_logger.error(f"Key rotation process failed: {e}")
    async def _find_keys_needing_rotation(self) -> List[LiteLLM_VerificationToken]:
        """
        Find keys that are due for rotation based on their key_rotation_at timestamp.

        Logic:
        - Key has auto_rotate = true
        - key_rotation_at is null (needs initial setup) OR key_rotation_at <= now

        Returns:
            List of verification-token rows matching the criteria.
        """
        now = datetime.now(timezone.utc)
        keys_with_rotation = (
            await self.prisma_client.db.litellm_verificationtoken.find_many(
                where={
                    "auto_rotate": True,  # Only keys marked for auto rotation
                    "OR": [
                        {
                            "key_rotation_at": None
                        },  # Keys that need initial rotation time setup
                        {
                            "key_rotation_at": {"lte": now}
                        },  # Keys where rotation time has passed
                    ],
                }
            )
        )
        return keys_with_rotation
    async def _cleanup_expired_deprecated_keys(self) -> None:
        """
        Remove deprecated key entries whose revoke_at has passed.

        Best-effort: any error (e.g. the deprecated-token table not existing
        on older schemas) is logged at debug level and ignored.
        """
        try:
            now = datetime.now(timezone.utc)
            # NOTE(review): assumes Prisma's delete_many returns the deleted-row
            # count as an int — confirm against the generated client.
            result = await self.prisma_client.db.litellm_deprecatedverificationtoken.delete_many(
                where={"revoke_at": {"lt": now}}
            )
            if result > 0:
                verbose_proxy_logger.debug(
                    "Cleaned up %s expired deprecated key(s)", result
                )
        except Exception as e:
            verbose_proxy_logger.debug(
                "Deprecated key cleanup skipped (table may not exist): %s", e
            )
    def _should_rotate_key(self, key: LiteLLM_VerificationToken, now: datetime) -> bool:
        """
        Determine if a key should be rotated based on key_rotation_at timestamp.

        Args:
            key: The verification-token row to inspect.
            now: Current time used for the comparison.

        Returns:
            True when the key has a rotation interval and its rotation time
            is unset or has passed.
        """
        if not key.rotation_interval:
            return False
        # If key_rotation_at is not set, rotate immediately (and set it)
        if key.key_rotation_at is None:
            return True
        # Check if the rotation time has passed
        # NOTE(review): assumes key_rotation_at is timezone-aware like `now`;
        # comparing against a naive datetime would raise TypeError — confirm
        # the DB driver returns aware datetimes.
        return now >= key.key_rotation_at
    async def _rotate_key(self, key: LiteLLM_VerificationToken):
        """
        Rotate a single key using existing regenerate_key_fn and call the rotation hook

        Steps:
        1. Regenerate the key (old token kept valid for a grace period).
        2. Stamp rotation bookkeeping (count, last/next rotation) on the NEW token.
        3. Fire the key-rotated event hook (notifications, audit logs, etc.).
        """
        # Create regenerate request with grace period for seamless cutover
        regenerate_request = RegenerateKeyRequest(
            key=key.token or "",
            key_alias=key.key_alias,  # Pass key alias to ensure correct secret is updated in AWS Secrets Manager
            grace_period=LITELLM_KEY_ROTATION_GRACE_PERIOD or None,
        )
        # Create a system user for key rotation
        from litellm.proxy._types import UserAPIKeyAuth
        system_user = UserAPIKeyAuth.get_litellm_internal_jobs_user_api_key_auth()
        # Use existing regenerate key function
        response = await regenerate_key_fn(
            data=regenerate_request,
            user_api_key_dict=system_user,
            litellm_changed_by=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
        )
        # Update the NEW key with rotation info (regenerate_key_fn creates a new token)
        if (
            isinstance(response, GenerateKeyResponse)
            and response.token_id
            and key.rotation_interval
        ):
            # Calculate next rotation time using helper function
            now = datetime.now(timezone.utc)
            next_rotation_time = _calculate_key_rotation_time(key.rotation_interval)
            await self.prisma_client.db.litellm_verificationtoken.update(
                where={"token": response.token_id},
                data={
                    "rotation_count": (key.rotation_count or 0) + 1,
                    "last_rotation_at": now,
                    "key_rotation_at": next_rotation_time,
                },
            )
        # Call the existing rotation hook for notifications, audit logs, etc.
        if isinstance(response, GenerateKeyResponse):
            await KeyManagementEventHooks.async_key_rotated_hook(
                data=regenerate_request,
                existing_key_row=key,
                response=response,
                user_api_key_dict=system_user,
                litellm_changed_by=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
            )

View File

@@ -0,0 +1,178 @@
import os
import yaml
from litellm._logging import verbose_proxy_logger
def get_file_contents_from_s3(bucket_name, object_key):
    """
    Fetch a YAML config object from S3 and return it parsed.

    Credentials come from litellm's bedrock credential helper; boto3 performs
    the actual request. Returns the parsed YAML, or None on any failure
    (including boto3 not being installed).
    """
    try:
        # v0 rely on boto3 for authentication - allowing boto3 to handle IAM credentials etc
        import boto3
        from botocore.credentials import Credentials

        from litellm.main import bedrock_converse_chat_completion

        credentials: Credentials = bedrock_converse_chat_completion.get_credentials()
        client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,  # Optional, if using temporary credentials
        )
        verbose_proxy_logger.debug(
            f"Retrieving {object_key} from S3 bucket: {bucket_name}"
        )
        s3_response = client.get_object(Bucket=bucket_name, Key=object_key)
        verbose_proxy_logger.debug(f"Response: {s3_response}")
        # Read, decode, and parse the YAML payload in one pass.
        raw_yaml = s3_response["Body"].read().decode("utf-8")
        verbose_proxy_logger.debug("File contents retrieved from S3")
        return yaml.safe_load(raw_yaml)
    except ImportError as e:
        # this is most likely if a user is not using the litellm docker container
        verbose_proxy_logger.error(f"ImportError: {str(e)}")
        return None
    except Exception as e:
        verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
        return None
async def get_config_file_contents_from_gcs(bucket_name, object_key):
    """
    Download a YAML config object from GCS and return it parsed.

    Returns the parsed YAML, or None if the download or parse fails.
    """
    try:
        from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger

        bucket = GCSBucketLogger(
            bucket_name=bucket_name,
        )
        raw_bytes = await bucket.download_gcs_object(object_key)
        if raw_bytes is None:
            raise Exception(f"File contents are None for {object_key}")
        # download_gcs_object returns bytes; decode before parsing as YAML.
        return yaml.safe_load(raw_bytes.decode("utf-8"))
    except Exception as e:
        verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}")
        return None
def download_python_file_from_s3(
    bucket_name: str,
    object_key: str,
    local_file_path: str,
) -> bool:
    """
    Download a Python file from S3 and save it to local filesystem.

    Args:
        bucket_name (str): S3 bucket name
        object_key (str): S3 object key (file path in bucket)
        local_file_path (str): Local path where file should be saved

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        import boto3
        from botocore.credentials import Credentials

        from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM

        base_aws_llm = BaseAWSLLM()
        credentials: Credentials = base_aws_llm.get_credentials()
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,
        )
        verbose_proxy_logger.debug(
            f"Downloading Python file {object_key} from S3 bucket: {bucket_name}"
        )
        response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
        # Read the file contents
        file_contents = response["Body"].read().decode("utf-8")
        verbose_proxy_logger.debug(f"File contents: {file_contents}")
        # Ensure the parent directory exists. os.path.dirname() returns ""
        # for a bare filename and os.makedirs("") raises, so guard it.
        parent_dir = os.path.dirname(local_file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # Write with an explicit encoding so the round-trip matches the
        # utf-8 decode above regardless of the platform default encoding.
        with open(local_file_path, "w", encoding="utf-8") as f:
            f.write(file_contents)
        verbose_proxy_logger.debug(
            f"Python file downloaded successfully to {local_file_path}"
        )
        return True
    except ImportError as e:
        # Most likely boto3 is not installed outside the docker container.
        verbose_proxy_logger.error(f"ImportError: {str(e)}")
        return False
    except Exception as e:
        verbose_proxy_logger.exception(f"Error downloading Python file: {str(e)}")
        return False
async def download_python_file_from_gcs(
    bucket_name: str,
    object_key: str,
    local_file_path: str,
) -> bool:
    """
    Download a Python file from GCS and save it to local filesystem.

    Args:
        bucket_name (str): GCS bucket name
        object_key (str): GCS object key (file path in bucket)
        local_file_path (str): Local path where file should be saved

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        from litellm.integrations.gcs_bucket.gcs_bucket import GCSBucketLogger

        gcs_bucket = GCSBucketLogger(
            bucket_name=bucket_name,
        )
        file_contents = await gcs_bucket.download_gcs_object(object_key)
        if file_contents is None:
            raise Exception(f"File contents are None for {object_key}")
        # file_contents is a bytes object, decode it
        file_contents = file_contents.decode("utf-8")
        # Ensure the parent directory exists. os.path.dirname() returns ""
        # for a bare filename and os.makedirs("") raises, so guard it.
        parent_dir = os.path.dirname(local_file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # Write with an explicit encoding so the round-trip matches the
        # utf-8 decode above regardless of the platform default encoding.
        with open(local_file_path, "w", encoding="utf-8") as f:
            f.write(file_contents)
        verbose_proxy_logger.debug(
            f"Python file downloaded successfully to {local_file_path}"
        )
        return True
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error downloading Python file from GCS: {str(e)}"
        )
        return False
# # Example usage
# bucket_name = 'litellm-proxy'
# object_key = 'litellm_proxy_config.yaml'

View File

@@ -0,0 +1,71 @@
"""
Contains utils used by OpenAI compatible endpoints
"""
from typing import Optional, Set
from fastapi import Request
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
SENSITIVE_DATA_MASKER = SensitiveDataMasker()
def remove_sensitive_info_from_deployment(
    deployment_dict: dict,
    excluded_keys: Optional[Set[str]] = None,
) -> dict:
    """
    Removes sensitive information from a deployment dictionary.

    Known credential fields are stripped outright, then the remaining
    litellm_params are run through the sensitive-data masker.

    Args:
        deployment_dict (dict): The deployment dictionary to remove sensitive information from.
        excluded_keys (Optional[Set[str]]): Set of keys that should not be masked (exact match).

    Returns:
        dict: The modified deployment dictionary with sensitive information removed.
    """
    litellm_params = deployment_dict.get("litellm_params")
    # Robustness fix: previously a missing/None "litellm_params" raised
    # KeyError/AttributeError. Nothing to sanitize in that case.
    if not isinstance(litellm_params, dict):
        return deployment_dict
    for sensitive_key in (
        "api_key",
        "client_secret",
        "vertex_credentials",
        "aws_access_key_id",
        "aws_secret_access_key",
    ):
        litellm_params.pop(sensitive_key, None)
    deployment_dict["litellm_params"] = SENSITIVE_DATA_MASKER.mask_dict(
        litellm_params, excluded_keys=excluded_keys
    )
    return deployment_dict
async def get_custom_llm_provider_from_request_body(request: Request) -> Optional[str]:
    """
    Get the `custom_llm_provider` from the request body

    Safely reads the request body
    """
    request_body: dict = await _read_request_body(request=request) or {}
    return request_body.get("custom_llm_provider")
def get_custom_llm_provider_from_request_query(request: Request) -> Optional[str]:
    """
    Get the `custom_llm_provider` from the request query parameters

    Safely reads the request query parameters
    """
    query_params = request.query_params
    if "custom_llm_provider" not in query_params:
        return None
    return query_params["custom_llm_provider"]
def get_custom_llm_provider_from_request_headers(request: Request) -> Optional[str]:
    """
    Get the `custom_llm_provider` from the request header `custom-llm-provider`
    """
    headers = request.headers
    if "custom-llm-provider" not in headers:
        return None
    return headers["custom-llm-provider"]

View File

@@ -0,0 +1,121 @@
"""
Utility module for handling OpenAPI schema generation compatibility with FastAPI 0.120+.
FastAPI 0.120+ has stricter schema generation that fails on certain types like openai.Timeout.
This module provides a compatibility layer to handle these cases gracefully.
"""
from typing import Any, Dict
from litellm._logging import verbose_proxy_logger
def get_openapi_schema_with_compat(
    get_openapi_func,
    title: str,
    version: str,
    description: str,
    routes: list,
) -> Dict[str, Any]:
    """
    Generate OpenAPI schema with compatibility handling for FastAPI 0.120+.
    This function patches Pydantic's schema generation to handle non-serializable types
    like openai.Timeout that cause PydanticSchemaGenerationError in FastAPI 0.120+.
    Args:
        get_openapi_func: The FastAPI get_openapi function
        title: API title
        version: API version
        description: API description
        routes: List of routes
    Returns:
        OpenAPI schema dictionary
    """
    # FastAPI 0.120+ may fail schema generation for certain types (e.g., openai.Timeout)
    # Patch Pydantic's schema generation to handle unknown types gracefully
    try:
        # NOTE(review): GenerateSchema lives in pydantic._internal, a private API
        # that can move between pydantic minor versions; the AttributeError
        # branch below is the safety net when it does.
        from pydantic._internal._generate_schema import GenerateSchema
        from pydantic_core import core_schema
        # Store original method
        original_unknown_type_schema = GenerateSchema._unknown_type_schema
        def patched_unknown_type_schema(self, obj):
            """Patch to handle openai.Timeout and other non-serializable types"""
            # Check if it's openai.Timeout or similar types
            obj_str = str(obj)
            obj_module = getattr(obj, "__module__", "")
            if (obj_module == "openai" and "Timeout" in obj_str) or (
                hasattr(obj, "__name__")
                and obj.__name__ == "Timeout"
                and obj_module == "openai"
            ):
                # Return a simple string schema for Timeout types
                return core_schema.str_schema()
            # For other unknown types, try to return a default schema
            # This prevents the error from propagating
            try:
                return core_schema.any_schema()
            except Exception:
                # Last resort: return string schema
                return core_schema.str_schema()
        # Apply patch
        setattr(GenerateSchema, "_unknown_type_schema", patched_unknown_type_schema)
        try:
            openapi_schema = get_openapi_func(
                title=title,
                version=version,
                description=description,
                routes=routes,
            )
        finally:
            # Restore original method (even if schema generation raised) so the
            # patch never leaks into unrelated pydantic usage.
            setattr(
                GenerateSchema, "_unknown_type_schema", original_unknown_type_schema
            )
        return openapi_schema
    except (ImportError, AttributeError) as e:
        # If patching fails, try normal generation with error handling
        verbose_proxy_logger.debug(
            f"Could not patch Pydantic schema generation: {e}. Trying normal generation."
        )
        try:
            return get_openapi_func(
                title=title,
                version=version,
                description=description,
                routes=routes,
            )
        except Exception as pydantic_error:
            # Check if it's a PydanticSchemaGenerationError by checking the error type name
            # This avoids import issues if PydanticSchemaGenerationError is not available
            error_type_name = type(pydantic_error).__name__
            if (
                error_type_name == "PydanticSchemaGenerationError"
                or "PydanticSchemaGenerationError" in str(type(pydantic_error))
            ):
                # If we still get the error, log it and return minimal schema
                # (valid but empty OpenAPI 3.0 document) so the docs page loads.
                verbose_proxy_logger.warning(
                    f"PydanticSchemaGenerationError during schema generation: {pydantic_error}"
                )
                return {
                    "openapi": "3.0.0",
                    "info": {
                        "title": title,
                        "version": version,
                        "description": description or "",
                    },
                    "paths": {},
                    "components": {"schemas": {}},
                }
            else:
                # Re-raise if it's a different error
                raise

View File

@@ -0,0 +1,214 @@
# Performance Utilities Documentation
This module provides performance monitoring and profiling functionality for LiteLLM proxy server using `cProfile` and `line_profiler`.
## Table of Contents
- [Line Profiler Usage](#line-profiler-usage)
- [Example 1: Wrapping a function directly](#example-1-wrapping-a-function-directly)
- [Example 2: Wrapping a module function dynamically](#example-2-wrapping-a-module-function-dynamically)
- [Example 3: Manual stats collection](#example-3-manual-stats-collection)
- [Example 4: Analyzing the profile output](#example-4-analyzing-the-profile-output)
- [Example 5: Using in a decorator pattern](#example-5-using-in-a-decorator-pattern)
- [cProfile Usage](#cprofile-usage)
- [Installation](#installation)
- [Notes](#notes)
## Line Profiler Usage
### Example 1: Wrapping a function directly
This is how it's used in `litellm/utils.py` to profile `wrapper_async`:
```python
from litellm.proxy.common_utils.performance_utils import (
register_shutdown_handler,
wrap_function_directly,
)
def client(original_function):
@wraps(original_function)
async def wrapper_async(*args, **kwargs):
# ... function implementation ...
pass
# Wrap the function with line_profiler
wrapper_async = wrap_function_directly(wrapper_async)
# Register shutdown handler to collect stats on server shutdown
register_shutdown_handler(output_file="wrapper_async_line_profile.lprof")
return wrapper_async
```
### Example 2: Wrapping a module function dynamically
```python
import my_module
from litellm.proxy.common_utils.performance_utils import (
wrap_function_with_line_profiler,
register_shutdown_handler,
)
# Wrap a function in a module
wrap_function_with_line_profiler(my_module, "expensive_function")
# Register shutdown handler
register_shutdown_handler(output_file="my_profile.lprof")
# Now all calls to my_module.expensive_function will be profiled
my_module.expensive_function()
```
### Example 3: Manual stats collection
```python
from litellm.proxy.common_utils.performance_utils import (
wrap_function_directly,
collect_line_profiler_stats,
)
def my_function():
# ... implementation ...
pass
# Wrap the function
my_function = wrap_function_directly(my_function)
# Run your code
my_function()
# Collect stats manually (instead of waiting for shutdown)
collect_line_profiler_stats(output_file="manual_profile.lprof")
```
### Example 4: Analyzing the profile output
After running your code, analyze the `.lprof` file:
```bash
# View the profile
python -m line_profiler wrapper_async_line_profile.lprof
# Save to text file
python -m line_profiler wrapper_async_line_profile.lprof > profile_report.txt
```
The output shows:
- **Line #**: Line number in the source file
- **Hits**: Number of times the line was executed
- **Time**: Total time spent on that line (in microseconds)
- **Per Hit**: Average time per execution
- **% Time**: Percentage of total function time
- **Line Contents**: The actual source code
Example output:
```
Timer unit: 1e-06 s
Total time: 3.73697 s
File: litellm/utils.py
Function: client.<locals>.wrapper_async at line 1657
Line # Hits Time Per Hit % Time Line Contents
==============================================================
1657 @wraps(original_function)
1658 async def wrapper_async(*args, **kwargs):
1659 2005 7577.1 3.8 0.2 print_args_passed_to_litellm(...)
1763 2005 1351909.0 674.3 36.2 result = await original_function(*args, **kwargs)
1846 4010 1543688.1 385.0 41.3 update_response_metadata(...)
```
### Example 5: Using in a decorator pattern
```python
from litellm.proxy.common_utils.performance_utils import (
wrap_function_directly,
register_shutdown_handler,
)
def profile_decorator(func):
# Wrap the function
profiled_func = wrap_function_directly(func)
# Register shutdown handler (only once)
if not hasattr(profile_decorator, '_registered'):
register_shutdown_handler(output_file="decorated_functions.lprof")
profile_decorator._registered = True
return profiled_func
@profile_decorator
async def my_async_function():
# This function will be profiled
pass
```
## cProfile Usage
### Example: Using the profile_endpoint decorator
```python
from litellm.proxy.common_utils.performance_utils import profile_endpoint
@profile_endpoint(sampling_rate=0.1) # Profile 10% of requests
async def my_endpoint():
# ... implementation ...
pass
```
The `sampling_rate` parameter controls what percentage of requests are profiled:
- `1.0`: Profile all requests (100%)
- `0.1`: Profile 1 in 10 requests (10%)
- `0.0`: Profile no requests (0%)
## Installation
`line_profiler` must be installed to use the line profiling functionality:
```bash
pip install line_profiler
```
On Windows with Python 3.14+, you may need to install Microsoft Visual C++ Build Tools to compile `line_profiler` from source.
## Notes
- The profiler aggregates stats by source code location, so multiple instances of the same function (e.g., closures) will be profiled together
- Stats are automatically collected on server shutdown via `atexit` handler when using `register_shutdown_handler()`
- You can also manually collect stats using `collect_line_profiler_stats()`
- The line profiler will fail with an `ImportError` if `line_profiler` is not installed (as configured in `litellm/utils.py`)
## API Reference
### `wrap_function_directly(func: Callable) -> Callable`
Wrap a function directly with line_profiler. This is the recommended way to profile functions, especially closures or functions created dynamically.
**Raises:**
- `ImportError`: If line_profiler is not available
- `RuntimeError`: If line_profiler cannot be enabled or function cannot be wrapped
### `wrap_function_with_line_profiler(module: Any, function_name: str) -> bool`
Dynamically wrap a function in a module with line_profiler.
**Returns:** `True` if wrapping was successful, `False` otherwise
### `collect_line_profiler_stats(output_file: Optional[str] = None) -> None`
Collect and save line_profiler statistics. If `output_file` is provided, saves to file. Otherwise, prints to stdout.
### `register_shutdown_handler(output_file: Optional[str] = None) -> None`
Register an `atexit` handler that will automatically save profiling statistics when the Python process exits. Safe to call multiple times (only registers once).
**Default output file:** `line_profile_stats.lprof` if not specified
### `profile_endpoint(sampling_rate: float = 1.0)`
Decorator to sample endpoint hits and save to a profile file using cProfile.
**Args:**
- `sampling_rate`: Rate of requests to profile (0.0 to 1.0)

View File

@@ -0,0 +1,296 @@
"""
Performance utilities for LiteLLM proxy server.
This module provides performance monitoring and profiling functionality for endpoint
performance analysis using cProfile with configurable sampling rates, and line_profiler
for line-by-line profiling.
See performance_utils.md for detailed usage examples and documentation.
"""
import atexit
import cProfile
import functools
import inspect
import threading
from pathlib import Path as PathLib
from typing import Any, Callable, Optional
from litellm._logging import verbose_proxy_logger
# Global profiling state
_profile_lock = threading.Lock()  # guards creation/enable/disable of the shared cProfile instance
_profiler = None  # lazily-created, process-wide cProfile.Profile (None until first sampled request)
_last_profile_file_path = None  # most recent .pstat output path chosen by profile_endpoint
_sample_counter = 0  # monotonically increasing request counter used for deterministic sampling
_sample_counter_lock = threading.Lock()  # guards _sample_counter increments across threads
# Global line_profiler state
_line_profiler: Optional[Any] = None  # lazily-created line_profiler.LineProfiler instance
_line_profiler_lock = threading.Lock()  # guards creation and stats collection of _line_profiler
_wrapped_functions: dict[str, Callable] = {}  # Store original functions (keyed by bare function name)
def _should_sample(profile_sampling_rate: float) -> bool:
"""Determine if current request should be sampled based on sampling rate."""
if profile_sampling_rate >= 1.0:
return True # Always sample
elif profile_sampling_rate <= 0.0:
return False # Never sample
# Use deterministic sampling based on counter for consistent rate
global _sample_counter
with _sample_counter_lock:
_sample_counter += 1
# Sample based on rate (e.g., 0.1 means sample every 10th request)
should_sample = (_sample_counter % int(1.0 / profile_sampling_rate)) == 0
return should_sample
def _start_profiling(profile_sampling_rate: float) -> None:
    """Lazily create and enable the process-wide cProfile profiler.

    No-op when the profiler is already running; the sampling rate is only
    used in the log message.
    """
    global _profiler
    with _profile_lock:
        if _profiler is not None:
            return
        _profiler = cProfile.Profile()
        _profiler.enable()
        verbose_proxy_logger.info(
            f"Profiling started with sampling rate: {profile_sampling_rate}"
        )
def _start_profiling_for_request(profile_sampling_rate: float) -> bool:
    """Start (or keep running) the global profiler if this request is sampled.

    Returns True when the request is being profiled, False otherwise.
    """
    if not _should_sample(profile_sampling_rate):
        return False
    _start_profiling(profile_sampling_rate)
    return True
def _save_stats(profile_file: PathLib) -> None:
    """Dump the accumulated cProfile stats to profile_file.

    The profiler must be briefly disabled while dumping; it is re-enabled
    afterwards (including on error) so profiling continues.
    """
    with _profile_lock:
        if _profiler is None:
            return
        try:
            # cProfile cannot dump while enabled - pause, dump, resume.
            _profiler.disable()
            _profiler.dump_stats(str(profile_file))
            _profiler.enable()
            verbose_proxy_logger.debug(f"Profiling stats saved to {profile_file}")
        except Exception as e:
            verbose_proxy_logger.error(f"Error saving profiling stats: {e}")
            # Best effort: keep profiling alive even if the dump failed.
            try:
                _profiler.enable()
            except Exception:
                pass
def profile_endpoint(sampling_rate: float = 1.0):
    """Decorator to sample endpoint hits and save to a profile file.

    Args:
        sampling_rate: Rate of requests to profile (0.0 to 1.0)
            - 1.0: Profile all requests (100%)
            - 0.1: Profile 1 in 10 requests (10%)
            - 0.0: Profile no requests (0%)
    """
    def decorator(func):
        def _remember_profile_path(path: PathLib) -> None:
            # Record the output location so other tooling can find it.
            global _last_profile_file_path
            _last_profile_file_path = path

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                sampled = _start_profiling_for_request(sampling_rate)
                profile_path = PathLib("endpoint_profile.pstat")
                _remember_profile_path(profile_path)
                try:
                    return await func(*args, **kwargs)
                finally:
                    # Stats are flushed on both success and error paths.
                    if sampled:
                        _save_stats(profile_path)
            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            sampled = _start_profiling_for_request(sampling_rate)
            profile_path = PathLib("endpoint_profile.pstat")
            _remember_profile_path(profile_path)
            try:
                return func(*args, **kwargs)
            finally:
                if sampled:
                    _save_stats(profile_path)
        return sync_wrapper
    return decorator
def enable_line_profiler() -> None:
    """Enable line_profiler for dynamic function wrapping.

    Idempotent: the shared LineProfiler instance is created only once.

    Raises:
        ImportError: If line_profiler is not available
    """
    global _line_profiler
    from line_profiler import LineProfiler  # Will raise ImportError if not available

    with _line_profiler_lock:
        if _line_profiler is not None:
            return
        _line_profiler = LineProfiler()
        verbose_proxy_logger.info("Line profiler enabled")
def wrap_function_with_line_profiler(module: Any, function_name: str) -> bool:
    """Dynamically wrap a function with line_profiler.

    The wrapped (profiled) function replaces the original on the module via
    setattr; the original is kept in _wrapped_functions.

    Args:
        module: The module containing the function
        function_name: Name of the function to wrap
    Returns:
        True if wrapping was successful, False otherwise
    """
    try:
        enable_line_profiler()  # May raise ImportError if not available
    except ImportError:
        return False
    if _line_profiler is None:
        return False
    try:
        original_function = getattr(module, function_name, None)
        if original_function is None:
            verbose_proxy_logger.warning(
                f"Function {function_name} not found in module {module.__name__}"
            )
            return False
        # Store original function if not already wrapped
        # NOTE(review): keyed by bare function name only — wrapping two
        # functions with the same name from different modules would collide;
        # confirm whether module-qualified keys are needed.
        if function_name not in _wrapped_functions:
            _wrapped_functions[function_name] = original_function
        # Wrap with line_profiler
        profiled_function = _line_profiler(original_function)
        setattr(module, function_name, profiled_function)
        verbose_proxy_logger.info(
            f"Wrapped {module.__name__}.{function_name} with line_profiler"
        )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            f"Error wrapping {function_name} with line_profiler: {e}"
        )
        return False
def wrap_function_directly(func: Callable) -> Callable:
    """Wrap a function directly with line_profiler.

    This is the recommended way to profile functions, especially closures or
    functions created dynamically (like wrapper_async in litellm/utils.py).

    Args:
        func: The function to wrap
    Returns:
        The wrapped function that will be profiled when called
    Raises:
        ImportError: If line_profiler is not available
        RuntimeError: If line_profiler cannot be enabled or function cannot be wrapped
    """
    import warnings

    enable_line_profiler()  # Will raise ImportError if not available
    if _line_profiler is None:
        raise RuntimeError("Line profiler was not initialized")
    # Suppress warnings about __wrapped__ - we intentionally want to profile the wrapper
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore", message=".*__wrapped__.*", category=UserWarning
        )
        # Register with the profiler, then wrap so calls are timed.
        _line_profiler.add_function(func)
        wrapped = _line_profiler(func)
    verbose_proxy_logger.info(f"Wrapped function {func.__name__} with line_profiler")
    return wrapped
def collect_line_profiler_stats(output_file: Optional[str] = None) -> None:
    """Collect and save line_profiler statistics.

    This can be called manually to collect stats at any time, or it's
    automatically called on shutdown if register_shutdown_handler() was used.

    Args:
        output_file: Optional path to save stats. If None, prints to stdout.
    """
    global _line_profiler
    with _line_profiler_lock:
        if _line_profiler is None:
            verbose_proxy_logger.debug("Line profiler not enabled, nothing to collect")
            return
        try:
            if not output_file:
                # No destination given: render the stats and log them.
                from io import StringIO

                buffer = StringIO()
                _line_profiler.print_stats(stream=buffer)
                verbose_proxy_logger.info("Line profiler stats:\n" + buffer.getvalue())
            else:
                destination = PathLib(output_file)
                _line_profiler.dump_stats(str(destination))
                verbose_proxy_logger.info(f"Line profiler stats saved to {destination}")
        except Exception as e:
            verbose_proxy_logger.error(f"Error collecting line profiler stats: {e}")
def register_shutdown_handler(output_file: Optional[str] = None) -> None:
    """Register a shutdown handler to collect line_profiler stats.

    This registers an atexit handler that will automatically save profiling
    statistics when the Python process exits. Safe to call multiple times
    (only registers once); the output_file of the first call wins.

    Args:
        output_file: Optional path to save stats on shutdown.
            Defaults to 'line_profile_stats.lprof'
    """
    if output_file is None:
        output_file = "line_profile_stats.lprof"
    # Guard so repeated calls do not stack duplicate atexit handlers —
    # previously every call registered a fresh handler, contradicting the
    # documented "only registers once" contract.
    if getattr(register_shutdown_handler, "_registered", False):
        return
    register_shutdown_handler._registered = True  # type: ignore[attr-defined]

    def shutdown_handler():
        # Deferred call so stats reflect everything profiled before exit.
        collect_line_profiler_stats(output_file=output_file)

    atexit.register(shutdown_handler)
    verbose_proxy_logger.debug(
        f"Registered line_profiler shutdown handler for {output_file}"
    )

View File

@@ -0,0 +1,36 @@
"""
This file is used to store the state variables of the proxy server.
Example: `spend_logs_row_count` is used to store the number of rows in the `LiteLLM_SpendLogs` table.
"""
from typing import Any, Literal
from litellm.proxy._types import ProxyStateVariables
class ProxyState:
    """
    Typed get/set accessors over the proxy server's shared state variables.
    """

    # Note: mypy cannot derive the key set from
    # ProxyStateVariables.__annotations__.keys(), so the valid keys are
    # duplicated here as a Literal.
    valid_keys_literal = Literal["spend_logs_row_count"]

    def __init__(self) -> None:
        # All state lives in one typed dict so it can be inspected at once.
        self.proxy_state_variables: ProxyStateVariables = ProxyStateVariables(
            spend_logs_row_count=0,
        )

    def get_proxy_state_variable(
        self,
        variable_name: valid_keys_literal,
    ) -> Any:
        """Return the stored value for *variable_name*, or None when unset."""
        return self.proxy_state_variables.get(variable_name)

    def set_proxy_state_variable(
        self,
        variable_name: valid_keys_literal,
        value: Any,
    ) -> None:
        """Store *value* under *variable_name*."""
        self.proxy_state_variables[variable_name] = value

View File

@@ -0,0 +1,70 @@
"""
RBAC utility helpers for feature-level access control.
These helpers are used by agent and vector store endpoints to enforce
proxy-admin-configurable toggles that restrict access for internal users.
"""
from typing import Literal
from fastapi import HTTPException
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
FeatureName = Literal["agents", "vector_stores"]
async def check_feature_access_for_user(
    user_api_key_dict: UserAPIKeyAuth,
    feature_name: FeatureName,
) -> None:
    """
    Enforce the proxy-admin-configurable toggles that gate a feature.

    Raises HTTP 403 when general_settings disables *feature_name* for the
    caller's role; returns silently otherwise.

    Args:
        user_api_key_dict: The authenticated user.
        feature_name: Either "agents" or "vector_stores".
    """
    # Proxy admins (and view-only admins) are never blocked. Both the enum
    # members and their string values are accepted, since user_role may be
    # stored either way.
    admin_roles = (
        LitellmUserRoles.PROXY_ADMIN,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY,
        LitellmUserRoles.PROXY_ADMIN.value,
        LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value,
    )
    if user_api_key_dict.user_role in admin_roles:
        return
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        general_settings,
        prisma_client,
        user_api_key_cache,
    )

    if not general_settings.get(f"disable_{feature_name}_for_internal_users", False):
        # Feature is not disabled — allow all authenticated users.
        return
    # Feature is disabled; team/org admins may still be exempted by config.
    if general_settings.get(f"allow_{feature_name}_for_team_admins", False):
        from litellm.proxy.management_endpoints.common_utils import (
            _user_has_admin_privileges,
        )

        if await _user_has_admin_privileges(
            user_api_key_dict=user_api_key_dict,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
        ):
            return
    raise HTTPException(
        status_code=403,
        detail={
            "error": f"Access to {feature_name} is disabled for your role. Contact your proxy admin."
        },
    )

View File

@@ -0,0 +1,13 @@
import json
from functools import lru_cache
from typing import Optional

from litellm.constants import _REALTIME_BODY_CACHE_SIZE
@lru_cache(maxsize=_REALTIME_BODY_CACHE_SIZE)
def _realtime_request_body(model: Optional[str]) -> bytes:
    """
    Generate the realtime websocket request body. Cached with LRU semantics to avoid repeated
    serialization work while keeping memory usage bounded.

    Uses json.dumps instead of string interpolation so model names containing
    quotes or backslashes still produce valid JSON; for ordinary names the
    output is byte-identical to the previous f-string form.
    """
    return json.dumps({"model": model or ""}).encode()

View File

@@ -0,0 +1,619 @@
import asyncio
import json
import time
from datetime import datetime, timedelta, timezone
from typing import List, Literal, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
LiteLLM_BudgetTableFull,
LiteLLM_EndUserTable,
LiteLLM_TeamTable,
LiteLLM_UserTable,
LiteLLM_VerificationToken,
)
from litellm.proxy.utils import PrismaClient, ProxyLogging
from litellm.types.services import ServiceTypes
class ResetBudgetJob:
    """
    Resets the budget for all the keys, users, and teams that need it.

    Every `reset_budget_for_*` method follows the same shape: fetch the rows
    whose reset time has passed, zero their spend (advancing budget_reset_at
    where applicable), persist the updates, then fire a service-logging
    success/failure hook carrying metadata about what was found/updated/failed.
    """
    def __init__(self, proxy_logging_obj: ProxyLogging, prisma_client: PrismaClient):
        # proxy_logging_obj is only used for its service_logging_obj hooks.
        self.proxy_logging_obj: ProxyLogging = proxy_logging_obj
        # Prisma DB access layer; reset_budget() no-ops when this is None.
        self.prisma_client: PrismaClient = prisma_client
    async def reset_budget(
        self,
    ):
        """
        Gets all the non-expired keys for a db, which need spend to be reset
        Resets their spend
        Updates db
        """
        if self.prisma_client is not None:
            ### RESET KEY BUDGET ###
            await self.reset_budget_for_litellm_keys()
            ### RESET USER BUDGET ###
            await self.reset_budget_for_litellm_users()
            ## Reset Team Budget
            await self.reset_budget_for_litellm_teams()
            ### RESET ENDUSER (Customer) BUDGET and corresponding Budget duration ###
            await self.reset_budget_for_litellm_budget_table()
    async def reset_budget_for_litellm_team_members(
        self, budgets_to_reset: List[LiteLLM_BudgetTableFull]
    ):
        """
        Resets the budget for all LiteLLM Team Members if their budget has expired.

        Bulk-zeroes spend for every team membership whose budget_id is in the
        given budget list.
        """
        return await self.prisma_client.db.litellm_teammembership.update_many(
            where={
                "budget_id": {
                    "in": [
                        budget.budget_id
                        for budget in budgets_to_reset
                        if budget.budget_id is not None
                    ]
                }
            },
            data={
                "spend": 0,
            },
        )
    async def reset_budget_for_keys_linked_to_budgets(
        self, budgets_to_reset: List[LiteLLM_BudgetTableFull]
    ):
        """
        Resets the spend for keys linked to budget tiers that are being reset.
        This handles keys that have budget_id but no budget_duration set on the key
        itself. Keys with budget_id rely on their linked budget tier's reset schedule
        rather than having their own budget_duration.
        Keys that have their own budget_duration are already handled by
        reset_budget_for_litellm_keys() and are excluded here to avoid
        double-resetting.
        """
        budget_ids = [
            budget.budget_id
            for budget in budgets_to_reset
            if budget.budget_id is not None
        ]
        if not budget_ids:
            return
        return await self.prisma_client.db.litellm_verificationtoken.update_many(
            where={
                "budget_id": {"in": budget_ids},
                "budget_duration": None,  # only keys without their own reset schedule
                "spend": {"gt": 0},  # only reset keys that have accumulated spend
            },
            data={
                "spend": 0,
            },
        )
    async def reset_budget_for_litellm_budget_table(self):
        """
        Resets the budget for all LiteLLM End-Users (Customers), and Team Members if their budget has expired
        The corresponding Budget duration is also updated.
        """
        # NOTE(review): this method uses timezone-aware UTC, while the
        # key/user/team methods use naive datetime.utcnow() — confirm the DB
        # comparisons treat both consistently.
        now = datetime.now(timezone.utc)
        start_time = time.time()
        endusers_to_reset: Optional[List[LiteLLM_EndUserTable]] = None
        budgets_to_reset: Optional[List[LiteLLM_BudgetTableFull]] = None
        updated_endusers: List[LiteLLM_EndUserTable] = []
        failed_endusers = []
        try:
            budgets_to_reset = await self.prisma_client.get_data(
                table_name="budget", query_type="find_all", reset_at=now
            )
            if budgets_to_reset is not None and len(budgets_to_reset) > 0:
                # _reset_budget_reset_at_date mutates each budget in place;
                # the reassignment keeps the helper's return for clarity.
                for budget in budgets_to_reset:
                    budget = await ResetBudgetJob._reset_budget_reset_at_date(
                        budget, now
                    )
                await self.prisma_client.update_data(
                    query_type="update_many",
                    data_list=budgets_to_reset,
                    table_name="budget",
                )
                # Only endusers attached to one of the reset budgets need
                # their spend zeroed.
                endusers_to_reset = await self.prisma_client.get_data(
                    table_name="enduser",
                    query_type="find_all",
                    budget_id_list=[
                        budget.budget_id
                        for budget in budgets_to_reset
                        if budget.budget_id is not None
                    ],
                )
                await self.reset_budget_for_litellm_team_members(
                    budgets_to_reset=budgets_to_reset
                )
                await self.reset_budget_for_keys_linked_to_budgets(
                    budgets_to_reset=budgets_to_reset
                )
            if endusers_to_reset is not None and len(endusers_to_reset) > 0:
                for enduser in endusers_to_reset:
                    try:
                        updated_enduser = (
                            await ResetBudgetJob._reset_budget_for_enduser(
                                enduser=enduser
                            )
                        )
                        if updated_enduser is not None:
                            updated_endusers.append(updated_enduser)
                        else:
                            failed_endusers.append(
                                {
                                    "enduser": enduser,
                                    "error": "Returned None without exception",
                                }
                            )
                    except Exception as e:
                        failed_endusers.append({"enduser": enduser, "error": str(e)})
                        verbose_proxy_logger.exception(
                            "Failed to reset budget for enduser: %s", enduser
                        )
                verbose_proxy_logger.debug(
                    "Updated users %s",
                    json.dumps(updated_endusers, indent=4, default=str),
                )
                # NOTE(review): unlike the key/user/team methods this update
                # is not guarded by `if updated_endusers:` — confirm
                # update_many tolerates an empty data_list.
                await self.prisma_client.update_data(
                    query_type="update_many",
                    data_list=updated_endusers,
                    table_name="enduser",
                )
            end_time = time.time()
            if len(failed_endusers) > 0:  # If any endusers failed to reset
                raise Exception(
                    f"Failed to reset {len(failed_endusers)} endusers: {json.dumps(failed_endusers, default=str)}"
                )
            # Fire-and-forget service telemetry; never blocks the reset path.
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    call_type="reset_budget_budget_table",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_budgets_found": (
                            len(budgets_to_reset) if budgets_to_reset else 0
                        ),
                        "budgets_found": json.dumps(
                            budgets_to_reset, indent=4, default=str
                        ),
                        "num_endusers_found": (
                            len(endusers_to_reset) if endusers_to_reset else 0
                        ),
                        "endusers_found": json.dumps(
                            endusers_to_reset, indent=4, default=str
                        ),
                        "num_endusers_updated": len(updated_endusers),
                        "endusers_updated": json.dumps(
                            updated_endusers, indent=4, default=str
                        ),
                        "num_endusers_failed": len(failed_endusers),
                        "endusers_failed": json.dumps(
                            failed_endusers, indent=4, default=str
                        ),
                    },
                )
            )
        except Exception as e:
            end_time = time.time()
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    error=e,
                    call_type="reset_budget_endusers",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_budgets_found": (
                            len(budgets_to_reset) if budgets_to_reset else 0
                        ),
                        "budgets_found": json.dumps(
                            budgets_to_reset, indent=4, default=str
                        ),
                        "num_endusers_found": (
                            len(endusers_to_reset) if endusers_to_reset else 0
                        ),
                        "endusers_found": json.dumps(
                            endusers_to_reset, indent=4, default=str
                        ),
                    },
                )
            )
            verbose_proxy_logger.exception("Failed to reset budget for endusers: %s", e)
    async def reset_budget_for_litellm_keys(self):
        """
        Resets the budget for all the litellm keys
        Catches Exceptions and logs them
        """
        # NOTE(review): naive UTC here vs aware UTC in the budget-table
        # method above — confirm the DB layer normalizes both.
        now = datetime.utcnow()
        start_time = time.time()
        keys_to_reset: Optional[List[LiteLLM_VerificationToken]] = None
        try:
            keys_to_reset = await self.prisma_client.get_data(
                table_name="key", query_type="find_all", expires=now, reset_at=now
            )
            verbose_proxy_logger.debug(
                "Keys to reset %s", json.dumps(keys_to_reset, indent=4, default=str)
            )
            updated_keys: List[LiteLLM_VerificationToken] = []
            failed_keys = []
            if keys_to_reset is not None and len(keys_to_reset) > 0:
                # Per-key try/except so one bad key cannot block the rest.
                for key in keys_to_reset:
                    try:
                        updated_key = await ResetBudgetJob._reset_budget_for_key(
                            key=key, current_time=now
                        )
                        if updated_key is not None:
                            updated_keys.append(updated_key)
                        else:
                            failed_keys.append(
                                {"key": key, "error": "Returned None without exception"}
                            )
                    except Exception as e:
                        failed_keys.append({"key": key, "error": str(e)})
                        verbose_proxy_logger.exception(
                            "Failed to reset budget for key: %s", key
                        )
            verbose_proxy_logger.debug(
                "Updated keys %s", json.dumps(updated_keys, indent=4, default=str)
            )
            if updated_keys:
                await self.prisma_client.update_data(
                    query_type="update_many",
                    data_list=updated_keys,
                    table_name="key",
                )
            end_time = time.time()
            if len(failed_keys) > 0:  # If any keys failed to reset
                raise Exception(
                    f"Failed to reset {len(failed_keys)} keys: {json.dumps(failed_keys, default=str)}"
                )
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    call_type="reset_budget_keys",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_keys_found": len(keys_to_reset) if keys_to_reset else 0,
                        "keys_found": json.dumps(keys_to_reset, indent=4, default=str),
                        "num_keys_updated": len(updated_keys),
                        "keys_updated": json.dumps(updated_keys, indent=4, default=str),
                        "num_keys_failed": len(failed_keys),
                        "keys_failed": json.dumps(failed_keys, indent=4, default=str),
                    },
                )
            )
        except Exception as e:
            end_time = time.time()
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    error=e,
                    call_type="reset_budget_keys",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_keys_found": len(keys_to_reset) if keys_to_reset else 0,
                        "keys_found": json.dumps(keys_to_reset, indent=4, default=str),
                    },
                )
            )
            verbose_proxy_logger.exception("Failed to reset budget for keys: %s", e)
    async def reset_budget_for_litellm_users(self):
        """
        Resets the budget for all LiteLLM Internal Users if their budget has expired
        """
        now = datetime.utcnow()
        start_time = time.time()
        users_to_reset: Optional[List[LiteLLM_UserTable]] = None
        try:
            users_to_reset = await self.prisma_client.get_data(
                table_name="user", query_type="find_all", reset_at=now
            )
            updated_users: List[LiteLLM_UserTable] = []
            failed_users = []
            if users_to_reset is not None and len(users_to_reset) > 0:
                # Per-user try/except so one bad row cannot block the rest.
                for user in users_to_reset:
                    try:
                        updated_user = await ResetBudgetJob._reset_budget_for_user(
                            user=user, current_time=now
                        )
                        if updated_user is not None:
                            updated_users.append(updated_user)
                        else:
                            failed_users.append(
                                {
                                    "user": user,
                                    "error": "Returned None without exception",
                                }
                            )
                    except Exception as e:
                        failed_users.append({"user": user, "error": str(e)})
                        verbose_proxy_logger.exception(
                            "Failed to reset budget for user: %s", user
                        )
            verbose_proxy_logger.debug(
                "Updated users %s", json.dumps(updated_users, indent=4, default=str)
            )
            if updated_users:
                await self.prisma_client.update_data(
                    query_type="update_many",
                    data_list=updated_users,
                    table_name="user",
                )
            end_time = time.time()
            if len(failed_users) > 0:  # If any users failed to reset
                raise Exception(
                    f"Failed to reset {len(failed_users)} users: {json.dumps(failed_users, default=str)}"
                )
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    call_type="reset_budget_users",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_users_found": len(users_to_reset) if users_to_reset else 0,
                        "users_found": json.dumps(
                            users_to_reset, indent=4, default=str
                        ),
                        "num_users_updated": len(updated_users),
                        "users_updated": json.dumps(
                            updated_users, indent=4, default=str
                        ),
                        "num_users_failed": len(failed_users),
                        "users_failed": json.dumps(failed_users, indent=4, default=str),
                    },
                )
            )
        except Exception as e:
            end_time = time.time()
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    error=e,
                    call_type="reset_budget_users",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_users_found": len(users_to_reset) if users_to_reset else 0,
                        "users_found": json.dumps(
                            users_to_reset, indent=4, default=str
                        ),
                    },
                )
            )
            verbose_proxy_logger.exception("Failed to reset budget for users: %s", e)
    async def reset_budget_for_litellm_teams(self):
        """
        Resets the budget for all LiteLLM Internal Teams if their budget has expired
        """
        now = datetime.utcnow()
        start_time = time.time()
        teams_to_reset: Optional[List[LiteLLM_TeamTable]] = None
        try:
            teams_to_reset = await self.prisma_client.get_data(
                table_name="team", query_type="find_all", reset_at=now
            )
            updated_teams: List[LiteLLM_TeamTable] = []
            failed_teams = []
            if teams_to_reset is not None and len(teams_to_reset) > 0:
                # Per-team try/except so one bad row cannot block the rest.
                for team in teams_to_reset:
                    try:
                        updated_team = await ResetBudgetJob._reset_budget_for_team(
                            team=team, current_time=now
                        )
                        if updated_team is not None:
                            updated_teams.append(updated_team)
                        else:
                            failed_teams.append(
                                {
                                    "team": team,
                                    "error": "Returned None without exception",
                                }
                            )
                    except Exception as e:
                        failed_teams.append({"team": team, "error": str(e)})
                        verbose_proxy_logger.exception(
                            "Failed to reset budget for team: %s", team
                        )
            verbose_proxy_logger.debug(
                "Updated teams %s", json.dumps(updated_teams, indent=4, default=str)
            )
            if updated_teams:
                await self.prisma_client.update_data(
                    query_type="update_many",
                    data_list=updated_teams,
                    table_name="team",
                )
            end_time = time.time()
            if len(failed_teams) > 0:  # If any teams failed to reset
                raise Exception(
                    f"Failed to reset {len(failed_teams)} teams: {json.dumps(failed_teams, default=str)}"
                )
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_success_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    call_type="reset_budget_teams",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_teams_found": len(teams_to_reset) if teams_to_reset else 0,
                        "teams_found": json.dumps(
                            teams_to_reset, indent=4, default=str
                        ),
                        "num_teams_updated": len(updated_teams),
                        "teams_updated": json.dumps(
                            updated_teams, indent=4, default=str
                        ),
                        "num_teams_failed": len(failed_teams),
                        "teams_failed": json.dumps(failed_teams, indent=4, default=str),
                    },
                )
            )
        except Exception as e:
            end_time = time.time()
            asyncio.create_task(
                self.proxy_logging_obj.service_logging_obj.async_service_failure_hook(
                    service=ServiceTypes.RESET_BUDGET_JOB,
                    duration=end_time - start_time,
                    error=e,
                    call_type="reset_budget_teams",
                    start_time=start_time,
                    end_time=end_time,
                    event_metadata={
                        "num_teams_found": len(teams_to_reset) if teams_to_reset else 0,
                        "teams_found": json.dumps(
                            teams_to_reset, indent=4, default=str
                        ),
                    },
                )
            )
            verbose_proxy_logger.exception("Failed to reset budget for teams: %s", e)
    @staticmethod
    async def _reset_budget_common(
        item: Union[LiteLLM_TeamTable, LiteLLM_UserTable, LiteLLM_VerificationToken],
        current_time: datetime,
        item_type: Literal["key", "team", "user"],
    ):
        """
        In-place, updates spend=0, and sets budget_reset_at to current_time + budget_duration
        Common logic for resetting budget for a team, user, or key
        """
        # NOTE(review): current_time is not read here — the next reset time
        # comes entirely from get_budget_reset_time(budget_duration).
        try:
            item.spend = 0.0
            if hasattr(item, "budget_duration") and item.budget_duration is not None:
                # Get standardized reset time based on budget duration
                from litellm.proxy.common_utils.timezone_utils import (
                    get_budget_reset_time,
                )

                item.budget_reset_at = get_budget_reset_time(
                    budget_duration=item.budget_duration
                )
            return item
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error resetting budget for %s: %s. Item: %s", item_type, e, item
            )
            raise e
    @staticmethod
    async def _reset_budget_for_team(
        team: LiteLLM_TeamTable, current_time: datetime
    ) -> Optional[LiteLLM_TeamTable]:
        # Mutates `team` in place via the common helper, then returns it.
        await ResetBudgetJob._reset_budget_common(
            item=team, current_time=current_time, item_type="team"
        )
        return team
    @staticmethod
    async def _reset_budget_for_user(
        user: LiteLLM_UserTable, current_time: datetime
    ) -> Optional[LiteLLM_UserTable]:
        # Mutates `user` in place via the common helper, then returns it.
        await ResetBudgetJob._reset_budget_common(
            item=user, current_time=current_time, item_type="user"
        )
        return user
    @staticmethod
    async def _reset_budget_for_enduser(
        enduser: LiteLLM_EndUserTable,
    ) -> Optional[LiteLLM_EndUserTable]:
        # Endusers have no budget_duration of their own; only spend is zeroed.
        try:
            enduser.spend = 0.0
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error resetting budget for enduser: %s. Item: %s", e, enduser
            )
            raise e
        return enduser
    @staticmethod
    async def _reset_budget_reset_at_date(
        budget: LiteLLM_BudgetTableFull, current_time: datetime
    ) -> LiteLLM_BudgetTableFull:
        """Advance budget.budget_reset_at in place by budget_duration."""
        try:
            if budget.budget_duration is not None:
                from litellm.litellm_core_utils.duration_parser import (
                    duration_in_seconds,
                )

                duration_s = duration_in_seconds(duration=budget.budget_duration)
                # Fallback for existing budgets that do not have a budget_reset_at date set, ensuring the duration is taken into account
                if (
                    budget.budget_reset_at is None
                    and budget.created_at + timedelta(seconds=duration_s) > current_time
                ):
                    budget.budget_reset_at = budget.created_at + timedelta(
                        seconds=duration_s
                    )
                else:
                    budget.budget_reset_at = current_time + timedelta(
                        seconds=duration_s
                    )
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error resetting budget_reset_at for budget: %s. Item: %s", e, budget
            )
            raise e
        return budget
    @staticmethod
    async def _reset_budget_for_key(
        key: LiteLLM_VerificationToken, current_time: datetime
    ) -> Optional[LiteLLM_VerificationToken]:
        # Mutates `key` in place via the common helper, then returns it.
        await ResetBudgetJob._reset_budget_common(
            item=key, current_time=current_time, item_type="key"
        )
        return key

View File

@@ -0,0 +1,48 @@
from typing import Any, Dict
from pydantic import BaseModel, Field
from litellm.exceptions import LITELLM_EXCEPTION_TYPES
class ErrorResponse(BaseModel):
    """OpenAPI schema model for proxy error payloads."""
    # `example` feeds the generated docs UI; the keys mirror the OpenAI
    # error envelope (message/type/param/code).
    detail: Dict[str, Any] = Field(
        ...,
        example={  # type: ignore
            "error": {
                "message": "Error message",
                "type": "error_type",
                "param": "error_param",
                "code": "error_code",
            }
        },
    )
# Define a function to get the status code
def get_status_code(exception):
    """Return the HTTP status code to document for an exception class.

    Uses the class's own status_code attribute when present; otherwise falls
    back to a small name-keyed map, defaulting to 500.
    """
    if hasattr(exception, "status_code"):
        return exception.status_code
    # Default status codes for exceptions without a status_code attribute
    fallback_codes = {
        "Timeout": 408,  # Request Timeout
        "APIConnectionError": 503,  # Service Unavailable
    }
    return fallback_codes.get(exception.__name__, 500)  # 500 = Internal Server Error
# Create the OpenAPI error-response map: one entry per LiteLLM exception
# status code (later exceptions with the same code overwrite earlier ones).
ERROR_RESPONSES = {
    get_status_code(exc): {
        "model": ErrorResponse,
        "description": exc.__doc__ or exc.__name__,
    }
    for exc in LITELLM_EXCEPTION_TYPES
}
# Guarantee a generic 500 entry even if no exception maps to it.
ERROR_RESPONSES.setdefault(
    500,
    {
        "model": ErrorResponse,
        "description": "Internal Server Error",
    },
)

View File

@@ -0,0 +1,29 @@
from datetime import datetime, timezone
import litellm
from litellm.litellm_core_utils.duration_parser import get_next_standardized_reset_time
def get_budget_reset_timezone():
    """
    Return the timezone name used for budget resets, defaulting to "UTC".

    litellm_settings values are set as attributes on the litellm module
    by proxy_server.py at startup (via setattr(litellm, key, value)).
    """
    configured = getattr(litellm, "timezone", None)
    return configured if configured else "UTC"
def get_budget_reset_time(budget_duration: str):
    """
    Compute the next budget reset datetime for *budget_duration*, honoring
    the configured reset timezone (UTC when none is configured).
    """
    return get_next_standardized_reset_time(
        duration=budget_duration,
        current_time=datetime.now(timezone.utc),
        timezone_str=get_budget_reset_timezone(),
    )