chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,53 @@
from typing import Any, Literal, List
class CustomDB:
    """
    Abstract interface that any custom DB implementation (e.g. DynamoDB)
    is expected to follow.

    Every method here is a no-op stub; concrete subclasses override the
    operations they support. `table_name` is always one of the three
    logical tables: "user", "key", or "config".
    """

    def __init__(self) -> None:
        pass

    def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
        """Look up `key` in the given table (e.g. to check a key is valid)."""
        pass

    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]):
        """Write a new row - used by new key / user creation logic."""
        pass

    def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ):
        """Update an existing row - used by cost tracking logic."""
        pass

    def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ):
        """Delete the given keys - used by the /key/delete endpoint."""
        pass

    def connect(
        self,
    ):
        """Open the DB connection and create / update any tables."""
        pass

    def disconnect(
        self,
    ):
        """Close the connection on server shutdown."""
        pass

View File

@@ -0,0 +1,104 @@
"""Module for checking differences between Prisma schema and database."""
import os
import subprocess
from typing import List, Optional, Tuple
from litellm._logging import verbose_logger
def extract_sql_commands(diff_output: str) -> List[str]:
    """
    Extract SQL commands from the Prisma migrate diff output.

    A `-- ` comment line marks the start of a command block; subsequent
    lines are accumulated until one ends with ";". An unterminated
    command still in the buffer at the end is also returned.

    Args:
        diff_output (str): The full output from prisma migrate diff.

    Returns:
        List[str]: A list of SQL commands extracted from the diff output.
    """
    stripped_lines = [ln.strip() for ln in diff_output.split("\n") if ln.strip()]

    commands: List[str] = []
    buffer = ""
    collecting = False

    for ln in stripped_lines:
        if ln.startswith("-- "):
            # Comment line, likely a table operation description; flush any
            # command still being built and start a new block.
            if collecting and buffer:
                commands.append(buffer.strip())
                buffer = ""
            collecting = True
        elif collecting:
            if ln.endswith(";"):
                buffer += ln
                commands.append(buffer.strip())
                buffer = ""
                collecting = False
            else:
                buffer += ln + " "

    # Flush any trailing, unterminated command.
    if buffer:
        commands.append(buffer.strip())

    return commands
def check_prisma_schema_diff_helper(db_url: str) -> Tuple[bool, List[str]]:
    """Compare the live database against ./schema.prisma via `prisma migrate diff`.

    Args:
        db_url: connection string of the database to diff against.

    Returns:
        A tuple containing:
        - True and the list of SQL commands needed to sync schema and DB
          when differences are found;
        - False and an empty list when there is no diff, or when the
          prisma command itself fails (the error is printed, not raised).
    """
    verbose_logger.debug("Checking for Prisma schema diff...")  # noqa: T201
    diff_cmd = [
        "prisma",
        "migrate",
        "diff",
        "--from-url",
        db_url,
        "--to-schema-datamodel",
        "./schema.prisma",
        "--script",
    ]
    try:
        completed = subprocess.run(
            diff_cmd,
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Failed to generate migration diff. Error: {e.stderr}")  # noqa: T201
        return False, []

    sql_commands = extract_sql_commands(completed.stdout)
    if not sql_commands:
        return False, []

    print("Changes to DB Schema detected")  # noqa: T201
    print("Required SQL commands:")  # noqa: T201
    for command in sql_commands:
        print(command)  # noqa: T201
    return True, sql_commands
def check_prisma_schema_diff(db_url: Optional[str] = None) -> None:
    """Run the Prisma schema diff check and log the SQL needed to sync.

    Args:
        db_url: Database connection string. Falls back to the DATABASE_URL
            environment variable when not provided.

    Raises:
        Exception: If no db_url is given and DATABASE_URL is not set.
    """
    if db_url is None:
        db_url = os.getenv("DATABASE_URL")
    if db_url is None:
        raise Exception("DATABASE_URL not set")
    has_diff, sql_commands = check_prisma_schema_diff_helper(db_url)
    if has_diff:
        # Use .error, not .exception: we are not inside an exception handler,
        # so .exception would attach a misleading "NoneType: None" traceback
        # to the log record (logging docs: only call .exception from a handler).
        verbose_logger.error(
            "🚨🚨🚨 prisma schema out of sync with db. Consider running these sql_commands to sync the two - {}".format(
                sql_commands
            )
        )

View File

@@ -0,0 +1,227 @@
from typing import Any
from litellm import verbose_logger
# Placeholder type alias for the async database client handle; typed as `Any`
# because the concrete client type (presumably the Prisma client — TODO
# confirm) is only available at runtime.
_db = Any
async def create_missing_views(db: _db):  # noqa: PLR0915
    """
    --------------------------------------------------
    NOTE: Copy of `litellm/db_scripts/create_views.py`.
    --------------------------------------------------
    Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend exists in the user's db.
    LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth
    MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month
    If the view doesn't exist, one will be created.

    Pattern for every view below: probe it with a cheap
    `SELECT 1 ... LIMIT 1`; any exception from the probe is treated as
    "view missing" and the view is (re)created.
    """
    # -- LiteLLM_VerificationTokenView: each verification token joined with
    # its team's spend / budget / tpm / rpm columns.
    try:
        # Try to select one row from the view
        await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""")
        print("LiteLLM_VerificationTokenView Exists!")  # noqa
    except Exception:
        # If an error occurs, the view does not exist, so create it
        # NOTE(review): this one uses plain CREATE VIEW while the views below
        # use CREATE OR REPLACE VIEW — confirm the asymmetry is intentional.
        await db.execute_raw(
            """
            CREATE VIEW "LiteLLM_VerificationTokenView" AS
            SELECT
            v.*,
            t.spend AS team_spend,
            t.max_budget AS team_max_budget,
            t.tpm_limit AS team_tpm_limit,
            t.rpm_limit AS team_rpm_limit
            FROM "LiteLLM_VerificationToken" v
            LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
            """
        )
        print("LiteLLM_VerificationTokenView Created!")  # noqa

    # -- MonthlyGlobalSpend: total spend per day over the last 30 days.
    try:
        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
        print("MonthlyGlobalSpend Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
        SELECT
        DATE("startTime") AS date,
        SUM("spend") AS spend
        FROM
        "LiteLLM_SpendLogs"
        WHERE
        "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
        DATE("startTime");
        """
        await db.execute_raw(query=sql_query)
        print("MonthlyGlobalSpend Created!")  # noqa

    # -- Last30dKeysBySpend: top API keys (with their alias/name) by total
    # spend over the last 30 days.
    try:
        await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
        print("Last30dKeysBySpend Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
        SELECT
        L."api_key",
        V."key_alias",
        V."key_name",
        SUM(L."spend") AS total_spend
        FROM
        "LiteLLM_SpendLogs" L
        LEFT JOIN
        "LiteLLM_VerificationToken" V
        ON
        L."api_key" = V."token"
        WHERE
        L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
        L."api_key", V."key_alias", V."key_name"
        ORDER BY
        total_spend DESC;
        """
        await db.execute_raw(query=sql_query)
        print("Last30dKeysBySpend Created!")  # noqa

    # -- Last30dModelsBySpend: models ranked by spend over the last 30 days
    # (empty model names excluded).
    try:
        await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
        print("Last30dModelsBySpend Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
        SELECT
        "model",
        SUM("spend") AS total_spend
        FROM
        "LiteLLM_SpendLogs"
        WHERE
        "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        AND "model" != ''
        GROUP BY
        "model"
        ORDER BY
        total_spend DESC;
        """
        await db.execute_raw(query=sql_query)
        print("Last30dModelsBySpend Created!")  # noqa

    # -- MonthlyGlobalSpendPerKey: daily spend broken down by api_key.
    try:
        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""")
        print("MonthlyGlobalSpendPerKey Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
        SELECT
        DATE("startTime") AS date,
        SUM("spend") AS spend,
        api_key as api_key
        FROM
        "LiteLLM_SpendLogs"
        WHERE
        "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
        DATE("startTime"),
        api_key;
        """
        await db.execute_raw(query=sql_query)
        print("MonthlyGlobalSpendPerKey Created!")  # noqa

    # -- MonthlyGlobalSpendPerUserPerKey: daily spend broken down by both
    # user and api_key.
    try:
        await db.query_raw(
            """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
        )
        print("MonthlyGlobalSpendPerUserPerKey Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
        SELECT
        DATE("startTime") AS date,
        SUM("spend") AS spend,
        api_key as api_key,
        "user" as "user"
        FROM
        "LiteLLM_SpendLogs"
        WHERE
        "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
        DATE("startTime"),
        "user",
        api_key;
        """
        await db.execute_raw(query=sql_query)
        print("MonthlyGlobalSpendPerUserPerKey Created!")  # noqa

    # -- DailyTagSpend: per-tag request count and spend per day
    # (request_tags is unnested with jsonb_array_elements_text).
    try:
        await db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""")
        print("DailyTagSpend Exists!")  # noqa
    except Exception:
        sql_query = """
        CREATE OR REPLACE VIEW "DailyTagSpend" AS
        SELECT
        jsonb_array_elements_text(request_tags) AS individual_request_tag,
        DATE(s."startTime") AS spend_date,
        COUNT(*) AS log_count,
        SUM(spend) AS total_spend
        FROM "LiteLLM_SpendLogs" s
        GROUP BY individual_request_tag, DATE(s."startTime");
        """
        await db.execute_raw(query=sql_query)
        print("DailyTagSpend Created!")  # noqa

    # -- Last30dTopEndUsersSpend: top 100 end users by spend in the last 30
    # days (rows where end_user is empty or equals the internal user are
    # excluded).
    try:
        await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""")
        print("Last30dTopEndUsersSpend Exists!")  # noqa
    except Exception:
        # NOTE(review): plain CREATE VIEW here as well — confirm intentional.
        sql_query = """
        CREATE VIEW "Last30dTopEndUsersSpend" AS
        SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
        FROM "LiteLLM_SpendLogs"
        WHERE end_user <> '' AND end_user <> user
        AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
        GROUP BY end_user
        ORDER BY total_spend DESC
        LIMIT 100;
        """
        await db.execute_raw(query=sql_query)
        print("Last30dTopEndUsersSpend Created!")  # noqa

    return
async def should_create_missing_views(db: _db) -> bool:
    """
    Run only on first time startup.
    If SpendLogs table already has values, then don't create views on startup.

    Uses pg_class.reltuples (the planner's row-count estimate) instead of a
    full COUNT(*): it is 0 for an empty table and -1 for a table that has
    never been vacuumed/analyzed.

    Returns:
        True when the table appears empty (estimate 0 or -1), else False.
    """
    sql_query = """
    SELECT reltuples::BIGINT
    FROM pg_class
    WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
    """
    result = await db.query_raw(query=sql_query)
    verbose_logger.debug("Estimated Row count of LiteLLM_SpendLogs = {}".format(result))
    if (
        result
        and isinstance(result, list)
        and len(result) > 0
        and isinstance(result[0], dict)
        and "reltuples" in result[0]
    ):
        reltuples = result[0]["reltuples"]
        # BUGFIX: the old guard required `reltuples` to be truthy before
        # comparing it to 0, which made the `== 0` branch unreachable
        # (0 is falsy). Compare against None explicitly so an empty table
        # (estimate 0) also triggers view creation.
        if reltuples is not None and reltuples in (0, -1):
            verbose_logger.debug("Should create views")
            return True
    return False

View File

@@ -0,0 +1,62 @@
"""
Base class for in memory buffer for database transactions
"""
import asyncio
from typing import Optional
from litellm._logging import verbose_proxy_logger
from litellm._service_logger import ServiceLogging
service_logger_obj = (
ServiceLogging()
) # used for tracking metrics for In memory buffer, redis buffer, pod lock manager
from litellm.constants import (
LITELLM_ASYNCIO_QUEUE_MAXSIZE,
MAX_IN_MEMORY_QUEUE_FLUSH_COUNT,
MAX_SIZE_IN_MEMORY_QUEUE,
)
class BaseUpdateQueue:
    """Base class for an in-memory buffer of pending database transactions."""

    def __init__(self):
        # Bounded asyncio queue so a slow DB cannot grow memory without limit.
        self.update_queue = asyncio.Queue(maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE)
        self.MAX_SIZE_IN_MEMORY_QUEUE = MAX_SIZE_IN_MEMORY_QUEUE
        # Warn when the aggregation threshold can never fire because the
        # underlying asyncio.Queue blocks before the threshold is reached.
        if MAX_SIZE_IN_MEMORY_QUEUE >= LITELLM_ASYNCIO_QUEUE_MAXSIZE:
            verbose_proxy_logger.warning(
                "Misconfigured queue thresholds: MAX_SIZE_IN_MEMORY_QUEUE (%d) >= LITELLM_ASYNCIO_QUEUE_MAXSIZE (%d). "
                "The spend aggregation check will never trigger because the asyncio.Queue blocks at %d items. "
                "Set MAX_SIZE_IN_MEMORY_QUEUE to a value less than LITELLM_ASYNCIO_QUEUE_MAXSIZE (recommended: 80%% of it).",
                MAX_SIZE_IN_MEMORY_QUEUE,
                LITELLM_ASYNCIO_QUEUE_MAXSIZE,
                LITELLM_ASYNCIO_QUEUE_MAXSIZE,
            )

    async def add_update(self, update):
        """Enqueue an update and emit a queue-size gauge event."""
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)
        await self._emit_new_item_added_to_queue_event(
            queue_size=self.update_queue.qsize()
        )

    async def flush_all_updates_from_in_memory_queue(self):
        """Drain the queue, returning at most MAX_IN_MEMORY_QUEUE_FLUSH_COUNT items."""
        drained = []
        while not self.update_queue.empty():
            # Circuit breaker to ensure we're not stuck dequeuing updates.
            # Protect CPU utilization.
            if len(drained) >= MAX_IN_MEMORY_QUEUE_FLUSH_COUNT:
                verbose_proxy_logger.debug(
                    "Max in memory queue flush count reached, stopping flush"
                )
                break
            drained.append(await self.update_queue.get())
        return drained

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """Placeholder hook; subclasses emit an event when an item is enqueued."""
        pass

View File

@@ -0,0 +1,155 @@
import asyncio
from copy import deepcopy
from typing import Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
from litellm.proxy._types import BaseDailySpendTransaction
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
BaseUpdateQueue,
service_logger_obj,
)
from litellm.types.services import ServiceTypes
class DailySpendUpdateQueue(BaseUpdateQueue):
    """
    In-memory buffer of daily spend aggregate transactions awaiting a DB commit.

    Each queued item is a dict mapping a daily transaction key to a
    BaseDailySpendTransaction payload, e.g.:

        {
            "user1_date_api_key_model_custom_llm_provider": {
                "spend": 10,
                "prompt_tokens": 100,
                "completion_tokens": 100,
                "api_requests": 100,
                "successful_requests": 100,
                "failed_requests": 100,
            }
        }

    When the queue reaches MAX_SIZE_IN_MEMORY_QUEUE items, all queued items
    are merged into one aggregated item to bound memory usage.
    """

    # Counters summed on every merge; the two cache-token fields are optional
    # and handled separately (missing/None treated as 0).
    _SUMMED_FIELDS = (
        "spend",
        "prompt_tokens",
        "completion_tokens",
        "api_requests",
        "successful_requests",
        "failed_requests",
    )
    _OPTIONAL_FIELDS = ("cache_read_input_tokens", "cache_creation_input_tokens")

    def __init__(self):
        super().__init__()
        # Re-declare the queue with a precise element type.
        self.update_queue: asyncio.Queue[
            Dict[str, BaseDailySpendTransaction]
        ] = asyncio.Queue(maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE)

    async def add_update(self, update: Dict[str, BaseDailySpendTransaction]):
        """Enqueue an update; compact the queue when it is full."""
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)
        if self.update_queue.qsize() < self.MAX_SIZE_IN_MEMORY_QUEUE:
            return
        verbose_proxy_logger.warning(
            "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
        )
        await self.aggregate_queue_updates()

    async def aggregate_queue_updates(self):
        """
        Drain the queue and re-enqueue a single merged transaction dict,
        shrinking the in-memory footprint.
        """
        pending: List[
            Dict[str, BaseDailySpendTransaction]
        ] = await self.flush_all_updates_from_in_memory_queue()
        merged = self.get_aggregated_daily_spend_update_transactions(pending)
        await self.update_queue.put(merged)

    async def flush_and_get_aggregated_daily_spend_update_transactions(
        self,
    ) -> Dict[str, BaseDailySpendTransaction]:
        """
        Drain the queue and return every update merged by daily_transaction_key.
        Works for both user and team spend updates.
        """
        pending = await self.flush_all_updates_from_in_memory_queue()
        if len(pending) > 0:
            verbose_proxy_logger.info(
                "Spend tracking - flushed %d daily spend update items from in-memory queue",
                len(pending),
            )
        merged = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
            pending
        )
        verbose_proxy_logger.debug(
            "Aggregated daily spend update transactions: %s",
            merged,
        )
        return merged

    @staticmethod
    def get_aggregated_daily_spend_update_transactions(
        updates: List[Dict[str, BaseDailySpendTransaction]],
    ) -> Dict[str, BaseDailySpendTransaction]:
        """Merge a list of update dicts into one dict keyed by daily_transaction_key."""
        merged: Dict[str, BaseDailySpendTransaction] = {}
        for update in updates:
            for tx_key, payload in update.items():
                existing = merged.get(tx_key)
                if existing is None:
                    # First sighting: deep-copy so later in-place merges do
                    # not mutate the caller's payload.
                    merged[tx_key] = deepcopy(payload)
                    continue
                for field in DailySpendUpdateQueue._SUMMED_FIELDS:
                    existing[field] += payload[field]  # type: ignore[literal-required]
                for field in DailySpendUpdateQueue._OPTIONAL_FIELDS:
                    existing[field] = (  # type: ignore[literal-required]
                        payload.get(field, 0) or 0
                    ) + existing.get(field, 0)
        return merged

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """Emit a gauge event recording the current queue size."""
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )

