chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,579 @@
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
decrypt_value_helper,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.types.proxy.cloudzero_endpoints import (
|
||||
CloudZeroExportRequest,
|
||||
CloudZeroExportResponse,
|
||||
CloudZeroInitRequest,
|
||||
CloudZeroInitResponse,
|
||||
CloudZeroSettingsUpdate,
|
||||
CloudZeroSettingsView,
|
||||
)
|
||||
|
||||
# Router exposing the /cloudzero/* admin management endpoints.
router = APIRouter()


# Initialize the sensitive data masker for API key masking
# (used to redact the stored CloudZero API key before returning it to admins).
_sensitive_masker = SensitiveDataMasker()
|
||||
|
||||
|
||||
async def _set_cloudzero_settings(api_key: str, connection_id: str, timezone: str):
    """
    Persist CloudZero settings in the database, encrypting the API key first.

    Args:
        api_key: CloudZero API key; encrypted via encrypt_value_helper before storage.
        connection_id: CloudZero connection ID.
        timezone: Timezone used for date handling.

    Raises:
        HTTPException: 500 when the database client is not connected.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Never store the raw key — encrypt it before it touches the DB.
    serialized_settings = json.dumps(
        {
            "api_key": encrypt_value_helper(api_key),
            "connection_id": connection_id,
            "timezone": timezone,
        }
    )

    # Upsert so first-time init and subsequent updates share one code path.
    await prisma_client.db.litellm_config.upsert(
        where={"param_name": "cloudzero_settings"},
        data={
            "create": {
                "param_name": "cloudzero_settings",
                "param_value": serialized_settings,
            },
            "update": {"param_value": serialized_settings},
        },
    )
|
||||
|
||||
|
||||
async def _get_cloudzero_settings():
    """
    Load CloudZero settings from the database, decrypting the stored API key.

    Returns:
        dict: Settings with a plaintext "api_key", or an empty dict when
        CloudZero has not been configured.

    Raises:
        HTTPException: 500 when the DB is unavailable or the key cannot be
        decrypted (e.g. salt key misconfiguration).
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    config_row = await prisma_client.db.litellm_config.find_first(
        where={"param_name": "cloudzero_settings"}
    )
    if config_row is None or config_row.param_value is None:
        return {}

    # param_value may come back as a dict, a JSON string, or another
    # mapping-like object depending on the DB client — normalize to a dict.
    raw_value = config_row.param_value
    if isinstance(raw_value, dict):
        settings = raw_value
    elif isinstance(raw_value, str):
        settings = json.loads(raw_value)
    else:
        settings = dict(raw_value)

    # Swap the encrypted API key for its plaintext form for callers.
    stored_key = settings.get("api_key")
    if stored_key:
        plaintext_key = decrypt_value_helper(
            stored_key, key="cloudzero_api_key", exception_type="error"
        )
        if plaintext_key is None:
            raise HTTPException(
                status_code=500,
                detail={
                    "error": "Failed to decrypt CloudZero API key. Check your salt key configuration."
                },
            )
        settings["api_key"] = plaintext_key

    return settings
|
||||
|
||||
|
||||
@router.get(
    "/cloudzero/settings",
    tags=["CloudZero"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CloudZeroSettingsView,
)
async def get_cloudzero_settings(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    View current CloudZero settings.

    Returns the current CloudZero configuration with the API key masked for security.
    Masking is delegated to SensitiveDataMasker — NOTE(review): exact mask format
    (e.g. how many characters remain visible) is defined by that helper; confirm there.
    Returns null/empty values when settings are not configured (consistent with other settings endpoints).

    Only admin users can view CloudZero settings.
    """
    # Validation: only proxy admins may read the configuration.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail={"error": CommonProxyErrors.not_allowed_access.value},
        )

    try:
        # Get CloudZero settings using the accessor method (decrypts the key).
        settings = await _get_cloudzero_settings()

        # If settings are empty, return null/empty values (consistent with other endpoints)
        if not settings:
            return CloudZeroSettingsView(
                api_key_masked=None,
                connection_id=None,
                timezone=None,
                status=None,
            )

        # Use SensitiveDataMasker to mask the API key before it leaves the server.
        masked_settings = _sensitive_masker.mask_dict(settings)

        return CloudZeroSettingsView(
            api_key_masked=masked_settings.get("api_key"),
            connection_id=settings.get("connection_id"),
            timezone=settings.get("timezone"),
            status="configured",
        )

    except HTTPException as e:
        # Re-raise HTTPExceptions as-is (403 above, 500s from the accessor).
        raise e
    except Exception as e:
        verbose_proxy_logger.error(f"Error retrieving CloudZero settings: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to retrieve CloudZero settings: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.put(
    "/cloudzero/settings",
    tags=["CloudZero"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CloudZeroInitResponse,
)
async def update_cloudzero_settings(
    request: CloudZeroSettingsUpdate,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update existing CloudZero settings.

    Allows updating individual CloudZero configuration fields without requiring all fields.
    Only provided fields will be updated; others will remain unchanged.

    Parameters:
    - api_key: (Optional) New CloudZero API key for authentication
    - connection_id: (Optional) New CloudZero connection ID for data submission
    - timezone: (Optional) New timezone for date handling

    Raises:
    - 403 when the caller is not a proxy admin
    - 400 when no fields are provided
    - 404 when settings were never initialized (use /cloudzero/init first)
    - 500 on unexpected failures

    Only admin users can update CloudZero settings.
    """
    # Validation: only proxy admins may change the configuration.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail={"error": CommonProxyErrors.not_allowed_access.value},
        )

    # A PUT with nothing to change is a client error.
    if not any([request.api_key, request.connection_id, request.timezone]):
        raise HTTPException(
            status_code=400,
            detail={"error": "At least one field must be provided for update"},
        )

    try:
        # Get current settings. _get_cloudzero_settings() returns {} when
        # CloudZero was never initialized — previously that fell through to a
        # KeyError on current_settings["api_key"] and surfaced as a generic
        # 500 (and the old 400→404 remap was dead code). Raise 404 explicitly.
        current_settings = await _get_cloudzero_settings()
        if not current_settings:
            raise HTTPException(
                status_code=404,
                detail={
                    "error": "CloudZero settings not found. Please initialize settings first using /cloudzero/init"
                },
            )

        # Merge: prefer the request value, fall back to the stored one.
        # .get() tolerates stored settings that lack a field.
        updated_api_key = (
            request.api_key
            if request.api_key is not None
            else current_settings.get("api_key")
        )
        updated_connection_id = (
            request.connection_id
            if request.connection_id is not None
            else current_settings.get("connection_id")
        )
        updated_timezone = (
            request.timezone
            if request.timezone is not None
            else current_settings.get("timezone")
        )

        # Persist via the shared setter so the API key is re-encrypted.
        await _set_cloudzero_settings(
            api_key=updated_api_key,
            connection_id=updated_connection_id,
            timezone=updated_timezone,
        )

        verbose_proxy_logger.info("CloudZero settings updated successfully")

        return CloudZeroInitResponse(
            message="CloudZero settings updated successfully", status="success"
        )

    except HTTPException:
        # Propagate intentional HTTP errors (403/400/404 and accessor 500s) as-is.
        raise
    except Exception as e:
        verbose_proxy_logger.error(f"Error updating CloudZero settings: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to update CloudZero settings: {str(e)}"},
        )
