chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,53 @@
|
||||
from typing import Any, Literal, List
|
||||
|
||||
|
||||
class CustomDB:
    """
    Implements a base class that we expect any custom db implementation (e.g. DynamoDB) to follow.

    Subclasses are expected to override every method; each method here is an
    intentional no-op stub that returns None.
    """

    def __init__(self) -> None:
        # No shared state in the base class; subclasses hold their own config.
        pass

    def get_data(self, key: str, table_name: Literal["user", "key", "config"]) -> Any:
        """
        Check if key valid.

        Args:
            key: identifier to look up.
            table_name: which logical table to read from ("user", "key", or "config").
        """
        pass

    def insert_data(self, value: Any, table_name: Literal["user", "key", "config"]) -> Any:
        """
        For new key / user logic.

        Args:
            value: row/object to write.
            table_name: which logical table to write to.
        """
        pass

    def update_data(
        self, key: str, value: Any, table_name: Literal["user", "key", "config"]
    ) -> Any:
        """
        For cost tracking logic.

        Args:
            key: identifier of the row to update.
            value: new data to apply.
            table_name: which logical table to update.
        """
        pass

    def delete_data(
        self, keys: List[str], table_name: Literal["user", "key", "config"]
    ) -> Any:
        """
        For /key/delete endpoints.

        Args:
            keys: list of identifiers to remove.
            table_name: which logical table to delete from.
        """

    def connect(
        self,
    ) -> Any:
        """
        For connecting to db and creating / updating any tables.
        """
        pass

    def disconnect(
        self,
    ) -> Any:
        """
        For closing connection on server shutdown.
        """
        pass
|
||||
@@ -0,0 +1,104 @@
|
||||
"""Module for checking differences between Prisma schema and database."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
def extract_sql_commands(diff_output: str) -> List[str]:
    """
    Extract SQL commands from the Prisma migrate diff output.

    Statements are collected only after a `-- ` comment line (which prisma
    emits as a table-operation description); each statement ends at a line
    terminated by `;`.

    Args:
        diff_output (str): The full output from prisma migrate diff.

    Returns:
        List[str]: A list of SQL commands extracted from the diff output.
    """
    # Normalize: strip whitespace and drop blank lines up front.
    meaningful_lines = [raw.strip() for raw in diff_output.split("\n") if raw.strip()]

    commands: List[str] = []
    pending = ""
    collecting = False

    for text in meaningful_lines:
        if text.startswith("-- "):
            # A comment line starts a new statement group; flush anything
            # still pending from the previous group.
            if collecting and pending:
                commands.append(pending.strip())
                pending = ""
            collecting = True
        elif collecting:
            if text.endswith(";"):
                # Statement terminator reached — emit the accumulated command.
                pending += text
                commands.append(pending.strip())
                pending = ""
                collecting = False
            else:
                pending += text + " "

    # Flush a trailing, unterminated command if one remains.
    if pending:
        commands.append(pending.strip())

    return commands
|
||||
|
||||
|
||||
def check_prisma_schema_diff_helper(db_url: str) -> Tuple[bool, List[str]]:
    """Checks for differences between current database and Prisma schema.

    Shells out to `prisma migrate diff --script` and parses the emitted SQL.

    Args:
        db_url: Postgres connection string passed to `prisma migrate diff`.

    Returns:
        A tuple containing:
        - A boolean indicating if differences were found (True) or not (False).
        - The list of SQL commands needed to sync the database with the
          schema (empty when there is no diff, or when the prisma command
          fails — a CalledProcessError is caught and logged, not re-raised).
    """
    verbose_logger.debug("Checking for Prisma schema diff...")  # noqa: T201
    try:
        # --script makes prisma emit raw SQL rather than a human summary,
        # which extract_sql_commands then parses.
        result = subprocess.run(
            [
                "prisma",
                "migrate",
                "diff",
                "--from-url",
                db_url,
                "--to-schema-datamodel",
                "./schema.prisma",
                "--script",
            ],
            capture_output=True,
            text=True,
            check=True,
        )

        # return True, "Migration diff generated successfully."
        sql_commands = extract_sql_commands(result.stdout)

        if sql_commands:
            print("Changes to DB Schema detected")  # noqa: T201
            print("Required SQL commands:")  # noqa: T201
            for command in sql_commands:
                print(command)  # noqa: T201
            return True, sql_commands
        else:
            return False, []
    except subprocess.CalledProcessError as e:
        # check=True raises on non-zero exit; treat a failed diff as "no diff"
        # after surfacing the error so startup is not blocked.
        error_message = f"Failed to generate migration diff. Error: {e.stderr}"
        print(error_message)  # noqa: T201
        return False, []
|
||||
|
||||
|
||||
def check_prisma_schema_diff(db_url: Optional[str] = None) -> None:
    """Main function to run the Prisma schema diff check.

    Resolves the database URL (falling back to the DATABASE_URL environment
    variable), runs the diff helper, and logs loudly when the schema and the
    database are out of sync.
    """
    resolved_url = db_url if db_url is not None else os.getenv("DATABASE_URL")
    if resolved_url is None:
        raise Exception("DATABASE_URL not set")
    has_diff, message = check_prisma_schema_diff_helper(resolved_url)
    if has_diff:
        verbose_logger.exception(
            "🚨🚨🚨 prisma schema out of sync with db. Consider running these sql_commands to sync the two - {}".format(
                message
            )
        )
|
||||
@@ -0,0 +1,227 @@
|
||||
from typing import Any
|
||||
|
||||
from litellm import verbose_logger
|
||||
|
||||
_db = Any
|
||||
|
||||
|
||||
async def create_missing_views(db: _db):  # noqa: PLR0915
    """
    --------------------------------------------------
    NOTE: Copy of `litellm/db_scripts/create_views.py`.
    --------------------------------------------------
    Checks if the LiteLLM_VerificationTokenView and MonthlyGlobalSpend exists in the user's db.

    LiteLLM_VerificationTokenView: This view is used for getting the token + team data in user_api_key_auth

    MonthlyGlobalSpend: This view is used for the admin view to see global spend for this month

    If the view doesn't exist, one will be created.

    Args:
        db: client exposing async `query_raw` / `execute_raw` (Prisma-style).
    """
    # Pattern used for every view below: probe it with a cheap SELECT; any
    # exception is treated as "view missing" and triggers creation.
    try:
        # Try to select one row from the view
        await db.query_raw("""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1""")
        print("LiteLLM_VerificationTokenView Exists!")  # noqa
    except Exception:
        # If an error occurs, the view does not exist, so create it
        await db.execute_raw(
            """
            CREATE VIEW "LiteLLM_VerificationTokenView" AS
            SELECT
                v.*,
                t.spend AS team_spend,
                t.max_budget AS team_max_budget,
                t.tpm_limit AS team_tpm_limit,
                t.rpm_limit AS team_rpm_limit
            FROM "LiteLLM_VerificationToken" v
            LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
            """
        )

        print("LiteLLM_VerificationTokenView Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpend" LIMIT 1""")
        print("MonthlyGlobalSpend Exists!")  # noqa
    except Exception:
        # Rolling 30-day daily spend totals across all keys/users.
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpend" AS
        SELECT
            DATE("startTime") AS date,
            SUM("spend") AS spend
        FROM
            "LiteLLM_SpendLogs"
        WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
            DATE("startTime");
        """
        await db.execute_raw(query=sql_query)

        print("MonthlyGlobalSpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "Last30dKeysBySpend" LIMIT 1""")
        print("Last30dKeysBySpend Exists!")  # noqa
    except Exception:
        # Top keys by 30-day spend, joined against the token table to
        # surface key alias / name alongside the hashed api_key.
        sql_query = """
        CREATE OR REPLACE VIEW "Last30dKeysBySpend" AS
        SELECT
            L."api_key",
            V."key_alias",
            V."key_name",
            SUM(L."spend") AS total_spend
        FROM
            "LiteLLM_SpendLogs" L
        LEFT JOIN
            "LiteLLM_VerificationToken" V
        ON
            L."api_key" = V."token"
        WHERE
            L."startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
            L."api_key", V."key_alias", V."key_name"
        ORDER BY
            total_spend DESC;
        """
        await db.execute_raw(query=sql_query)

        print("Last30dKeysBySpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "Last30dModelsBySpend" LIMIT 1""")
        print("Last30dModelsBySpend Exists!")  # noqa
    except Exception:
        # Top models by 30-day spend; rows with an empty model are excluded.
        sql_query = """
        CREATE OR REPLACE VIEW "Last30dModelsBySpend" AS
        SELECT
            "model",
            SUM("spend") AS total_spend
        FROM
            "LiteLLM_SpendLogs"
        WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
            AND "model" != ''
        GROUP BY
            "model"
        ORDER BY
            total_spend DESC;
        """
        await db.execute_raw(query=sql_query)

        print("Last30dModelsBySpend Created!")  # noqa
    try:
        await db.query_raw("""SELECT 1 FROM "MonthlyGlobalSpendPerKey" LIMIT 1""")
        print("MonthlyGlobalSpendPerKey Exists!")  # noqa
    except Exception:
        # As MonthlyGlobalSpend, but broken out per api_key.
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerKey" AS
        SELECT
            DATE("startTime") AS date,
            SUM("spend") AS spend,
            api_key as api_key
        FROM
            "LiteLLM_SpendLogs"
        WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
            DATE("startTime"),
            api_key;
        """
        await db.execute_raw(query=sql_query)

        print("MonthlyGlobalSpendPerKey Created!")  # noqa
    try:
        await db.query_raw(
            """SELECT 1 FROM "MonthlyGlobalSpendPerUserPerKey" LIMIT 1"""
        )
        print("MonthlyGlobalSpendPerUserPerKey Exists!")  # noqa
    except Exception:
        # As above, additionally broken out per "user" column.
        sql_query = """
        CREATE OR REPLACE VIEW "MonthlyGlobalSpendPerUserPerKey" AS
        SELECT
            DATE("startTime") AS date,
            SUM("spend") AS spend,
            api_key as api_key,
            "user" as "user"
        FROM
            "LiteLLM_SpendLogs"
        WHERE
            "startTime" >= (CURRENT_DATE - INTERVAL '30 days')
        GROUP BY
            DATE("startTime"),
            "user",
            api_key;
        """
        await db.execute_raw(query=sql_query)

        print("MonthlyGlobalSpendPerUserPerKey Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "DailyTagSpend" LIMIT 1""")
        print("DailyTagSpend Exists!")  # noqa
    except Exception:
        # Explodes the request_tags JSONB array so each tag gets its own
        # per-day spend/count row.
        sql_query = """
        CREATE OR REPLACE VIEW "DailyTagSpend" AS
        SELECT
            jsonb_array_elements_text(request_tags) AS individual_request_tag,
            DATE(s."startTime") AS spend_date,
            COUNT(*) AS log_count,
            SUM(spend) AS total_spend
        FROM "LiteLLM_SpendLogs" s
        GROUP BY individual_request_tag, DATE(s."startTime");
        """
        await db.execute_raw(query=sql_query)

        print("DailyTagSpend Created!")  # noqa

    try:
        await db.query_raw("""SELECT 1 FROM "Last30dTopEndUsersSpend" LIMIT 1""")
        print("Last30dTopEndUsersSpend Exists!")  # noqa
    except Exception:
        # NOTE(review): the predicate `end_user <> user` uses the unquoted
        # keyword `user`, which in Postgres resolves to the current database
        # user, not a "user" column — confirm this is intended before
        # changing anything here.
        sql_query = """
        CREATE VIEW "Last30dTopEndUsersSpend" AS
        SELECT end_user, COUNT(*) AS total_events, SUM(spend) AS total_spend
        FROM "LiteLLM_SpendLogs"
        WHERE end_user <> '' AND end_user <> user
            AND "startTime" >= CURRENT_DATE - INTERVAL '30 days'
        GROUP BY end_user
        ORDER BY total_spend DESC
        LIMIT 100;
        """
        await db.execute_raw(query=sql_query)

        print("Last30dTopEndUsersSpend Created!")  # noqa

    return