View File

@@ -0,0 +1,186 @@
import asyncio
from litellm._uuid import uuid
from typing import TYPE_CHECKING, Any, Optional
from litellm._logging import verbose_proxy_logger
from litellm.caching.redis_cache import RedisCache
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.types.services import ServiceTypes
if TYPE_CHECKING:
ProxyLogging = Any
else:
ProxyLogging = Any
class PodLockManager:
    """
    Manager for acquiring and releasing cron-job locks across pods using Redis.

    Ensures that only one pod can run a cron job at a time. The lock is a
    Redis key (``cronjob_lock:<cronjob_id>``) whose value is the acquiring
    pod's id, set atomically with NX (only-if-absent) and a TTL so a crashed
    pod cannot hold the lock forever.
    """

    def __init__(self, redis_cache: Optional[RedisCache] = None):
        # Unique identity for this pod; stored as the lock value so we can
        # tell whether we are the current lock holder.
        self.pod_id = str(uuid.uuid4())
        self.redis_cache = redis_cache

    @staticmethod
    def get_redis_lock_key(cronjob_id: str) -> str:
        """Return the Redis key under which the lock for `cronjob_id` is stored."""
        return f"cronjob_lock:{cronjob_id}"

    async def acquire_lock(
        self,
        cronjob_id: str,
    ) -> Optional[bool]:
        """
        Attempt to acquire the lock for a specific cron job using Redis.
        Uses the SET command with NX and EX options to ensure atomicity.

        Args:
            cronjob_id: The ID of the cron job to lock

        Returns:
            True when this pod holds the lock (newly acquired or already held),
            False when another pod holds it or an error occurred,
            None when no Redis cache is configured.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock")
            return None
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            # Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
            # and with an expiration (EX) to avoid deadlocks.
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
            acquired = await self.redis_cache.async_set_cache(
                lock_key,
                self.pod_id,
                nx=True,
                ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
            )
            if acquired:
                verbose_proxy_logger.info(
                    "Pod %s successfully acquired Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
                # FIX: emit the acquired-lock gauge event on a fresh
                # acquisition too. Previously only the "already holds the
                # lock" branch emitted it, so the gauge never recorded
                # first-time acquisitions even though release always emits 0.
                self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                return True
            else:
                # Check if the current pod already holds the lock
                current_value = await self.redis_cache.async_get_cache(lock_key)
                if current_value is not None:
                    if isinstance(current_value, bytes):
                        current_value = current_value.decode("utf-8")
                    if current_value == self.pod_id:
                        verbose_proxy_logger.info(
                            "Pod %s already holds the Redis lock for cronjob_id=%s",
                            self.pod_id,
                            cronjob_id,
                        )
                        self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                        return True
                    else:
                        verbose_proxy_logger.info(
                            "Spend tracking - pod %s could not acquire lock for cronjob_id=%s, "
                            "held by pod %s. Spend updates in Redis will wait for the leader pod to commit.",
                            self.pod_id,
                            cronjob_id,
                            current_value,
                        )
            return False
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error acquiring Redis lock for {cronjob_id}: {e}"
            )
            return False

    async def release_lock(
        self,
        cronjob_id: str,
    ):
        """
        Release the lock if the current pod holds it.
        Uses get and delete commands to ensure that only the owner can release the lock.

        Args:
            cronjob_id: The ID of the cron job whose lock should be released.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
            return
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to release Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
            current_value = await self.redis_cache.async_get_cache(lock_key)
            if current_value is not None:
                if isinstance(current_value, bytes):
                    current_value = current_value.decode("utf-8")
                if current_value == self.pod_id:
                    # We own the lock — delete it and emit the release gauge.
                    result = await self.redis_cache.async_delete_cache(lock_key)
                    if result == 1:
                        verbose_proxy_logger.info(
                            "Pod %s successfully released Redis lock for cronjob_id=%s",
                            self.pod_id,
                            cronjob_id,
                        )
                        self._emit_released_lock_event(
                            cronjob_id=cronjob_id,
                            pod_id=self.pod_id,
                        )
                    else:
                        verbose_proxy_logger.warning(
                            "Spend tracking - pod %s failed to release Redis lock for cronjob_id=%s. "
                            "Lock will expire after TTL=%ds.",
                            self.pod_id,
                            cronjob_id,
                            DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                        )
                else:
                    verbose_proxy_logger.debug(
                        "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
                        self.pod_id,
                        cronjob_id,
                        current_value,
                    )
            else:
                verbose_proxy_logger.debug(
                    "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
                    self.pod_id,
                    cronjob_id,
                )
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error releasing Redis lock for {cronjob_id}: {e}"
            )

    @staticmethod
    def _emit_acquired_lock_event(cronjob_id: str, pod_id: str):
        """Emit gauge=1 for this (cronjob, pod) pair — the lock is held."""
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.POD_LOCK_MANAGER,
                duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                call_type="_emit_acquired_lock_event",
                event_metadata={
                    "gauge_labels": f"{cronjob_id}:{pod_id}",
                    "gauge_value": 1,
                },
            )
        )

    @staticmethod
    def _emit_released_lock_event(cronjob_id: str, pod_id: str):
        """Emit gauge=0 for this (cronjob, pod) pair — the lock was released."""
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.POD_LOCK_MANAGER,
                duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                call_type="_emit_released_lock_event",
                event_metadata={
                    "gauge_labels": f"{cronjob_id}:{pod_id}",
                    "gauge_value": 0,
                },
            )
        )

View File