|
||||
|
||||
|
||||
# Global variable to track if CloudZero background job has been initialized.
# NOTE(review): not read or written anywhere in this view of the file —
# presumably toggled by the background-job scheduler elsewhere; confirm.
_cloudzero_background_job_initialized = False
|
||||
|
||||
|
||||
async def is_cloudzero_setup_in_db() -> bool:
    """
    Return True when CloudZero settings exist in the database.

    CloudZero counts as set up in the DB when a "cloudzero_settings" config
    row exists and carries a non-None value. Any failure — including a
    missing DB connection — is logged and treated as "not set up".
    """
    try:
        from litellm.proxy.proxy_server import prisma_client

        # No DB connection means there is nothing configured to find.
        if prisma_client is None:
            return False

        # Look for the CloudZero settings row.
        config_row = await prisma_client.db.litellm_config.find_first(
            where={"param_name": "cloudzero_settings"}
        )

        # Setup requires both an existing row and a non-None value.
        return config_row is not None and config_row.param_value is not None

    except Exception as e:
        verbose_proxy_logger.error(f"Error checking CloudZero status: {str(e)}")
        return False
|
||||
|
||||
|
||||
def is_cloudzero_setup_in_config() -> bool:
    """
    Return True when "cloudzero" is registered in litellm.callbacks,
    i.e. CloudZero was enabled via config.yaml / startup configuration.
    """
    import litellm

    registered_callbacks = litellm.callbacks
    return "cloudzero" in registered_callbacks
|
||||
|
||||
|
||||
async def is_cloudzero_setup() -> bool:
    """
    Return True when CloudZero is configured via config/env OR the database.

    Checks the in-process callback configuration first, then falls back to
    the database. Any unexpected error is logged and treated as "not set up".
    """
    try:
        # Cheap in-process check first; DB lookup only as a fallback.
        if is_cloudzero_setup_in_config():
            return True
        return await is_cloudzero_setup_in_db()
    except Exception as e:
        verbose_proxy_logger.error(f"Error checking CloudZero setup: {str(e)}")
        return False