|
||||
|
||||
|
||||
async def should_create_missing_views(db: _db) -> bool:
    """
    Run only on first time startup.

    If SpendLogs table already has values, then don't create views on startup.

    Uses the planner's row estimate from ``pg_class.reltuples``: a value of
    ``-1`` (Postgres 13+, never analyzed) or ``0`` (older Postgres / empty
    table) means the table has no data yet, so views should be created.

    Args:
        db: client exposing async `query_raw` (Prisma-style).

    Returns:
        True if the views should be created, False otherwise.
    """

    sql_query = """
    SELECT reltuples::BIGINT
    FROM pg_class
    WHERE oid = '"LiteLLM_SpendLogs"'::regclass;
    """

    result = await db.query_raw(query=sql_query)

    verbose_logger.debug("Estimated Row count of LiteLLM_SpendLogs = {}".format(result))
    if (
        result
        and isinstance(result, list)
        and len(result) > 0
        and isinstance(result[0], dict)
        and "reltuples" in result[0]
        # BUG FIX: the previous code required `result[0]["reltuples"]` to be
        # truthy before comparing it, which made the `== 0` case unreachable
        # (0 is falsy). Check for None explicitly so an estimate of 0 (empty
        # table on Postgres < 13) correctly triggers view creation.
        and result[0]["reltuples"] is not None
        and (result[0]["reltuples"] == 0 or result[0]["reltuples"] == -1)
    ):
        verbose_logger.debug("Should create views")
        return True

    return False
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
Base class for in memory buffer for database transactions
|
||||
"""
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._service_logger import ServiceLogging
|
||||
|
||||
service_logger_obj = (
|
||||
ServiceLogging()
|
||||
) # used for tracking metrics for In memory buffer, redis buffer, pod lock manager
|
||||
from litellm.constants import (
|
||||
LITELLM_ASYNCIO_QUEUE_MAXSIZE,
|
||||
MAX_IN_MEMORY_QUEUE_FLUSH_COUNT,
|
||||
MAX_SIZE_IN_MEMORY_QUEUE,
|
||||
)
|
||||
|
||||
|
||||
class BaseUpdateQueue:
    """Base class for in memory buffer for database transactions.

    Wraps an asyncio.Queue and provides enqueue + bounded drain operations;
    subclasses override `_emit_new_item_added_to_queue_event` to report
    queue-size metrics.
    """

    def __init__(self):
        self.update_queue = asyncio.Queue(maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE)
        self.MAX_SIZE_IN_MEMORY_QUEUE = MAX_SIZE_IN_MEMORY_QUEUE
        # Warn on a misconfiguration where the aggregation threshold can
        # never be reached before the underlying queue blocks.
        if MAX_SIZE_IN_MEMORY_QUEUE >= LITELLM_ASYNCIO_QUEUE_MAXSIZE:
            verbose_proxy_logger.warning(
                "Misconfigured queue thresholds: MAX_SIZE_IN_MEMORY_QUEUE (%d) >= LITELLM_ASYNCIO_QUEUE_MAXSIZE (%d). "
                "The spend aggregation check will never trigger because the asyncio.Queue blocks at %d items. "
                "Set MAX_SIZE_IN_MEMORY_QUEUE to a value less than LITELLM_ASYNCIO_QUEUE_MAXSIZE (recommended: 80%% of it).",
                MAX_SIZE_IN_MEMORY_QUEUE,
                LITELLM_ASYNCIO_QUEUE_MAXSIZE,
                LITELLM_ASYNCIO_QUEUE_MAXSIZE,
            )

    async def add_update(self, update):
        """Enqueue an update."""
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)
        current_size = self.update_queue.qsize()
        await self._emit_new_item_added_to_queue_event(queue_size=current_size)

    async def flush_all_updates_from_in_memory_queue(self):
        """Get all updates from the queue."""
        drained = []
        while not self.update_queue.empty():
            # Circuit breaker to ensure we're not stuck dequeuing updates. Protect CPU utilization
            if len(drained) >= MAX_IN_MEMORY_QUEUE_FLUSH_COUNT:
                verbose_proxy_logger.debug(
                    "Max in memory queue flush count reached, stopping flush"
                )
                break
            drained.append(await self.update_queue.get())
        return drained

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        """placeholder, emit event when a new item is added to the queue"""
        pass
|
||||
@@ -0,0 +1,155 @@
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
|
||||
from litellm.proxy._types import BaseDailySpendTransaction
|
||||
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
|
||||
BaseUpdateQueue,
|
||||
service_logger_obj,
|
||||
)
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
|
||||
class DailySpendUpdateQueue(BaseUpdateQueue):
    """
    In memory buffer for daily spend updates that should be committed to the database

    To add a new daily spend update transaction, use the following format:
    daily_spend_update_queue.add_update({
        "user1_date_api_key_model_custom_llm_provider": {
            "spend": 10,
            "prompt_tokens": 100,
            "completion_tokens": 100,
        }
    })

    Queue contains a list of daily spend update transactions

    eg
    queue = [
        {
            "user1_date_api_key_model_custom_llm_provider": {
                "spend": 10,
                "prompt_tokens": 100,
                "completion_tokens": 100,
                "api_requests": 100,
                "successful_requests": 100,
                "failed_requests": 100,
            }
        },
        {
            "user2_date_api_key_model_custom_llm_provider": {
                "spend": 10,
                "prompt_tokens": 100,
                "completion_tokens": 100,
                "api_requests": 100,
                "successful_requests": 100,
                "failed_requests": 100,
            }
        }
    ]
    """

    def __init__(self):
        super().__init__()
        # Re-create the queue with a typed annotation; this intentionally
        # replaces the untyped queue built by BaseUpdateQueue.__init__.
        self.update_queue: asyncio.Queue[
            Dict[str, BaseDailySpendTransaction]
        ] = asyncio.Queue(maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE)

    async def add_update(self, update: Dict[str, BaseDailySpendTransaction]):
        """Enqueue an update.

        When the queue reaches MAX_SIZE_IN_MEMORY_QUEUE, all queued entries
        are aggregated into a single entry to bound memory use.

        NOTE(review): unlike BaseUpdateQueue.add_update, this override does
        not call `_emit_new_item_added_to_queue_event` — confirm whether the
        per-item gauge emission is intentionally skipped here.
        """
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()

    async def aggregate_queue_updates(self):
        """
        Combine all updates in the queue into a single update.
        This is used to reduce the size of the in-memory queue.
        """
        updates: List[
            Dict[str, BaseDailySpendTransaction]
        ] = await self.flush_all_updates_from_in_memory_queue()
        aggregated_updates = self.get_aggregated_daily_spend_update_transactions(
            updates
        )
        # Put the single aggregated dict back; the queue was just drained so
        # this put cannot block on maxsize.
        await self.update_queue.put(aggregated_updates)

    async def flush_and_get_aggregated_daily_spend_update_transactions(
        self,
    ) -> Dict[str, BaseDailySpendTransaction]:
        """Get all updates from the queue and return all updates aggregated by daily_transaction_key. Works for both user and team spend updates."""
        updates = await self.flush_all_updates_from_in_memory_queue()
        if len(updates) > 0:
            verbose_proxy_logger.info(
                "Spend tracking - flushed %d daily spend update items from in-memory queue",
                len(updates),
            )
        aggregated_daily_spend_update_transactions = (
            DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
                updates
            )
        )
        verbose_proxy_logger.debug(
            "Aggregated daily spend update transactions: %s",
            aggregated_daily_spend_update_transactions,
        )
        return aggregated_daily_spend_update_transactions

    @staticmethod
    def get_aggregated_daily_spend_update_transactions(
        updates: List[Dict[str, BaseDailySpendTransaction]],
    ) -> Dict[str, BaseDailySpendTransaction]:
        """Aggregate updates by daily_transaction_key.

        Transactions sharing a key have their numeric counters summed; the
        first transaction seen for a key is deep-copied so later mutation
        does not alias the caller's payload.
        """
        aggregated_daily_spend_update_transactions: Dict[
            str, BaseDailySpendTransaction
        ] = {}
        for _update in updates:
            for _key, payload in _update.items():
                if _key in aggregated_daily_spend_update_transactions:
                    daily_transaction = aggregated_daily_spend_update_transactions[_key]
                    daily_transaction["spend"] += payload["spend"]
                    daily_transaction["prompt_tokens"] += payload["prompt_tokens"]
                    daily_transaction["completion_tokens"] += payload[
                        "completion_tokens"
                    ]
                    daily_transaction["api_requests"] += payload["api_requests"]
                    daily_transaction["successful_requests"] += payload[
                        "successful_requests"
                    ]
                    daily_transaction["failed_requests"] += payload["failed_requests"]

                    # Add optional metrics cache_read_input_tokens and cache_creation_input_tokens
                    # (`or 0` guards against an explicit None in the payload;
                    # .get(..., 0) guards against the key being absent).
                    daily_transaction["cache_read_input_tokens"] = (
                        payload.get("cache_read_input_tokens", 0) or 0
                    ) + daily_transaction.get("cache_read_input_tokens", 0)

                    daily_transaction["cache_creation_input_tokens"] = (
                        payload.get("cache_creation_input_tokens", 0) or 0
                    ) + daily_transaction.get("cache_creation_input_tokens", 0)

                else:
                    aggregated_daily_spend_update_transactions[_key] = deepcopy(payload)
        return aggregated_daily_spend_update_transactions

    async def _emit_new_item_added_to_queue_event(
        self,
        queue_size: Optional[int] = None,
    ):
        # Fire-and-forget gauge emission of the current queue size.
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": ServiceTypes.IN_MEMORY_DAILY_SPEND_UPDATE_QUEUE,
                    "gauge_value": queue_size,
                },
            )
        )
|
||||
@@ -0,0 +1,186 @@
|
||||
import asyncio
|
||||
from litellm._uuid import uuid
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.redis_cache import RedisCache
|
||||
from litellm.constants import DEFAULT_CRON_JOB_LOCK_TTL_SECONDS
|
||||
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
ProxyLogging = Any
|
||||
else:
|
||||
ProxyLogging = Any
|
||||
|
||||
|
||||
class PodLockManager:
    """
    Manager for acquiring and releasing locks for cron jobs using Redis.

    Ensures that only one pod can run a cron job at a time.

    Each pod gets a unique uuid; the lock value is the owning pod's id so
    only the owner can release it, and locks carry a TTL so a crashed pod
    cannot hold a lock forever.
    """

    def __init__(self, redis_cache: Optional[RedisCache] = None):
        # Unique identity for this pod — used as the lock's value.
        self.pod_id = str(uuid.uuid4())
        self.redis_cache = redis_cache

    @staticmethod
    def get_redis_lock_key(cronjob_id: str) -> str:
        """Return the Redis key under which the lock for this cron job lives."""
        return f"cronjob_lock:{cronjob_id}"

    async def acquire_lock(
        self,
        cronjob_id: str,
    ) -> Optional[bool]:
        """
        Attempt to acquire the lock for a specific cron job using Redis.
        Uses the SET command with NX and EX options to ensure atomicity.

        Args:
            cronjob_id: The ID of the cron job to lock

        Returns:
            True if this pod holds the lock (freshly acquired, or it already
            held it), False if another pod holds it or an error occurred,
            None if no Redis cache is configured.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping acquire_lock")
            return None
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to acquire Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            # Try to set the lock key with the pod_id as its value, only if it doesn't exist (NX)
            # and with an expiration (EX) to avoid deadlocks.
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)
            acquired = await self.redis_cache.async_set_cache(
                lock_key,
                self.pod_id,
                nx=True,
                ttl=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
            )
            if acquired:
                verbose_proxy_logger.info(
                    "Pod %s successfully acquired Redis lock for cronjob_id=%s",
                    self.pod_id,
                    cronjob_id,
                )
                # FIX: emit the acquired-lock gauge on a fresh acquire too.
                # Previously this event was only emitted on the
                # "already holds the lock" branch below, so metrics missed
                # first-time acquisitions.
                self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                return True
            else:
                # Check if the current pod already holds the lock
                current_value = await self.redis_cache.async_get_cache(lock_key)
                if current_value is not None:
                    if isinstance(current_value, bytes):
                        current_value = current_value.decode("utf-8")
                    if current_value == self.pod_id:
                        verbose_proxy_logger.info(
                            "Pod %s already holds the Redis lock for cronjob_id=%s",
                            self.pod_id,
                            cronjob_id,
                        )
                        self._emit_acquired_lock_event(cronjob_id, self.pod_id)
                        return True
                    else:
                        verbose_proxy_logger.info(
                            "Spend tracking - pod %s could not acquire lock for cronjob_id=%s, "
                            "held by pod %s. Spend updates in Redis will wait for the leader pod to commit.",
                            self.pod_id,
                            cronjob_id,
                            current_value,
                        )
            return False
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error acquiring Redis lock for {cronjob_id}: {e}"
            )
            return False

    async def release_lock(
        self,
        cronjob_id: str,
    ):
        """
        Release the lock if the current pod holds it.
        Uses get and delete commands to ensure that only the owner can release the lock.
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug("redis_cache is None, skipping release_lock")
            return
        try:
            verbose_proxy_logger.debug(
                "Pod %s attempting to release Redis lock for cronjob_id=%s",
                self.pod_id,
                cronjob_id,
            )
            lock_key = PodLockManager.get_redis_lock_key(cronjob_id)

            current_value = await self.redis_cache.async_get_cache(lock_key)
            if current_value is not None:
                if isinstance(current_value, bytes):
                    current_value = current_value.decode("utf-8")
                if current_value == self.pod_id:
                    result = await self.redis_cache.async_delete_cache(lock_key)
                    if result == 1:
                        verbose_proxy_logger.info(
                            "Pod %s successfully released Redis lock for cronjob_id=%s",
                            self.pod_id,
                            cronjob_id,
                        )
                        self._emit_released_lock_event(
                            cronjob_id=cronjob_id,
                            pod_id=self.pod_id,
                        )
                    else:
                        verbose_proxy_logger.warning(
                            "Spend tracking - pod %s failed to release Redis lock for cronjob_id=%s. "
                            "Lock will expire after TTL=%ds.",
                            self.pod_id,
                            cronjob_id,
                            DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                        )
                else:
                    verbose_proxy_logger.debug(
                        "Pod %s cannot release Redis lock for cronjob_id=%s because it is held by pod %s",
                        self.pod_id,
                        cronjob_id,
                        current_value,
                    )
            else:
                verbose_proxy_logger.debug(
                    "Pod %s attempted to release Redis lock for cronjob_id=%s, but no lock was found",
                    self.pod_id,
                    cronjob_id,
                )
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error releasing Redis lock for {cronjob_id}: {e}"
            )

    @staticmethod
    def _emit_acquired_lock_event(cronjob_id: str, pod_id: str):
        # Fire-and-forget gauge: 1 == lock held by pod_id for cronjob_id.
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.POD_LOCK_MANAGER,
                duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                call_type="_emit_acquired_lock_event",
                event_metadata={
                    "gauge_labels": f"{cronjob_id}:{pod_id}",
                    "gauge_value": 1,
                },
            )
        )

    @staticmethod
    def _emit_released_lock_event(cronjob_id: str, pod_id: str):
        # Fire-and-forget gauge: 0 == lock released by pod_id for cronjob_id.
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=ServiceTypes.POD_LOCK_MANAGER,
                duration=DEFAULT_CRON_JOB_LOCK_TTL_SECONDS,
                call_type="_emit_released_lock_event",
                event_metadata={
                    "gauge_labels": f"{cronjob_id}:{pod_id}",
                    "gauge_value": 0,
                },
            )
        )
|
||||
@@ -0,0 +1,677 @@
|
||||
"""
|
||||
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
|
||||
|
||||
This is to prevent deadlocks and improve reliability
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import RedisCache
|
||||
from litellm.constants import (
|
||||
MAX_REDIS_BUFFER_DEQUEUE_COUNT,
|
||||
REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
|
||||
REDIS_UPDATE_BUFFER_KEY,
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import (
|
||||
DailyAgentSpendTransaction,
|
||||
DailyEndUserSpendTransaction,
|
||||
DailyOrganizationSpendTransaction,
|
||||
DailyTagSpendTransaction,
|
||||
DailyTeamSpendTransaction,
|
||||
DailyUserSpendTransaction,
|
||||
DBSpendUpdateTransactions,
|
||||
)
|
||||
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
|
||||
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
|
||||
DailySpendUpdateQueue,
|
||||
)
|
||||
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.types.caching import (
|
||||
RedisPipelineLpopOperation,
|
||||
RedisPipelineRpushOperation,
|
||||
)
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
else:
|
||||
PrismaClient = Any
|
||||
|
||||
|
||||
class RedisUpdateBuffer:
|
||||
"""
|
||||
Handles buffering database `UPDATE` transactions in Redis before committing them to the database
|
||||
|
||||
This is to prevent deadlocks and improve reliability
|
||||
"""
|
||||
|
||||
    def __init__(
        self,
        redis_cache: Optional[RedisCache] = None,
    ):
        # Shared Redis client used as the transaction buffer; when None the
        # buffer is effectively disabled (methods guard on this attribute
        # and no-op).
        self.redis_cache = redis_cache