@@ -0,0 +1,677 @@
"""
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
This is to prevent deadlocks and improve reliability
"""
import asyncio
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
from litellm._logging import verbose_proxy_logger
from litellm.caching import RedisCache
from litellm.constants import (
MAX_REDIS_BUFFER_DEQUEUE_COUNT,
REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
REDIS_UPDATE_BUFFER_KEY,
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import (
DailyAgentSpendTransaction,
DailyEndUserSpendTransaction,
DailyOrganizationSpendTransaction,
DailyTagSpendTransaction,
DailyTeamSpendTransaction,
DailyUserSpendTransaction,
DBSpendUpdateTransactions,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
DailySpendUpdateQueue,
)
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
from litellm.secret_managers.main import str_to_bool
from litellm.types.caching import (
RedisPipelineLpopOperation,
RedisPipelineRpushOperation,
)
from litellm.types.services import ServiceTypes
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
else:
PrismaClient = Any
class RedisUpdateBuffer:
"""
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
This is to prevent deadlocks and improve reliability
"""
def __init__(
    self,
    redis_cache: Optional[RedisCache] = None,
):
    """Create the buffer around an optional Redis cache client."""
    # When redis_cache is None every buffer operation in this class is a
    # no-op (each public method checks this before touching Redis).
    self.redis_cache = redis_cache
@staticmethod
def _should_commit_spend_updates_to_redis() -> bool:
    """
    Whether this pod should buffer spend `UPDATE` transactions in Redis.

    Reads `use_redis_transaction_buffer` from the proxy's general_settings;
    accepts either a bool or a truthy/falsy string. Defaults to False.
    Buffering in Redis improves reliability and reduces database contention.
    """
    from litellm.proxy.proxy_server import general_settings

    flag: Optional[Union[bool, str]] = general_settings.get(
        "use_redis_transaction_buffer", False
    )
    if isinstance(flag, str):
        flag = str_to_bool(flag)
    if flag is None:
        return False
    return flag
async def _store_transactions_in_redis(
    self,
    transactions: Any,
    redis_key: str,
    service_type: ServiceTypes,
) -> None:
    """
    Serialize `transactions` and append them to a Redis list buffer,
    then emit a gauge event with the resulting buffer size.

    No-op when `transactions` is empty/None or when no Redis cache is
    configured; push failures are logged, never raised.

    Args:
        transactions: The transactions to store
        redis_key: The Redis key to store under
        service_type: The service type for event emission
    """
    if transactions is None or len(transactions) == 0:
        return
    serialized_batch = [safe_dumps(transactions)]
    if self.redis_cache is None:
        return
    try:
        buffer_size = await self.redis_cache.async_rpush(
            key=redis_key,
            values=serialized_batch,
        )
        verbose_proxy_logger.debug(
            "Spend tracking - pushed spend updates to Redis buffer. "
            "redis_key=%s, buffer_size=%s",
            redis_key,
            buffer_size,
        )
        await self._emit_new_item_added_to_redis_buffer_event(
            queue_size=buffer_size,
            service=service_type,
        )
    except Exception as e:
        verbose_proxy_logger.error(
            "Spend tracking - failed to push spend updates to Redis (redis_key=%s). "
            "Error: %s",
            redis_key,
            str(e),
        )
async def store_in_memory_spend_updates_in_redis(
    self,
    spend_update_queue: SpendUpdateQueue,
    daily_spend_update_queue: DailySpendUpdateQueue,
    daily_team_spend_update_queue: DailySpendUpdateQueue,
    daily_org_spend_update_queue: DailySpendUpdateQueue,
    daily_end_user_spend_update_queue: DailySpendUpdateQueue,
    daily_agent_spend_update_queue: DailySpendUpdateQueue,
    daily_tag_spend_update_queue: DailySpendUpdateQueue,
):
    """
    Stores the in-memory spend updates to Redis

    Stores the following in memory data structures in Redis:
    - SpendUpdateQueue - Key, User, Team, TeamMember, Org, EndUser Spend updates
    - DailySpendUpdateQueue - Daily Spend updates Aggregate view

    For SpendUpdateQueue:
    Each transaction is a dict stored as following:
    - key is the entity id
    - value is the spend amount

        ```
        Redis List:
        key_list_transactions:
        [
            "0929880201": 1.2,
            "0929880202": 0.01,
            "0929880203": 0.001,
        ]
        ```

    For DailySpendUpdateQueue:
    Each transaction is a Dict[str, DailyUserSpendTransaction] stored as following:
    - key is the daily_transaction_key
    - value is the DailyUserSpendTransaction

        ```
        Redis List:
        daily_spend_update_transactions:
        [
            {
                "user_keyhash_1_model_1": {
                    "spend": 1.2,
                    "prompt_tokens": 1000,
                    "completion_tokens": 1000,
                    "api_requests": 1000,
                    "successful_requests": 1000,
                },
            }
        ]
        ```
    """
    if self.redis_cache is None:
        verbose_proxy_logger.debug(
            "redis_cache is None, skipping store_in_memory_spend_updates_in_redis"
        )
        return
    # Get all transactions
    # NOTE: each flush call below drains (and aggregates) its in-memory
    # queue, so from here until the rpush succeeds the data only lives in
    # these local variables.
    db_spend_update_transactions = (
        await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
    )
    daily_spend_update_transactions = (
        await daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
    )
    daily_team_spend_update_transactions = (
        await daily_team_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
    )
    daily_org_spend_update_transactions = (
        await daily_org_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
    )
    daily_end_user_spend_update_transactions = (
        await daily_end_user_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
    )
    daily_agent_spend_update_transactions = (
        await daily_agent_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
    )
    daily_tag_spend_update_transactions = (
        await daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
    )
    verbose_proxy_logger.debug(
        "ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
    )
    verbose_proxy_logger.debug(
        "ALL DAILY SPEND UPDATE TRANSACTIONS: %s", daily_spend_update_transactions
    )
    # Build a list of rpush operations, skipping empty/None transaction sets
    # Each entry pairs a transaction set with its Redis buffer key and the
    # service type used for gauge emission — the three must stay aligned.
    _queue_configs: List[Tuple[Any, str, ServiceTypes]] = [
        (
            db_spend_update_transactions,
            REDIS_UPDATE_BUFFER_KEY,
            ServiceTypes.REDIS_SPEND_UPDATE_QUEUE,
        ),
        (
            daily_spend_update_transactions,
            REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
            ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
        ),
        (
            daily_team_spend_update_transactions,
            REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
            ServiceTypes.REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE,
        ),
        (
            daily_org_spend_update_transactions,
            REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
            ServiceTypes.REDIS_DAILY_ORG_SPEND_UPDATE_QUEUE,
        ),
        (
            daily_end_user_spend_update_transactions,
            REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
            ServiceTypes.REDIS_DAILY_END_USER_SPEND_UPDATE_QUEUE,
        ),
        (
            daily_agent_spend_update_transactions,
            REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
            ServiceTypes.REDIS_DAILY_AGENT_SPEND_UPDATE_QUEUE,
        ),
        (
            daily_tag_spend_update_transactions,
            REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
            ServiceTypes.REDIS_DAILY_TAG_SPEND_UPDATE_QUEUE,
        ),
    ]
    rpush_list: List[RedisPipelineRpushOperation] = []
    service_types: List[ServiceTypes] = []
    for transactions, redis_key, service_type in _queue_configs:
        if transactions is None or len(transactions) == 0:
            continue
        rpush_list.append(
            RedisPipelineRpushOperation(
                key=redis_key,
                values=[safe_dumps(transactions)],
            )
        )
        service_types.append(service_type)
    if len(rpush_list) == 0:
        return
    # Single pipelined round-trip pushes every non-empty buffer at once;
    # the returned lengths line up with rpush_list / service_types order.
    result_lengths = await self.redis_cache.async_rpush_pipeline(
        rpush_list=rpush_list,
    )
    # Emit gauge events for each queue
    for i, queue_size in enumerate(result_lengths):
        if i < len(service_types):
            await self._emit_new_item_added_to_redis_buffer_event(
                queue_size=queue_size,
                service=service_types[i],
            )
@staticmethod
def _number_of_transactions_to_store_in_redis(
    db_spend_update_transactions: DBSpendUpdateTransactions,
) -> int:
    """
    Count the individual entity updates across every per-entity-type dict
    in `db_spend_update_transactions` (non-dict values are ignored).
    """
    return sum(
        len(entity_updates)
        for entity_updates in db_spend_update_transactions.values()
        if isinstance(entity_updates, dict)
    )
@staticmethod
def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    """
    Return a copy of `data` with `prefix` stripped from the start of each key.

    Keys that do not start with `prefix` are returned unchanged.
    (BUGFIX: the previous `key.replace(prefix, "", 1)` removed the first
    occurrence of `prefix` *anywhere* in the key, not only at the start —
    e.g. prefix "user_" would corrupt a key like "end_user_id_user_2".)
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in data.items()
    }
async def get_all_update_transactions_from_redis_buffer(
    self,
) -> Optional[DBSpendUpdateTransactions]:
    """
    Pop up to MAX_REDIS_BUFFER_DEQUEUE_COUNT serialized DBSpendUpdateTransactions
    batches from the Redis buffer and merge them into a single transaction set.

    Each Redis list entry is a JSON string of a DBSpendUpdateTransactions
    dict (per entity type, entity-id -> spend amount). The popped entries
    are removed from Redis, so the caller must commit the returned result
    to the database. Returns None when Redis is not configured or the
    buffer yields nothing.
    """
    if self.redis_cache is None:
        return None
    raw_batches = await self.redis_cache.async_lpop(
        key=REDIS_UPDATE_BUFFER_KEY,
        count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
    )
    if raw_batches is None:
        return None
    verbose_proxy_logger.info(
        "Spend tracking - popped %d spend update batches from Redis buffer (key=%s). "
        "These items are now removed from Redis and must be committed to DB.",
        len(raw_batches) if isinstance(raw_batches, list) else 1,
        REDIS_UPDATE_BUFFER_KEY,
    )
    # Decode each JSON batch, then fold them into one combined transaction.
    parsed_batches = self._parse_list_of_transactions(raw_batches)
    if len(parsed_batches) == 0:
        return None
    return self._combine_list_of_transactions(parsed_batches)
async def get_all_transactions_from_redis_buffer_pipeline(
self,
) -> Tuple[
Optional[DBSpendUpdateTransactions],
Optional[Dict[str, DailyUserSpendTransaction]],
Optional[Dict[str, DailyTeamSpendTransaction]],
Optional[Dict[str, DailyOrganizationSpendTransaction]],
Optional[Dict[str, DailyEndUserSpendTransaction]],
Optional[Dict[str, DailyAgentSpendTransaction]],
Optional[Dict[str, DailyTagSpendTransaction]],
]:
"""
Drains all 7 Redis buffer queues in a single pipeline round-trip.
Returns a 7-tuple of parsed results in this order:
0: DBSpendUpdateTransactions
1: daily user spend
2: daily team spend
3: daily org spend
4: daily end-user spend
5: daily agent spend
6: daily tag spend
"""
if self.redis_cache is None:
return None, None, None, None, None, None, None
lpop_list: List[RedisPipelineLpopOperation] = [
RedisPipelineLpopOperation(
key=REDIS_UPDATE_BUFFER_KEY, count=MAX_REDIS_BUFFER_DEQUEUE_COUNT
),
RedisPipelineLpopOperation(
key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
),
RedisPipelineLpopOperation(
key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
),
RedisPipelineLpopOperation(
key=REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
),
RedisPipelineLpopOperation(
key=REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
),
RedisPipelineLpopOperation(
key=REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
),
RedisPipelineLpopOperation(
key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
),
]
raw_results = await self.redis_cache.async_lpop_pipeline(lpop_list=lpop_list)
# Pad with None if pipeline returned fewer results than expected
while len(raw_results) < 7:
raw_results.append(None)
# Slot 0: DBSpendUpdateTransactions
db_spend: Optional[DBSpendUpdateTransactions] = None
if raw_results[0] is not None:
parsed = self._parse_list_of_transactions(raw_results[0])
if len(parsed) > 0:
db_spend = self._combine_list_of_transactions(parsed)
# Slots 1-6: daily spend categories
daily_results: List[Optional[Dict[str, Any]]] = []
for slot in range(1, 7):
if raw_results[slot] is None:
daily_results.append(None)
else:
list_of_daily = [json.loads(t) for t in raw_results[slot]] # type: ignore
aggregated = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
list_of_daily
)
daily_results.append(aggregated)
return (
db_spend,
cast(Optional[Dict[str, DailyUserSpendTransaction]], daily_results[0]),
cast(Optional[Dict[str, DailyTeamSpendTransaction]], daily_results[1]),
cast(
Optional[Dict[str, DailyOrganizationSpendTransaction]], daily_results[2]
),
cast(Optional[Dict[str, DailyEndUserSpendTransaction]], daily_results[3]),
cast(Optional[Dict[str, DailyAgentSpendTransaction]], daily_results[4]),
cast(Optional[Dict[str, DailyTagSpendTransaction]], daily_results[5]),
)
async def get_all_daily_spend_update_transactions_from_redis_buffer(
self,
) -> Optional[Dict[str, DailyUserSpendTransaction]]:
"""
Gets all the daily spend update transactions from Redis
"""
if self.redis_cache is None:
return None
list_of_transactions = await self.redis_cache.async_lpop(
key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
)
if list_of_transactions is None:
return None
list_of_daily_spend_update_transactions = [
json.loads(transaction) for transaction in list_of_transactions
]
return cast(
Dict[str, DailyUserSpendTransaction],
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
list_of_daily_spend_update_transactions
),
)
async def get_all_daily_team_spend_update_transactions_from_redis_buffer(
self,
) -> Optional[Dict[str, DailyTeamSpendTransaction]]:
"""
Gets all the daily team spend update transactions from Redis
"""
if self.redis_cache is None:
return None
list_of_transactions = await self.redis_cache.async_lpop(
key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
)
if list_of_transactions is None:
return None
list_of_daily_spend_update_transactions = [
json.loads(transaction) for transaction in list_of_transactions
]
return cast(
Dict[str, DailyTeamSpendTransaction],
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
list_of_daily_spend_update_transactions
),
)
async def get_all_daily_org_spend_update_transactions_from_redis_buffer(
self,
) -> Optional[Dict[str, DailyOrganizationSpendTransaction]]:
"""
Gets all the daily organization spend update transactions from Redis
"""
if self.redis_cache is None:
return None
list_of_transactions = await self.redis_cache.async_lpop(
key=REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
)
if list_of_transactions is None:
return None
list_of_daily_spend_update_transactions = [
json.loads(transaction) for transaction in list_of_transactions
]
return cast(
Dict[str, DailyOrganizationSpendTransaction],
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
list_of_daily_spend_update_transactions
),
)
async def get_all_daily_end_user_spend_update_transactions_from_redis_buffer(
self,
) -> Optional[Dict[str, DailyEndUserSpendTransaction]]:
"""
Gets all the daily end-user spend update transactions from Redis
"""
if self.redis_cache is None:
return None
list_of_transactions = await self.redis_cache.async_lpop(
key=REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
)
if list_of_transactions is None:
return None
list_of_daily_spend_update_transactions = [
json.loads(transaction) for transaction in list_of_transactions
]
return cast(
Dict[str, DailyEndUserSpendTransaction],
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
list_of_daily_spend_update_transactions
),
)
async def get_all_daily_agent_spend_update_transactions_from_redis_buffer(
self,
) -> Optional[Dict[str, DailyAgentSpendTransaction]]:
"""
Gets all the daily agent spend update transactions from Redis
"""
if self.redis_cache is None:
return None
list_of_transactions = await self.redis_cache.async_lpop(
key=REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
)
if list_of_transactions is None:
return None
list_of_daily_spend_update_transactions = [
json.loads(transaction) for transaction in list_of_transactions
]
return cast(
Dict[str, DailyAgentSpendTransaction],
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
list_of_daily_spend_update_transactions
),
)
async def get_all_daily_tag_spend_update_transactions_from_redis_buffer(
self,
) -> Optional[Dict[str, DailyTagSpendTransaction]]:
"""
Gets all the daily tag spend update transactions from Redis
"""
if self.redis_cache is None:
return None
list_of_transactions = await self.redis_cache.async_lpop(
key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
)
if list_of_transactions is None:
return None
list_of_daily_spend_update_transactions = [
json.loads(transaction) for transaction in list_of_transactions
]
return cast(
Dict[str, DailyTagSpendTransaction],
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
list_of_daily_spend_update_transactions
),
)
@staticmethod
def _parse_list_of_transactions(
list_of_transactions: Union[Any, List[Any]],
) -> List[DBSpendUpdateTransactions]:
"""
Parses the list of transactions from Redis
"""
if isinstance(list_of_transactions, list):
return [json.loads(transaction) for transaction in list_of_transactions]
else:
return [json.loads(list_of_transactions)]
@staticmethod
def _combine_list_of_transactions(
list_of_transactions: List[DBSpendUpdateTransactions],
) -> DBSpendUpdateTransactions:
"""
Combines the list of transactions into a single DBSpendUpdateTransactions object
"""
# Initialize a new combined transaction object with empty dictionaries
combined_transaction = DBSpendUpdateTransactions(
user_list_transactions={},
end_user_list_transactions={},
key_list_transactions={},
team_list_transactions={},
team_member_list_transactions={},
org_list_transactions={},
tag_list_transactions={},
agent_list_transactions={},
)
# Define the transaction fields to process
transaction_fields = [
"user_list_transactions",
"end_user_list_transactions",
"key_list_transactions",
"team_list_transactions",
"team_member_list_transactions",
"org_list_transactions",
"tag_list_transactions",
"agent_list_transactions",
]
# Loop through each transaction and combine the values
for transaction in list_of_transactions:
# Process each field type
for field in transaction_fields:
if transaction.get(field):
for entity_id, amount in transaction[field].items(): # type: ignore
combined_transaction[field][entity_id] = ( # type: ignore
combined_transaction[field].get(entity_id, 0) + amount # type: ignore
)
return combined_transaction
    async def _emit_new_item_added_to_redis_buffer_event(
        self,
        service: ServiceTypes,
        queue_size: int,
    ):
        """
        Emit a gauge-style service metric recording the Redis buffer's current
        queue size for `service`.

        Fire-and-forget: the task is deliberately not awaited so callers are
        never blocked on metrics emission.
        """
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=service,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": service,
                    "gauge_value": queue_size,
                },
            )
        )

View File

@@ -0,0 +1,172 @@
import asyncio
from datetime import datetime, timedelta, timezone
from typing import Optional
from litellm._logging import verbose_proxy_logger
from litellm.caching import RedisCache
from litellm.constants import (
SPEND_LOG_CLEANUP_BATCH_SIZE,
SPEND_LOG_CLEANUP_JOB_NAME,
SPEND_LOG_RUN_LOOPS,
)
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
from litellm.proxy.utils import PrismaClient
class SpendLogCleanup:
    """
    Handles cleaning up old spend logs based on maximum retention period.
    Deletes logs in batches to prevent timeouts.
    Uses PodLockManager to ensure only one pod runs cleanup in multi-pod deployments.
    """
    def __init__(self, general_settings=None, redis_cache: Optional[RedisCache] = None):
        """
        Args:
            general_settings: proxy settings dict; falls back to the global
                proxy_server.general_settings when not provided.
            redis_cache: NOTE(review) - accepted but never used here; the pod
                lock manager is sourced from proxy_logging_obj instead.
                Confirm whether this parameter can be removed.
        """
        self.batch_size = SPEND_LOG_CLEANUP_BATCH_SIZE
        self.retention_seconds: Optional[int] = None
        # Imported lazily, presumably to avoid a circular import with
        # proxy_server - confirm before moving to module level.
        from litellm.proxy.proxy_server import general_settings as default_settings
        self.general_settings = general_settings or default_settings
        from litellm.proxy.proxy_server import proxy_logging_obj
        pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager
        self.pod_lock_manager = pod_lock_manager
        verbose_proxy_logger.info(
            f"SpendLogCleanup initialized with batch size: {self.batch_size}"
        )
    def _should_delete_spend_logs(self) -> bool:
        """
        Determines if logs should be deleted based on the max retention period in settings.

        Side effect: on success, stores the parsed retention period (in
        seconds) in self.retention_seconds for cleanup_old_spend_logs().
        """
        retention_setting = self.general_settings.get(
            "maximum_spend_logs_retention_period"
        )
        verbose_proxy_logger.info(f"Checking retention setting: {retention_setting}")
        if retention_setting is None:
            verbose_proxy_logger.info("No retention setting found")
            return False
        try:
            if isinstance(retention_setting, int):
                # Bare integers are ambiguous; interpret as days for backward
                # compatibility, and warn so operators switch to e.g. "3d".
                verbose_proxy_logger.warning(
                    f"maximum_spend_logs_retention_period is an integer ({retention_setting}); treating as days. "
                    "Use a string like '3d' to be explicit."
                )
                retention_setting = f"{retention_setting}d"
            self.retention_seconds = duration_in_seconds(retention_setting)
            verbose_proxy_logger.info(
                f"Retention period set to {self.retention_seconds} seconds"
            )
            return True
        except ValueError as e:
            # Unparseable duration string -> treat as "cleanup disabled".
            verbose_proxy_logger.warning(
                f"Invalid maximum_spend_logs_retention_period value: {retention_setting}, error: {str(e)}"
            )
            return False
    async def _delete_old_logs(
        self, prisma_client: PrismaClient, cutoff_date: datetime
    ) -> int:
        """
        Helper method to delete old logs in batches.
        Returns the total number of logs deleted.
        """
        total_deleted = 0
        run_count = 0
        while True:
            # NOTE(review): `>` allows SPEND_LOG_RUN_LOOPS + 1 batches before
            # breaking, and the message hard-codes "1,00,000" instead of
            # deriving from SPEND_LOG_RUN_LOOPS * batch_size - confirm intent.
            if run_count > SPEND_LOG_RUN_LOOPS:
                verbose_proxy_logger.info(
                    "Max logs deleted - 1,00,000, rest of the logs will be deleted in next run"
                )
                break
            # Step 1: Find logs and delete them in one go without fetching to application
            # Delete in batches, limited by self.batch_size.
            # $1 = cutoff timestamp, $2 = batch size (positional raw-SQL params).
            deleted_count = await prisma_client.db.execute_raw(
                """
                DELETE FROM "LiteLLM_SpendLogs"
                WHERE "request_id" IN (
                    SELECT "request_id" FROM "LiteLLM_SpendLogs"
                    WHERE "startTime" < $1::timestamptz
                    LIMIT $2
                )
                """,
                cutoff_date,
                self.batch_size,
            )
            verbose_proxy_logger.info(f"Deleted {deleted_count} logs in this batch")
            if deleted_count == 0:
                verbose_proxy_logger.info(
                    f"No more logs to delete. Total deleted: {total_deleted}"
                )
                break
            total_deleted += deleted_count
            run_count += 1
            # Add a small sleep to prevent overwhelming the database
            await asyncio.sleep(0.1)
        return total_deleted
    async def cleanup_old_spend_logs(self, prisma_client: PrismaClient) -> None:
        """
        Main cleanup function. Deletes old spend logs in batches.
        If pod_lock_manager is available, ensures only one pod runs cleanup.
        If no pod_lock_manager, runs cleanup without distributed locking.
        """
        lock_acquired = False
        try:
            verbose_proxy_logger.info(f"Cleanup job triggered at {datetime.now()}")
            # Bail out early when retention is unset or invalid.
            if not self._should_delete_spend_logs():
                return
            if self.retention_seconds is None:
                verbose_proxy_logger.error(
                    "Retention seconds is None, cannot proceed with cleanup"
                )
                return
            # If we have a pod lock manager, try to acquire the lock
            if self.pod_lock_manager and self.pod_lock_manager.redis_cache:
                lock_acquired = (
                    await self.pod_lock_manager.acquire_lock(
                        cronjob_id=SPEND_LOG_CLEANUP_JOB_NAME,
                    )
                    or False
                )
                verbose_proxy_logger.info(
                    f"Lock acquisition attempt: {'successful' if lock_acquired else 'failed'} at {datetime.now()}"
                )
                if not lock_acquired:
                    verbose_proxy_logger.info("Another pod is already running cleanup")
                    return
            # Anything that started before this timestamp gets deleted.
            cutoff_date = datetime.now(timezone.utc) - timedelta(
                seconds=float(self.retention_seconds)
            )
            verbose_proxy_logger.info(
                f"Deleting logs older than {cutoff_date.isoformat()}"
            )
            # Perform the actual deletion
            total_deleted = await self._delete_old_logs(prisma_client, cutoff_date)
            verbose_proxy_logger.info(f"Deleted {total_deleted} logs")
        except Exception as e:
            # Cleanup is best-effort: swallow errors so the scheduler survives.
            verbose_proxy_logger.error(f"Error during cleanup: {str(e)}")
            return  # Return after error handling
        finally:
            # Only release the lock if it was actually acquired
            if (
                lock_acquired
                and self.pod_lock_manager
                and self.pod_lock_manager.redis_cache
            ):
                await self.pod_lock_manager.release_lock(
                    cronjob_id=SPEND_LOG_CLEANUP_JOB_NAME
                )
                verbose_proxy_logger.info("Released cleanup lock")

View File

@@ -0,0 +1,246 @@
import asyncio
from typing import Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
from litellm.proxy._types import (
DBSpendUpdateTransactions,
Litellm_EntityType,
SpendUpdateQueueItem,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
BaseUpdateQueue,
service_logger_obj,
)
from litellm.types.services import ServiceTypes
class SpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for spend updates that should be committed to the database.

    Items are SpendUpdateQueueItem dicts (entity_type, entity_id,
    response_cost). When the queue grows past MAX_SIZE_IN_MEMORY_QUEUE
    (inherited from BaseUpdateQueue - not visible here, confirm), entries are
    compacted in place by summing response_cost per (entity_type, entity_id).
    """
    def __init__(self):
        super().__init__()
        # Bounded asyncio queue; LITELLM_ASYNCIO_QUEUE_MAXSIZE is the hard cap
        # at which `put` would block awaiting space (standard asyncio.Queue
        # semantics). add_update() compacts the queue before that point.
        self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue(
            maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE
        )
    async def flush_and_get_aggregated_db_spend_update_transactions(
        self,
    ) -> DBSpendUpdateTransactions:
        """Flush all updates from the queue and return all updates aggregated by entity type."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        if len(updates) > 0:
            verbose_proxy_logger.info(
                "Spend tracking - flushed %d spend update items from in-memory queue",
                len(updates),
            )
        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
        return self.get_aggregated_db_spend_update_transactions(updates)
    async def add_update(self, update: SpendUpdateQueueItem):
        """Enqueue an update to the spend update queue"""
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)
        # if the queue is full, aggregate the updates
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()
    async def aggregate_queue_updates(self):
        """Concatenate all updates in the queue to reduce the size of in-memory queue"""
        updates: List[
            SpendUpdateQueueItem
        ] = await self.flush_all_updates_from_in_memory_queue()
        aggregated_updates = self._get_aggregated_spend_update_queue_item(updates)
        # Re-enqueue the compacted entries so later flushes still see them.
        for update in aggregated_updates:
            await self.update_queue.put(update)
        return
    def _get_aggregated_spend_update_queue_item(
        self, updates: List[SpendUpdateQueueItem]
    ) -> List[SpendUpdateQueueItem]:
        """
        This is used to reduce the size of the in-memory queue by aggregating updates by entity type + id
        Aggregate updates by entity type + id
        eg.
        ```
        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 100
            },
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 200
            }
        ]
        ```
        becomes
        ```
        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 300
            }
        ]
        ```
        """
        verbose_proxy_logger.debug(
            "Aggregating spend updates, current queue size: %s",
            self.update_queue.qsize(),
        )
        aggregated_spend_updates: List[SpendUpdateQueueItem] = []
        _in_memory_map: Dict[str, SpendUpdateQueueItem] = {}
        """
        Used for combining several updates into a single update
        Key=entity_type:entity_id
        Value=SpendUpdateQueueItem
        """
        for update in updates:
            _key = f"{update.get('entity_type')}:{update.get('entity_id')}"
            if _key not in _in_memory_map:
                # avoid mutating caller-owned dicts while aggregating queue entries
                _in_memory_map[_key] = update.copy()
            else:
                # Missing/None costs are treated as 0 when summing.
                current_cost = _in_memory_map[_key].get("response_cost", 0) or 0
                update_cost = update.get("response_cost", 0) or 0
                _in_memory_map[_key]["response_cost"] = current_cost + update_cost
        for _key, update in _in_memory_map.items():
            aggregated_spend_updates.append(update)
        verbose_proxy_logger.debug(
            "Aggregated spend updates: %s", aggregated_spend_updates
        )
        return aggregated_spend_updates
    def get_aggregated_db_spend_update_transactions(
        self, updates: List[SpendUpdateQueueItem]
    ) -> DBSpendUpdateTransactions:
        """
        Aggregate updates by entity type.

        Returns a DBSpendUpdateTransactions with one {entity_id: total_cost}
        dict per entity type. Updates with a missing or unknown entity_type
        are skipped (logged at debug level).
        """
        # Initialize all transaction lists as empty dicts
        db_spend_update_transactions = DBSpendUpdateTransactions(
            user_list_transactions={},
            end_user_list_transactions={},
            key_list_transactions={},
            team_list_transactions={},
            team_member_list_transactions={},
            org_list_transactions={},
            tag_list_transactions={},
            agent_list_transactions={},
        )
        # Map entity types to their corresponding transaction dictionary keys
        entity_type_to_dict_key = {
            Litellm_EntityType.USER: "user_list_transactions",
            Litellm_EntityType.END_USER: "end_user_list_transactions",
            Litellm_EntityType.KEY: "key_list_transactions",
            Litellm_EntityType.TEAM: "team_list_transactions",
            Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
            Litellm_EntityType.ORGANIZATION: "org_list_transactions",
            Litellm_EntityType.TAG: "tag_list_transactions",
            Litellm_EntityType.AGENT: "agent_list_transactions",
        }
        for update in updates:
            entity_type = update.get("entity_type")
            entity_id = update.get("entity_id") or ""
            response_cost = update.get("response_cost") or 0
            if entity_type is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is None",
                    update,
                )
                continue
            dict_key = entity_type_to_dict_key.get(entity_type)
            if dict_key is None:
                verbose_proxy_logger.debug(
                    "Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
                    update,
                )
                continue  # Skip unknown entity types
            # Type-safe access using if/elif statements
            # (TypedDict indexing with a variable key is not narrowed by the
            # type checker, hence the explicit branch per literal key.)
            if dict_key == "user_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "user_list_transactions"
                ]
            elif dict_key == "end_user_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "end_user_list_transactions"
                ]
            elif dict_key == "key_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "key_list_transactions"
                ]
            elif dict_key == "team_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "team_list_transactions"
                ]
            elif dict_key == "team_member_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "team_member_list_transactions"
                ]
            elif dict_key == "org_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "org_list_transactions"
                ]
            elif dict_key == "tag_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "tag_list_transactions"
                ]
            elif dict_key == "agent_list_transactions":
                transactions_dict = db_spend_update_transactions[
                    "agent_list_transactions"
                ]
            else:
                continue
            # Defensive: all fields are initialized to {} above, so this
            # branch only fires if a field was somehow set to None.
            if transactions_dict is None:
                transactions_dict = {}
            # dict_key is guaranteed to be one of the literal TypedDict keys
            # handled in the if/elif chain above.
            db_spend_update_transactions[dict_key] = transactions_dict  # type: ignore
            if entity_id not in transactions_dict:
                transactions_dict[entity_id] = 0
            transactions_dict[entity_id] += response_cost or 0
        return db_spend_update_transactions
    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """
        Fire-and-forget: emit a gauge-style metric with the in-memory queue
        size. The task is intentionally not awaited.
        """
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )

View File

@@ -0,0 +1,54 @@
"""
In-memory buffer for tool registry upserts.
Unlike SpendUpdateQueue (which aggregates increments), ToolDiscoveryQueue
uses set-deduplication: each unique tool_name is only queued once per flush
cycle (~30s). The seen-set is cleared on every flush so that call_count
increments in subsequent cycles rather than stopping after the first flush.
"""
from typing import List, Set
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ToolDiscoveryQueueItem
class ToolDiscoveryQueue:
    """
    In-memory buffer for tool registry upserts.

    Deduplicates by tool_name within each flush cycle: a tool is queued at
    most once per ~30s batch, so call_count increments once per flush cycle
    in which the tool appears (not once per invocation, and not once per pod
    lifetime either). The seen-set is reset on flush so later batches can
    count the same tool again.
    """
    def __init__(self) -> None:
        # Names already queued this cycle; reset by flush().
        self._seen_tool_names: Set[str] = set()
        # Items awaiting the next flush, in arrival order.
        self._pending: List[ToolDiscoveryQueueItem] = []
    def add_update(self, item: ToolDiscoveryQueueItem) -> None:
        """Queue `item` unless its tool_name is empty or already seen this cycle."""
        name = item.get("tool_name", "")
        if not name:
            # Unnamed tools are silently dropped.
            return
        if name in self._seen_tool_names:
            verbose_proxy_logger.debug(
                "ToolDiscoveryQueue: skipping already-seen tool %s", name
            )
            return
        self._seen_tool_names.add(name)
        self._pending.append(item)
        verbose_proxy_logger.debug(
            "ToolDiscoveryQueue: queued new tool %s (origin=%s)",
            name,
            item.get("origin"),
        )
    def flush(self) -> List[ToolDiscoveryQueueItem]:
        """Return all pending items and reset both buffers for the next cycle."""
        flushed = self._pending
        self._pending = []
        self._seen_tool_names.clear()
        return flushed

View File

@@ -0,0 +1,71 @@
"""
Deprecated. Only PostgresSQL is supported.
"""
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import DynamoDBArgs
from litellm.proxy.db.base_client import CustomDB
class DynamoDBWrapper(CustomDB):
    """
    Deprecated DynamoDB implementation of the CustomDB interface.
    Only PostgreSQL is supported going forward.
    """
    # NOTE(review): class-level import runs at class-creation time, so merely
    # importing this module requires aiodynamo to be installed.
    from aiodynamo.credentials import Credentials, StaticCredentials
    credentials: Credentials
    def __init__(self, database_arguments: DynamoDBArgs):
        """
        Configure DynamoDB throughput from `database_arguments`.

        Raises:
            Exception: when billing_mode is PROVISIONED_THROUGHPUT but
                read_capacity_units / write_capacity_units are missing or
                not integers.
        """
        from aiodynamo.models import PayPerRequest, Throughput
        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            if (
                database_arguments.read_capacity_units is not None
                and isinstance(database_arguments.read_capacity_units, int)
                and database_arguments.write_capacity_units is not None
                and isinstance(database_arguments.write_capacity_units, int)
            ):
                self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units)  # type: ignore
            else:
                raise Exception(
                    f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
                )
        self.database_arguments = database_arguments
        self.region_name = database_arguments.region_name
    def set_env_vars_based_on_arn(self):
        """
        Resolve AWS credentials via a two-step STS role assumption and export
        them as environment variables for aiodynamo.

        No-op when aws_role_name is not configured.
        """
        if self.database_arguments.aws_role_name is None:
            return
        verbose_proxy_logger.debug(
            f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
        )
        import os
        import boto3
        sts_client = boto3.client("sts")
        # call 1
        # NOTE(review): the result of this call is discarded - presumably it
        # is needed only for its side effect on the STS session; confirm.
        sts_client.assume_role_with_web_identity(
            RoleArn=self.database_arguments.aws_role_name,
            RoleSessionName=self.database_arguments.aws_session_name,
            WebIdentityToken=self.database_arguments.aws_web_identity_token,
        )
        # call 2
        assumed_role = sts_client.assume_role(
            RoleArn=self.database_arguments.assume_role_aws_role_name,
            RoleSessionName=self.database_arguments.assume_role_aws_session_name,
        )
        aws_access_key_id = assumed_role["Credentials"]["AccessKeyId"]
        aws_secret_access_key = assumed_role["Credentials"]["SecretAccessKey"]
        aws_session_token = assumed_role["Credentials"]["SessionToken"]
        verbose_proxy_logger.debug(
            f"Got STS assumed Role, aws_access_key_id={aws_access_key_id}"
        )
        # set these in the env so aiodynamo can use them
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
        os.environ["AWS_SESSION_TOKEN"] = aws_session_token

View File

@@ -0,0 +1,106 @@
from typing import Union
from litellm.proxy._types import (
DB_CONNECTION_ERROR_TYPES,
ProxyErrorTypes,
ProxyException,
)
from litellm.secret_managers.main import str_to_bool
class PrismaDBExceptionHandler:
    """
    Helpers for classifying Prisma / DB connection failures and deciding
    whether a request may proceed while the DB is unavailable.
    """
    @staticmethod
    def should_allow_request_on_db_unavailable() -> bool:
        """
        Returns True if the request should be allowed to proceed despite the DB connection error
        """
        from litellm.proxy.proxy_server import general_settings
        setting: Union[bool, str] = general_settings.get(
            "allow_requests_on_db_unavailable", False
        )
        if isinstance(setting, bool):
            return setting
        # String values (e.g. "true" from an env var) are coerced here.
        return str_to_bool(setting) is True
    @staticmethod
    def is_database_connection_error(e: Exception) -> bool:
        """
        Returns True if the exception is from a database outage / connection error.
        Any PrismaError qualifies — the DB failed to serve the request.
        Used by allow_requests_on_db_unavailable logic and endpoint 503 responses.
        """
        import prisma
        if isinstance(e, (DB_CONNECTION_ERROR_TYPES, prisma.errors.PrismaError)):
            return True
        return (
            isinstance(e, ProxyException)
            and e.type == ProxyErrorTypes.no_db_connection
        )
    @staticmethod
    def is_database_transport_error(e: Exception) -> bool:
        """
        Returns True only for transport/connectivity failures where a reconnect
        attempt makes sense (e.g. DB is unreachable, connection dropped).
        Use this for reconnect logic — data-layer errors like UniqueViolationError
        mean the DB IS reachable, so reconnecting would be pointless.
        """
        import prisma
        if isinstance(e, DB_CONNECTION_ERROR_TYPES):
            return True
        transport_error_types = (
            prisma.errors.ClientNotConnectedError,
            prisma.errors.HTTPClientClosedError,
        )
        if isinstance(e, transport_error_types):
            return True
        if isinstance(e, prisma.errors.PrismaError):
            # Fall back to message sniffing for generic PrismaErrors.
            lowered_message = str(e).lower()
            connection_keywords = (
                "can't reach database server",
                "cannot reach database server",
                "can't connect",
                "cannot connect",
                "connection error",
                "connection closed",
                "timed out",
                "timeout",
                "connection refused",
                "network is unreachable",
                "no route to host",
                "broken pipe",
            )
            for keyword in connection_keywords:
                if keyword in lowered_message:
                    return True
        return (
            isinstance(e, ProxyException)
            and e.type == ProxyErrorTypes.no_db_connection
        )
    @staticmethod
    def handle_db_exception(e: Exception):
        """
        Primary handler for the `allow_requests_on_db_unavailable` flag.

        Swallows DB connection errors (returns None) when the flag is on;
        otherwise re-raises `e` unchanged.
        """
        if (
            PrismaDBExceptionHandler.is_database_connection_error(e)
            and PrismaDBExceptionHandler.should_allow_request_on_db_unavailable()
        ):
            return None
        raise e

View File

@@ -0,0 +1,142 @@
"""
Handles logging DB success/failure to ServiceLogger()
ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
"""
import asyncio
from datetime import datetime
from functools import wraps
from typing import Callable, Dict, Tuple
from litellm._service_logger import ServiceTypes
from litellm.litellm_core_utils.core_helpers import (
_get_parent_otel_span_from_kwargs,
get_litellm_metadata_from_kwargs,
)
def log_db_metrics(func):
    """
    Decorator to log the duration of a DB related function to ServiceLogger()

    Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog

    When logging Failure it checks if the Exception is a PrismaError, httpx.ConnectError or httpx.TimeoutException and then logs that as a DB Service Failure

    Args:
        func: The function to be decorated

    Returns:
        Result from the decorated function

    Raises:
        Exception: If the decorated function raises an exception
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time: datetime = datetime.now()
        try:
            result = await func(*args, **kwargs)
            end_time: datetime = datetime.now()
            from litellm.proxy.proxy_server import proxy_logging_obj
            # Success path 1: ordinary DB helpers (no "PROXY" in the function
            # name) are reported as a ServiceTypes.DB success. Fire-and-forget.
            if "PROXY" not in func.__name__:
                asyncio.create_task(
                    proxy_logging_obj.service_logging_obj.async_service_success_hook(
                        service=ServiceTypes.DB,
                        call_type=func.__name__,
                        parent_otel_span=kwargs.get("parent_otel_span", None),
                        duration=(end_time - start_time).total_seconds(),
                        start_time=start_time,
                        end_time=end_time,
                        event_metadata={
                            "function_name": func.__name__,
                            "function_kwargs": kwargs,
                            "function_args": args,
                        },
                    )
                )
            # Success path 2: callback-style functions, logged as
            # BATCH_WRITE_TO_DB only when an OTEL parent span is present.
            elif (
                # in litellm custom callbacks kwargs is passed as arg[0]
                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
                args is not None
                and len(args) > 1
                and isinstance(args[1], dict)
            ):
                passed_kwargs = args[1]
                parent_otel_span = _get_parent_otel_span_from_kwargs(
                    kwargs=passed_kwargs
                )
                if parent_otel_span is not None:
                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)
                    asyncio.create_task(
                        proxy_logging_obj.service_logging_obj.async_service_success_hook(
                            service=ServiceTypes.BATCH_WRITE_TO_DB,
                            call_type=func.__name__,
                            parent_otel_span=parent_otel_span,
                            duration=0.0,
                            start_time=start_time,
                            end_time=end_time,
                            event_metadata=metadata,
                        )
                    )
            # end of logging to otel
            return result
        except Exception as e:
            end_time: datetime = datetime.now()
            # Report DB-related failures (non-DB exceptions are filtered
            # inside the handler), then re-raise for the caller.
            await _handle_logging_db_exception(
                e=e,
                func=func,
                kwargs=kwargs,
                args=args,
                start_time=start_time,
                end_time=end_time,
            )
            raise e
    return wrapper