|
||||
|
||||
|
||||
@router.post(
    "/cloudzero/init",
    tags=["CloudZero"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CloudZeroInitResponse,
)
async def init_cloudzero_settings(
    request: CloudZeroInitRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Initialize CloudZero settings and store in the database.

    This endpoint stores the CloudZero API key, connection ID, and timezone configuration
    in the proxy database for use by the CloudZero logger.

    Parameters:
    - api_key: CloudZero API key for authentication
    - connection_id: CloudZero connection ID for data submission
    - timezone: Timezone for date handling (default: UTC)

    Only admin users can configure CloudZero settings.
    """
    # Validation: only proxy admins may initialize the configuration.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail={"error": CommonProxyErrors.not_allowed_access.value},
        )

    try:
        # Store settings using the setter method with encryption.
        await _set_cloudzero_settings(
            api_key=request.api_key,
            connection_id=request.connection_id,
            timezone=request.timezone,
        )

        verbose_proxy_logger.info("CloudZero settings initialized successfully")

        return CloudZeroInitResponse(
            message="CloudZero settings initialized successfully", status="success"
        )

    except Exception as e:
        # NOTE(review): this also catches the setter's HTTPException (DB not
        # connected) and re-wraps it as a generic 500 — confirm intended.
        verbose_proxy_logger.error(f"Error initializing CloudZero settings: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to initialize CloudZero settings: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.post(
    "/cloudzero/dry-run",
    tags=["CloudZero"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CloudZeroExportResponse,
)
async def cloudzero_dry_run_export(
    request: CloudZeroExportRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Perform a dry run export using the CloudZero logger.

    This endpoint uses the CloudZero logger to perform a dry run export,
    which returns the data that would be exported without actually sending it to CloudZero.

    Parameters:
    - limit: Optional limit on number of records to process (default: 10000)

    Returns:
    - usage_data: Sample of the raw usage data (first 50 records)
    - cbf_data: CloudZero CBF formatted data ready for export
    - summary: Statistics including total cost, tokens, and record counts

    Only admin users can perform CloudZero exports.
    """
    # Validation: only proxy admins may trigger exports.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail={"error": CommonProxyErrors.not_allowed_access.value},
        )

    try:
        # Import and initialize CloudZero logger.
        from litellm.integrations.cloudzero.cloudzero import CloudZeroLogger

        # NOTE(review): unlike /cloudzero/export, the logger here is
        # constructed without explicit credentials — presumably
        # CloudZeroLogger resolves them itself for dry runs; confirm.
        logger = CloudZeroLogger()
        dry_run_result = await logger.dry_run_export_usage_data(limit=request.limit)

        verbose_proxy_logger.info("CloudZero dry run export completed successfully")

        return CloudZeroExportResponse(
            message="CloudZero dry run export completed successfully.",
            status="success",
            dry_run_data=dry_run_result,
            summary=dry_run_result.get("summary") if dry_run_result else None,
        )

    except Exception as e:
        verbose_proxy_logger.error(
            f"Error performing CloudZero dry run export: {str(e)}"
        )
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to perform CloudZero dry run export: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.post(
    "/cloudzero/export",
    tags=["CloudZero"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CloudZeroExportResponse,
)
async def cloudzero_export(
    request: CloudZeroExportRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Perform an actual export using the CloudZero logger.

    This endpoint uses the CloudZero logger to export usage data to CloudZero AnyCost API.

    Parameters:
    - limit: Optional limit on number of records to export
    - operation: CloudZero operation type ("replace_hourly" or "sum", default: "replace_hourly")

    Only admin users can perform CloudZero exports.
    """
    # Validation: only proxy admins may trigger exports.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail={"error": CommonProxyErrors.not_allowed_access.value},
        )

    try:
        # Get CloudZero settings using the accessor method with decryption.
        settings = await _get_cloudzero_settings()

        # Import and initialize CloudZero logger with credentials.
        from litellm.integrations.cloudzero.cloudzero import CloudZeroLogger

        # Initialize logger with the decrypted credentials directly.
        logger = CloudZeroLogger(
            api_key=settings.get("api_key"),
            connection_id=settings.get("connection_id"),
            timezone=settings.get("timezone"),
        )
        await logger.export_usage_data(
            limit=request.limit,
            operation=request.operation,
            start_time_utc=request.start_time_utc,
            end_time_utc=request.end_time_utc,
        )

        verbose_proxy_logger.info("CloudZero export completed successfully")

        # dry_run_data/summary are only populated by the dry-run endpoint.
        return CloudZeroExportResponse(
            message="CloudZero export completed successfully",
            status="success",
            dry_run_data=None,
            summary=None,
        )

    except Exception as e:
        # NOTE(review): this also re-wraps HTTPExceptions from the settings
        # accessor as a generic 500 — confirm intended.
        verbose_proxy_logger.error(f"Error performing CloudZero export: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to perform CloudZero export: {str(e)}"},
        )
|
||||
|
||||
|
||||
@router.delete(
    "/cloudzero/delete",
    tags=["CloudZero"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=CloudZeroInitResponse,
)
async def delete_cloudzero_settings(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete CloudZero settings from the database.

    Removes only the "cloudzero_settings" row (API key, connection ID,
    timezone); every other configuration value in the database is left
    untouched.

    Raises:
    - 403 when the caller is not a proxy admin
    - 404 when no CloudZero settings exist
    - 500 when the DB is unavailable or the delete fails

    Only admin users can delete CloudZero settings.
    """
    # Only proxy admins may delete the configuration.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail={"error": CommonProxyErrors.not_allowed_access.value},
        )

    try:
        from litellm.proxy.proxy_server import prisma_client

        if prisma_client is None:
            raise HTTPException(
                status_code=500,
                detail={"error": CommonProxyErrors.db_not_connected_error.value},
            )

        # 404 when there is nothing to delete.
        existing_row = await prisma_client.db.litellm_config.find_first(
            where={"param_name": "cloudzero_settings"}
        )
        if existing_row is None:
            raise HTTPException(
                status_code=404,
                detail={"error": "CloudZero settings not found"},
            )

        # Scoped delete: the where clause targets only the cloudzero_settings row.
        await prisma_client.db.litellm_config.delete(
            where={"param_name": "cloudzero_settings"}
        )

        verbose_proxy_logger.info("CloudZero settings deleted successfully")

        return CloudZeroInitResponse(
            message="CloudZero settings deleted successfully", status="success"
        )

    except HTTPException as e:
        # Intentional HTTP errors pass through unchanged.
        raise e
    except Exception as e:
        verbose_proxy_logger.error(f"Error deleting CloudZero settings: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail={"error": f"Failed to delete CloudZero settings: {str(e)}"},
        )
|
||||
@@ -0,0 +1,67 @@
|
||||
"""
|
||||
This module is responsible for handling Getting/Setting the proxy server request from cold storage.
|
||||
|
||||
It allows fetching a dict of the proxy server request from s3 or GCS bucket.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm import _custom_logger_compatible_callbacks_literal
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
class ColdStorageHandler:
    """
    Handles Getting/Setting the proxy server request from cold storage.

    Fetches a dict of the proxy server request from an S3 or GCS bucket via
    the configured cold-storage custom logger.
    """

    async def get_proxy_server_request_from_cold_storage_with_object_key(
        self,
        object_key: str,
    ) -> Optional[dict]:
        """
        Get the proxy server request from cold storage using the object key directly.

        Args:
            object_key: The S3/GCS object key to retrieve

        Returns:
            Optional[dict]: The proxy server request dict or None if not found
        """
        # Which callback (if any) is configured for cold storage?
        logger_name: Optional[
            _custom_logger_compatible_callbacks_literal
        ] = self._select_custom_logger_for_cold_storage()
        if logger_name is None:
            return None

        # Resolve the active/initialized logger instance for that callback.
        active_logger: Optional[
            CustomLogger
        ] = litellm.logging_callback_manager.get_active_custom_logger_for_callback_name(
            logger_name
        )
        if active_logger is None:
            # Configured but never initialized — nothing to fetch from.
            return None

        # Delegate the actual bucket read to the logger.
        return await active_logger.get_proxy_server_request_from_cold_storage_with_object_key(
            object_key=object_key,
        )

    def _select_custom_logger_for_cold_storage(
        self,
    ) -> Optional[_custom_logger_compatible_callbacks_literal]:
        """Return the callback name configured for cold storage, if any."""
        return litellm.cold_storage_custom_logger
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,932 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
from datetime import datetime
|
||||
from datetime import datetime as dt
|
||||
from datetime import timezone
|
||||
from typing import Any, List, Literal, Optional, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import (
|
||||
LITELLM_TRUNCATED_PAYLOAD_FIELD,
|
||||
LITELLM_TRUNCATION_DB_SAFEGUARD_NOTE,
|
||||
)
|
||||
from litellm.constants import (
|
||||
MAX_STRING_LENGTH_PROMPT_IN_DB as DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB,
|
||||
)
|
||||
from litellm.constants import REDACTED_BY_LITELM_STRING
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
get_litellm_metadata_from_kwargs,
|
||||
reconstruct_model_name,
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
|
||||
from litellm.proxy.utils import PrismaClient, hash_token
|
||||
from litellm.types.utils import (
|
||||
CostBreakdown,
|
||||
StandardLoggingGuardrailInformation,
|
||||
StandardLoggingMCPToolCall,
|
||||
StandardLoggingModelInformation,
|
||||
StandardLoggingPayload,
|
||||
StandardLoggingVectorStoreRequest,
|
||||
VectorStoreSearchResponse,
|
||||
)
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
|
||||
def _get_max_string_length_prompt_in_db() -> int:
|
||||
"""
|
||||
Resolve prompt truncation threshold at runtime so values loaded later via
|
||||
proxy config environment_variables are honored.
|
||||
"""
|
||||
max_length_str = os.getenv("MAX_STRING_LENGTH_PROMPT_IN_DB")
|
||||
if max_length_str is None:
|
||||
return DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB
|
||||
try:
|
||||
return int(max_length_str)
|
||||
except (TypeError, ValueError):
|
||||
return DEFAULT_MAX_STRING_LENGTH_PROMPT_IN_DB
|
||||
|
||||
|
||||
def _is_master_key(api_key: Optional[str], _master_key: Optional[str]) -> bool:
|
||||
if _master_key is None or api_key is None:
|
||||
return False
|
||||
|
||||
## string comparison
|
||||
is_master_key = secrets.compare_digest(api_key, _master_key)
|
||||
if is_master_key:
|
||||
return True
|
||||
|
||||
## hash comparison
|
||||
is_master_key = secrets.compare_digest(api_key, hash_token(_master_key))
|
||||
if is_master_key:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _get_spend_logs_metadata(
    metadata: Optional[dict],
    applied_guardrails: Optional[List[str]] = None,
    batch_models: Optional[List[str]] = None,
    mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = None,
    vector_store_request_metadata: Optional[
        List[StandardLoggingVectorStoreRequest]
    ] = None,
    guardrail_information: Optional[List[StandardLoggingGuardrailInformation]] = None,
    usage_object: Optional[dict] = None,
    model_map_information: Optional[StandardLoggingModelInformation] = None,
    cold_storage_object_key: Optional[str] = None,
    litellm_overhead_time_ms: Optional[float] = None,
    cost_breakdown: Optional[CostBreakdown] = None,
) -> SpendLogsMetadata:
    """
    Build the SpendLogsMetadata payload for a spend-log row.

    When `metadata` is None, returns an all-None payload except for
    status ("success") and cold_storage_object_key. Otherwise filters
    `metadata` down to the keys declared on SpendLogsMetadata and overlays
    the explicitly-passed keyword fields on top.
    """
    if metadata is None:
        # No request metadata: emit an empty payload with defaults.
        return SpendLogsMetadata(
            user_api_key=None,
            user_api_key_alias=None,
            user_api_key_team_id=None,
            user_api_key_project_id=None,
            user_api_key_org_id=None,
            user_api_key_user_id=None,
            user_api_key_team_alias=None,
            spend_logs_metadata=None,
            requester_ip_address=None,
            additional_usage_values=None,
            applied_guardrails=None,
            # `None or "success"` always evaluates to "success".
            status=None or "success",
            error_information=None,
            proxy_server_request=None,
            batch_models=None,
            mcp_tool_call_metadata=None,
            vector_store_request_metadata=None,
            model_map_information=None,
            usage_object=None,
            guardrail_information=None,
            cold_storage_object_key=cold_storage_object_key,
            litellm_overhead_time_ms=None,
            attempted_retries=None,
            max_retries=None,
            cost_breakdown=None,
        )
    verbose_proxy_logger.debug(
        "getting payload for SpendLogs, available keys in metadata: "
        + str(list(metadata.keys()))
    )

    # Filter the metadata dictionary to include only the specified keys
    clean_metadata = SpendLogsMetadata(
        **{  # type: ignore
            key: metadata.get(key) for key in SpendLogsMetadata.__annotations__.keys()
        }
    )
    # Explicit keyword arguments take precedence over values from `metadata`.
    clean_metadata["applied_guardrails"] = applied_guardrails
    clean_metadata["batch_models"] = batch_models
    clean_metadata["mcp_tool_call_metadata"] = mcp_tool_call_metadata
    clean_metadata[
        "vector_store_request_metadata"
    ] = _get_vector_store_request_for_spend_logs_payload(vector_store_request_metadata)
    clean_metadata["guardrail_information"] = guardrail_information
    clean_metadata["usage_object"] = usage_object
    clean_metadata["model_map_information"] = model_map_information
    clean_metadata["cold_storage_object_key"] = cold_storage_object_key
    clean_metadata["litellm_overhead_time_ms"] = litellm_overhead_time_ms
    clean_metadata["cost_breakdown"] = cost_breakdown

    return clean_metadata
|
||||
|
||||
|
||||
def generate_hash_from_response(response_obj: Any) -> str:
    """
    Generate a stable hash from a response object.

    Serializes the object to JSON with sorted keys so logically-equal dicts
    hash identically; falls back to hashing str(obj) when the object is not
    JSON-serializable. MD5 is used as a fingerprint here, not for security.

    Args:
        response_obj: The response object to hash (can be dict, list, etc.)

    Returns:
        A hex string representation of the MD5 hash
    """
    try:
        # Canonical JSON form: sorted keys give a consistent ordering.
        canonical = json.dumps(response_obj, sort_keys=True)
    except Exception:
        # Fallback for non-serializable objects: hash the string repr.
        return hashlib.md5(str(response_obj).encode()).hexdigest()
    return hashlib.md5(canonical.encode()).hexdigest()
|
||||
|
||||
|
||||
def get_spend_logs_id(
    call_type: str, response_obj: dict, kwargs: dict
) -> Optional[str]:
    """
    Choose the spend-log row id for a call.

    Batch-retrieve and file-create responses get a content hash (they lack a
    stable id); all other call types prefer the response "id", falling back
    to the litellm call id.
    """
    if call_type in ("aretrieve_batch", "acreate_file"):
        # Generate a deterministic hash from the response object.
        return generate_hash_from_response(response_obj)

    response_id = cast(Optional[str], response_obj.get("id"))
    return response_id or cast(Optional[str], kwargs.get("litellm_call_id"))
|
||||
|
||||
|
||||
def _extract_usage_for_ocr_call(response_obj: Any, response_obj_dict: dict) -> dict:
|
||||
"""
|
||||
Extract usage information for OCR/AOCR calls.
|
||||
|
||||
OCR responses use usage_info (with pages_processed) instead of token-based usage.
|
||||
|
||||
Args:
|
||||
response_obj: The raw response object (can be dict, BaseModel, or other)
|
||||
response_obj_dict: Dictionary representation of the response object
|
||||
|
||||
Returns:
|
||||
A dict with prompt_tokens=0, completion_tokens=0, total_tokens=0,
|
||||
and pages_processed from usage_info.
|
||||
"""
|
||||
usage_info = None
|
||||
|
||||
# Try to extract usage_info from dict
|
||||
if isinstance(response_obj_dict, dict) and "usage_info" in response_obj_dict:
|
||||
usage_info = response_obj_dict.get("usage_info")
|
||||
|
||||
# Try to extract usage_info from object attributes if not found in dict
|
||||
if not usage_info and hasattr(response_obj, "usage_info"):
|
||||
usage_info = response_obj.usage_info
|
||||
if hasattr(usage_info, "model_dump"):
|
||||
usage_info = usage_info.model_dump()
|
||||
elif hasattr(usage_info, "__dict__"):
|
||||
usage_info = vars(usage_info)
|
||||
|
||||
# For OCR, we track pages instead of tokens
|
||||
if usage_info is not None:
|
||||
# Handle dict or object with attributes
|
||||
if isinstance(usage_info, dict):
|
||||
result = {
|
||||
"prompt_tokens": 0, # OCR doesn't use traditional tokens
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
# Add all fields from usage_info, including pages_processed
|
||||
for key, value in usage_info.items():
|
||||
result[key] = value
|
||||
# Ensure pages_processed exists
|
||||
if "pages_processed" not in result:
|
||||
result["pages_processed"] = 0
|
||||
return result
|
||||
else:
|
||||
return {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"pages_processed": 0,
|
||||
}
|
||||
else:
|
||||
return {}
|
||||
|
||||
|
||||
def get_logging_payload(  # noqa: PLR0915
    kwargs, response_obj, start_time, end_time
) -> SpendLogsPayload:
    """
    Build the standardized SpendLogsPayload row for one request.

    This is the single payload-construction path shared by every spend-logging
    sink (DB, s3, dynamoDB, langfuse, ...). It normalizes the raw response,
    extracts token usage (page-based for OCR calls), hashes the API key,
    serializes metadata/tags, and assembles the final payload.

    Args:
        kwargs: litellm model_call_details for the request; may be None.
        response_obj: raw response (dict, Pydantic BaseModel, or any other
            object — non-dict/non-model values are wrapped as
            {"result": str(...)}).
        start_time: request start timestamp (converted to UTC for storage).
        end_time: request end timestamp (converted to UTC for storage).

    Returns:
        SpendLogsPayload ready to be written to the spend logs table.

    Raises:
        Exception: any error hit while constructing the payload is logged
            via verbose_proxy_logger.exception and re-raised.
    """
    from litellm.proxy.proxy_server import general_settings, master_key

    # Defensive normalization: callers may pass None for either argument.
    if kwargs is None:
        kwargs = {}

    if response_obj is None:
        response_obj = {}
    elif not isinstance(response_obj, BaseModel) and not isinstance(response_obj, dict):
        # Anything that is neither a dict nor a Pydantic model gets stringified.
        response_obj = {"result": str(response_obj)}
    # standardize this function to be used across, s3, dynamoDB, langfuse logging
    litellm_params = kwargs.get("litellm_params", {})
    metadata = get_litellm_metadata_from_kwargs(kwargs)
    # completion_start_time falls back to end_time when streaming start is unknown
    completion_start_time = kwargs.get("completion_start_time", end_time)
    call_type = kwargs.get("call_type")
    cache_hit = kwargs.get("cache_hit", False)

    # Convert response_obj to dict first
    if isinstance(response_obj, dict):
        response_obj_dict = response_obj
    elif isinstance(response_obj, BaseModel):
        response_obj_dict = response_obj.model_dump()
    else:
        response_obj_dict = {}

    # Handle OCR responses which use usage_info (pages) instead of usage (tokens)
    usage: dict = {}
    if call_type in ["ocr", "aocr"]:
        usage = _extract_usage_for_ocr_call(response_obj, response_obj_dict)
    else:
        # Use response_obj_dict instead of response_obj to avoid calling .get() on Pydantic models
        _usage = response_obj_dict.get("usage", None) or {}
        if isinstance(_usage, litellm.Usage):
            usage = dict(_usage)
        elif isinstance(_usage, dict):
            usage = _usage

    id = get_spend_logs_id(call_type or "acompletion", response_obj_dict, kwargs)
    standard_logging_payload = cast(
        Optional[StandardLoggingPayload], kwargs.get("standard_logging_object", None)
    )

    end_user_id = get_end_user_id_for_cost_tracking(litellm_params)

    api_key = metadata.get("user_api_key", "")

    # Token counts from the standard logging payload act as fallbacks when
    # the usage dict does not carry them.
    standard_logging_prompt_tokens: int = 0
    standard_logging_completion_tokens: int = 0
    standard_logging_total_tokens: int = 0
    if standard_logging_payload is not None:
        standard_logging_prompt_tokens = standard_logging_payload.get(
            "prompt_tokens", 0
        )
        standard_logging_completion_tokens = standard_logging_payload.get(
            "completion_tokens", 0
        )
        standard_logging_total_tokens = standard_logging_payload.get("total_tokens", 0)
    if api_key is not None and isinstance(api_key, str):
        if api_key.startswith("sk-"):
            # hash the api_key — raw keys are never stored in spend logs
            api_key = hash_token(api_key)
        if (
            _is_master_key(api_key=api_key, _master_key=master_key)
            and general_settings.get("disable_adding_master_key_hash_to_db") is True
        ):
            api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db

    if (
        standard_logging_payload is not None
    ):  # [TODO] migrate completely to sl payload. currently missing pass-through endpoint data
        api_key = (
            api_key
            or standard_logging_payload["metadata"].get("user_api_key_hash")
            or ""
        )
        end_user_id = end_user_id or standard_logging_payload["metadata"].get(
            "user_api_key_end_user_id"
        )
    # NOTE: when standard_logging_payload is None we keep the api_key already
    # extracted from metadata (and hashed above) rather than overwriting it.
    request_tags = (
        json.dumps(metadata.get("tags", []))
        if isinstance(metadata.get("tags", []), list)
        else "[]"
    )
    if (
        standard_logging_payload is not None
        and standard_logging_payload.get("request_tags") is not None
    ):  # use 'tags' from standard logging payload instead
        request_tags = json.dumps(standard_logging_payload["request_tags"])
    # Re-check master key masking: api_key may have been replaced above by the
    # hash from standard_logging_payload.
    if (
        _is_master_key(api_key=api_key, _master_key=master_key)
        and general_settings.get("disable_adding_master_key_hash_to_db") is True
    ):
        api_key = "litellm_proxy_master_key"  # use a known alias, if the user disabled storing master key in db

    _model_id = metadata.get("model_info", {}).get("id", "")
    _model_group = metadata.get("model_group", "")

    # Extract overhead from hidden_params if available
    litellm_overhead_time_ms = None
    if standard_logging_payload is not None:
        hidden_params = standard_logging_payload.get("hidden_params", {})
        litellm_overhead_time_ms = hidden_params.get("litellm_overhead_time_ms")

    # clean up litellm metadata — every optional field defaults to None when
    # the standard logging payload is absent
    clean_metadata = _get_spend_logs_metadata(
        metadata,
        applied_guardrails=(
            standard_logging_payload["metadata"].get("applied_guardrails", None)
            if standard_logging_payload is not None
            else None
        ),
        batch_models=(
            standard_logging_payload.get("hidden_params", {}).get("batch_models", None)
            if standard_logging_payload is not None
            else None
        ),
        mcp_tool_call_metadata=(
            standard_logging_payload["metadata"].get("mcp_tool_call_metadata", None)
            if standard_logging_payload is not None
            else None
        ),
        vector_store_request_metadata=(
            standard_logging_payload["metadata"].get(
                "vector_store_request_metadata", None
            )
            if standard_logging_payload is not None
            else None
        ),
        usage_object=(
            standard_logging_payload["metadata"].get("usage_object", None)
            if standard_logging_payload is not None
            else None
        ),
        model_map_information=(
            standard_logging_payload["model_map_information"]
            if standard_logging_payload is not None
            else None
        ),
        guardrail_information=(
            standard_logging_payload.get("guardrail_information", None)
            if standard_logging_payload is not None
            else (
                metadata.get("standard_logging_guardrail_information", None)
                if metadata is not None
                else None
            )
        ),
        cold_storage_object_key=(
            standard_logging_payload["metadata"].get("cold_storage_object_key", None)
            if standard_logging_payload is not None
            else None
        ),
        litellm_overhead_time_ms=litellm_overhead_time_ms,
        cost_breakdown=(
            standard_logging_payload.get("cost_breakdown", None)
            if standard_logging_payload is not None
            else None
        ),
    )

    # Any usage fields beyond the three standard token counters are preserved
    # in metadata as additional_usage_values (e.g. pages_processed for OCR).
    special_usage_fields = ["completion_tokens", "prompt_tokens", "total_tokens"]
    additional_usage_values = {}
    for k, v in usage.items():
        if k not in special_usage_fields:
            if isinstance(v, BaseModel):
                v = v.model_dump()
            additional_usage_values.update({k: v})
    clean_metadata["additional_usage_values"] = additional_usage_values

    if litellm.cache is not None:
        cache_key = litellm.cache.get_cache_key(**kwargs)
    else:
        cache_key = "Cache OFF"
    if cache_hit is True:
        import time

        id = f"{id}_cache_hit{time.time()}"  # SpendLogs does not allow duplicate request_id

    mcp_namespaced_tool_name = None
    mcp_tool_call_metadata: Optional[StandardLoggingMCPToolCall] = clean_metadata.get(
        "mcp_tool_call_metadata"
    )
    if mcp_tool_call_metadata is not None:
        mcp_namespaced_tool_name = mcp_tool_call_metadata.get(
            "namespaced_tool_name", None
        )

    # Extract agent_id for A2A requests (set directly on model_call_details)
    agent_id: Optional[str] = kwargs.get("agent_id") or metadata.get("agent_id")
    custom_llm_provider = kwargs.get("custom_llm_provider")
    raw_model = cast(str, kwargs.get("model") or "")
    model_name = reconstruct_model_name(raw_model, custom_llm_provider, metadata or {})

    try:
        payload: SpendLogsPayload = SpendLogsPayload(
            request_id=str(id),
            call_type=call_type or "",
            api_key=str(api_key),
            cache_hit=str(cache_hit),
            startTime=_ensure_datetime_utc(start_time),
            endTime=_ensure_datetime_utc(end_time),
            completionStartTime=_ensure_datetime_utc(completion_start_time),
            model=model_name,
            user=metadata.get("user_api_key_user_id", "") or "",
            team_id=metadata.get("user_api_key_team_id", "") or "",
            organization_id=metadata.get("user_api_key_org_id") or "",
            metadata=safe_dumps(clean_metadata),
            cache_key=cache_key,
            spend=kwargs.get("response_cost", 0),
            total_tokens=usage.get("total_tokens", standard_logging_total_tokens),
            prompt_tokens=usage.get("prompt_tokens", standard_logging_prompt_tokens),
            completion_tokens=usage.get(
                "completion_tokens", standard_logging_completion_tokens
            ),
            request_tags=request_tags,
            end_user=end_user_id or "",
            api_base=litellm_params.get("api_base", ""),
            model_group=_model_group,
            model_id=_model_id,
            mcp_namespaced_tool_name=mcp_namespaced_tool_name,
            agent_id=agent_id,
            requester_ip_address=clean_metadata.get("requester_ip_address", None),
            custom_llm_provider=kwargs.get("custom_llm_provider", ""),
            messages=_get_messages_for_spend_logs_payload(
                standard_logging_payload=standard_logging_payload, metadata=metadata
            ),
            response=_get_response_for_spend_logs_payload(
                payload=standard_logging_payload, kwargs=kwargs
            ),
            proxy_server_request=_get_proxy_server_request_for_spend_logs_payload(
                metadata=metadata, litellm_params=litellm_params, kwargs=kwargs
            ),
            session_id=_get_session_id_for_spend_log(
                kwargs=kwargs,
                standard_logging_payload=standard_logging_payload,
            ),
            request_duration_ms=_get_request_duration_ms(start_time, end_time),
            status=_get_status_for_spend_log(
                metadata=metadata,
            ),
        )

        verbose_proxy_logger.debug(
            "SpendTable: created payload - request_id: %s, model: %s, spend: %s",
            payload.get("request_id"),
            payload.get("model"),
            payload.get("spend"),
        )

        # Explicitly clear large intermediate objects to reduce memory pressure
        del response_obj_dict, usage, clean_metadata, additional_usage_values

        return payload
    except Exception as e:
        verbose_proxy_logger.exception(
            "Error creating spendlogs object - {}".format(str(e))
        )
        raise e
|
||||
|
||||
|
||||
def _get_session_id_for_spend_log(
    kwargs: dict,
    standard_logging_payload: Optional[StandardLoggingPayload],
) -> str:
    """
    Resolve the session id stored on a spend log row.

    Priority order:
      1. ``trace_id`` from the standard logging payload,
      2. a caller-supplied ``litellm_trace_id`` in ``kwargs``,
      3. a freshly generated UUID, so every row always has a session id.
    """
    from litellm._uuid import uuid

    if standard_logging_payload is not None:
        trace_id = standard_logging_payload.get("trace_id")
        if trace_id is not None:
            return str(trace_id)

    # Callers can pin the trace id per request via `litellm_trace_id`.
    dynamic_trace_id = kwargs.get("litellm_trace_id")
    if dynamic_trace_id is not None:
        return str(dynamic_trace_id)

    # No trace id anywhere -- generate one so the column is never empty.
    return str(uuid.uuid4())
|
||||
|
||||
|
||||
def _get_request_duration_ms(start_time: datetime, end_time: datetime) -> Optional[int]:
|
||||
"""Compute request duration in milliseconds from start and end times."""
|
||||
try:
|
||||
return int((end_time - start_time).total_seconds() * 1000)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _ensure_datetime_utc(timestamp: datetime) -> datetime:
|
||||
"""Helper to ensure datetime is in UTC"""
|
||||
timestamp = timestamp.astimezone(timezone.utc)
|
||||
return timestamp
|
||||
|
||||
|
||||
async def get_spend_by_team_and_customer(
    start_date: dt,
    end_date: dt,
    team_id: str,
    customer_id: str,
    prisma_client: PrismaClient,
):
    """
    Fetch daily spend for one (team, customer) pair from the spend logs.

    Runs a raw SQL query against "LiteLLM_SpendLogs" joined with
    "LiteLLM_TeamTable": spend and tokens are first aggregated per
    (day, team, customer, model, api_key), then rolled up per day with a
    JSON array of per-model / per-key breakdowns.

    Args:
        start_date: inclusive start of the window.
        end_date: end of the window; the query adds one day, so rows from
            end_date itself are included.
        team_id: team to filter on (SpendLogs.team_id).
        customer_id: matched against SpendLogs.end_user.
        prisma_client: DB client used to execute the raw, parameterized query.

    Returns:
        List of rows shaped {group_by_day, teams_customers}, ordered by day;
        an empty list when the query returns None.
    """
    sql_query = """
    WITH SpendByModelApiKey AS (
        SELECT
            date_trunc('day', sl."startTime") AS group_by_day,
            COALESCE(tt.team_alias, 'Unassigned Team') AS team_name,
            sl.end_user AS customer,
            sl.model,
            sl.api_key,
            SUM(sl.spend) AS model_api_spend,
            SUM(sl.total_tokens) AS model_api_tokens
        FROM
            "LiteLLM_SpendLogs" sl
        LEFT JOIN
            "LiteLLM_TeamTable" tt
        ON
            sl.team_id = tt.team_id
        WHERE
            sl."startTime" >= $1::timestamptz AND sl."startTime" < ($2::timestamptz + INTERVAL '1 day')
            AND sl.team_id = $3
            AND sl.end_user = $4
        GROUP BY
            date_trunc('day', sl."startTime"),
            tt.team_alias,
            sl.end_user,
            sl.model,
            sl.api_key
    )
        SELECT
            group_by_day,
            jsonb_agg(jsonb_build_object(
                'team_name', team_name,
                'customer', customer,
                'total_spend', total_spend,
                'metadata', metadata
            )) AS teams_customers
        FROM (
            SELECT
                group_by_day,
                team_name,
                customer,
                SUM(model_api_spend) AS total_spend,
                jsonb_agg(jsonb_build_object(
                    'model', model,
                    'api_key', api_key,
                    'spend', model_api_spend,
                    'total_tokens', model_api_tokens
                )) AS metadata
            FROM
                SpendByModelApiKey
            GROUP BY
                group_by_day,
                team_name,
                customer
        ) AS aggregated
        GROUP BY
            group_by_day
        ORDER BY
            group_by_day;
    """

    # Parameterized query: dates, team and customer are bound as $1-$4.
    db_response = await prisma_client.db.query_raw(
        sql_query, start_date, end_date, team_id, customer_id
    )
    if db_response is None:
        return []

    return db_response
|
||||
|
||||
|
||||
def _get_messages_for_spend_logs_payload(
    standard_logging_payload: Optional[StandardLoggingPayload],
    metadata: Optional[dict] = None,
) -> str:
    """
    Serialize request messages for storage on the spend-log row.

    Messages are stored only when prompt/response storage is enabled AND the
    call is a realtime call ("_arealtime"); every other case returns "{}".
    Serialization errors are swallowed so spend logging never fails here.
    """
    if not _should_store_prompts_and_responses_in_spend_logs():
        return "{}"
    if standard_logging_payload is None:
        return "{}"
    if standard_logging_payload.get("call_type", "") != "_arealtime":
        return "{}"

    messages = standard_logging_payload.get("messages")
    if messages is None:
        return "{}"

    try:
        return json.dumps(messages, default=str)
    except Exception:
        # Best-effort: an unserializable payload must not break logging.
        return "{}"
|
||||
|
||||
|
||||
def _sanitize_request_body_for_spend_logs_payload(
    request_body: dict,
    visited: Optional[set] = None,
    max_string_length_prompt_in_db: Optional[int] = None,
) -> dict:
    """
    Recursively sanitize a request body before storing it in spend logs.

    Strings longer than ``max_string_length_prompt_in_db`` characters (e.g.
    base64 blobs) are truncated, keeping 35% from the start and 65% from the
    end with a marker describing how many characters were skipped. Nested
    dicts and lists are processed recursively; circular references are
    replaced with an empty dict.

    Fixes vs. previous revision: removed an unreachable ``len(value) <= max``
    branch (dead code inside the ``len(value) > max`` arm) and a duplicated
    ``return value``; the ``visited`` entry is now popped after processing so
    a shared (non-cyclic) sub-dict appearing twice is sanitized instead of
    being silently replaced by ``{}`` — matching the cycle handling in
    ``_convert_to_json_serializable_dict``.

    Args:
        request_body: The request body to sanitize.
        visited: Object ids currently on the recursion stack (cycle guard).
        max_string_length_prompt_in_db: Maximum stored string length;
            defaults to the configured limit.

    Returns:
        A sanitized copy of ``request_body``.
    """
    from litellm.constants import (
        LITELLM_TRUNCATED_PAYLOAD_FIELD,
        LITELLM_TRUNCATION_DB_SAFEGUARD_NOTE,
    )

    if visited is None:
        visited = set()
    if max_string_length_prompt_in_db is None:
        max_string_length_prompt_in_db = _get_max_string_length_prompt_in_db()

    # Cycle guard: if this dict is already on the recursion stack, stop.
    obj_id = id(request_body)
    if obj_id in visited:
        return {}
    visited.add(obj_id)

    def _sanitize_value(value: Any) -> Any:
        if isinstance(value, dict):
            return _sanitize_request_body_for_spend_logs_payload(
                value, visited, max_string_length_prompt_in_db
            )
        elif isinstance(value, list):
            return [_sanitize_value(item) for item in value]
        elif isinstance(value, str):
            if len(value) > max_string_length_prompt_in_db:
                # Keep 35% from the beginning and 65% from the end; the end
                # of a conversation usually carries the most context.
                start_chars = int(max_string_length_prompt_in_db * 0.35)
                end_chars = int(max_string_length_prompt_in_db * 0.65)

                # Defensive clamp so the kept text never exceeds the limit
                # (int() truncation already guarantees this in practice).
                if start_chars + end_chars > max_string_length_prompt_in_db:
                    end_chars = max_string_length_prompt_in_db - start_chars

                skipped_chars = len(value) - (start_chars + end_chars)

                # beginning + truncation marker + end
                return (
                    f"{value[:start_chars]}"
                    f"... ({LITELLM_TRUNCATED_PAYLOAD_FIELD} skipped {skipped_chars} chars. "
                    f"{LITELLM_TRUNCATION_DB_SAFEGUARD_NOTE}) ..."
                    f"{value[-end_chars:]}"
                )
            return value
        return value

    try:
        return {k: _sanitize_value(v) for k, v in request_body.items()}
    finally:
        # Pop from the stack once done: only objects currently being
        # processed should count as "visited" (true cycles still return {}).
        visited.discard(obj_id)
|
||||
|
||||
|
||||
def _convert_to_json_serializable_dict(
    obj: Any, visited: Optional[set] = None, max_depth: int = 20
) -> Any:
    """
    Recursively convert *obj* into plain JSON-serializable data.

    Pydantic models are expanded with ``model_dump()`` instead of
    pickle-based deepcopy, which fails on Pydantic v2 models that hold
    ``_thread.RLock`` objects. Circular references are replaced with
    "<circular_reference>" and recursion beyond *max_depth* with
    "<max_depth_exceeded>".

    Args:
        obj: dict, list, Pydantic model, object with ``__dict__``, or primitive.
        visited: object ids currently on the recursion stack.
        max_depth: remaining recursion budget.

    Returns:
        A JSON-serializable equivalent of *obj*.
    """
    if max_depth <= 0:
        return "<max_depth_exceeded>"

    if visited is None:
        visited = set()

    obj_id = id(obj)
    if obj_id in visited:
        # This object is already on the current recursion stack -> cycle.
        return "<circular_reference>"

    # Only mutable containers / attribute-bearing objects can form cycles.
    if isinstance(obj, (dict, list)) or hasattr(obj, "__dict__"):
        visited.add(obj_id)

    try:
        if isinstance(obj, BaseModel):
            # Dump via Pydantic, then recurse over the plain dict.
            return _convert_to_json_serializable_dict(
                obj.model_dump(), visited, max_depth - 1
            )
        if isinstance(obj, dict):
            return {
                key: _convert_to_json_serializable_dict(val, visited, max_depth - 1)
                for key, val in obj.items()
            }
        if isinstance(obj, list):
            return [
                _convert_to_json_serializable_dict(entry, visited, max_depth - 1)
                for entry in obj
            ]
        if hasattr(obj, "__dict__"):
            # Plain objects: recurse over their attribute dict.
            return _convert_to_json_serializable_dict(
                obj.__dict__, visited, max_depth - 1
            )
        # Primitives (str, int, float, bool, None) pass through unchanged.
        return obj
    finally:
        # Pop from the recursion stack once this object is fully processed.
        visited.discard(obj_id)
|
||||
|
||||
|
||||
def _get_proxy_server_request_for_spend_logs_payload(
    metadata: dict,
    litellm_params: dict,
    kwargs: Optional[dict] = None,
) -> str:
    """
    Serialize the incoming proxy request body for storage on the spend log.

    Stored only when _should_store_prompts_and_responses_in_spend_logs() is
    True; otherwise returns "{}". If turn_off_message_logging is enabled,
    messages in the request body are redacted before serialization. Overly
    long string values are truncated by the sanitizer, and a truncation event
    is logged.

    Args:
        metadata: litellm request metadata (currently unused here directly).
        litellm_params: litellm params containing "proxy_server_request".
        kwargs: full model_call_details; used for realtime tools and to
            evaluate redaction settings.

    Returns:
        JSON string of the (possibly redacted/truncated) request body,
        or "{}" when storage is disabled or no request body exists.
    """
    if _should_store_prompts_and_responses_in_spend_logs():
        _proxy_server_request = cast(
            Optional[dict], litellm_params.get("proxy_server_request", {})
        )
        if _proxy_server_request is not None:
            _request_body = _proxy_server_request.get("body", {}) or {}

            # Realtime calls attach their tools separately; copy-on-write so
            # the original request body is not mutated.
            if kwargs is not None:
                realtime_tools = kwargs.get("realtime_tools")
                if realtime_tools:
                    _request_body = dict(_request_body)
                    _request_body["tools"] = realtime_tools

            # Apply message redaction if turn_off_message_logging is enabled
            if kwargs is not None:
                from litellm.litellm_core_utils.redact_messages import (
                    perform_redaction,
                    should_redact_message_logging,
                )

                # Build model_call_details dict to check redaction settings
                model_call_details = {
                    "litellm_params": litellm_params,
                    "standard_callback_dynamic_params": kwargs.get(
                        "standard_callback_dynamic_params"
                    ),
                }

                # If redaction is enabled, convert to serializable dict before redacting
                if should_redact_message_logging(model_call_details=model_call_details):
                    _request_body = _convert_to_json_serializable_dict(_request_body)
                    perform_redaction(model_call_details=_request_body, result=None)

            # Truncate oversized strings (e.g. base64 payloads) before storage.
            _request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
            _request_body_json_str = json.dumps(_request_body, default=str)
            if LITELLM_TRUNCATED_PAYLOAD_FIELD in _request_body_json_str:
                verbose_proxy_logger.info(
                    "Spend Log: request body was truncated before storing in DB. %s",
                    LITELLM_TRUNCATION_DB_SAFEGUARD_NOTE,
                )
            return _request_body_json_str
    return "{}"
|
||||
|
||||
|
||||
def _get_vector_store_request_for_spend_logs_payload(
    vector_store_request_metadata: Optional[List[StandardLoggingVectorStoreRequest]],
) -> Optional[List[StandardLoggingVectorStoreRequest]]:
    """
    Prepare vector store request metadata for storage in spend logs.

    When prompt/response storage is enabled the metadata passes through
    untouched. Otherwise, the text of every retrieved document is redacted
    in place so no document content reaches the database.
    """
    if _should_store_prompts_and_responses_in_spend_logs():
        return vector_store_request_metadata

    if vector_store_request_metadata is None:
        return None

    # Storage is disabled: blank out retrieved text in place.
    for request_entry in vector_store_request_metadata:
        search_response: VectorStoreSearchResponse = (
            request_entry.get("vector_store_search_response")
            or VectorStoreSearchResponse()
        )
        for result_item in search_response.get("data", []) or []:
            for content_entry in result_item.get("content", []) or []:
                if "text" in content_entry:
                    content_entry["text"] = REDACTED_BY_LITELM_STRING
    return vector_store_request_metadata
|
||||
|
||||
|
||||
def _get_response_for_spend_logs_payload(
    payload: Optional[StandardLoggingPayload],
    kwargs: Optional[dict] = None,
) -> str:
    """
    Serialize the model response for storage on the spend log.

    Stored only when _should_store_prompts_and_responses_in_spend_logs() is
    True; otherwise returns "{}". If turn_off_message_logging is enabled,
    the response content is redacted before serialization. Oversized string
    values are truncated by the sanitizer, and a truncation event is logged.

    Args:
        payload: the standard logging payload carrying the "response" field.
        kwargs: full model_call_details; used for realtime tool calls and to
            evaluate redaction settings.

    Returns:
        JSON string of the (possibly redacted/truncated) response, or "{}"
        when storage is disabled or no response exists.
    """
    if payload is None:
        return "{}"
    if _should_store_prompts_and_responses_in_spend_logs():
        response_obj: Any = payload.get("response")
        if response_obj is None:
            return "{}"

        # Realtime calls attach their tool calls separately; copy-on-write so
        # the payload's response dict is not mutated.
        if kwargs is not None:
            realtime_tool_calls = kwargs.get("realtime_tool_calls")
            if realtime_tool_calls and isinstance(response_obj, dict):
                response_obj = dict(response_obj)
                response_obj["tool_calls"] = realtime_tool_calls

        # Apply message redaction if turn_off_message_logging is enabled
        if kwargs is not None:
            from litellm.litellm_core_utils.redact_messages import (
                perform_redaction,
                should_redact_message_logging,
            )

            litellm_params = kwargs.get("litellm_params", {})
            model_call_details = {
                "litellm_params": litellm_params,
                "standard_callback_dynamic_params": kwargs.get(
                    "standard_callback_dynamic_params"
                ),
            }

            # If redaction is enabled, convert to serializable dict before redacting
            if should_redact_message_logging(model_call_details=model_call_details):
                response_obj = _convert_to_json_serializable_dict(response_obj)
                response_obj = perform_redaction(
                    model_call_details={}, result=response_obj
                )

        # Wrap in a dict so the sanitizer's string-truncation also applies to
        # a bare string response.
        sanitized_wrapper = _sanitize_request_body_for_spend_logs_payload(
            {"response": response_obj}
        )

        sanitized_response = sanitized_wrapper.get("response", response_obj)

        if sanitized_response is None:
            return "{}"
        if isinstance(sanitized_response, str):
            result_str = sanitized_response
        else:
            result_str = safe_dumps(sanitized_response)
        if LITELLM_TRUNCATED_PAYLOAD_FIELD in result_str:
            verbose_proxy_logger.info(
                "Spend Log: response was truncated before storing in DB. %s",
                LITELLM_TRUNCATION_DB_SAFEGUARD_NOTE,
            )
        return result_str
    return "{}"
|
||||
|
||||
|
||||
def _should_store_prompts_and_responses_in_spend_logs() -> bool:
    """
    Decide whether full prompts/responses should be written to spend logs.

    Checks ``general_settings["store_prompts_in_spend_logs"]`` first,
    accepting the boolean ``True`` or any casing of the string "true"
    (the setting may come from yaml, the DB, or an env override), then
    falls back to the ``STORE_PROMPTS_IN_SPEND_LOGS`` environment variable.
    """
    from litellm.proxy.proxy_server import general_settings
    from litellm.secret_managers.main import get_secret_bool

    configured = general_settings.get("store_prompts_in_spend_logs")

    # Normalize: bool True, or case-insensitive string "true", both enable.
    if configured is True:
        return True
    if isinstance(configured, str) and configured.lower() == "true":
        return True

    # Environment variable acts as a secondary opt-in.
    return get_secret_bool("STORE_PROMPTS_IN_SPEND_LOGS") is True
|
||||
|
||||
|
||||
def _get_status_for_spend_log(
|
||||
metadata: dict,
|
||||
) -> Literal["success", "failure"]:
|
||||
"""
|
||||
Get the status for the spend log.
|
||||
|
||||
It's only a failure if metadata.get("status") is "failure"
|
||||
"""
|
||||
_status: Optional[str] = metadata.get("status", None)
|
||||
if _status == "failure":
|
||||
return "failure"
|
||||
return "success"
|
||||
Reference in New Issue
Block a user