|
||||
|
||||
@staticmethod
|
||||
def _should_commit_spend_updates_to_redis() -> bool:
|
||||
"""
|
||||
Checks if the Pod should commit spend updates to Redis
|
||||
|
||||
This setting enables buffering database transactions in Redis
|
||||
to improve reliability and reduce database contention
|
||||
"""
|
||||
from litellm.proxy.proxy_server import general_settings
|
||||
|
||||
_use_redis_transaction_buffer: Optional[
|
||||
Union[bool, str]
|
||||
] = general_settings.get("use_redis_transaction_buffer", False)
|
||||
if isinstance(_use_redis_transaction_buffer, str):
|
||||
_use_redis_transaction_buffer = str_to_bool(_use_redis_transaction_buffer)
|
||||
if _use_redis_transaction_buffer is None:
|
||||
return False
|
||||
return _use_redis_transaction_buffer
|
||||
|
||||
async def _store_transactions_in_redis(
|
||||
self,
|
||||
transactions: Any,
|
||||
redis_key: str,
|
||||
service_type: ServiceTypes,
|
||||
) -> None:
|
||||
"""
|
||||
Helper method to store transactions in Redis and emit an event
|
||||
|
||||
Args:
|
||||
transactions: The transactions to store
|
||||
redis_key: The Redis key to store under
|
||||
service_type: The service type for event emission
|
||||
"""
|
||||
if transactions is None or len(transactions) == 0:
|
||||
return
|
||||
|
||||
list_of_transactions = [safe_dumps(transactions)]
|
||||
if self.redis_cache is None:
|
||||
return
|
||||
try:
|
||||
current_redis_buffer_size = await self.redis_cache.async_rpush(
|
||||
key=redis_key,
|
||||
values=list_of_transactions,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"Spend tracking - pushed spend updates to Redis buffer. "
|
||||
"redis_key=%s, buffer_size=%s",
|
||||
redis_key,
|
||||
current_redis_buffer_size,
|
||||
)
|
||||
await self._emit_new_item_added_to_redis_buffer_event(
|
||||
queue_size=current_redis_buffer_size,
|
||||
service=service_type,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"Spend tracking - failed to push spend updates to Redis (redis_key=%s). "
|
||||
"Error: %s",
|
||||
redis_key,
|
||||
str(e),
|
||||
)
|
||||
|
||||
    async def store_in_memory_spend_updates_in_redis(
        self,
        spend_update_queue: SpendUpdateQueue,
        daily_spend_update_queue: DailySpendUpdateQueue,
        daily_team_spend_update_queue: DailySpendUpdateQueue,
        daily_org_spend_update_queue: DailySpendUpdateQueue,
        daily_end_user_spend_update_queue: DailySpendUpdateQueue,
        daily_agent_spend_update_queue: DailySpendUpdateQueue,
        daily_tag_spend_update_queue: DailySpendUpdateQueue,
    ):
        """
        Stores the in-memory spend updates to Redis

        Stores the following in memory data structures in Redis:
        - SpendUpdateQueue - Key, User, Team, TeamMember, Org, EndUser Spend updates
        - DailySpendUpdateQueue - Daily Spend updates Aggregate view

        For SpendUpdateQueue:
        Each transaction is a dict stored as following:
        - key is the entity id
        - value is the spend amount

        ```
        Redis List:
        key_list_transactions:
        [
            "0929880201": 1.2,
            "0929880202": 0.01,
            "0929880203": 0.001,
        ]
        ```

        For DailySpendUpdateQueue:
        Each transaction is a Dict[str, DailyUserSpendTransaction] stored as following:
        - key is the daily_transaction_key
        - value is the DailyUserSpendTransaction

        ```
        Redis List:
        daily_spend_update_transactions:
        [
            {
                "user_keyhash_1_model_1": {
                    "spend": 1.2,
                    "prompt_tokens": 1000,
                    "completion_tokens": 1000,
                    "api_requests": 1000,
                    "successful_requests": 1000,
                },

            }
        ]
        ```
        """
        if self.redis_cache is None:
            verbose_proxy_logger.debug(
                "redis_cache is None, skipping store_in_memory_spend_updates_in_redis"
            )
            return

        # Get all transactions.  Each flush drains its in-memory queue and
        # returns the aggregated view, so the queues are empty afterwards.
        db_spend_update_transactions = (
            await spend_update_queue.flush_and_get_aggregated_db_spend_update_transactions()
        )
        daily_spend_update_transactions = (
            await daily_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_team_spend_update_transactions = (
            await daily_team_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_org_spend_update_transactions = (
            await daily_org_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_end_user_spend_update_transactions = (
            await daily_end_user_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_agent_spend_update_transactions = (
            await daily_agent_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )
        daily_tag_spend_update_transactions = (
            await daily_tag_spend_update_queue.flush_and_get_aggregated_daily_spend_update_transactions()
        )

        verbose_proxy_logger.debug(
            "ALL DB SPEND UPDATE TRANSACTIONS: %s", db_spend_update_transactions
        )
        verbose_proxy_logger.debug(
            "ALL DAILY SPEND UPDATE TRANSACTIONS: %s", daily_spend_update_transactions
        )

        # Build a list of rpush operations, skipping empty/None transaction sets.
        # Each entry pairs a transaction set with its Redis buffer key and the
        # service type used for the gauge event after the push.
        _queue_configs: List[Tuple[Any, str, ServiceTypes]] = [
            (
                db_spend_update_transactions,
                REDIS_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_spend_update_transactions,
                REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_team_spend_update_transactions,
                REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_TEAM_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_org_spend_update_transactions,
                REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_ORG_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_end_user_spend_update_transactions,
                REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_END_USER_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_agent_spend_update_transactions,
                REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_AGENT_SPEND_UPDATE_QUEUE,
            ),
            (
                daily_tag_spend_update_transactions,
                REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
                ServiceTypes.REDIS_DAILY_TAG_SPEND_UPDATE_QUEUE,
            ),
        ]

        # rpush_list[i] and service_types[i] stay index-aligned so the gauge
        # events below can be matched to the pipeline results.
        rpush_list: List[RedisPipelineRpushOperation] = []
        service_types: List[ServiceTypes] = []
        for transactions, redis_key, service_type in _queue_configs:
            if transactions is None or len(transactions) == 0:
                continue
            rpush_list.append(
                RedisPipelineRpushOperation(
                    key=redis_key,
                    values=[safe_dumps(transactions)],
                )
            )
            service_types.append(service_type)

        if len(rpush_list) == 0:
            return

        # Single pipelined round-trip pushes every non-empty buffer at once.
        # NOTE(review): unlike _store_transactions_in_redis, failures here are
        # not caught — the flushed transactions would be lost; confirm the
        # caller handles exceptions.
        result_lengths = await self.redis_cache.async_rpush_pipeline(
            rpush_list=rpush_list,
        )

        # Emit gauge events for each queue
        for i, queue_size in enumerate(result_lengths):
            if i < len(service_types):
                await self._emit_new_item_added_to_redis_buffer_event(
                    queue_size=queue_size,
                    service=service_types[i],
                )
|
||||
|
||||
@staticmethod
|
||||
def _number_of_transactions_to_store_in_redis(
|
||||
db_spend_update_transactions: DBSpendUpdateTransactions,
|
||||
) -> int:
|
||||
"""
|
||||
Gets the number of transactions to store in Redis
|
||||
"""
|
||||
num_transactions = 0
|
||||
for v in db_spend_update_transactions.values():
|
||||
if isinstance(v, dict):
|
||||
num_transactions += len(v)
|
||||
return num_transactions
|
||||
|
||||
@staticmethod
|
||||
def _remove_prefix_from_keys(data: Dict[str, Any], prefix: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Removes the specified prefix from the keys of a dictionary.
|
||||
"""
|
||||
return {key.replace(prefix, "", 1): value for key, value in data.items()}
|
||||
|
||||
    async def get_all_update_transactions_from_redis_buffer(
        self,
    ) -> Optional[DBSpendUpdateTransactions]:
        """
        Gets all the update transactions from Redis and combines them into one.

        Returns None when Redis is not configured, the buffer is empty, or
        nothing parseable was popped.  Popped items are removed from Redis and
        must be committed to the DB by the caller.

        On Redis we store a list of transactions as a JSON string

        eg.
        [
            DBSpendUpdateTransactions(
                user_list_transactions={
                    "user_id_1": 1.2,
                    "user_id_2": 0.01,
                },
                end_user_list_transactions={},
                key_list_transactions={
                    "0929880201": 1.2,
                    "0929880202": 0.01,
                },
                team_list_transactions={},
                team_member_list_transactions={},
                org_list_transactions={},
            ),
            DBSpendUpdateTransactions(
                user_list_transactions={
                    "user_id_3": 1.2,
                    "user_id_4": 0.01,
                },
                end_user_list_transactions={},
                key_list_transactions={
                    "key_id_1": 1.2,
                    "key_id_2": 0.01,
                },
                team_list_transactions={},
                team_member_list_transactions={},
                org_list_transactions={},
            ),
        ]
        """
        if self.redis_cache is None:
            return None
        # Destructive read: LPOP removes up to MAX_REDIS_BUFFER_DEQUEUE_COUNT
        # batches from the buffer in one call.
        list_of_transactions = await self.redis_cache.async_lpop(
            key=REDIS_UPDATE_BUFFER_KEY,
            count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
        )
        if list_of_transactions is None:
            return None

        # async_lpop may return a single element or a list — log accordingly.
        verbose_proxy_logger.info(
            "Spend tracking - popped %d spend update batches from Redis buffer (key=%s). "
            "These items are now removed from Redis and must be committed to DB.",
            len(list_of_transactions) if isinstance(list_of_transactions, list) else 1,
            REDIS_UPDATE_BUFFER_KEY,
        )

        # Parse the list of transactions from JSON strings
        parsed_transactions = self._parse_list_of_transactions(list_of_transactions)

        # If there are no transactions, return None
        if len(parsed_transactions) == 0:
            return None

        # Combine all transactions into a single transaction (amounts summed
        # per entity id)
        combined_transaction = self._combine_list_of_transactions(parsed_transactions)

        return combined_transaction
|
||||
|
||||
    async def get_all_transactions_from_redis_buffer_pipeline(
        self,
    ) -> Tuple[
        Optional[DBSpendUpdateTransactions],
        Optional[Dict[str, DailyUserSpendTransaction]],
        Optional[Dict[str, DailyTeamSpendTransaction]],
        Optional[Dict[str, DailyOrganizationSpendTransaction]],
        Optional[Dict[str, DailyEndUserSpendTransaction]],
        Optional[Dict[str, DailyAgentSpendTransaction]],
        Optional[Dict[str, DailyTagSpendTransaction]],
    ]:
        """
        Drains all 7 Redis buffer queues in a single pipeline round-trip.

        Returns a 7-tuple of parsed results in this order:
        0: DBSpendUpdateTransactions
        1: daily user spend
        2: daily team spend
        3: daily org spend
        4: daily end-user spend
        5: daily agent spend
        6: daily tag spend

        Any slot is None when its buffer was empty (or Redis is not
        configured, in which case all slots are None).  This is a destructive
        read: popped items must be committed to the DB by the caller.
        """
        if self.redis_cache is None:
            return None, None, None, None, None, None, None

        # IMPORTANT: the order of lpop_list fixes the slot order documented
        # above — keep it in sync with the return tuple.
        lpop_list: List[RedisPipelineLpopOperation] = [
            RedisPipelineLpopOperation(
                key=REDIS_UPDATE_BUFFER_KEY, count=MAX_REDIS_BUFFER_DEQUEUE_COUNT
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
            RedisPipelineLpopOperation(
                key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
                count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
            ),
        ]

        raw_results = await self.redis_cache.async_lpop_pipeline(lpop_list=lpop_list)

        # Pad with None if pipeline returned fewer results than expected
        while len(raw_results) < 7:
            raw_results.append(None)

        # Slot 0: DBSpendUpdateTransactions (summed per entity id)
        db_spend: Optional[DBSpendUpdateTransactions] = None
        if raw_results[0] is not None:
            parsed = self._parse_list_of_transactions(raw_results[0])
            if len(parsed) > 0:
                db_spend = self._combine_list_of_transactions(parsed)

        # Slots 1-6: daily spend categories — each is a list of JSON-encoded
        # batches that gets decoded then aggregated into one mapping.
        daily_results: List[Optional[Dict[str, Any]]] = []
        for slot in range(1, 7):
            if raw_results[slot] is None:
                daily_results.append(None)
            else:
                list_of_daily = [json.loads(t) for t in raw_results[slot]]  # type: ignore
                aggregated = DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
                    list_of_daily
                )
                daily_results.append(aggregated)

        # Casts only narrow the static type; the runtime objects are the
        # aggregated dicts built above.
        return (
            db_spend,
            cast(Optional[Dict[str, DailyUserSpendTransaction]], daily_results[0]),
            cast(Optional[Dict[str, DailyTeamSpendTransaction]], daily_results[1]),
            cast(
                Optional[Dict[str, DailyOrganizationSpendTransaction]], daily_results[2]
            ),
            cast(Optional[Dict[str, DailyEndUserSpendTransaction]], daily_results[3]),
            cast(Optional[Dict[str, DailyAgentSpendTransaction]], daily_results[4]),
            cast(Optional[Dict[str, DailyTagSpendTransaction]], daily_results[5]),
        )
|
||||
|
||||
async def get_all_daily_spend_update_transactions_from_redis_buffer(
|
||||
self,
|
||||
) -> Optional[Dict[str, DailyUserSpendTransaction]]:
|
||||
"""
|
||||
Gets all the daily spend update transactions from Redis
|
||||
"""
|
||||
if self.redis_cache is None:
|
||||
return None
|
||||
list_of_transactions = await self.redis_cache.async_lpop(
|
||||
key=REDIS_DAILY_SPEND_UPDATE_BUFFER_KEY,
|
||||
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
|
||||
)
|
||||
if list_of_transactions is None:
|
||||
return None
|
||||
list_of_daily_spend_update_transactions = [
|
||||
json.loads(transaction) for transaction in list_of_transactions
|
||||
]
|
||||
return cast(
|
||||
Dict[str, DailyUserSpendTransaction],
|
||||
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
|
||||
list_of_daily_spend_update_transactions
|
||||
),
|
||||
)
|
||||
|
||||
async def get_all_daily_team_spend_update_transactions_from_redis_buffer(
|
||||
self,
|
||||
) -> Optional[Dict[str, DailyTeamSpendTransaction]]:
|
||||
"""
|
||||
Gets all the daily team spend update transactions from Redis
|
||||
"""
|
||||
if self.redis_cache is None:
|
||||
return None
|
||||
list_of_transactions = await self.redis_cache.async_lpop(
|
||||
key=REDIS_DAILY_TEAM_SPEND_UPDATE_BUFFER_KEY,
|
||||
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
|
||||
)
|
||||
if list_of_transactions is None:
|
||||
return None
|
||||
list_of_daily_spend_update_transactions = [
|
||||
json.loads(transaction) for transaction in list_of_transactions
|
||||
]
|
||||
return cast(
|
||||
Dict[str, DailyTeamSpendTransaction],
|
||||
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
|
||||
list_of_daily_spend_update_transactions
|
||||
),
|
||||
)
|
||||
|
||||
async def get_all_daily_org_spend_update_transactions_from_redis_buffer(
|
||||
self,
|
||||
) -> Optional[Dict[str, DailyOrganizationSpendTransaction]]:
|
||||
"""
|
||||
Gets all the daily organization spend update transactions from Redis
|
||||
"""
|
||||
if self.redis_cache is None:
|
||||
return None
|
||||
list_of_transactions = await self.redis_cache.async_lpop(
|
||||
key=REDIS_DAILY_ORG_SPEND_UPDATE_BUFFER_KEY,
|
||||
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
|
||||
)
|
||||
if list_of_transactions is None:
|
||||
return None
|
||||
list_of_daily_spend_update_transactions = [
|
||||
json.loads(transaction) for transaction in list_of_transactions
|
||||
]
|
||||
return cast(
|
||||
Dict[str, DailyOrganizationSpendTransaction],
|
||||
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
|
||||
list_of_daily_spend_update_transactions
|
||||
),
|
||||
)
|
||||
|
||||
async def get_all_daily_end_user_spend_update_transactions_from_redis_buffer(
|
||||
self,
|
||||
) -> Optional[Dict[str, DailyEndUserSpendTransaction]]:
|
||||
"""
|
||||
Gets all the daily end-user spend update transactions from Redis
|
||||
"""
|
||||
if self.redis_cache is None:
|
||||
return None
|
||||
list_of_transactions = await self.redis_cache.async_lpop(
|
||||
key=REDIS_DAILY_END_USER_SPEND_UPDATE_BUFFER_KEY,
|
||||
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
|
||||
)
|
||||
if list_of_transactions is None:
|
||||
return None
|
||||
list_of_daily_spend_update_transactions = [
|
||||
json.loads(transaction) for transaction in list_of_transactions
|
||||
]
|
||||
return cast(
|
||||
Dict[str, DailyEndUserSpendTransaction],
|
||||
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
|
||||
list_of_daily_spend_update_transactions
|
||||
),
|
||||
)
|
||||
|
||||
async def get_all_daily_agent_spend_update_transactions_from_redis_buffer(
|
||||
self,
|
||||
) -> Optional[Dict[str, DailyAgentSpendTransaction]]:
|
||||
"""
|
||||
Gets all the daily agent spend update transactions from Redis
|
||||
"""
|
||||
if self.redis_cache is None:
|
||||
return None
|
||||
list_of_transactions = await self.redis_cache.async_lpop(
|
||||
key=REDIS_DAILY_AGENT_SPEND_UPDATE_BUFFER_KEY,
|
||||
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
|
||||
)
|
||||
if list_of_transactions is None:
|
||||
return None
|
||||
list_of_daily_spend_update_transactions = [
|
||||
json.loads(transaction) for transaction in list_of_transactions
|
||||
]
|
||||
return cast(
|
||||
Dict[str, DailyAgentSpendTransaction],
|
||||
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
|
||||
list_of_daily_spend_update_transactions
|
||||
),
|
||||
)
|
||||
|
||||
async def get_all_daily_tag_spend_update_transactions_from_redis_buffer(
|
||||
self,
|
||||
) -> Optional[Dict[str, DailyTagSpendTransaction]]:
|
||||
"""
|
||||
Gets all the daily tag spend update transactions from Redis
|
||||
"""
|
||||
if self.redis_cache is None:
|
||||
return None
|
||||
list_of_transactions = await self.redis_cache.async_lpop(
|
||||
key=REDIS_DAILY_TAG_SPEND_UPDATE_BUFFER_KEY,
|
||||
count=MAX_REDIS_BUFFER_DEQUEUE_COUNT,
|
||||
)
|
||||
if list_of_transactions is None:
|
||||
return None
|
||||
list_of_daily_spend_update_transactions = [
|
||||
json.loads(transaction) for transaction in list_of_transactions
|
||||
]
|
||||
return cast(
|
||||
Dict[str, DailyTagSpendTransaction],
|
||||
DailySpendUpdateQueue.get_aggregated_daily_spend_update_transactions(
|
||||
list_of_daily_spend_update_transactions
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_list_of_transactions(
|
||||
list_of_transactions: Union[Any, List[Any]],
|
||||
) -> List[DBSpendUpdateTransactions]:
|
||||
"""
|
||||
Parses the list of transactions from Redis
|
||||
"""
|
||||
if isinstance(list_of_transactions, list):
|
||||
return [json.loads(transaction) for transaction in list_of_transactions]
|
||||
else:
|
||||
return [json.loads(list_of_transactions)]
|
||||
|
||||
@staticmethod
|
||||
def _combine_list_of_transactions(
|
||||
list_of_transactions: List[DBSpendUpdateTransactions],
|
||||
) -> DBSpendUpdateTransactions:
|
||||
"""
|
||||
Combines the list of transactions into a single DBSpendUpdateTransactions object
|
||||
"""
|
||||
# Initialize a new combined transaction object with empty dictionaries
|
||||
combined_transaction = DBSpendUpdateTransactions(
|
||||
user_list_transactions={},
|
||||
end_user_list_transactions={},
|
||||
key_list_transactions={},
|
||||
team_list_transactions={},
|
||||
team_member_list_transactions={},
|
||||
org_list_transactions={},
|
||||
tag_list_transactions={},
|
||||
agent_list_transactions={},
|
||||
)
|
||||
|
||||
# Define the transaction fields to process
|
||||
transaction_fields = [
|
||||
"user_list_transactions",
|
||||
"end_user_list_transactions",
|
||||
"key_list_transactions",
|
||||
"team_list_transactions",
|
||||
"team_member_list_transactions",
|
||||
"org_list_transactions",
|
||||
"tag_list_transactions",
|
||||
"agent_list_transactions",
|
||||
]
|
||||
|
||||
# Loop through each transaction and combine the values
|
||||
for transaction in list_of_transactions:
|
||||
# Process each field type
|
||||
for field in transaction_fields:
|
||||
if transaction.get(field):
|
||||
for entity_id, amount in transaction[field].items(): # type: ignore
|
||||
combined_transaction[field][entity_id] = ( # type: ignore
|
||||
combined_transaction[field].get(entity_id, 0) + amount # type: ignore
|
||||
)
|
||||
|
||||
return combined_transaction
|
||||
|
||||
    async def _emit_new_item_added_to_redis_buffer_event(
        self,
        service: ServiceTypes,
        queue_size: int,
    ):
        """
        Emit a service-logging gauge event recording the current Redis buffer size.

        Fire-and-forget: the hook is scheduled as a background asyncio task and
        is not awaited, so logger failures do not block spend tracking.
        """
        # NOTE(review): the created task reference is not retained; presumably
        # the event loop keeps it alive until completion — confirm if gauge
        # events appear to be dropped under load.
        asyncio.create_task(
            service_logger_obj.async_service_success_hook(
                service=service,
                duration=0,
                call_type="_emit_new_item_added_to_queue_event",
                event_metadata={
                    "gauge_labels": service,
                    "gauge_value": queue_size,
                },
            )
        )
|
||||
@@ -0,0 +1,172 @@
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import RedisCache
|
||||
from litellm.constants import (
|
||||
SPEND_LOG_CLEANUP_BATCH_SIZE,
|
||||
SPEND_LOG_CLEANUP_JOB_NAME,
|
||||
SPEND_LOG_RUN_LOOPS,
|
||||
)
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
class SpendLogCleanup:
    """
    Handles cleaning up old spend logs based on maximum retention period.
    Deletes logs in batches to prevent timeouts.
    Uses PodLockManager to ensure only one pod runs cleanup in multi-pod deployments.
    """

    def __init__(self, general_settings=None, redis_cache: Optional[RedisCache] = None):
        # NOTE(review): the redis_cache parameter is not used anywhere in this
        # class — confirm whether it can be dropped or should be wired in.
        self.batch_size = SPEND_LOG_CLEANUP_BATCH_SIZE
        # Populated by _should_delete_spend_logs(); None until then.
        self.retention_seconds: Optional[int] = None
        # Imported lazily to avoid a circular import with proxy_server.
        from litellm.proxy.proxy_server import general_settings as default_settings

        self.general_settings = general_settings or default_settings
        from litellm.proxy.proxy_server import proxy_logging_obj

        # Reuse the shared pod lock manager so cleanup competes for the same
        # distributed lock as other cron jobs.
        pod_lock_manager = proxy_logging_obj.db_spend_update_writer.pod_lock_manager
        self.pod_lock_manager = pod_lock_manager
        verbose_proxy_logger.info(
            f"SpendLogCleanup initialized with batch size: {self.batch_size}"
        )

    def _should_delete_spend_logs(self) -> bool:
        """
        Determines if logs should be deleted based on the max retention period in settings.

        Side effect: on success, stores the parsed retention period in
        ``self.retention_seconds``.
        """
        retention_setting = self.general_settings.get(
            "maximum_spend_logs_retention_period"
        )
        verbose_proxy_logger.info(f"Checking retention setting: {retention_setting}")

        if retention_setting is None:
            verbose_proxy_logger.info("No retention setting found")
            return False

        try:
            # Bare integers are accepted for backwards compatibility and
            # interpreted as a number of days.
            if isinstance(retention_setting, int):
                verbose_proxy_logger.warning(
                    f"maximum_spend_logs_retention_period is an integer ({retention_setting}); treating as days. "
                    "Use a string like '3d' to be explicit."
                )
                retention_setting = f"{retention_setting}d"
            self.retention_seconds = duration_in_seconds(retention_setting)
            verbose_proxy_logger.info(
                f"Retention period set to {self.retention_seconds} seconds"
            )
            return True
        except ValueError as e:
            # Unparseable duration string — treat as "do not delete".
            verbose_proxy_logger.warning(
                f"Invalid maximum_spend_logs_retention_period value: {retention_setting}, error: {str(e)}"
            )
            return False

    async def _delete_old_logs(
        self, prisma_client: PrismaClient, cutoff_date: datetime
    ) -> int:
        """
        Helper method to delete old logs in batches.
        Returns the total number of logs deleted.

        Each iteration deletes at most ``self.batch_size`` rows; the loop stops
        after SPEND_LOG_RUN_LOOPS batches or when nothing is left to delete.
        """
        total_deleted = 0
        run_count = 0
        while True:
            # NOTE(review): `>` allows SPEND_LOG_RUN_LOOPS + 1 iterations, and
            # the "1,00,000" figure in the message assumes specific constant
            # values — confirm both against the constants module.
            if run_count > SPEND_LOG_RUN_LOOPS:
                verbose_proxy_logger.info(
                    "Max logs deleted - 1,00,000, rest of the logs will be deleted in next run"
                )
                break
            # Step 1: Find logs and delete them in one go without fetching to application
            # Delete in batches, limited by self.batch_size
            deleted_count = await prisma_client.db.execute_raw(
                """
                DELETE FROM "LiteLLM_SpendLogs"
                WHERE "request_id" IN (
                    SELECT "request_id" FROM "LiteLLM_SpendLogs"
                    WHERE "startTime" < $1::timestamptz
                    LIMIT $2
                )
                """,
                cutoff_date,
                self.batch_size,
            )
            verbose_proxy_logger.info(f"Deleted {deleted_count} logs in this batch")

            if deleted_count == 0:
                verbose_proxy_logger.info(
                    f"No more logs to delete. Total deleted: {total_deleted}"
                )
                break

            total_deleted += deleted_count
            run_count += 1

            # Add a small sleep to prevent overwhelming the database
            await asyncio.sleep(0.1)

        return total_deleted

    async def cleanup_old_spend_logs(self, prisma_client: PrismaClient) -> None:
        """
        Main cleanup function. Deletes old spend logs in batches.
        If pod_lock_manager is available, ensures only one pod runs cleanup.
        If no pod_lock_manager, runs cleanup without distributed locking.
        """
        # Tracks whether WE took the distributed lock, so the finally block
        # never releases a lock held by another pod.
        lock_acquired = False
        try:
            verbose_proxy_logger.info(f"Cleanup job triggered at {datetime.now()}")

            # Also parses and caches the retention period on success.
            if not self._should_delete_spend_logs():
                return

            if self.retention_seconds is None:
                verbose_proxy_logger.error(
                    "Retention seconds is None, cannot proceed with cleanup"
                )
                return

            # If we have a pod lock manager, try to acquire the lock
            if self.pod_lock_manager and self.pod_lock_manager.redis_cache:
                lock_acquired = (
                    await self.pod_lock_manager.acquire_lock(
                        cronjob_id=SPEND_LOG_CLEANUP_JOB_NAME,
                    )
                    or False
                )
                verbose_proxy_logger.info(
                    f"Lock acquisition attempt: {'successful' if lock_acquired else 'failed'} at {datetime.now()}"
                )

                if not lock_acquired:
                    verbose_proxy_logger.info("Another pod is already running cleanup")
                    return

            # Everything older than now - retention is eligible for deletion.
            cutoff_date = datetime.now(timezone.utc) - timedelta(
                seconds=float(self.retention_seconds)
            )
            verbose_proxy_logger.info(
                f"Deleting logs older than {cutoff_date.isoformat()}"
            )

            # Perform the actual deletion
            total_deleted = await self._delete_old_logs(prisma_client, cutoff_date)
            verbose_proxy_logger.info(f"Deleted {total_deleted} logs")

        except Exception as e:
            verbose_proxy_logger.error(f"Error during cleanup: {str(e)}")
            return  # Return after error handling
        finally:
            # Only release the lock if it was actually acquired
            if (
                lock_acquired
                and self.pod_lock_manager
                and self.pod_lock_manager.redis_cache
            ):
                await self.pod_lock_manager.release_lock(
                    cronjob_id=SPEND_LOG_CLEANUP_JOB_NAME
                )
                verbose_proxy_logger.info("Released cleanup lock")
|
||||
@@ -0,0 +1,246 @@
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import LITELLM_ASYNCIO_QUEUE_MAXSIZE
|
||||
from litellm.proxy._types import (
|
||||
DBSpendUpdateTransactions,
|
||||
Litellm_EntityType,
|
||||
SpendUpdateQueueItem,
|
||||
)
|
||||
from litellm.proxy.db.db_transaction_queue.base_update_queue import (
|
||||
BaseUpdateQueue,
|
||||
service_logger_obj,
|
||||
)
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
|
||||
class SpendUpdateQueue(BaseUpdateQueue):
|
||||
"""
|
||||
In memory buffer for spend updates that should be committed to the database
|
||||
"""
|
||||
|
||||
    def __init__(self):
        """Create the bounded in-memory queue backing this spend-update buffer."""
        super().__init__()
        # Bounded so a stalled DB writer cannot grow memory without limit;
        # add_update() compacts the queue as it approaches capacity.
        self.update_queue: asyncio.Queue[SpendUpdateQueueItem] = asyncio.Queue(
            maxsize=LITELLM_ASYNCIO_QUEUE_MAXSIZE
        )
|
||||
|
||||
    async def flush_and_get_aggregated_db_spend_update_transactions(
        self,
    ) -> DBSpendUpdateTransactions:
        """
        Flush all updates from the queue and return all updates aggregated by entity type.

        Draining empties the in-memory queue; the returned object is the only
        remaining copy of the flushed updates.
        """
        updates = await self.flush_all_updates_from_in_memory_queue()
        if len(updates) > 0:
            verbose_proxy_logger.info(
                "Spend tracking - flushed %d spend update items from in-memory queue",
                len(updates),
            )
        verbose_proxy_logger.debug("Aggregating updates by entity type: %s", updates)
        return self.get_aggregated_db_spend_update_transactions(updates)
|
||||
|
||||
    async def add_update(self, update: SpendUpdateQueueItem):
        """
        Enqueue an update to the spend update queue.

        When the queue reaches MAX_SIZE_IN_MEMORY_QUEUE entries it is compacted
        in place: duplicate (entity_type, entity_id) items are merged so the
        queue shrinks without losing spend.
        """
        verbose_proxy_logger.debug("Adding update to queue: %s", update)
        await self.update_queue.put(update)

        # if the queue is full, aggregate the updates
        if self.update_queue.qsize() >= self.MAX_SIZE_IN_MEMORY_QUEUE:
            verbose_proxy_logger.warning(
                "Spend update queue is full. Aggregating all entries in queue to concatenate entries."
            )
            await self.aggregate_queue_updates()
|
||||
|
||||
async def aggregate_queue_updates(self):
|
||||
"""Concatenate all updates in the queue to reduce the size of in-memory queue"""
|
||||
updates: List[
|
||||
SpendUpdateQueueItem
|
||||
] = await self.flush_all_updates_from_in_memory_queue()
|
||||
aggregated_updates = self._get_aggregated_spend_update_queue_item(updates)
|
||||
for update in aggregated_updates:
|
||||
await self.update_queue.put(update)
|
||||
return
|
||||
|
||||
    def _get_aggregated_spend_update_queue_item(
        self, updates: List[SpendUpdateQueueItem]
    ) -> List[SpendUpdateQueueItem]:
        """
        This is used to reduce the size of the in-memory queue by aggregating updates by entity type + id


        Aggregate updates by entity type + id

        eg.

        ```
        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 100
            },
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 200
            }
        ]

        ```

        becomes

        ```

        [
            {
                "entity_type": "user",
                "entity_id": "123",
                "response_cost": 300
            }
        ]

        ```
        """
        verbose_proxy_logger.debug(
            "Aggregating spend updates, current queue size: %s",
            self.update_queue.qsize(),
        )
        aggregated_spend_updates: List[SpendUpdateQueueItem] = []

        _in_memory_map: Dict[str, SpendUpdateQueueItem] = {}
        """
        Used for combining several updates into a single update
        Key=entity_type:entity_id
        Value=SpendUpdateQueueItem
        """
        for update in updates:
            _key = f"{update.get('entity_type')}:{update.get('entity_id')}"
            if _key not in _in_memory_map:
                # avoid mutating caller-owned dicts while aggregating queue entries
                _in_memory_map[_key] = update.copy()
            else:
                # Treat a missing or None cost as 0 so summation never fails.
                current_cost = _in_memory_map[_key].get("response_cost", 0) or 0
                update_cost = update.get("response_cost", 0) or 0
                _in_memory_map[_key]["response_cost"] = current_cost + update_cost

        # Flatten the dedup map back into a list; insertion order is preserved,
        # so items come out in first-seen order.
        for _key, update in _in_memory_map.items():
            aggregated_spend_updates.append(update)

        verbose_proxy_logger.debug(
            "Aggregated spend updates: %s", aggregated_spend_updates
        )
        return aggregated_spend_updates
|
||||
|
||||
def get_aggregated_db_spend_update_transactions(
|
||||
self, updates: List[SpendUpdateQueueItem]
|
||||
) -> DBSpendUpdateTransactions:
|
||||
"""Aggregate updates by entity type."""
|
||||
# Initialize all transaction lists as empty dicts
|
||||
db_spend_update_transactions = DBSpendUpdateTransactions(
|
||||
user_list_transactions={},
|
||||
end_user_list_transactions={},
|
||||
key_list_transactions={},
|
||||
team_list_transactions={},
|
||||
team_member_list_transactions={},
|
||||
org_list_transactions={},
|
||||
tag_list_transactions={},
|
||||
agent_list_transactions={},
|
||||
)
|
||||
|
||||
# Map entity types to their corresponding transaction dictionary keys
|
||||
entity_type_to_dict_key = {
|
||||
Litellm_EntityType.USER: "user_list_transactions",
|
||||
Litellm_EntityType.END_USER: "end_user_list_transactions",
|
||||
Litellm_EntityType.KEY: "key_list_transactions",
|
||||
Litellm_EntityType.TEAM: "team_list_transactions",
|
||||
Litellm_EntityType.TEAM_MEMBER: "team_member_list_transactions",
|
||||
Litellm_EntityType.ORGANIZATION: "org_list_transactions",
|
||||
Litellm_EntityType.TAG: "tag_list_transactions",
|
||||
Litellm_EntityType.AGENT: "agent_list_transactions",
|
||||
}
|
||||
|
||||
for update in updates:
|
||||
entity_type = update.get("entity_type")
|
||||
entity_id = update.get("entity_id") or ""
|
||||
response_cost = update.get("response_cost") or 0
|
||||
|
||||
if entity_type is None:
|
||||
verbose_proxy_logger.debug(
|
||||
"Skipping update spend for update: %s, because entity_type is None",
|
||||
update,
|
||||
)
|
||||
continue
|
||||
|
||||
dict_key = entity_type_to_dict_key.get(entity_type)
|
||||
if dict_key is None:
|
||||
verbose_proxy_logger.debug(
|
||||
"Skipping update spend for update: %s, because entity_type is not in entity_type_to_dict_key",
|
||||
update,
|
||||
)
|
||||
continue # Skip unknown entity types
|
||||
|
||||
# Type-safe access using if/elif statements
|
||||
if dict_key == "user_list_transactions":
|
||||
transactions_dict = db_spend_update_transactions[
|
||||
"user_list_transactions"
|
||||
]
|
||||
elif dict_key == "end_user_list_transactions":
|
||||
transactions_dict = db_spend_update_transactions[
|
||||
"end_user_list_transactions"
|
||||
]
|
||||
elif dict_key == "key_list_transactions":
|
||||
transactions_dict = db_spend_update_transactions[
|
||||
"key_list_transactions"
|
||||
]
|
||||
elif dict_key == "team_list_transactions":
|
||||
transactions_dict = db_spend_update_transactions[
|
||||
"team_list_transactions"
|
||||
]
|
||||
elif dict_key == "team_member_list_transactions":
|
||||
transactions_dict = db_spend_update_transactions[
|
||||
"team_member_list_transactions"
|
||||
]
|
||||
elif dict_key == "org_list_transactions":
|
||||
transactions_dict = db_spend_update_transactions[
|
||||
"org_list_transactions"
|
||||
]
|
||||
elif dict_key == "tag_list_transactions":
|
||||
transactions_dict = db_spend_update_transactions[
|
||||
"tag_list_transactions"
|
||||
]
|
||||
elif dict_key == "agent_list_transactions":
|
||||
transactions_dict = db_spend_update_transactions[
|
||||
"agent_list_transactions"
|
||||
]
|
||||
else:
|
||||
continue
|
||||
|
||||
if transactions_dict is None:
|
||||
transactions_dict = {}
|
||||
|
||||
# type ignore: dict_key is guaranteed to be one of "one of ("user_list_transactions", "end_user_list_transactions", "key_list_transactions", "team_list_transactions", "team_member_list_transactions", "org_list_transactions")"
|
||||
db_spend_update_transactions[dict_key] = transactions_dict # type: ignore
|
||||
|
||||
if entity_id not in transactions_dict:
|
||||
transactions_dict[entity_id] = 0
|
||||
|
||||
transactions_dict[entity_id] += response_cost or 0
|
||||
|
||||
return db_spend_update_transactions
|
||||
|
||||
async def _emit_new_item_added_to_queue_event(
|
||||
self,
|
||||
queue_size: Optional[int] = None,
|
||||
):
|
||||
asyncio.create_task(
|
||||
service_logger_obj.async_service_success_hook(
|
||||
service=ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
|
||||
duration=0,
|
||||
call_type="_emit_new_item_added_to_queue_event",
|
||||
event_metadata={
|
||||
"gauge_labels": ServiceTypes.IN_MEMORY_SPEND_UPDATE_QUEUE,
|
||||
"gauge_value": queue_size,
|
||||
},
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,54 @@
|
||||
"""
|
||||
In-memory buffer for tool registry upserts.
|
||||
|
||||
Unlike SpendUpdateQueue (which aggregates increments), ToolDiscoveryQueue
|
||||
uses set-deduplication: each unique tool_name is only queued once per flush
|
||||
cycle (~30s). The seen-set is cleared on every flush so that call_count
|
||||
increments in subsequent cycles rather than stopping after the first flush.
|
||||
"""
|
||||
|
||||
from typing import List, Set
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ToolDiscoveryQueueItem
|
||||
|
||||
|
||||
class ToolDiscoveryQueue:
    """
    In-memory buffer for tool registry upserts.

    Deduplicates by tool_name within each flush cycle: a tool is queued at
    most once per ~30s batch, so call_count increments once per flush cycle
    in which the tool appears (not once per invocation, and not once per pod
    lifetime either). The seen-set is reset on flush so later batches can
    re-count the same tool.
    """

    def __init__(self) -> None:
        # tool names already queued in the current flush cycle
        self._seen_tool_names: Set[str] = set()
        # items waiting to be flushed
        self._pending: List[ToolDiscoveryQueueItem] = []

    def add_update(self, item: ToolDiscoveryQueueItem) -> None:
        """Queue *item* unless its tool_name is empty or already seen this cycle."""
        tool_name = item.get("tool_name", "")
        if not tool_name:
            return
        first_sighting = tool_name not in self._seen_tool_names
        if not first_sighting:
            verbose_proxy_logger.debug(
                "ToolDiscoveryQueue: skipping already-seen tool %s", tool_name
            )
            return
        self._seen_tool_names.add(tool_name)
        self._pending.append(item)
        verbose_proxy_logger.debug(
            "ToolDiscoveryQueue: queued new tool %s (origin=%s)",
            tool_name,
            item.get("origin"),
        )

    def flush(self) -> List[ToolDiscoveryQueueItem]:
        """Drain and return all pending items; reset the seen-set so the next
        flush cycle can re-count the same tools."""
        drained = self._pending
        self._pending = []
        self._seen_tool_names.clear()
        return drained
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Deprecated. Only PostgresSQL is supported.
|
||||
"""
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import DynamoDBArgs
|
||||
from litellm.proxy.db.base_client import CustomDB
|
||||
|
||||
|
||||
class DynamoDBWrapper(CustomDB):
    """
    Deprecated DynamoDB implementation of the CustomDB interface.

    Kept only for backward compatibility; only PostgreSQL is supported.
    """

    from aiodynamo.credentials import Credentials, StaticCredentials

    # aiodynamo credentials object used to sign requests
    credentials: Credentials

    def __init__(self, database_arguments: DynamoDBArgs):
        """
        Resolve the DynamoDB throughput configuration from `billing_mode`.

        Raises:
            Exception: if billing_mode is PROVISIONED_THROUGHPUT but either
                read_capacity_units or write_capacity_units is not an int.
        """
        from aiodynamo.models import PayPerRequest, Throughput

        self.throughput_type = None
        if database_arguments.billing_mode == "PAY_PER_REQUEST":
            self.throughput_type = PayPerRequest()
        elif database_arguments.billing_mode == "PROVISIONED_THROUGHPUT":
            # isinstance(None, int) is False, so the isinstance checks alone
            # cover the previous explicit `is not None` checks.
            if isinstance(database_arguments.read_capacity_units, int) and isinstance(
                database_arguments.write_capacity_units, int
            ):
                self.throughput_type = Throughput(read=database_arguments.read_capacity_units, write=database_arguments.write_capacity_units)  # type: ignore
            else:
                raise Exception(
                    f"Invalid args passed in. Need to set both read_capacity_units and write_capacity_units. Args passed in - {database_arguments}"
                )
        self.database_arguments = database_arguments
        self.region_name = database_arguments.region_name

    def set_env_vars_based_on_arn(self):
        """
        Assume the configured AWS role (web-identity then role chaining) and
        export the resulting credentials as env vars for aiodynamo.

        No-op when aws_role_name is not configured.
        """
        if self.database_arguments.aws_role_name is None:
            return
        verbose_proxy_logger.debug(
            f"DynamoDB: setting env vars based on arn={self.database_arguments.aws_role_name}"
        )
        import os

        import boto3

        sts_client = boto3.client("sts")

        # call 1: exchange the web identity token for the first role
        sts_client.assume_role_with_web_identity(
            RoleArn=self.database_arguments.aws_role_name,
            RoleSessionName=self.database_arguments.aws_session_name,
            WebIdentityToken=self.database_arguments.aws_web_identity_token,
        )

        # call 2: chain into the target role
        assumed_role = sts_client.assume_role(
            RoleArn=self.database_arguments.assume_role_aws_role_name,
            RoleSessionName=self.database_arguments.assume_role_aws_session_name,
        )

        aws_access_key_id = assumed_role["Credentials"]["AccessKeyId"]
        aws_secret_access_key = assumed_role["Credentials"]["SecretAccessKey"]
        aws_session_token = assumed_role["Credentials"]["SessionToken"]

        verbose_proxy_logger.debug(
            f"Got STS assumed Role, aws_access_key_id={aws_access_key_id}"
        )
        # set these in the env so aiodynamo can use them
        os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id
        os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key
        os.environ["AWS_SESSION_TOKEN"] = aws_session_token
|
||||
@@ -0,0 +1,106 @@
|
||||
from typing import Union
|
||||
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
)
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
||||
|
||||
class PrismaDBExceptionHandler:
    """
    Helpers for classifying DB exceptions / connection errors and deciding
    whether a request may proceed without the DB.
    """

    @staticmethod
    def should_allow_request_on_db_unavailable() -> bool:
        """
        True when `allow_requests_on_db_unavailable` is enabled in general_settings,
        i.e. requests may proceed despite a DB connection error.
        """
        from litellm.proxy.proxy_server import general_settings

        flag: Union[bool, str] = general_settings.get(
            "allow_requests_on_db_unavailable", False
        )
        if isinstance(flag, bool):
            return flag
        # string values like "true"/"false" come from yaml/env config
        return str_to_bool(flag) is True

    @staticmethod
    def is_database_connection_error(e: Exception) -> bool:
        """
        True if the exception came from a database outage / connection error.
        Any PrismaError qualifies — the DB failed to serve the request.
        Used by allow_requests_on_db_unavailable logic and endpoint 503 responses.
        """
        import prisma

        return (
            isinstance(e, DB_CONNECTION_ERROR_TYPES)
            or isinstance(e, prisma.errors.PrismaError)
            or (
                isinstance(e, ProxyException)
                and e.type == ProxyErrorTypes.no_db_connection
            )
        )

    @staticmethod
    def is_database_transport_error(e: Exception) -> bool:
        """
        True only for transport/connectivity failures where a reconnect attempt
        makes sense (DB unreachable, connection dropped).

        Use this for reconnect logic — data-layer errors like UniqueViolationError
        mean the DB IS reachable, so reconnecting would be pointless.
        """
        import prisma

        if isinstance(e, DB_CONNECTION_ERROR_TYPES):
            return True
        transport_error_classes = (
            prisma.errors.ClientNotConnectedError,
            prisma.errors.HTTPClientClosedError,
        )
        if isinstance(e, transport_error_classes):
            return True
        if isinstance(e, prisma.errors.PrismaError):
            # fall back to message sniffing for connectivity-shaped Prisma errors
            lowered_message = str(e).lower()
            for keyword in (
                "can't reach database server",
                "cannot reach database server",
                "can't connect",
                "cannot connect",
                "connection error",
                "connection closed",
                "timed out",
                "timeout",
                "connection refused",
                "network is unreachable",
                "no route to host",
                "broken pipe",
            ):
                if keyword in lowered_message:
                    return True
        return (
            isinstance(e, ProxyException) and e.type == ProxyErrorTypes.no_db_connection
        )

    @staticmethod
    def handle_db_exception(e: Exception):
        """
        Primary handler for the `allow_requests_on_db_unavailable` flag.

        Returns None (swallows the error) when the exception is a DB connection
        error AND the flag allows requests to proceed; otherwise re-raises.
        """
        can_swallow = PrismaDBExceptionHandler.is_database_connection_error(
            e
        ) and PrismaDBExceptionHandler.should_allow_request_on_db_unavailable()
        if can_swallow:
            return None
        raise e
|
||||
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Handles logging DB success/failure to ServiceLogger()
|
||||
|
||||
ServiceLogger() then sends DB logs to Prometheus, OTEL, Datadog etc
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, Tuple
|
||||
|
||||
from litellm._service_logger import ServiceTypes
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
_get_parent_otel_span_from_kwargs,
|
||||
get_litellm_metadata_from_kwargs,
|
||||
)
|
||||
|
||||
|
||||
def log_db_metrics(func):
    """
    Decorator to log the duration of a DB related function to ServiceLogger()

    Handles logging DB success/failure to ServiceLogger(), which logs to Prometheus, OTEL, Datadog

    When logging Failure it checks if the Exception is a PrismaError, httpx.ConnectError or httpx.TimeoutException and then logs that as a DB Service Failure

    Args:
        func: The async function to be decorated

    Returns:
        Result from the decorated function

    Raises:
        Exception: If the decorated function raises an exception
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time: datetime = datetime.now()

        try:
            result = await func(*args, **kwargs)
            end_time: datetime = datetime.now()
            # imported lazily to avoid a circular import at module load time
            from litellm.proxy.proxy_server import proxy_logging_obj

            if "PROXY" not in func.__name__:
                # fire-and-forget: success logging must not block the DB call path
                asyncio.create_task(
                    proxy_logging_obj.service_logging_obj.async_service_success_hook(
                        service=ServiceTypes.DB,
                        call_type=func.__name__,
                        parent_otel_span=kwargs.get("parent_otel_span", None),
                        duration=(end_time - start_time).total_seconds(),
                        start_time=start_time,
                        end_time=end_time,
                        event_metadata={
                            "function_name": func.__name__,
                            "function_kwargs": kwargs,
                            "function_args": args,
                        },
                    )
                )
            elif (
                # in litellm custom callbacks kwargs is passed as arg[0]
                # https://docs.litellm.ai/docs/observability/custom_callback#callback-functions
                # NOTE: `args` from *args is always a tuple, so the previous
                # `args is not None` check was redundant and has been removed.
                len(args) > 1
                and isinstance(args[1], dict)
            ):
                passed_kwargs = args[1]
                parent_otel_span = _get_parent_otel_span_from_kwargs(
                    kwargs=passed_kwargs
                )
                if parent_otel_span is not None:
                    metadata = get_litellm_metadata_from_kwargs(kwargs=passed_kwargs)

                    asyncio.create_task(
                        proxy_logging_obj.service_logging_obj.async_service_success_hook(
                            service=ServiceTypes.BATCH_WRITE_TO_DB,
                            call_type=func.__name__,
                            parent_otel_span=parent_otel_span,
                            duration=0.0,
                            start_time=start_time,
                            end_time=end_time,
                            event_metadata=metadata,
                        )
                    )
            # end of logging to otel
            return result
        except Exception as e:
            end_time: datetime = datetime.now()
            # only logs when the exception is DB-related; always re-raises after
            await _handle_logging_db_exception(
                e=e,
                func=func,
                kwargs=kwargs,
                args=args,
                start_time=start_time,
                end_time=end_time,
            )
            raise e

    return wrapper
|
||||
|
||||
|
||||
def _is_exception_related_to_db(e: Exception) -> bool:
    """
    Returns True if the exception is related to the DB
    """

    import httpx
    from prisma.errors import PrismaError

    db_error_types = (PrismaError, httpx.ConnectError, httpx.TimeoutException)
    return isinstance(e, db_error_types)
|
||||
|
||||
|
||||
async def _handle_logging_db_exception(
    e: Exception,
    func: Callable,
    kwargs: Dict,
    args: Tuple,
    start_time: datetime,
    end_time: datetime,
) -> None:
    """Log *e* as a DB service failure; no-op when the DB did not raise it."""
    from litellm.proxy.proxy_server import proxy_logging_obj

    # don't log this as a DB Service Failure, if the DB did not raise an exception
    if not _is_exception_related_to_db(e):
        return

    await proxy_logging_obj.service_logging_obj.async_service_failure_hook(
        error=e,
        service=ServiceTypes.DB,
        call_type=func.__name__,
        parent_otel_span=kwargs.get("parent_otel_span"),
        duration=(end_time - start_time).total_seconds(),
        start_time=start_time,
        end_time=end_time,
        event_metadata={
            "function_name": func.__name__,
            "function_kwargs": kwargs,
            "function_args": args,
        },
    )
|
||||
@@ -0,0 +1,438 @@
|
||||
"""
|
||||
This file contains the PrismaWrapper class, which is used to wrap the Prisma client and handle the RDS IAM token.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import random
|
||||
import subprocess
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
|
||||
|
||||
class PrismaWrapper:
    """
    Wrapper around Prisma client that handles RDS IAM token authentication.

    When iam_token_db_auth is enabled, this wrapper:
    1. Proactively refreshes IAM tokens before they expire (background task)
    2. Falls back to synchronous refresh if a token is found expired
    3. Uses proper locking to prevent race conditions during reconnection

    RDS IAM tokens are valid for 15 minutes. This wrapper refreshes them
    3 minutes before expiration to ensure uninterrupted database connectivity.
    """

    # Buffer time in seconds before token expiration to trigger refresh
    # Refresh 3 minutes (180 seconds) before the token expires
    TOKEN_REFRESH_BUFFER_SECONDS = 180

    # Fallback refresh interval if token parsing fails (10 minutes)
    FALLBACK_REFRESH_INTERVAL_SECONDS = 600

    def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
        # wrapped Prisma client; all unknown attribute access proxies to it
        self._original_prisma = original_prisma
        self.iam_token_db_auth = iam_token_db_auth

        # Background token refresh task management
        self._token_refresh_task: Optional[asyncio.Task] = None
        self._reconnection_lock = asyncio.Lock()
        self._last_refresh_time: Optional[datetime] = None

    def _extract_token_from_db_url(self, db_url: Optional[str]) -> Optional[str]:
        """
        Extract the token (password) from the DATABASE_URL.

        The token contains the AWS signature with X-Amz-Date and X-Amz-Expires parameters.

        Important: We must parse the URL while it's still encoded to preserve structure,
        then decode the password portion. Otherwise the '?' in the token breaks URL parsing.

        Returns None on a missing URL, a URL without a password, or any parse error.
        """
        if db_url is None:
            return None
        try:
            # Parse URL while still encoded to preserve structure
            parsed = urllib.parse.urlparse(db_url)
            if parsed.password:
                # Now decode just the password/token
                return urllib.parse.unquote(parsed.password)
            return None
        except Exception:
            # best-effort parse: any failure is treated as "no token found"
            return None

    def _parse_token_expiration(self, token: Optional[str]) -> Optional[datetime]:
        """
        Parse the token to extract its expiration time.

        Expiration = X-Amz-Date (creation time, naive UTC) + X-Amz-Expires seconds.

        Returns the datetime when the token expires, or None if parsing fails.
        """
        if token is None:
            return None

        try:
            # Token format: ...?X-Amz-Date=YYYYMMDDTHHMMSSZ&X-Amz-Expires=900&...
            if "?" not in token:
                return None

            query_string = token.split("?", 1)[1]
            params = urllib.parse.parse_qs(query_string)

            expires_str = params.get("X-Amz-Expires", [None])[0]
            date_str = params.get("X-Amz-Date", [None])[0]

            if not expires_str or not date_str:
                return None

            # naive datetime; compared against naive datetime.utcnow() elsewhere,
            # so both sides are consistently UTC-without-tzinfo
            token_created = datetime.strptime(date_str, "%Y%m%dT%H%M%SZ")
            expires_in = int(expires_str)

            return token_created + timedelta(seconds=expires_in)
        except Exception as e:
            verbose_proxy_logger.debug(f"Failed to parse token expiration: {e}")
            return None

    def _calculate_seconds_until_refresh(self) -> float:
        """
        Calculate exactly how many seconds until we need to refresh the token.

        Uses precise timing: sleeps until (token_expiration - buffer_seconds).
        For a 15-minute (900s) token with 180s buffer, this returns ~720s (12 min).

        Returns:
            Number of seconds to sleep before the next refresh.
            Returns 0 if token should be refreshed immediately.
            Returns FALLBACK_REFRESH_INTERVAL_SECONDS if parsing fails.
        """
        db_url = os.getenv("DATABASE_URL")
        token = self._extract_token_from_db_url(db_url)
        expiration_time = self._parse_token_expiration(token)

        if expiration_time is None:
            # If we can't parse the token, use fallback interval
            verbose_proxy_logger.debug(
                f"Could not parse token expiration, using fallback interval of "
                f"{self.FALLBACK_REFRESH_INTERVAL_SECONDS}s"
            )
            return self.FALLBACK_REFRESH_INTERVAL_SECONDS

        # Calculate when we should refresh (expiration - buffer)
        refresh_at = expiration_time - timedelta(
            seconds=self.TOKEN_REFRESH_BUFFER_SECONDS
        )

        # How long until refresh time? (naive UTC on both sides)
        now = datetime.utcnow()
        seconds_until_refresh = (refresh_at - now).total_seconds()

        # If already past refresh time, return 0 (refresh immediately)
        return max(0, seconds_until_refresh)

    def is_token_expired(self, token_url: Optional[str]) -> bool:
        """Check if the token in the given URL is expired.

        Unparseable tokens are treated as expired so a refresh is triggered.
        """
        if token_url is None:
            return True

        token = self._extract_token_from_db_url(token_url)
        expiration_time = self._parse_token_expiration(token)

        if expiration_time is None:
            # If we can't parse the token, assume it's expired to trigger refresh
            verbose_proxy_logger.debug(
                "Could not parse token expiration, treating as expired"
            )
            return True

        return datetime.utcnow() > expiration_time

    def get_rds_iam_token(self) -> Optional[str]:
        """Generate a new RDS IAM token and update DATABASE_URL.

        Builds a postgresql:// URL from DATABASE_HOST/PORT/USER/NAME (+ optional
        DATABASE_SCHEMA) env vars, writes it back to os.environ["DATABASE_URL"],
        and returns it. Returns None when iam_token_db_auth is disabled.
        """
        if self.iam_token_db_auth:
            from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token

            db_host = os.getenv("DATABASE_HOST")
            db_port = os.getenv("DATABASE_PORT")
            db_user = os.getenv("DATABASE_USER")
            db_name = os.getenv("DATABASE_NAME")
            db_schema = os.getenv("DATABASE_SCHEMA")

            token = generate_iam_auth_token(
                db_host=db_host, db_port=db_port, db_user=db_user
            )

            _db_url = f"postgresql://{db_user}:{token}@{db_host}:{db_port}/{db_name}"
            if db_schema:
                _db_url += f"?schema={db_schema}"

            os.environ["DATABASE_URL"] = _db_url
            return _db_url
        return None

    async def recreate_prisma_client(
        self, new_db_url: str, http_client: Optional[Any] = None
    ):
        """Disconnect and reconnect the Prisma client with a new database URL.

        A failed disconnect is logged and ignored; the new client is created
        and connected regardless.

        NOTE(review): `new_db_url` is not passed to Prisma() here — presumably
        the client reads DATABASE_URL from the environment (which
        get_rds_iam_token updates); confirm against the Prisma client config.
        """
        from prisma import Prisma  # type: ignore

        try:
            await self._original_prisma.disconnect()
        except Exception as e:
            verbose_proxy_logger.warning(f"Failed to disconnect Prisma client: {e}")

        if http_client is not None:
            self._original_prisma = Prisma(http=http_client)
        else:
            self._original_prisma = Prisma()

        await self._original_prisma.connect()

    async def start_token_refresh_task(self) -> None:
        """
        Start the background token refresh task.

        This task proactively refreshes RDS IAM tokens before they expire,
        preventing connection failures. Should be called after the initial
        Prisma client connection is established.

        Idempotent: does nothing if IAM auth is disabled or a task is running.
        """
        if not self.iam_token_db_auth:
            verbose_proxy_logger.debug(
                "IAM token auth not enabled, skipping token refresh task"
            )
            return

        if self._token_refresh_task is not None:
            verbose_proxy_logger.debug("Token refresh task already running")
            return

        self._token_refresh_task = asyncio.create_task(self._token_refresh_loop())
        verbose_proxy_logger.info(
            "Started RDS IAM token proactive refresh background task"
        )

    async def stop_token_refresh_task(self) -> None:
        """
        Stop the background token refresh task gracefully.

        Should be called during application shutdown to clean up resources.
        Safe to call when no task is running.
        """
        if self._token_refresh_task is None:
            return

        self._token_refresh_task.cancel()
        try:
            # await so cancellation completes before we drop the reference
            await self._token_refresh_task
        except asyncio.CancelledError:
            pass
        self._token_refresh_task = None
        verbose_proxy_logger.info("Stopped RDS IAM token refresh background task")

    async def _token_refresh_loop(self) -> None:
        """
        Background loop that proactively refreshes RDS IAM tokens before expiration.

        Uses precise timing: calculates the exact sleep duration until the token
        needs to be refreshed (expiration - 3 minute buffer), then refreshes.
        This is more efficient than polling, requiring only 1 wake-up per token cycle.

        Exits only on task cancellation; other errors are logged and retried
        after FALLBACK_REFRESH_INTERVAL_SECONDS.
        """
        verbose_proxy_logger.info(
            f"RDS IAM token refresh loop started. "
            f"Tokens will be refreshed {self.TOKEN_REFRESH_BUFFER_SECONDS}s before expiration."
        )

        while True:
            try:
                # Calculate exactly how long to sleep until next refresh
                sleep_seconds = self._calculate_seconds_until_refresh()

                if sleep_seconds > 0:
                    verbose_proxy_logger.info(
                        f"RDS IAM token refresh scheduled in {sleep_seconds:.0f} seconds "
                        f"({sleep_seconds / 60:.1f} minutes)"
                    )
                    await asyncio.sleep(sleep_seconds)

                # Refresh the token
                verbose_proxy_logger.info("Proactively refreshing RDS IAM token...")
                await self._safe_refresh_token()

            except asyncio.CancelledError:
                verbose_proxy_logger.info("RDS IAM token refresh loop cancelled")
                break
            except Exception as e:
                verbose_proxy_logger.error(
                    f"Error in RDS IAM token refresh loop: {e}. "
                    f"Retrying in {self.FALLBACK_REFRESH_INTERVAL_SECONDS}s..."
                )
                # On error, wait before retrying to avoid tight error loops
                try:
                    await asyncio.sleep(self.FALLBACK_REFRESH_INTERVAL_SECONDS)
                except asyncio.CancelledError:
                    break

    async def _safe_refresh_token(self) -> None:
        """
        Refresh the RDS IAM token with proper locking to prevent race conditions.

        Uses an asyncio lock to ensure only one refresh operation happens at a time,
        preventing multiple concurrent reconnection attempts.
        """
        async with self._reconnection_lock:
            new_db_url = self.get_rds_iam_token()
            if new_db_url:
                await self.recreate_prisma_client(new_db_url)
                self._last_refresh_time = datetime.utcnow()
                verbose_proxy_logger.info(
                    "RDS IAM token refreshed successfully. New token valid for ~15 minutes."
                )
            else:
                verbose_proxy_logger.error(
                    "Failed to generate new RDS IAM token during proactive refresh"
                )

    def __getattr__(self, name: str):
        """
        Proxy attribute access to the underlying Prisma client.

        If IAM token auth is enabled and the token is expired, this method
        provides a synchronous fallback to refresh the token. However, this
        should rarely be needed since the background task proactively refreshes
        tokens before they expire.

        The fallback blocks on the reconnection (via future.result / asyncio.run)
        instead of the earlier fire-and-forget pattern, so callers only ever see
        the reconnected client.

        NOTE(review): asyncio.run_coroutine_threadsafe schedules onto `loop` and
        then future.result() blocks the *current* thread. If __getattr__ is ever
        invoked on the loop's own thread, the loop can't run the coroutine and
        this would stall until the 30s timeout — confirm this path only runs
        from threads other than the event-loop thread.

        NOTE(review): asyncio.get_event_loop() is deprecated when no loop is
        running on newer Python versions — confirm the target runtime.
        """
        # resolve first so a missing attribute raises AttributeError immediately
        original_attr = getattr(self._original_prisma, name)

        if self.iam_token_db_auth:
            db_url = os.getenv("DATABASE_URL")

            # Check if token is expired (should be rare if background task is running)
            if self.is_token_expired(db_url):
                verbose_proxy_logger.warning(
                    "RDS IAM token expired in __getattr__ - proactive refresh may have failed. "
                    "Triggering synchronous fallback refresh..."
                )

                new_db_url = self.get_rds_iam_token()
                if new_db_url:
                    loop = asyncio.get_event_loop()

                    if loop.is_running():
                        # Block on the reconnection rather than fire-and-forget,
                        # so the attribute returned below belongs to the new client.
                        future = asyncio.run_coroutine_threadsafe(
                            self.recreate_prisma_client(new_db_url), loop
                        )
                        try:
                            # Wait up to 30 seconds for reconnection
                            future.result(timeout=30)
                            verbose_proxy_logger.info(
                                "Synchronous token refresh completed successfully"
                            )
                        except Exception as e:
                            verbose_proxy_logger.error(
                                f"Failed to refresh token synchronously: {e}"
                            )
                            raise
                    else:
                        # no running loop: safe to drive the coroutine directly
                        asyncio.run(self.recreate_prisma_client(new_db_url))

                    # Get the NEW attribute after reconnection
                    original_attr = getattr(self._original_prisma, name)
                else:
                    raise ValueError("Failed to get RDS IAM token")

        return original_attr
|
||||
|
||||
|
||||
class PrismaManager:
    @staticmethod
    def _get_prisma_dir() -> str:
        """Return the directory prisma CLI commands are run from: the parent
        of this file's directory (not the migrations directory itself)."""
        abspath = os.path.abspath(__file__)
        dname = os.path.dirname(os.path.dirname(abspath))
        return dname

    @staticmethod
    def setup_database(use_migrate: bool = False) -> bool:
        """
        Set up the database using either prisma migrate or prisma db push

        Retries up to 4 times on CLI timeout/failure with a random 5-15s backoff.
        The working directory is changed to the prisma dir for each attempt and
        always restored afterwards.

        Args:
            use_migrate: When True, delegate to ProxyExtrasDBManager (prisma
                migrate); otherwise run `prisma db push`.

        Returns:
            bool: True if setup was successful, False otherwise
        """

        for attempt in range(4):
            original_dir = os.getcwd()
            prisma_dir = PrismaManager._get_prisma_dir()
            # prisma CLI must be invoked from the directory containing the schema
            os.chdir(prisma_dir)
            try:
                if use_migrate:
                    try:
                        from litellm_proxy_extras.utils import ProxyExtrasDBManager
                    except ImportError as e:
                        verbose_proxy_logger.error(
                            f"\033[1;31mLiteLLM: Failed to import proxy extras. Got {e}\033[0m"
                        )
                        return False

                    # (removed a dead re-assignment of prisma_dir here — the
                    # value was recomputed but never used before returning)
                    return ProxyExtrasDBManager.setup_database(use_migrate=use_migrate)
                else:
                    # Use prisma db push with increased timeout
                    subprocess.run(
                        ["prisma", "db", "push", "--accept-data-loss"],
                        timeout=60,
                        check=True,
                    )
                    return True
            except subprocess.TimeoutExpired:
                verbose_proxy_logger.warning(f"Attempt {attempt + 1} timed out")
                time.sleep(random.randrange(5, 15))
            except subprocess.CalledProcessError as e:
                attempts_left = 3 - attempt
                retry_msg = (
                    f" Retrying... ({attempts_left} attempts left)"
                    if attempts_left > 0
                    else ""
                )
                verbose_proxy_logger.warning(
                    f"The process failed to execute. Details: {e}.{retry_msg}"
                )
                time.sleep(random.randrange(5, 15))
            finally:
                # always restore the caller's working directory
                os.chdir(original_dir)
        return False
|
||||
|
||||
|
||||
def should_update_prisma_schema(
    disable_updates: Optional[Union[bool, str]] = None,
) -> bool:
    """
    Decide whether Prisma schema updates should be applied during startup.

    Args:
        disable_updates: Whether schema updates are disabled. Accepts a
            boolean or a string ('true'/'false'). When None, falls back to
            the DISABLE_SCHEMA_UPDATE environment variable.

    Returns:
        bool: True if schema updates should be applied, False if updates are disabled.

    Examples:
        >>> should_update_prisma_schema()          # reads DISABLE_SCHEMA_UPDATE
        >>> should_update_prisma_schema(True)      # explicitly disable updates
        >>> should_update_prisma_schema("false")   # enable updates via string
    """
    if disable_updates is None:
        disable_updates = os.getenv("DISABLE_SCHEMA_UPDATE", "false")

    # Normalize string input ('true'/'false') to a boolean before negating.
    flag = (
        str_to_bool(disable_updates)
        if isinstance(disable_updates, str)
        else disable_updates
    )
    return not bool(flag)
|
||||
@@ -0,0 +1,151 @@
|
||||
"""
|
||||
Track tool usage for the dashboard: insert into SpendLogToolIndex when spend logs
|
||||
are written, so "last N requests for tool X" and "how is this tool called in production"
|
||||
queries are fast.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
def _add_tool_calls_to_set(tool_calls: Any, out: Set[str]) -> None:
|
||||
"""Extract tool names from OpenAI-style tool_calls list into out."""
|
||||
if not isinstance(tool_calls, list):
|
||||
return
|
||||
for tc in tool_calls:
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
fn = tc.get("function")
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name")
|
||||
if name and isinstance(name, str) and name.strip():
|
||||
out.add(name.strip())
|
||||
|
||||
|
||||
def _parse_tool_names_from_payload(payload: Dict[str, Any]) -> Set[str]:
|
||||
"""
|
||||
Extract deduplicated tool names from a spend log payload.
|
||||
Sources: mcp_namespaced_tool_name, response (tool_calls), proxy_server_request (tools).
|
||||
"""
|
||||
tool_names: Set[str] = set()
|
||||
|
||||
# Top-level MCP tool name (single tool per request for that flow)
|
||||
mcp_name = payload.get("mcp_namespaced_tool_name")
|
||||
if mcp_name and isinstance(mcp_name, str) and mcp_name.strip():
|
||||
tool_names.add(mcp_name.strip())
|
||||
|
||||
# Response: OpenAI-style tool_calls[].function.name or choices[0].message.tool_calls
|
||||
response_raw = payload.get("response")
|
||||
if response_raw:
|
||||
response_obj = (
|
||||
safe_json_loads(response_raw, default=None)
|
||||
if isinstance(response_raw, str)
|
||||
else response_raw
|
||||
)
|
||||
if isinstance(response_obj, dict):
|
||||
_add_tool_calls_to_set(response_obj.get("tool_calls"), tool_names)
|
||||
choices = response_obj.get("choices")
|
||||
if isinstance(choices, list) and choices:
|
||||
msg = (
|
||||
choices[0].get("message") if isinstance(choices[0], dict) else None
|
||||
)
|
||||
if isinstance(msg, dict):
|
||||
_add_tool_calls_to_set(msg.get("tool_calls"), tool_names)
|
||||
|
||||
# Request body: tools[].function.name
|
||||
request_raw = payload.get("proxy_server_request")
|
||||
if request_raw:
|
||||
request_obj = (
|
||||
safe_json_loads(request_raw, default=None)
|
||||
if isinstance(request_raw, str)
|
||||
else request_raw
|
||||
)
|
||||
if isinstance(request_obj, dict):
|
||||
body = request_obj.get("body", request_obj)
|
||||
if isinstance(body, dict):
|
||||
request_obj = body
|
||||
if isinstance(request_obj, dict):
|
||||
tools = request_obj.get("tools")
|
||||
if isinstance(tools, list):
|
||||
for t in tools:
|
||||
if isinstance(t, dict):
|
||||
fn = t.get("function")
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name")
|
||||
if name and isinstance(name, str) and name.strip():
|
||||
tool_names.add(name.strip())
|
||||
|
||||
return tool_names
|
||||
|
||||
|
||||
async def process_spend_logs_tool_usage(
    prisma_client: "PrismaClient",
    logs_to_process: List[Dict[str, Any]],
) -> None:
    """
    After spend logs are written: insert SpendLogToolIndex rows from each payload.

    Extracts tool names from mcp_namespaced_tool_name, response tool_calls, and
    proxy_server_request tools. DB failures are logged and swallowed (non-fatal).

    Args:
        prisma_client: Prisma client used for the bulk insert.
        logs_to_process: Spend-log payload dicts; rows without a request_id
            or a parseable startTime are skipped.
    """
    if not logs_to_process:
        return

    index_rows: List[Dict[str, Any]] = []

    for payload in logs_to_process:
        request_id = payload.get("request_id")
        start_time = payload.get("startTime")
        if not request_id or not start_time:
            continue
        # Normalize start_time to a tz-aware datetime (UTC) once, up front,
        # so rows can be inserted as-is below.
        if isinstance(start_time, str):
            try:
                start_time = datetime.fromisoformat(start_time.replace("Z", "+00:00"))
            except (ValueError, TypeError):
                continue
        if start_time.tzinfo is None:
            start_time = start_time.replace(tzinfo=timezone.utc)

        for tool_name in _parse_tool_names_from_payload(payload):
            index_rows.append(
                {
                    "request_id": request_id,
                    "tool_name": tool_name,
                    "start_time": start_time,
                }
            )

    if not index_rows:
        return

    try:
        await prisma_client.db.litellm_spendlogtoolindex.create_many(
            data=index_rows,
            skip_duplicates=True,
        )
    except Exception as e:
        verbose_proxy_logger.warning(
            "Tool usage tracking (SpendLogToolIndex) failed (non-fatal): %s", e
        )
|
||||
@@ -0,0 +1,440 @@
|
||||
"""
|
||||
DB helpers for LiteLLM_ToolTable — the global tool registry.
|
||||
|
||||
Tools are auto-discovered from LLM responses and upserted here.
|
||||
Admins use the management endpoints to read and update input_policy / output_policy.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ToolDiscoveryQueueItem
|
||||
from litellm.types.tool_management import (
|
||||
LiteLLM_ToolTableRow,
|
||||
ToolPolicyOverrideRow,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
def _row_to_model(row: Union[dict, Any]) -> LiteLLM_ToolTableRow:
    """Convert a Prisma model instance or dict to LiteLLM_ToolTableRow."""
    dump = getattr(row, "model_dump", None)
    if callable(dump):
        # Pydantic-style model: dump to a plain dict.
        data = dump()
    elif isinstance(row, dict):
        data = row
    else:
        # Fallback: pull known fields off the object attribute-by-attribute.
        field_names = (
            "tool_id",
            "tool_name",
            "origin",
            "input_policy",
            "output_policy",
            "call_count",
            "assignments",
            "key_hash",
            "team_id",
            "key_alias",
            "user_agent",
            "last_used_at",
            "created_at",
            "updated_at",
            "created_by",
            "updated_by",
        )
        data = {name: getattr(row, name, None) for name in field_names}
    return LiteLLM_ToolTableRow(
        tool_id=data.get("tool_id", ""),
        tool_name=data.get("tool_name", ""),
        origin=data.get("origin"),
        input_policy=data.get("input_policy") or "untrusted",
        output_policy=data.get("output_policy") or "untrusted",
        call_count=int(data.get("call_count") or 0),
        assignments=data.get("assignments"),
        key_hash=data.get("key_hash"),
        team_id=data.get("team_id"),
        key_alias=data.get("key_alias"),
        user_agent=data.get("user_agent"),
        last_used_at=data.get("last_used_at"),
        created_at=data.get("created_at"),
        updated_at=data.get("updated_at"),
        created_by=data.get("created_by"),
        updated_by=data.get("updated_by"),
    )
|
||||
|
||||
|
||||
async def batch_upsert_tools(
    prisma_client: "PrismaClient",
    items: List[ToolDiscoveryQueueItem],
) -> None:
    """
    Batch-upsert tool registry rows via Prisma.

    On first insert: sets input_policy/output_policy = "untrusted" (default), call_count = 1.
    On conflict: increments call_count; preserves existing policies.
    Errors are logged and swallowed — this is a best-effort background writer.
    """
    if not items:
        return
    try:
        # Drop queue entries without a tool name; there is nothing to register.
        valid_items = [entry for entry in items if entry.get("tool_name")]
        if not valid_items:
            return
        now = datetime.now(timezone.utc)
        table = prisma_client.db.litellm_tooltable
        for entry in valid_items:
            name = entry.get("tool_name", "")
            creator = entry.get("created_by") or "system"
            await table.upsert(
                where={"tool_name": name},
                data={
                    "create": {
                        "tool_id": str(uuid.uuid4()),
                        "tool_name": name,
                        "origin": entry.get("origin") or "user_defined",
                        "input_policy": "untrusted",
                        "output_policy": "untrusted",
                        "call_count": 1,
                        "created_by": creator,
                        "updated_by": creator,
                        "key_hash": entry.get("key_hash"),
                        "team_id": entry.get("team_id"),
                        "key_alias": entry.get("key_alias"),
                        "user_agent": entry.get("user_agent"),
                        "last_used_at": now,
                    },
                    "update": {
                        "call_count": {"increment": 1},
                        "updated_at": now,
                        "last_used_at": now,
                    },
                },
            )
        verbose_proxy_logger.debug(
            "tool_registry_writer: upserted %d tool(s)", len(valid_items)
        )
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer batch_upsert_tools error: %s", e
        )
|
||||
|
||||
|
||||
async def list_tools(
    prisma_client: "PrismaClient",
    input_policy: Optional[str] = None,
) -> List[LiteLLM_ToolTableRow]:
    """Return all tools (newest first), optionally filtered by input_policy."""
    where_clause: dict = {}
    if input_policy is not None:
        where_clause["input_policy"] = input_policy
    try:
        rows = await prisma_client.db.litellm_tooltable.find_many(
            where=where_clause,
            order={"created_at": "desc"},
        )
        return [_row_to_model(record) for record in rows]
    except Exception as e:
        # Non-fatal: surface an empty list to the caller on DB errors.
        verbose_proxy_logger.error("tool_registry_writer list_tools error: %s", e)
        return []
|
||||
|
||||
|
||||
async def get_tool(
    prisma_client: "PrismaClient",
    tool_name: str,
) -> Optional[LiteLLM_ToolTableRow]:
    """Return a single tool row by tool_name, or None if absent or on DB error."""
    try:
        record = await prisma_client.db.litellm_tooltable.find_unique(
            where={"tool_name": tool_name},
        )
        return _row_to_model(record) if record is not None else None
    except Exception as e:
        verbose_proxy_logger.error("tool_registry_writer get_tool error: %s", e)
        return None
|
||||
|
||||
|
||||
async def update_tool_policy(
    prisma_client: "PrismaClient",
    tool_name: str,
    updated_by: Optional[str],
    input_policy: Optional[str] = None,
    output_policy: Optional[str] = None,
) -> Optional[LiteLLM_ToolTableRow]:
    """Update input_policy and/or output_policy for a tool. Upserts the row if it does not exist yet."""
    try:
        actor = updated_by or "system"
        now = datetime.now(timezone.utc)

        # Only overwrite the policies that were explicitly provided.
        update_data: dict = {"updated_by": actor, "updated_at": now}
        if input_policy is not None:
            update_data["input_policy"] = input_policy
        if output_policy is not None:
            update_data["output_policy"] = output_policy

        # New rows fall back to the "untrusted" default for missing policies.
        create_data: dict = {
            "tool_id": str(uuid.uuid4()),
            "tool_name": tool_name,
            "input_policy": input_policy or "untrusted",
            "output_policy": output_policy or "untrusted",
            "created_by": actor,
            "updated_by": actor,
            "created_at": now,
            "updated_at": now,
        }

        await prisma_client.db.litellm_tooltable.upsert(
            where={"tool_name": tool_name},
            data={"create": create_data, "update": update_data},
        )
        return await get_tool(prisma_client, tool_name)
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer update_tool_policy error: %s", e
        )
        return None
|
||||
|
||||
|
||||
async def get_tools_by_names(
    prisma_client: "PrismaClient",
    tool_names: List[str],
) -> Dict[str, Tuple[str, str]]:
    """
    Return a {tool_name: (input_policy, output_policy)} map for the given tool names.
    Missing or NULL policies default to "untrusted"; DB errors yield an empty map.
    """
    if not tool_names:
        return {}
    try:
        rows = await prisma_client.db.litellm_tooltable.find_many(
            where={"tool_name": {"in": tool_names}},
        )
        result: Dict[str, Tuple[str, str]] = {}
        for row in rows:
            in_policy = getattr(row, "input_policy", "untrusted") or "untrusted"
            out_policy = getattr(row, "output_policy", "untrusted") or "untrusted"
            result[row.tool_name] = (in_policy, out_policy)
        return result
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer get_tools_by_names error: %s", e
        )
        return {}
|
||||
|
||||
|
||||
async def list_overrides_for_tool(
    prisma_client: "PrismaClient",
    tool_name: str,
) -> List[ToolPolicyOverrideRow]:
    """
    Return override-like rows for a tool by finding object permissions that have
    this tool in blocked_tools, then resolving each permission to key/team scope for display.
    """
    results: List[ToolPolicyOverrideRow] = []
    try:
        perms = await prisma_client.db.litellm_objectpermissiontable.find_many(
            where={"blocked_tools": {"has": tool_name}},
            include={"verification_tokens": True, "teams": True},
        )
        for perm in perms:
            op_id = getattr(perm, "object_permission_id", None) or ""
            # One display row per key (verification token) on this permission.
            for token in getattr(perm, "verification_tokens", []) or []:
                results.append(
                    ToolPolicyOverrideRow(
                        override_id=op_id,
                        tool_name=tool_name,
                        team_id=None,
                        key_hash=getattr(token, "token", None),
                        input_policy="blocked",
                        key_alias=getattr(token, "key_alias", None),
                        created_at=None,
                        updated_at=None,
                    )
                )
            # One display row per team on this permission.
            for team in getattr(perm, "teams", []) or []:
                results.append(
                    ToolPolicyOverrideRow(
                        override_id=op_id,
                        tool_name=tool_name,
                        team_id=getattr(team, "team_id", None),
                        key_hash=None,
                        input_policy="blocked",
                        key_alias=getattr(team, "team_alias", None),
                        created_at=None,
                        updated_at=None,
                    )
                )
        return results
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer list_overrides_for_tool error: %s", e
        )
        return []