def _is_exception_related_to_db(e: Exception) -> bool:
    """
    Return True when `e` indicates a DB or DB-transport failure
    (PrismaError, httpx connect error, or httpx timeout).
    """
    import httpx
    from prisma.errors import PrismaError
    db_error_types = (PrismaError, httpx.ConnectError, httpx.TimeoutException)
    return isinstance(e, db_error_types)
async def _handle_logging_db_exception(
    e: Exception,
    func: Callable,
    kwargs: Dict,
    args: Tuple,
    start_time: datetime,
    end_time: datetime,
) -> None:
    """
    Report `e` to the service logger as a DB failure, but only when the
    exception is actually DB-related; anything else is ignored.
    """
    from litellm.proxy.proxy_server import proxy_logging_obj
    # don't log this as a DB Service Failure, if the DB did not raise an exception
    if not _is_exception_related_to_db(e):
        return
    await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
        error=e,
        service=ServiceTypes.DB,
        call_type=func.__name__,
        parent_otel_span=kwargs.get("parent_otel_span"),
        duration=(end_time - start_time).total_seconds(),
        start_time=start_time,
        end_time=end_time,
        event_metadata={
            "function_name": func.__name__,
            "function_kwargs": kwargs,
            "function_args": args,
        },
    )

View File

@@ -0,0 +1,438 @@
"""
This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token.
"""
import asyncio
import os
import random
import subprocess
import time
import urllib
import urllib.parse
from datetime import datetime, timedelta
from typing import Any, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.secret_managers.main import str_to_bool
class PrismaWrapper:
    """
    Wrapper around Prisma client that handles RDS IAM token authentication.

    When iam_token_db_auth is enabled, this wrapper:
    1. Proactively refreshes IAM tokens before they expire (background task)
    2. Falls back to synchronous refresh if a token is found expired
    3. Uses proper locking to prevent race conditions during reconnection

    RDS IAM tokens are valid for 15 minutes. This wrapper refreshes them
    3 minutes before expiration to ensure uninterrupted database connectivity.
    """

    # Buffer time in seconds before token expiration to trigger refresh
    # Refresh 3 minutes (180 seconds) before the token expires
    TOKEN_REFRESH_BUFFER_SECONDS = 180

    # Fallback refresh interval if token parsing fails (10 minutes)
    FALLBACK_REFRESH_INTERVAL_SECONDS = 600

    def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
        # Wrapped Prisma client; replaced wholesale on reconnection.
        self._original_prisma = original_prisma
        self.iam_token_db_auth = iam_token_db_auth
        # Background token refresh task management
        self._token_refresh_task: Optional[asyncio.Task] = None
        self._reconnection_lock = asyncio.Lock()
        self._last_refresh_time: Optional[datetime] = None

    def _extract_token_from_db_url(self, db_url: Optional[str]) -> Optional[str]:
        """
        Extract the token (password) from the DATABASE_URL.

        The token contains the AWS signature with X-Amz-Date and X-Amz-Expires parameters.

        Important: We must parse the URL while it's still encoded to preserve structure,
        then decode the password portion. Otherwise the '?' in the token breaks URL parsing.
        """
        if db_url is None:
            return None
        try:
            # Parse URL while still encoded to preserve structure
            parsed = urllib.parse.urlparse(db_url)
            if parsed.password:
                # Now decode just the password/token
                return urllib.parse.unquote(parsed.password)
            return None
        except Exception:
            # Malformed URL -> behave as if no token is present.
            return None

    def _parse_token_expiration(self, token: Optional[str]) -> Optional[datetime]:
        """
        Parse the token to extract its expiration time.

        Returns the datetime when the token expires, or None if parsing fails.
        NOTE: X-Amz-Date is in UTC, so the returned datetime is naive UTC and
        is only compared against datetime.utcnow() elsewhere in this class.
        """
        if token is None:
            return None
        try:
            # Token format: ...?X-Amz-Date=YYYYMMDDTHHMMSSZ&X-Amz-Expires=900&...
            if "?" not in token:
                return None
            query_string = token.split("?", 1)[1]
            params = urllib.parse.parse_qs(query_string)
            expires_str = params.get("X-Amz-Expires", [None])[0]
            date_str = params.get("X-Amz-Date", [None])[0]
            if not expires_str or not date_str:
                return None
            token_created = datetime.strptime(date_str, "%Y%m%dT%H%M%SZ")
            expires_in = int(expires_str)
            return token_created + timedelta(seconds=expires_in)
        except Exception as e:
            verbose_proxy_logger.debug(f"Failed to parse token expiration: {e}")
            return None

    def _calculate_seconds_until_refresh(self) -> float:
        """
        Calculate exactly how many seconds until we need to refresh the token.

        Uses precise timing: sleeps until (token_expiration - buffer_seconds).
        For a 15-minute (900s) token with 180s buffer, this returns ~720s (12 min).

        Returns:
            Number of seconds to sleep before the next refresh.
            Returns 0 if token should be refreshed immediately.
            Returns FALLBACK_REFRESH_INTERVAL_SECONDS if parsing fails.
        """
        db_url = os.getenv("DATABASE_URL")
        token = self._extract_token_from_db_url(db_url)
        expiration_time = self._parse_token_expiration(token)
        if expiration_time is None:
            # If we can't parse the token, use fallback interval
            verbose_proxy_logger.debug(
                f"Could not parse token expiration, using fallback interval of "
                f"{self.FALLBACK_REFRESH_INTERVAL_SECONDS}s"
            )
            return self.FALLBACK_REFRESH_INTERVAL_SECONDS
        # Calculate when we should refresh (expiration - buffer)
        refresh_at = expiration_time - timedelta(
            seconds=self.TOKEN_REFRESH_BUFFER_SECONDS
        )
        # How long until refresh time? (naive UTC on both sides)
        now = datetime.utcnow()
        seconds_until_refresh = (refresh_at - now).total_seconds()
        # If already past refresh time, return 0 (refresh immediately)
        return max(0, seconds_until_refresh)

    def is_token_expired(self, token_url: Optional[str]) -> bool:
        """Check if the token in the given URL is expired.

        A missing or unparseable token is treated as expired so callers
        trigger a refresh rather than using a possibly dead credential.
        """
        if token_url is None:
            return True
        token = self._extract_token_from_db_url(token_url)
        expiration_time = self._parse_token_expiration(token)
        if expiration_time is None:
            # If we can't parse the token, assume it's expired to trigger refresh
            verbose_proxy_logger.debug(
                "Could not parse token expiration, treating as expired"
            )
            return True
        return datetime.utcnow() > expiration_time

    def get_rds_iam_token(self) -> Optional[str]:
        """Generate a new RDS IAM token and update DATABASE_URL.

        Reads connection parameters from DATABASE_HOST/PORT/USER/NAME/SCHEMA
        env vars and writes the rebuilt URL back into os.environ so that any
        newly created Prisma client picks it up. Returns the new URL, or None
        when IAM token auth is disabled.
        """
        if self.iam_token_db_auth:
            from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token

            db_host = os.getenv("DATABASE_HOST")
            db_port = os.getenv("DATABASE_PORT")
            db_user = os.getenv("DATABASE_USER")
            db_name = os.getenv("DATABASE_NAME")
            db_schema = os.getenv("DATABASE_SCHEMA")
            token = generate_iam_auth_token(
                db_host=db_host, db_port=db_port, db_user=db_user
            )
            # NOTE(review): the token is embedded without URL-encoding; the
            # encoded-first parsing in _extract_token_from_db_url compensates,
            # but confirm downstream consumers of DATABASE_URL tolerate this.
            _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
            if db_schema:
                _db_url += f"?schema={db_schema}"
            os.environ["DATABASE_URL"] = _db_url
            return _db_url
        return None

    async def recreate_prisma_client(
        self, new_db_url: str, http_client: Optional[Any] = None
    ):
        """Disconnect and reconnect the Prisma client with a new database URL.

        Presumably Prisma() reads DATABASE_URL from the environment (set by
        get_rds_iam_token); new_db_url itself is not passed to the constructor.
        A failed disconnect is logged and ignored — the old client is being
        discarded anyway.
        """
        from prisma import Prisma  # type: ignore

        try:
            await self._original_prisma.disconnect()
        except Exception as e:
            verbose_proxy_logger.warning(f"Failed to disconnect Prisma client: {e}")

        if http_client is not None:
            self._original_prisma = Prisma(http=http_client)
        else:
            self._original_prisma = Prisma()
        await self._original_prisma.connect()

    async def start_token_refresh_task(self) -> None:
        """
        Start the background token refresh task.

        This task proactively refreshes RDS IAM tokens before they expire,
        preventing connection failures. Should be called after the initial
        Prisma client connection is established. Idempotent: a second call
        while the task is running is a no-op.
        """
        if not self.iam_token_db_auth:
            verbose_proxy_logger.debug(
                "IAM token auth not enabled, skipping token refresh task"
            )
            return
        if self._token_refresh_task is not None:
            verbose_proxy_logger.debug("Token refresh task already running")
            return

        self._token_refresh_task = asyncio.create_task(self._token_refresh_loop())
        verbose_proxy_logger.info(
            "Started RDS IAM token proactive refresh background task"
        )

    async def stop_token_refresh_task(self) -> None:
        """
        Stop the background token refresh task gracefully.

        Should be called during application shutdown to clean up resources.
        """
        if self._token_refresh_task is None:
            return
        self._token_refresh_task.cancel()
        try:
            # Await so the CancelledError is consumed and the task finalizes.
            await self._token_refresh_task
        except asyncio.CancelledError:
            pass
        self._token_refresh_task = None
        verbose_proxy_logger.info("Stopped RDS IAM token refresh background task")

    async def _token_refresh_loop(self) -> None:
        """
        Background loop that proactively refreshes RDS IAM tokens before expiration.

        Uses precise timing: calculates the exact sleep duration until the token
        needs to be refreshed (expiration - 3 minute buffer), then refreshes.
        This is more efficient than polling, requiring only 1 wake-up per token cycle.
        """
        verbose_proxy_logger.info(
            f"RDS IAM token refresh loop started. "
            f"Tokens will be refreshed {self.TOKEN_REFRESH_BUFFER_SECONDS}s before expiration."
        )
        while True:
            try:
                # Calculate exactly how long to sleep until next refresh
                sleep_seconds = self._calculate_seconds_until_refresh()
                if sleep_seconds > 0:
                    verbose_proxy_logger.info(
                        f"RDS IAM token refresh scheduled in {sleep_seconds:.0f} seconds "
                        f"({sleep_seconds / 60:.1f} minutes)"
                    )
                    await asyncio.sleep(sleep_seconds)
                # Refresh the token
                verbose_proxy_logger.info("Proactively refreshing RDS IAM token...")
                await self._safe_refresh_token()
            except asyncio.CancelledError:
                verbose_proxy_logger.info("RDS IAM token refresh loop cancelled")
                break
            except Exception as e:
                verbose_proxy_logger.error(
                    f"Error in RDS IAM token refresh loop: {e}. "
                    f"Retrying in {self.FALLBACK_REFRESH_INTERVAL_SECONDS}s..."
                )
                # On error, wait before retrying to avoid tight error loops
                try:
                    await asyncio.sleep(self.FALLBACK_REFRESH_INTERVAL_SECONDS)
                except asyncio.CancelledError:
                    break

    async def _safe_refresh_token(self) -> None:
        """
        Refresh the RDS IAM token with proper locking to prevent race conditions.

        Uses an asyncio lock to ensure only one refresh operation happens at a time,
        preventing multiple concurrent reconnection attempts.
        """
        async with self._reconnection_lock:
            new_db_url = self.get_rds_iam_token()
            if new_db_url:
                await self.recreate_prisma_client(new_db_url)
                self._last_refresh_time = datetime.utcnow()
                verbose_proxy_logger.info(
                    "RDS IAM token refreshed successfully. New token valid for ~15 minutes."
                )
            else:
                verbose_proxy_logger.error(
                    "Failed to generate new RDS IAM token during proactive refresh"
                )

    def __getattr__(self, name: str):
        """
        Proxy attribute access to the underlying Prisma client.

        If IAM token auth is enabled and the token is expired, this method
        provides a synchronous fallback to refresh the token. However, this
        should rarely be needed since the background task proactively refreshes
        tokens before they expire.

        FIXED: Now properly waits for reconnection to complete before returning,
        instead of the previous fire-and-forget pattern that caused the bug.
        """
        original_attr = getattr(self._original_prisma, name)
        if self.iam_token_db_auth:
            db_url = os.getenv("DATABASE_URL")
            # Check if token is expired (should be rare if background task is running)
            if self.is_token_expired(db_url):
                verbose_proxy_logger.warning(
                    "RDS IAM token expired in __getattr__ - proactive refresh may have failed. "
                    "Triggering synchronous fallback refresh..."
                )
                new_db_url = self.get_rds_iam_token()
                if new_db_url:
                    loop = asyncio.get_event_loop()
                    if loop.is_running():
                        # FIXED: Actually wait for the reconnection to complete!
                        # The previous code used fire-and-forget which caused the bug.
                        # NOTE(review): if __getattr__ runs ON the event-loop thread,
                        # future.result() would block the loop the coroutine needs —
                        # potential deadlock until the 30s timeout; confirm callers
                        # only hit this path from worker threads.
                        future = asyncio.run_coroutine_threadsafe(
                            self.recreate_prisma_client(new_db_url), loop
                        )
                        try:
                            # Wait up to 30 seconds for reconnection
                            future.result(timeout=30)
                            verbose_proxy_logger.info(
                                "Synchronous token refresh completed successfully"
                            )
                        except Exception as e:
                            verbose_proxy_logger.error(
                                f"Failed to refresh token synchronously: {e}"
                            )
                            raise
                    else:
                        asyncio.run(self.recreate_prisma_client(new_db_url))
                    # Get the NEW attribute after reconnection
                    original_attr = getattr(self._original_prisma, name)
                else:
                    raise ValueError("Failed to get RDS IAM token")
        return original_attr
class PrismaManager:
    """Helpers for preparing the database schema via the Prisma CLI."""

    @staticmethod
    def _get_prisma_dir() -> str:
        """Get the path to the migrations directory"""
        # Parent of this file's package directory — where schema.prisma lives.
        abspath = os.path.abspath(__file__)
        dname = os.path.dirname(os.path.dirname(abspath))
        return dname

    @staticmethod
    def setup_database(use_migrate: bool = False) -> bool:
        """
        Set up the database using either prisma migrate or prisma db push

        Args:
            use_migrate: when True, delegate to litellm_proxy_extras'
                migration runner; otherwise shell out to `prisma db push`.

        Returns:
            bool: True if setup was successful, False otherwise
        """
        # Up to 4 attempts with randomized backoff between failed attempts.
        for attempt in range(4):
            original_dir = os.getcwd()
            prisma_dir = PrismaManager._get_prisma_dir()
            # The prisma CLI resolves the schema relative to the CWD.
            os.chdir(prisma_dir)
            try:
                if use_migrate:
                    try:
                        from litellm_proxy_extras.utils import ProxyExtrasDBManager
                    except ImportError as e:
                        verbose_proxy_logger.error(
                            f"\033[1;31mLiteLLM: Failed to import proxy extras. Got {e}\033[0m"
                        )
                        return False

                    # NOTE(review): prisma_dir is recomputed here but never
                    # used afterwards — looks redundant; confirm before removing.
                    prisma_dir = PrismaManager._get_prisma_dir()
                    return ProxyExtrasDBManager.setup_database(use_migrate=use_migrate)
                else:
                    # Use prisma db push with increased timeout
                    subprocess.run(
                        ["prisma", "db", "push", "--accept-data-loss"],
                        timeout=60,
                        check=True,
                    )
                    return True
            except subprocess.TimeoutExpired:
                verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
                time.sleep(random.randrange(5, 15))
            except subprocess.CalledProcessError as e:
                attempts_left = 3 - attempt
                retry_msg = (
                    f" Retrying... ({attempts_left} attempts left)"
                    if attempts_left > 0
                    else ""
                )
                verbose_proxy_logger.warning(
                    f"The process failed to execute. Details: {e}.{retry_msg}"
                )
                time.sleep(random.randrange(5, 15))
            finally:
                # Always restore the caller's working directory.
                os.chdir(original_dir)
        return False