|
||||
|
||||
|
||||
class ToolPolicyRegistry:
    """
    In-memory registry of tool policies synced from DB.
    Hot path uses get_effective_policies only — no DB, no cache.
    """

    def __init__(self) -> None:
        # tool_name -> policy string; unknown tools default to "untrusted"
        self._tool_input_policies: Dict[str, str] = {}
        self._tool_output_policies: Dict[str, str] = {}
        # object_permission_id -> list of blocked tool names
        self._blocked_tools_by_op_id: Dict[str, List[str]] = {}
        self._initialized: bool = False

    def is_initialized(self) -> bool:
        """True once sync_tool_policy_from_db has completed successfully."""
        return self._initialized

    async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
        """Load all tool policies and object-permission blocked_tools from DB."""
        try:
            tools = await prisma_client.db.litellm_tooltable.find_many()
            self._tool_input_policies = {
                tool.tool_name: getattr(tool, "input_policy", "untrusted")
                or "untrusted"
                for tool in tools
            }
            self._tool_output_policies = {
                tool.tool_name: getattr(tool, "output_policy", "untrusted")
                or "untrusted"
                for tool in tools
            }

            perms = await prisma_client.db.litellm_objectpermissiontable.find_many()
            blocked_map: Dict[str, List[str]] = {}
            for perm in perms:
                perm_id = getattr(perm, "object_permission_id", None)
                if perm_id:
                    blocked_map[perm_id] = list(
                        getattr(perm, "blocked_tools", None) or []
                    )
            self._blocked_tools_by_op_id = blocked_map

            self._initialized = True
            verbose_proxy_logger.info(
                "ToolPolicyRegistry: synced %d tool policies and %d object permissions from DB",
                len(self._tool_input_policies),
                len(self._blocked_tools_by_op_id),
            )
        except Exception as e:
            verbose_proxy_logger.exception(
                "ToolPolicyRegistry sync_tool_policy_from_db error: %s", e
            )
            raise

    def get_input_policy(self, tool_name: str) -> str:
        """Global input policy for a tool; "untrusted" when unknown."""
        return self._tool_input_policies.get(tool_name, "untrusted")

    def get_output_policy(self, tool_name: str) -> str:
        """Global output policy for a tool; "untrusted" when unknown."""
        return self._tool_output_policies.get(tool_name, "untrusted")

    def get_effective_policies(
        self,
        tool_names: List[str],
        object_permission_id: Optional[str] = None,
        team_object_permission_id: Optional[str] = None,
    ) -> Dict[str, str]:
        """
        Return effective input_policy per tool from in-memory state.
        If tool is in key or team blocked_tools -> "blocked", else global input_policy or "untrusted".
        """
        if not tool_names:
            return {}
        blocked_names: set = set()
        for perm_id in (object_permission_id, team_object_permission_id):
            if perm_id and perm_id.strip():
                blocked_names.update(
                    self._blocked_tools_by_op_id.get(perm_id.strip(), [])
                )
        return {
            name: "blocked"
            if name in blocked_names
            else self._tool_input_policies.get(name, "untrusted")
            for name in tool_names
        }
|
||||
|
||||
|
||||
# Lazily-created module-level singleton.
_tool_policy_registry: Optional[ToolPolicyRegistry] = None


def get_tool_policy_registry() -> ToolPolicyRegistry:
    """Return the global ToolPolicyRegistry singleton, creating it on first use."""
    global _tool_policy_registry
    registry = _tool_policy_registry
    if registry is None:
        registry = ToolPolicyRegistry()
        _tool_policy_registry = registry
    return registry
|
||||
|
||||
|
||||
async def add_tool_to_object_permission_blocked(
    prisma_client: "PrismaClient",
    object_permission_id: str,
    tool_name: str,
) -> bool:
    """Add tool_name to the permission's blocked_tools if not already present."""
    if not (object_permission_id and tool_name):
        return False
    try:
        perm = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": object_permission_id},
        )
        if perm is None:
            return False
        blocked = list(getattr(perm, "blocked_tools", []) or [])
        if tool_name in blocked:
            # Already blocked — treat as success without writing.
            return True
        await prisma_client.db.litellm_objectpermissiontable.update(
            where={"object_permission_id": object_permission_id},
            data={"blocked_tools": blocked + [tool_name]},
        )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer add_tool_to_object_permission_blocked error: %s", e
        )
        return False
|
||||
|
||||
|
||||
async def remove_tool_from_object_permission_blocked(
    prisma_client: "PrismaClient",
    object_permission_id: str,
    tool_name: str,
) -> bool:
    """Remove tool_name from the permission's blocked_tools. Returns False if tool was not in list."""
    if not (object_permission_id and tool_name):
        return False
    try:
        perm = await prisma_client.db.litellm_objectpermissiontable.find_unique(
            where={"object_permission_id": object_permission_id},
        )
        if perm is None:
            return False
        blocked = list(getattr(perm, "blocked_tools", []) or [])
        if tool_name not in blocked:
            return False
        remaining = [name for name in blocked if name != tool_name]
        await prisma_client.db.litellm_objectpermissiontable.update(
            where={"object_permission_id": object_permission_id},
            data={"blocked_tools": remaining},
        )
        return True
    except Exception as e:
        verbose_proxy_logger.error(
            "tool_registry_writer remove_tool_from_object_permission_blocked error: %s",
            e,
        )
        return False
|
||||
Reference in New Issue
Block a user