def should_update_prisma_schema(
    disable_updates: Optional[Union[bool, str]] = None,
) -> bool:
    """
    Determines if Prisma Schema updates should be applied during startup.

    Args:
        disable_updates: Controls whether schema updates are disabled.
            Accepts boolean or string ('true'/'false'). Defaults to checking
            the DISABLE_SCHEMA_UPDATE env var.

    Returns:
        bool: True if schema updates should be applied, False if updates are disabled.

    Examples:
        >>> should_update_prisma_schema()  # Checks DISABLE_SCHEMA_UPDATE env var
        >>> should_update_prisma_schema(True)  # Explicitly disable updates
        >>> should_update_prisma_schema("false")  # Enable updates using string
    """
    flag = disable_updates
    if flag is None:
        # Fall back to the env var when no explicit value was given.
        flag = os.getenv("DISABLE_SCHEMA_UPDATE", "false")
    if flag is not None and isinstance(flag, str):
        # Normalize 'true'/'false' strings to a boolean.
        flag = str_to_bool(flag)
    # Updates run unless explicitly disabled.
    return not bool(flag)

View File

@@ -0,0 +1,151 @@
"""
Track tool usage for the dashboard: insert into SpendLogToolIndex when spend logs
are written, so "last N requests for tool X" and "how is this tool called in production"
queries are fast.
"""
from datetime import datetime, timezone
from typing import Any, Dict, List, Set
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
from litellm.proxy.utils import PrismaClient
def _add_tool_calls_to_set(tool_calls: Any, out: Set[str]) -> None:
"""Extract tool names from OpenAI-style tool_calls list into out."""
if not isinstance(tool_calls, list):
return
for tc in tool_calls:
if not isinstance(tc, dict):
continue
fn = tc.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if name and isinstance(name, str) and name.strip():
out.add(name.strip())
def _parse_tool_names_from_payload(payload: Dict[str, Any]) -> Set[str]:
"""
Extract deduplicated tool names from a spend log payload.
Sources: mcp_namespaced_tool_name, response (tool_calls), proxy_server_request (tools).
"""
tool_names: Set[str] = set()
# Top-level MCP tool name (single tool per request for that flow)
mcp_name = payload.get("mcp_namespaced_tool_name")
if mcp_name and isinstance(mcp_name, str) and mcp_name.strip():
tool_names.add(mcp_name.strip())
# Response: OpenAI-style tool_calls[].function.name or choices[0].message.tool_calls
response_raw = payload.get("response")
if response_raw:
response_obj = (
safe_json_loads(response_raw, default=None)
if isinstance(response_raw, str)
else response_raw
)
if isinstance(response_obj, dict):
_add_tool_calls_to_set(response_obj.get("tool_calls"), tool_names)
choices = response_obj.get("choices")
if isinstance(choices, list) and choices:
msg = (
choices[0].get("message") if isinstance(choices[0], dict) else None
)
if isinstance(msg, dict):
_add_tool_calls_to_set(msg.get("tool_calls"), tool_names)
# Request body: tools[].function.name
request_raw = payload.get("proxy_server_request")
if request_raw:
request_obj = (
safe_json_loads(request_raw, default=None)
if isinstance(request_raw, str)
else request_raw
)
if isinstance(request_obj, dict):
body = request_obj.get("body", request_obj)
if isinstance(body, dict):
request_obj = body
if isinstance(request_obj, dict):
tools = request_obj.get("tools")
if isinstance(tools, list):
for t in tools:
if isinstance(t, dict):
fn = t.get("function")
if isinstance(fn, dict):
name = fn.get("name")
if name and isinstance(name, str) and name.strip():
tool_names.add(name.strip())
return tool_names
async def process_spend_logs_tool_usage(
    prisma_client: PrismaClient,
    logs_to_process: List[Dict[str, Any]],
) -> None:
    """
    After spend logs are written: insert SpendLogToolIndex rows from each payload.

    Extracts tool names from mcp_namespaced_tool_name, response tool_calls, and
    proxy_server_request tools. Failures are logged and swallowed (non-fatal):
    the index is an optimization, not part of the spend-log write path.

    Fix: the original built `index_rows` with fully normalized timestamps and
    then re-ran the same string-parse / tz-normalization over every row a second
    time before the insert — that second pass was dead code and has been removed;
    rows are normalized exactly once here.
    """
    if not logs_to_process:
        return

    index_rows: List[Dict[str, Any]] = []
    for payload in logs_to_process:
        request_id = payload.get("request_id")
        start_time = payload.get("startTime")
        # Rows without an id or timestamp cannot be indexed.
        if not request_id or not start_time:
            continue
        if isinstance(start_time, str):
            try:
                # Accept ISO-8601 strings, including a trailing 'Z' (UTC).
                start_time = datetime.fromisoformat(start_time.replace("Z", "+00:00"))
            except (ValueError, TypeError):
                continue
        if start_time.tzinfo is None:
            # Treat naive timestamps as UTC for consistent index ordering.
            start_time = start_time.replace(tzinfo=timezone.utc)
        for tool_name in _parse_tool_names_from_payload(payload):
            index_rows.append(
                {
                    "request_id": request_id,
                    "tool_name": tool_name,
                    "start_time": start_time,
                }
            )

    if not index_rows:
        return
    try:
        await prisma_client.db.litellm_spendlogtoolindex.create_many(
            data=index_rows,
            skip_duplicates=True,
        )
    except Exception as e:
        verbose_proxy_logger.warning(
            "Tool usage tracking (SpendLogToolIndex) failed (non-fatal): %s", e
        )

View File

@@ -0,0 +1,440 @@
"""
DB helpers for LiteLLM_ToolTable — the global tool registry.
Tools are auto-discovered from LLM responses and upserted here.
Admins use the management endpoints to read and update input_policy / output_policy.
"""
import uuid
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ToolDiscoveryQueueItem
from litellm.types.tool_management import (
LiteLLM_ToolTableRow,
ToolPolicyOverrideRow,
)
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
def _row_to_model(row: Union[dict, Any]) -> LiteLLM_ToolTableRow:
    """Convert a Prisma model instance or dict to LiteLLM_ToolTableRow."""
    # Field names mirrored from the LiteLLM_ToolTable schema.
    _FIELDS = (
        "tool_id",
        "tool_name",
        "origin",
        "input_policy",
        "output_policy",
        "call_count",
        "assignments",
        "key_hash",
        "team_id",
        "key_alias",
        "user_agent",
        "last_used_at",
        "created_at",
        "updated_at",
        "created_by",
        "updated_by",
    )
    dump = getattr(row, "model_dump", None)
    if callable(dump):
        # Pydantic-style model: serialize to a plain dict.
        data = dump()
    elif isinstance(row, dict):
        data = row
    else:
        # Arbitrary object: pull known fields via attribute access.
        data = {field: getattr(row, field, None) for field in _FIELDS}
    return LiteLLM_ToolTableRow(
        tool_id=data.get("tool_id", ""),
        tool_name=data.get("tool_name", ""),
        origin=data.get("origin"),
        input_policy=data.get("input_policy") or "untrusted",
        output_policy=data.get("output_policy") or "untrusted",
        call_count=int(data.get("call_count") or 0),
        assignments=data.get("assignments"),
        key_hash=data.get("key_hash"),
        team_id=data.get("team_id"),
        key_alias=data.get("key_alias"),
        user_agent=data.get("user_agent"),
        last_used_at=data.get("last_used_at"),
        created_at=data.get("created_at"),
        updated_at=data.get("updated_at"),
        created_by=data.get("created_by"),
        updated_by=data.get("updated_by"),
    )
async def batch_upsert_tools(
    prisma_client: "PrismaClient",
    items: List[ToolDiscoveryQueueItem],
) -> None:
    """
    Batch-upsert tool registry rows via Prisma.

    On first insert: sets input_policy/output_policy = "untrusted" (default), call_count = 1.
    On conflict: increments call_count; preserves existing policies.
    Errors are logged and swallowed — discovery writes are best-effort.
    """
    if not items:
        return
    try:
        # Drop queue entries without a tool name.
        valid_items = [item for item in items if item.get("tool_name")]
        if not valid_items:
            return
        now = datetime.now(timezone.utc)
        table = prisma_client.db.litellm_tooltable
        for entry in valid_items:
            name = entry.get("tool_name", "")
            creator = entry.get("created_by") or "system"
            create_payload = {
                "tool_id": str(uuid.uuid4()),
                "tool_name": name,
                "origin": entry.get("origin") or "user_defined",
                "input_policy": "untrusted",
                "output_policy": "untrusted",
                "call_count": 1,
                "created_by": creator,
                "updated_by": creator,
                "key_hash": entry.get("key_hash"),
                "team_id": entry.get("team_id"),
                "key_alias": entry.get("key_alias"),
                "user_agent": entry.get("user_agent"),
                "last_used_at": now,
            }
            update_payload = {
                "call_count": {"increment": 1},
                "updated_at": now,
                "last_used_at": now,
            }
            await table.upsert(
                where={"tool_name": name},
                data={"create": create_payload, "update": update_payload},
            )
        verbose_proxy_logger.debug(
            "tool_registry_writer: upserted %d tool(s)", len(valid_items)
        )
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer batch_upsert_tools error: %s", e
        )
async def list_tools(
    prisma_client: "PrismaClient",
    input_policy: Optional[str] = None,
) -> List[LiteLLM_ToolTableRow]:
    """Return all tools (newest first), optionally filtered by input_policy."""
    try:
        filters: dict = {}
        if input_policy is not None:
            filters["input_policy"] = input_policy
        db_rows = await prisma_client.db.litellm_tooltable.find_many(
            where=filters,
            order={"created_at": "desc"},
        )
        return [_row_to_model(db_row) for db_row in db_rows]
    except Exception as e:
        verbose_proxy_logger.error("tool_registry_writer list_tools error: %s", e)
        # Best-effort read path: errors surface as an empty listing.
        return []
async def get_tool(
    prisma_client: "PrismaClient",
    tool_name: str,
) -> Optional[LiteLLM_ToolTableRow]:
    """Return a single tool row by tool_name, or None when missing or on error."""
    try:
        record = await prisma_client.db.litellm_tooltable.find_unique(
            where={"tool_name": tool_name},
        )
        return _row_to_model(record) if record is not None else None
    except Exception as e:
        verbose_proxy_logger.error("tool_registry_writer get_tool error: %s", e)
        return None
async def update_tool_policy(
    prisma_client: "PrismaClient",
    tool_name: str,
    updated_by: Optional[str],
    input_policy: Optional[str] = None,
    output_policy: Optional[str] = None,
) -> Optional[LiteLLM_ToolTableRow]:
    """Update input_policy and/or output_policy for a tool. Upserts the row if it does not exist yet."""
    try:
        actor = updated_by or "system"
        now = datetime.now(timezone.utc)
        # Row written when the tool has never been discovered before.
        create_data: dict = {
            "tool_id": str(uuid.uuid4()),
            "tool_name": tool_name,
            "input_policy": input_policy or "untrusted",
            "output_policy": output_policy or "untrusted",
            "created_by": actor,
            "updated_by": actor,
            "created_at": now,
            "updated_at": now,
        }
        # For existing rows, only overwrite the policies the caller provided.
        update_data: dict = {
            "updated_by": actor,
            "updated_at": now,
        }
        if input_policy is not None:
            update_data["input_policy"] = input_policy
        if output_policy is not None:
            update_data["output_policy"] = output_policy
        await prisma_client.db.litellm_tooltable.upsert(
            where={"tool_name": tool_name},
            data={"create": create_data, "update": update_data},
        )
        # Re-read so callers get the post-update row.
        return await get_tool(prisma_client, tool_name)
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer update_tool_policy error: %s", e
        )
        return None
async def get_tools_by_names(
    prisma_client: "PrismaClient",
    tool_names: List[str],
) -> Dict[str, Tuple[str, str]]:
    """
    Return a {tool_name: (input_policy, output_policy)} map for the given tool names.

    Unknown names are simply absent from the result; errors yield an empty map.
    """
    if not tool_names:
        return {}
    try:
        matches = await prisma_client.db.litellm_tooltable.find_many(
            where={"tool_name": {"in": tool_names}},
        )
        policies: Dict[str, Tuple[str, str]] = {}
        for match in matches:
            policies[match.tool_name] = (
                getattr(match, "input_policy", "untrusted") or "untrusted",
                getattr(match, "output_policy", "untrusted") or "untrusted",
            )
        return policies
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer get_tools_by_names error: %s", e
        )
        return {}
async def list_overrides_for_tool(
    prisma_client: "PrismaClient",
    tool_name: str,
) -> List[ToolPolicyOverrideRow]:
    """
    Return override-like rows for a tool by finding object permissions that have
    this tool in blocked_tools, then resolving each permission to key/team scope for display.
    """
    overrides: List[ToolPolicyOverrideRow] = []
    try:
        permissions = await prisma_client.db.litellm_objectpermissiontable.find_many(
            where={"blocked_tools": {"has": tool_name}},
            include={
                "verification_tokens": True,
                "teams": True,
            },
        )
        for permission in permissions:
            perm_id = getattr(permission, "object_permission_id", None) or ""
            # One display row per key that carries this permission.
            for token_row in getattr(permission, "verification_tokens", None) or []:
                overrides.append(
                    ToolPolicyOverrideRow(
                        override_id=perm_id,
                        tool_name=tool_name,
                        team_id=None,
                        key_hash=getattr(token_row, "token", None),
                        input_policy="blocked",
                        key_alias=getattr(token_row, "key_alias", None),
                        created_at=None,
                        updated_at=None,
                    )
                )
            # One display row per team that carries this permission.
            for team_row in getattr(permission, "teams", None) or []:
                overrides.append(
                    ToolPolicyOverrideRow(
                        override_id=perm_id,
                        tool_name=tool_name,
                        team_id=getattr(team_row, "team_id", None),
                        key_hash=None,
                        input_policy="blocked",
                        key_alias=getattr(team_row, "team_alias", None),
                        created_at=None,
                        updated_at=None,
                    )
                )
        return overrides
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer list_overrides_for_tool error: %s", e
        )
        return []
class ToolPolicyRegistry:
    """
    In-memory registry of tool policies synced from DB.

    Hot path uses get_effective_policies only — no DB, no cache.
    """

    def __init__(self) -> None:
        # tool_name -> policy; "untrusted" is the implicit default.
        self._tool_input_policies: Dict[str, str] = {}
        self._tool_output_policies: Dict[str, str] = {}
        # object_permission_id -> tool names blocked for that permission.
        self._blocked_tools_by_op_id: Dict[str, List[str]] = {}
        self._initialized: bool = False

    def is_initialized(self) -> bool:
        """True once a successful DB sync has populated the registry."""
        return self._initialized

    async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
        """Load all tool policies and object-permission blocked_tools from DB."""
        try:
            tools = await prisma_client.db.litellm_tooltable.find_many()
            input_map: Dict[str, str] = {}
            output_map: Dict[str, str] = {}
            for tool_row in tools:
                input_map[tool_row.tool_name] = (
                    getattr(tool_row, "input_policy", "untrusted") or "untrusted"
                )
                output_map[tool_row.tool_name] = (
                    getattr(tool_row, "output_policy", "untrusted") or "untrusted"
                )
            self._tool_input_policies = input_map
            self._tool_output_policies = output_map

            perms = await prisma_client.db.litellm_objectpermissiontable.find_many()
            blocked_map: Dict[str, List[str]] = {}
            for perm_row in perms:
                perm_id = getattr(perm_row, "object_permission_id", None)
                if perm_id:
                    blocked_map[perm_id] = list(
                        getattr(perm_row, "blocked_tools", None) or []
                    )
            self._blocked_tools_by_op_id = blocked_map

            self._initialized = True
            verbose_proxy_logger.info(
                "ToolPolicyRegistry: synced %d tool policies and %d object permissions from DB",
                len(self._tool_input_policies),
                len(self._blocked_tools_by_op_id),
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "ToolPolicyRegistry sync_tool_policy_from_db error: %s", e
            )
            raise

    def get_input_policy(self, tool_name: str) -> str:
        """Global input policy for a tool; defaults to 'untrusted'."""
        return self._tool_input_policies.get(tool_name, "untrusted")

    def get_output_policy(self, tool_name: str) -> str:
        """Global output policy for a tool; defaults to 'untrusted'."""
        return self._tool_output_policies.get(tool_name, "untrusted")

    def get_effective_policies(
        self,
        tool_names: List[str],
        object_permission_id: Optional[str] = None,
        team_object_permission_id: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Return effective input_policy per tool from in-memory state.

        If tool is in key or team blocked_tools -> "blocked", else global input_policy or "untrusted".
        """
        if not tool_names:
            return {}
        blocked_names: set = set()
        for perm_id in (object_permission_id, team_object_permission_id):
            if perm_id and perm_id.strip():
                blocked_names.update(
                    self._blocked_tools_by_op_id.get(perm_id.strip(), [])
                )
        return {
            name: (
                "blocked"
                if name in blocked_names
                else self._tool_input_policies.get(name, "untrusted")
            )
            for name in tool_names
        }
# Module-level singleton; constructed lazily on first access.
_tool_policy_registry: Optional[ToolPolicyRegistry] = None


def get_tool_policy_registry() -> ToolPolicyRegistry:
    """Return the global ToolPolicyRegistry singleton, creating it on first use."""
    global _tool_policy_registry
    registry = _tool_policy_registry
    if registry is None:
        registry = ToolPolicyRegistry()
        _tool_policy_registry = registry
    return registry
async def add_tool_to_object_permission_blocked(
    prisma_client: "PrismaClient",
    object_permission_id: str,
    tool_name: str,
) -> bool:
    """Add tool_name to the permission's blocked_tools if not already present."""
    # Guard: both identifiers are required.
    if not object_permission_id or not tool_name:
        return False
    try:
        perm = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": object_permission_id},
        )
        if perm is None:
            return False
        blocked = list(getattr(perm, "blocked_tools", None) or [])
        if tool_name in blocked:
            # Already blocked — nothing to write.
            return True
        await prisma_client.db.litellm_objectpermissiontable.update(
            where={"object_permission_id": object_permission_id},
            data={"blocked_tools": blocked + [tool_name]},
        )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer add_tool_to_object_permission_blocked error: %s", e
        )
        return False
async def remove_tool_from_object_permission_blocked(
    prisma_client: "PrismaClient",
    object_permission_id: str,
    tool_name: str,
) -> bool:
    """Remove tool_name from the permission's blocked_tools. Returns False if tool was not in list."""
    # Guard: both identifiers are required.
    if not object_permission_id or not tool_name:
        return False
    try:
        perm = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": object_permission_id},
        )
        if perm is None:
            return False
        blocked = list(getattr(perm, "blocked_tools", None) or [])
        if tool_name not in blocked:
            return False
        remaining = [name for name in blocked if name != tool_name]
        await prisma_client.db.litellm_objectpermissiontable.update(
            where={"object_permission_id": object_permission_id},
            data={"blocked_tools": remaining},
        )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer remove_tool_from_object_permission_blocked error: %s",
            e,
        )
        return False