chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,71 @@
# Pass-Through Endpoints Architecture
## Why Pass-Through Endpoints Transform Requests
Even "pass-through" endpoints must perform essential transformations. The request **body** passes through unchanged, but:
```mermaid
sequenceDiagram
participant Client
participant Proxy as LiteLLM Proxy
participant Provider as LLM Provider
Client->>Proxy: POST /vertex_ai/v1/projects/.../generateContent
Note over Client,Proxy: Headers: Authorization: Bearer sk-litellm-key
Note over Client,Proxy: Body: { "contents": [...] }
rect rgb(240, 240, 240)
Note over Proxy: 1. URL Construction
Note over Proxy: Build regional/provider-specific URL
end
rect rgb(240, 240, 240)
Note over Proxy: 2. Auth Header Replacement
Note over Proxy: LiteLLM key → provider credentials
end
rect rgb(240, 240, 240)
Note over Proxy: 3. Extra Operations
Note over Proxy: • x-pass-* headers (strip prefix, forward)
Note over Proxy: • x-litellm-tags → metadata
Note over Proxy: • Guardrails (opt-in)
Note over Proxy: • Multipart form reconstruction
end
Proxy->>Provider: POST https://us-central1-aiplatform.googleapis.com/...
Note over Proxy,Provider: Headers: Authorization: Bearer ya29.google-oauth...
Note over Proxy,Provider: Body: { "contents": [...] } ← UNCHANGED
Provider-->>Proxy: Response (streaming or non-streaming)
rect rgb(240, 240, 240)
Note over Proxy: 4. Response Handling (async)
Note over Proxy: • Collect streaming chunks for logging
Note over Proxy: • Cost injection (if enabled)
Note over Proxy: • Parse response → calculate cost → log
end
Proxy-->>Client: Response (unchanged)
```
## Essential Transformations
- **URL Construction** - Build correct provider URL (e.g., regional endpoints for Vertex AI, Bedrock)
- **Auth Header Replacement** - Swap LiteLLM virtual key for actual provider credentials
## Extra Operations
| Operation | Description |
|-----------|-------------|
| `x-pass-*` headers | Strip prefix and forward (e.g., `x-pass-anthropic-beta` → `anthropic-beta`) |
| `x-litellm-tags` header | Extract tags and add to request metadata for logging |
| Streaming chunk collection | Collect chunks async for logging after stream completes |
| Multipart form handling | Reconstruct multipart/form-data requests for file uploads |
| Guardrails (opt-in) | Run content filtering when explicitly configured |
| Cost injection | Inject cost into streaming chunks when `include_cost_in_streaming_usage` enabled |
## What Does NOT Change
- Request body
- Response body
- Provider-specific parameters

View File

@@ -0,0 +1,16 @@
from fastapi import Request
def get_litellm_virtual_key(request: Request) -> str:
    """
    Return the API key for this request, formatted as a Bearer token.

    The `x-litellm-api-key` header takes precedence over `Authorization`:
    the Vertex JS SDK occupies the `Authorization` header, so callers pass
    the LiteLLM virtual key via `x-litellm-api-key` instead.
    """
    virtual_key = request.headers.get("x-litellm-api-key")
    if not virtual_key:
        # Fall back to the raw Authorization header (may already be "Bearer ...").
        return request.headers.get("Authorization", "")
    return f"Bearer {virtual_key}"

View File

@@ -0,0 +1,94 @@
"""
JSONPath Extractor Module
Extracts field values from data using simple JSONPath-like expressions.
"""
from typing import Any, List, Union
from litellm._logging import verbose_proxy_logger
class JsonPathExtractor:
    """Extracts field values from data using JSONPath-like expressions."""

    @staticmethod
    def extract_fields(
        data: dict,
        jsonpath_expressions: List[str],
    ) -> str:
        """
        Extract field values from *data* for each expression and join them.

        Supported expression forms (see :meth:`evaluate`):
            - "query"              -> data["query"]
            - "documents[*].text"  -> every text field in the documents array
            - "messages[*].content"-> every content field in the messages array

        Returns:
            All extracted values, stringified and joined with newlines.
            Expressions that fail or yield falsy values are skipped.
        """
        collected: List[str] = []
        for expression in jsonpath_expressions:
            try:
                extracted = JsonPathExtractor.evaluate(data, expression)
                if not extracted:
                    continue
                if isinstance(extracted, list):
                    collected.extend(str(item) for item in extracted if item)
                else:
                    collected.append(str(extracted))
            except Exception as e:
                # Extraction is best-effort; a bad expression must not break the caller.
                verbose_proxy_logger.debug(
                    "Failed to extract field %s: %s", expression, str(e)
                )
        return "\n".join(collected)

    @staticmethod
    def evaluate(data: dict, expr: str) -> Union[str, List[str], None]:
        """
        Evaluate a simple JSONPath-like expression against *data*.

        Supports:
            - Simple key:     "query"         -> data["query"]
            - Nested key:     "foo.bar"       -> data["foo"]["bar"]
            - Array wildcard: "items[*].text" -> [item["text"] for item in data["items"]]

        Returns the resolved value, a flat list of values for wildcard
        expressions, or None when the path cannot be resolved.
        """
        if not expr or not data:
            return None
        # Normalize "[*]" into its own dot-separated token, e.g.
        # "items[*].text" -> ["items", "[*]", "text"].
        tokens = expr.replace("[*]", ".[*]").split(".")
        node: Any = data
        for idx, token in enumerate(tokens):
            if node is None:
                return None
            if token == "[*]":
                # Wildcard requires the current node to be a list.
                if not isinstance(node, list):
                    return None
                tail = ".".join(tokens[idx + 1 :])
                if not tail:
                    # Trailing wildcard: return the list itself.
                    return node
                # Recurse into each dict element with the remaining path,
                # flattening nested list results.
                gathered: List[str] = []
                for element in node:
                    if not isinstance(element, dict):
                        continue
                    sub = JsonPathExtractor.evaluate(element, tail)
                    if not sub:
                        continue
                    if isinstance(sub, list):
                        gathered.extend(sub)
                    else:
                        gathered.append(sub)
                return gathered or None
            if isinstance(node, dict):
                node = node.get(token)
            else:
                # Path continues but the node is a scalar: unresolvable.
                return None
        return node

View File

@@ -0,0 +1,619 @@
import json
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.anthropic import get_anthropic_config
from litellm.llms.anthropic.chat.handler import (
ModelResponseIterator as AnthropicModelResponseIterator,
)
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import LiteLLMBatch, ModelResponse, TextCompletionResponse
if TYPE_CHECKING:
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
from ..success_handler import PassThroughEndpointLogging
else:
PassThroughEndpointLogging = Any
EndpointType = Any
class AnthropicPassthroughLoggingHandler:
    """Anthropic-specific logging for LiteLLM passthrough endpoints.

    Stateless namespace of static methods that convert raw Anthropic
    passthrough responses (non-streaming, streaming chunks, and batch
    creation) into LiteLLM's standard logging payloads so downstream
    logging and cost-tracking callbacks run unchanged.
    """

    @staticmethod
    def anthropic_passthrough_handler(
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: Optional[dict] = None,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Transforms Anthropic response to OpenAI response, generates a standard logging object so downstream logging can be handled

        Args:
            httpx_response: Raw provider HTTP response.
            response_body: Parsed JSON body of the provider response.
            logging_obj: LiteLLM logging object for this call.
            url_route: Passthrough route that was hit (used to detect batch routes).
            result: Raw response text, forwarded to downstream logging.
            start_time: Request start timestamp.
            end_time: Request end timestamp.
            cache_hit: Whether the response came from cache.
            request_body: Original request body, if available.

        Returns:
            Dict with "result" (transformed ModelResponse) and "kwargs" for logging.
        """
        # Batch-creation responses are logged via a dedicated handler — they
        # need managed-object bookkeeping for deferred cost tracking.
        if "/v1/messages/batches" in url_route and httpx_response.status_code == 200:
            # Get request body from parameter or kwargs
            request_body = request_body or kwargs.get("request_body", {})
            return AnthropicPassthroughLoggingHandler.batch_creation_handler(
                httpx_response=httpx_response,
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                request_body=request_body,
                **kwargs,
            )
        model = response_body.get("model", "")
        anthropic_config = get_anthropic_config(url_route)
        # Reuse the provider config's transform to get an OpenAI-format
        # ModelResponse; messages/params are empty because the proxy never
        # parsed the original request into chat messages.
        litellm_model_response: ModelResponse = anthropic_config().transform_response(
            raw_response=httpx_response,
            model_response=litellm.ModelResponse(),
            model=model,
            messages=[],
            logging_obj=logging_obj,
            optional_params={},
            api_key="",
            request_data={},
            encoding=litellm.encoding,
            json_mode=False,
            litellm_params={},
        )
        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
            litellm_model_response=litellm_model_response,
            model=model,
            kwargs=kwargs,
            start_time=start_time,
            end_time=end_time,
            logging_obj=logging_obj,
        )
        return {
            "result": litellm_model_response,
            "kwargs": kwargs,
        }

    @staticmethod
    def _get_user_from_metadata(
        passthrough_logging_payload: PassthroughStandardLoggingPayload,
    ) -> Optional[str]:
        # Pull the end-user id out of the original request body, if present.
        request_body = passthrough_logging_payload.get("request_body")
        if request_body:
            return get_end_user_id_from_request_body(request_body)
        return None

    @staticmethod
    def _create_anthropic_response_logging_payload(
        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
        model: str,
        kwargs: dict,
        start_time: datetime,
        end_time: datetime,
        logging_obj: LiteLLMLoggingObj,
    ) -> dict:
        """
        Create the standard logging object for Anthropic passthrough

        handles streaming and non-streaming responses

        Returns the (mutated) kwargs dict; on any failure it is returned
        unchanged rather than raising, so logging never breaks the request.
        """
        try:
            # Get custom_llm_provider from logging object if available (e.g., azure_ai for Azure Anthropic)
            custom_llm_provider = logging_obj.model_call_details.get(
                "custom_llm_provider"
            )
            # Prepend custom_llm_provider to model if not already present
            model_for_cost = model
            if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
                model_for_cost = f"{custom_llm_provider}/{model}"
            response_cost = litellm.completion_cost(
                completion_response=litellm_model_response,
                model=model_for_cost,
                custom_llm_provider=custom_llm_provider,
            )
            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = (  # type: ignore
                kwargs.get("passthrough_logging_payload")
            )
            if passthrough_logging_payload:
                user = AnthropicPassthroughLoggingHandler._get_user_from_metadata(
                    passthrough_logging_payload=passthrough_logging_payload,
                )
                if user:
                    # Surface the end-user id where standard logging expects it.
                    kwargs.setdefault("litellm_params", {})
                    kwargs["litellm_params"].update(
                        {"proxy_server_request": {"body": {"user": user}}}
                    )
            # pretty print standard logging object
            verbose_proxy_logger.debug(
                "kwargs= %s",
                json.dumps(kwargs, indent=4, default=str),
            )
            # set litellm_call_id to logging response object
            litellm_model_response.id = logging_obj.litellm_call_id
            litellm_model_response.model = model
            logging_obj.model_call_details["model"] = model
            # Default provider to "anthropic" only if nothing upstream set it.
            if not logging_obj.model_call_details.get("custom_llm_provider"):
                logging_obj.model_call_details[
                    "custom_llm_provider"
                ] = litellm.LlmProviders.ANTHROPIC.value
            return kwargs
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error creating Anthropic response logging payload: %s", e
            )
            return kwargs

    @staticmethod
    def _handle_logging_anthropic_collected_chunks(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        end_time: datetime,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks

        - Builds complete response from chunks
        - Creates standard logging object
        - Logs in litellm callbacks
        """
        model = request_body.get("model", "")
        # Check if it's available in the logging object
        if (
            not model
            and hasattr(litellm_logging_obj, "model_call_details")
            and litellm_logging_obj.model_call_details.get("model")
        ):
            model = cast(str, litellm_logging_obj.model_call_details.get("model"))
        complete_streaming_response = (
            AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
                all_chunks=all_chunks,
                litellm_logging_obj=litellm_logging_obj,
                model=model,
            )
        )
        if complete_streaming_response is None:
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
            )
            return {
                "result": None,
                "kwargs": {},
            }
        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs={},
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
        )
        return {
            "result": complete_streaming_response,
            "kwargs": kwargs,
        }

    @staticmethod
    def _split_sse_chunk_into_events(chunk: Union[str, bytes]) -> List[str]:
        """
        Split a chunk that may contain multiple SSE events into individual events.

        SSE format: "event: type\ndata: {...}\n\n"
        Multiple events in a single chunk are separated by double newlines.

        Args:
            chunk: Raw chunk string that may contain multiple SSE events

        Returns:
            List of individual SSE event strings (each containing "event: X\ndata: {...}")
        """
        # Handle bytes input
        if isinstance(chunk, bytes):
            chunk = chunk.decode("utf-8")
        # Split on double newlines to separate SSE events
        # Filter out empty strings
        events = [event.strip() for event in chunk.split("\n\n") if event.strip()]
        return events

    @staticmethod
    def _build_complete_streaming_response(
        all_chunks: Sequence[Union[str, bytes]],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """
        Builds complete response from raw Anthropic chunks

        - Splits multi-event chunks into individual SSE events
        - Converts str chunks to generic chunks
        - Converts generic chunks to litellm chunks (OpenAI format)
        - Builds complete response from litellm chunks
        """
        verbose_proxy_logger.debug(
            "Building complete streaming response from %d chunks", len(all_chunks)
        )
        # Iterator is used only for its chunk-conversion helper; there is no
        # live stream behind it.
        anthropic_model_response_iterator = AnthropicModelResponseIterator(
            streaming_response=None,
            sync_stream=False,
        )
        all_openai_chunks = []
        # Process each chunk - a chunk may contain multiple SSE events
        for _chunk_str in all_chunks:
            # Split chunk into individual SSE events
            individual_events = (
                AnthropicPassthroughLoggingHandler._split_sse_chunk_into_events(
                    _chunk_str
                )
            )
            # Process each individual event
            for event_str in individual_events:
                try:
                    transformed_openai_chunk = anthropic_model_response_iterator.convert_str_chunk_to_generic_chunk(
                        chunk=event_str
                    )
                    if transformed_openai_chunk is not None:
                        all_openai_chunks.append(transformed_openai_chunk)
                except (StopIteration, StopAsyncIteration):
                    # End-of-stream signal: stop processing this chunk's events.
                    break
        complete_streaming_response = litellm.stream_chunk_builder(
            chunks=all_openai_chunks,
            logging_obj=litellm_logging_obj,
        )
        verbose_proxy_logger.debug(
            "Complete streaming response built: %s", complete_streaming_response
        )
        return complete_streaming_response

    @staticmethod
    def batch_creation_handler(  # noqa: PLR0915
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: Optional[dict] = None,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Handle Anthropic batch creation passthrough logging.

        Creates a managed object for cost tracking when batch job is successfully created.
        Always returns a logging payload, even on failure — batch logging must
        never raise into the request path.
        """
        import base64

        from litellm._uuid import uuid
        from litellm.llms.anthropic.batches.transformation import (
            AnthropicBatchesConfig,
        )
        from litellm.types.utils import Choices, SpecialEnums

        try:
            _json_response = httpx_response.json()
            # Only handle successful batch job creation: status 200 with an "id"
            # in the body. NOTE(review): an earlier comment here claimed 201,
            # but the code checks 200 — confirm Anthropic's actual status code.
            if httpx_response.status_code == 200 and "id" in _json_response:
                # Transform Anthropic response to LiteLLM batch format
                anthropic_batches_config = AnthropicBatchesConfig()
                litellm_batch_response = (
                    anthropic_batches_config.transform_retrieve_batch_response(
                        model=None,
                        raw_response=httpx_response,
                        logging_obj=logging_obj,
                        litellm_params={},
                    )
                )
                # Set status to "validating" for newly created batches so polling mechanism picks them up
                # The polling mechanism only looks for status="validating" jobs
                litellm_batch_response.status = "validating"
                # Extract batch ID from the response
                batch_id = _json_response.get("id", "")
                # Get model from request body (batch response doesn't include model)
                request_body = request_body or {}
                # Try to extract model from the batch request body, supporting Anthropic's nested structure
                model_name: str = "unknown"
                if isinstance(request_body, dict):
                    # Standard: {"model": ...}
                    model_name = request_body.get("model") or "unknown"
                    if model_name == "unknown":
                        # Anthropic batches: look under requests[0].params.model
                        requests_list = request_body.get("requests", [])
                        if isinstance(requests_list, list) and len(requests_list) > 0:
                            first_req = requests_list[0]
                            if isinstance(first_req, dict):
                                params = first_req.get("params", {})
                                if isinstance(params, dict):
                                    extracted_model = params.get("model")
                                    if extracted_model:
                                        model_name = extracted_model
                # Create unified object ID for tracking
                # Format: base64(litellm_proxy;model_id:{};llm_batch_id:{})
                # For Anthropic passthrough, prefix model with "anthropic/" so router can determine provider
                actual_model_id = (
                    AnthropicPassthroughLoggingHandler.get_actual_model_id_from_router(
                        model_name
                    )
                )
                # If model not in router, use "anthropic/{model_name}" format so router can determine provider
                if actual_model_id == model_name and not actual_model_id.startswith(
                    "anthropic/"
                ):
                    actual_model_id = f"anthropic/{model_name}"
                unified_id_string = (
                    SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format(
                        actual_model_id, batch_id
                    )
                )
                # URL-safe base64 without padding so the id is safe in routes/keys.
                unified_object_id = (
                    base64.urlsafe_b64encode(unified_id_string.encode())
                    .decode()
                    .rstrip("=")
                )
                # Store the managed object for cost tracking
                # This will be picked up by check_batch_cost polling mechanism
                AnthropicPassthroughLoggingHandler._store_batch_managed_object(
                    unified_object_id=unified_object_id,
                    batch_object=litellm_batch_response,
                    model_object_id=batch_id,
                    logging_obj=logging_obj,
                    **kwargs,
                )
                # Create a batch job response for logging
                litellm_model_response = ModelResponse()
                litellm_model_response.id = str(uuid.uuid4())
                litellm_model_response.model = model_name
                litellm_model_response.object = "batch"
                litellm_model_response.created = int(start_time.timestamp())
                # Add batch-specific metadata to indicate this is a pending batch job
                litellm_model_response.choices = [
                    Choices(
                        finish_reason="stop",
                        index=0,
                        message={
                            "role": "assistant",
                            "content": f"Batch job {batch_id} created and is pending. Status will be updated when the batch completes.",
                            "tool_calls": None,
                            "function_call": None,
                            "provider_specific_fields": {
                                "batch_job_id": batch_id,
                                "batch_job_state": "in_progress",
                                "unified_object_id": unified_object_id,
                            },
                        },
                    )
                ]
                # Set response cost to 0 initially (will be updated when batch completes)
                response_cost = 0.0
                kwargs["response_cost"] = response_cost
                kwargs["model"] = model_name
                kwargs["batch_id"] = batch_id
                kwargs["unified_object_id"] = unified_object_id
                kwargs["batch_job_state"] = "in_progress"
                logging_obj.model = model_name
                logging_obj.model_call_details["model"] = logging_obj.model
                logging_obj.model_call_details["response_cost"] = response_cost
                logging_obj.model_call_details["batch_id"] = batch_id
                return {
                    "result": litellm_model_response,
                    "kwargs": kwargs,
                }
            else:
                # Handle non-successful responses
                litellm_model_response = ModelResponse()
                litellm_model_response.id = str(uuid.uuid4())
                litellm_model_response.model = "anthropic_batch"
                litellm_model_response.object = "batch"
                litellm_model_response.created = int(start_time.timestamp())
                # Add error-specific metadata
                litellm_model_response.choices = [
                    Choices(
                        finish_reason="stop",
                        index=0,
                        message={
                            "role": "assistant",
                            "content": f"Batch job creation failed. Status: {httpx_response.status_code}",
                            "tool_calls": None,
                            "function_call": None,
                            "provider_specific_fields": {
                                "batch_job_state": "failed",
                                "status_code": httpx_response.status_code,
                            },
                        },
                    )
                ]
                kwargs["response_cost"] = 0.0
                kwargs["model"] = "anthropic_batch"
                kwargs["batch_job_state"] = "failed"
                return {
                    "result": litellm_model_response,
                    "kwargs": kwargs,
                }
        except Exception as e:
            verbose_proxy_logger.error(f"Error in batch_creation_handler: {e}")
            # Return basic response on error
            litellm_model_response = ModelResponse()
            litellm_model_response.id = str(uuid.uuid4())
            litellm_model_response.model = "anthropic_batch"
            litellm_model_response.object = "batch"
            litellm_model_response.created = int(start_time.timestamp())
            # Add error-specific metadata
            litellm_model_response.choices = [
                Choices(
                    finish_reason="stop",
                    index=0,
                    message={
                        "role": "assistant",
                        "content": f"Error creating batch job: {str(e)}",
                        "tool_calls": None,
                        "function_call": None,
                        "provider_specific_fields": {
                            "batch_job_state": "failed",
                            "error": str(e),
                        },
                    },
                )
            ]
            kwargs["response_cost"] = 0.0
            kwargs["model"] = "anthropic_batch"
            kwargs["batch_job_state"] = "failed"
            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }

    @staticmethod
    def _store_batch_managed_object(
        unified_object_id: str,
        batch_object: LiteLLMBatch,
        model_object_id: str,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> None:
        """
        Store batch managed object for cost tracking.
        This will be picked up by the check_batch_cost polling mechanism.

        Best-effort: failures are logged, never raised.
        """
        try:
            # Get the managed files hook from the logging object
            # This is a bit of a hack, but we need access to the proxy logging system
            from litellm.proxy.proxy_server import proxy_logging_obj

            managed_files_hook = proxy_logging_obj.get_proxy_hook("managed_files")
            if managed_files_hook is not None and hasattr(
                managed_files_hook, "store_unified_object_id"
            ):
                # Create a mock user API key dict for the managed object storage
                from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth

                user_api_key_dict = UserAPIKeyAuth(
                    user_id=kwargs.get("user_id", "default-user"),
                    api_key="",
                    team_id=None,
                    team_alias=None,
                    user_role=LitellmUserRoles.CUSTOMER,  # Use proper enum value
                    user_email=None,
                    max_budget=None,
                    spend=0.0,  # Set to 0.0 instead of None
                    models=[],  # Set to empty list instead of None
                    tpm_limit=None,
                    rpm_limit=None,
                    budget_duration=None,
                    budget_reset_at=None,
                    max_parallel_requests=None,
                    allowed_model_region=None,
                    metadata={},  # Set to empty dict instead of None
                    key_alias=None,
                    permissions={},  # Set to empty dict instead of None
                    model_max_budget={},  # Set to empty dict instead of None
                    model_spend={},  # Set to empty dict instead of None
                )
                # Store the unified object for batch cost tracking
                # NOTE(review): fire-and-forget; assumes a running event loop —
                # asyncio.create_task raises outside one. Confirm caller context.
                import asyncio

                asyncio.create_task(
                    managed_files_hook.store_unified_object_id(  # type: ignore
                        unified_object_id=unified_object_id,
                        file_object=batch_object,
                        litellm_parent_otel_span=None,
                        model_object_id=model_object_id,
                        file_purpose="batch",
                        user_api_key_dict=user_api_key_dict,
                    )
                )
                verbose_proxy_logger.info(
                    f"Stored Anthropic batch managed object with unified_object_id={unified_object_id}, batch_id={model_object_id}"
                )
            else:
                verbose_proxy_logger.warning(
                    "Managed files hook not available, cannot store batch object for cost tracking"
                )
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error storing Anthropic batch managed object: {e}"
            )

    @staticmethod
    def get_actual_model_id_from_router(model_name: str) -> str:
        """Resolve a model name to a router deployment id.

        Falls back to the given model name when the router is unavailable or
        does not know the model.
        """
        from litellm.proxy.proxy_server import llm_router

        if llm_router is not None:
            # Try to find the model in the router by the model name
            # Use the existing get_model_ids method from router
            model_ids = llm_router.get_model_ids(model_name=model_name)
            if model_ids and len(model_ids) > 0:
                # Use the first model ID found
                actual_model_id = model_ids[0]
                verbose_proxy_logger.info(
                    f"Found model ID in router: {actual_model_id}"
                )
                return actual_model_id
            else:
                # Fallback to model name
                actual_model_id = model_name
                verbose_proxy_logger.warning(
                    f"Model not found in router, using model name: {actual_model_id}"
                )
                return actual_model_id
        else:
            # Fallback if router is not available
            verbose_proxy_logger.warning(
                f"Router not available, using model name: {model_name}"
            )
            return model_name

View File

@@ -0,0 +1,333 @@
import asyncio
import json
import time
from datetime import datetime
from typing import Literal, Optional
from urllib.parse import urlparse
import httpx
from typing_extensions import TypedDict
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.types.passthrough_endpoints.assembly_ai import (
ASSEMBLY_AI_MAX_POLLING_ATTEMPTS,
ASSEMBLY_AI_POLLING_INTERVAL,
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
class AssemblyAITranscriptResponse(TypedDict, total=False):
    # Subset of the AssemblyAI GET /v2/transcript/{id} response used for
    # logging and cost tracking. total=False: every key is optional.
    id: str  # transcript id
    speech_model: str  # model used; drives per-second pricing lookup
    acoustic_model: str
    language_code: str
    status: str  # terminal states polled for: "completed" or "error"
    audio_duration: float  # multiplied by a per-second rate to compute cost
class AssemblyAIPassthroughLoggingHandler:
    """Logging and cost tracking for AssemblyAI passthrough transcription requests.

    AssemblyAI prices by audio duration, which is only known once the
    transcript completes, so this handler polls the transcript endpoint in a
    background thread before emitting the logging payload.
    """

    def __init__(self):
        self.assembly_ai_base_url = "https://api.assemblyai.com"
        self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
        """
        The base URL for the AssemblyAI API
        """
        self.polling_interval: float = ASSEMBLY_AI_POLLING_INTERVAL
        """
        The polling interval for the AssemblyAI API.
        litellm needs to poll the GET /transcript/{transcript_id} endpoint to get the status of the transcript.
        """
        self.max_polling_attempts = ASSEMBLY_AI_MAX_POLLING_ATTEMPTS
        """
        The maximum number of polling attempts for the AssemblyAI API.
        """

    def assemblyai_passthrough_logging_handler(
        self,
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ):
        """
        Since cost tracking requires polling the AssemblyAI API, we need to handle this in a separate thread. Hence the executor.submit.
        """
        executor.submit(
            self._handle_assemblyai_passthrough_logging,
            httpx_response,
            response_body,
            logging_obj,
            url_route,
            result,
            start_time,
            end_time,
            cache_hit,
            **kwargs,
        )

    def _handle_assemblyai_passthrough_logging(
        self,
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ):
        """
        Handles logging for AssemblyAI successful passthrough requests

        Runs inside the thread-pool executor: blocks on polling the transcript
        until it reaches a terminal state, computes cost, then fires the
        standard passthrough logging.
        """
        from ..pass_through_endpoints import pass_through_endpoint_logging

        model = response_body.get("speech_model", "")
        verbose_proxy_logger.debug(
            "response body %s", json.dumps(response_body, indent=4)
        )
        kwargs["model"] = model
        kwargs["custom_llm_provider"] = "assemblyai"
        response_cost: Optional[float] = None
        transcript_id = response_body.get("id")
        if transcript_id is None:
            raise ValueError(
                "Transcript ID is required to log the cost of the transcription"
            )
        transcript_response = self._poll_assembly_for_transcript_response(
            transcript_id=transcript_id, url_route=url_route
        )
        verbose_proxy_logger.debug(
            "finished polling assembly for transcript response- got transcript response %s",
            json.dumps(transcript_response, indent=4),
        )
        if transcript_response:
            cost = self.get_cost_for_assembly_transcript(
                speech_model=model,
                transcript_response=transcript_response,
            )
            response_cost = cost
        # Make standard logging object for AssemblyAI
        standard_logging_object = get_standard_logging_object_payload(
            kwargs=kwargs,
            init_response_obj=transcript_response,
            start_time=start_time,
            end_time=end_time,
            logging_obj=logging_obj,
            status="success",
        )
        passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = (  # type: ignore
            kwargs.get("passthrough_logging_payload")
        )
        verbose_proxy_logger.debug(
            "standard_passthrough_logging_object %s",
            json.dumps(passthrough_logging_payload, indent=4),
        )
        # pretty print standard logging object
        verbose_proxy_logger.debug(
            "standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
        )
        logging_obj.model_call_details["model"] = model
        logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
        logging_obj.model_call_details["response_cost"] = response_cost
        # Safe to spin up a fresh event loop here: this method runs in an
        # executor thread, not on the server's main loop.
        asyncio.run(
            pass_through_endpoint_logging._handle_logging(
                logging_obj=logging_obj,
                standard_logging_response_object=self._get_response_to_log(
                    transcript_response
                ),
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                **kwargs,
            )
        )
        pass

    def _get_response_to_log(
        self, transcript_response: Optional[AssemblyAITranscriptResponse]
    ) -> dict:
        # Normalize the optional TypedDict into a plain dict for logging.
        if transcript_response is None:
            return {}
        return dict(transcript_response)

    def _get_assembly_transcript(
        self,
        transcript_id: str,
        request_region: Optional[Literal["eu"]] = None,
    ) -> Optional[dict]:
        """
        Get the transcript details from AssemblyAI API

        Args:
            transcript_id (str): ID of the transcript to fetch
            request_region (Optional[Literal["eu"]]): "eu" to use the EU endpoint

        Returns:
            Optional[dict]: Transcript details if successful, None otherwise
        """
        from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
            passthrough_endpoint_router,
        )

        _base_url = (
            self.assembly_ai_eu_base_url
            if request_region == "eu"
            else self.assembly_ai_base_url
        )
        _api_key = passthrough_endpoint_router.get_credentials(
            custom_llm_provider="assemblyai",
            region_name=request_region,
        )
        if _api_key is None:
            raise ValueError("AssemblyAI API key not found")
        try:
            url = f"{_base_url}/v2/transcript/{transcript_id}"
            headers = {
                "Authorization": f"Bearer {_api_key}",
                "Content-Type": "application/json",
            }
            response = httpx.get(url, headers=headers)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            # Non-blocking: polling failures must not crash the logging thread.
            verbose_proxy_logger.exception(
                f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
            )
            return None

    def _poll_assembly_for_transcript_response(
        self,
        transcript_id: str,
        url_route: Optional[str] = None,
    ) -> Optional[AssemblyAITranscriptResponse]:
        """
        Poll the status of the transcript until it is completed or timeout (30 minutes)
        """
        for _ in range(
            self.max_polling_attempts
        ):  # 180 attempts * 10s = 30 minutes max
            transcript = self._get_assembly_transcript(
                request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
                    url=url_route
                ),
                transcript_id=transcript_id,
            )
            if transcript is None:
                return None
            # Stop on either terminal state.
            if (
                transcript.get("status") == "completed"
                or transcript.get("status") == "error"
            ):
                return AssemblyAITranscriptResponse(**transcript)
            time.sleep(self.polling_interval)
        return None

    @staticmethod
    def get_cost_for_assembly_transcript(
        transcript_response: AssemblyAITranscriptResponse,
        speech_model: str,
    ) -> Optional[float]:
        """
        Get the cost for the assembly transcript

        Returns audio_duration * per-second rate, or None when either the
        duration or the rate is unavailable.
        """
        _audio_duration = transcript_response.get("audio_duration")
        if _audio_duration is None:
            return None
        _cost_per_second = (
            AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
                speech_model=speech_model
            )
        )
        if _cost_per_second is None:
            return None
        return _audio_duration * _cost_per_second

    @staticmethod
    def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
        """
        Get the cost per second for the assembly model.
        Falls back to assemblyai/nano if the specific speech model info cannot be found.
        """
        try:
            # First try with the provided speech model
            try:
                model_info = litellm.get_model_info(
                    model=speech_model,
                    custom_llm_provider="assemblyai",
                )
                if model_info and model_info.get("input_cost_per_second") is not None:
                    return model_info.get("input_cost_per_second")
            except Exception:
                pass  # Continue to fallback if model not found
            # Fallback to assemblyai/nano if speech model info not found
            try:
                model_info = litellm.get_model_info(
                    model="assemblyai/nano",
                    custom_llm_provider="assemblyai",
                )
                if model_info and model_info.get("input_cost_per_second") is not None:
                    return model_info.get("input_cost_per_second")
            except Exception:
                pass
            return None
        except Exception as e:
            verbose_proxy_logger.exception(
                f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
            )
            return None

    @staticmethod
    def _should_log_request(request_method: str) -> bool:
        """
        only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
        """
        return request_method == "POST"

    @staticmethod
    def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
        """
        Get the region from the URL

        Returns "eu" only for the exact host "eu.assemblyai.com", else None.
        """
        if url is None:
            return None
        if urlparse(url).hostname == "eu.assemblyai.com":
            return "eu"
        return None

    @staticmethod
    def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
        """
        Get the base URL for the AssemblyAI API
        if region == "eu", return "https://api.eu.assemblyai.com"
        else return "https://api.assemblyai.com"

        NOTE(review): this returns "api.eu.assemblyai.com" while
        __init__'s assembly_ai_eu_base_url is "eu.assemblyai.com" —
        confirm which EU host is correct; they look inconsistent.
        """
        if region == "eu":
            return "https://api.eu.assemblyai.com"
        return "https://api.assemblyai.com"

View File

@@ -0,0 +1,221 @@
import json
from datetime import datetime
from typing import TYPE_CHECKING, Any, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
if TYPE_CHECKING:
from ..success_handler import PassThroughEndpointLogging
from ..types import EndpointType
else:
PassThroughEndpointLogging = Any
EndpointType = Any
from abc import ABC, abstractmethod
class BasePassthroughLoggingHandler(ABC):
    """
    Abstract base class for provider-specific passthrough logging handlers.

    Passthrough endpoints forward the provider's raw response straight to the
    client; this class converts that raw response (or collected streaming
    chunks) into an OpenAI-format response and builds the standard logging
    payload so cost tracking and logging callbacks also work for passthrough
    traffic.
    """

    @property
    @abstractmethod
    def llm_provider_name(self) -> LlmProviders:
        """Provider enum used for logging / cost attribution."""
        pass

    @abstractmethod
    def get_provider_config(self, model: str) -> BaseConfig:
        """Return the provider's chat transformation config for ``model``."""
        pass

    def passthrough_chat_handler(
        self,
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Transforms LLM response to OpenAI response, generates a standard logging object so downstream logging can be handled

        Returns a dict with the transformed ``result`` and the enriched
        logging ``kwargs``.
        """
        # Prefer the model named in the request body; fall back to the response body.
        model = request_body.get("model", response_body.get("model", ""))
        provider_config = self.get_provider_config(model=model)
        # Passthrough traffic never went through litellm.completion(), so the
        # messages / params are unknown here — empty placeholders are passed.
        litellm_model_response: ModelResponse = provider_config.transform_response(
            raw_response=httpx_response,
            model_response=litellm.ModelResponse(),
            model=model,
            messages=[],
            logging_obj=logging_obj,
            optional_params={},
            api_key="",
            request_data={},
            encoding=litellm.encoding,
            json_mode=False,
            litellm_params={},
        )
        # Enrich kwargs with cost + standard logging object for callbacks.
        kwargs = self._create_response_logging_payload(
            litellm_model_response=litellm_model_response,
            model=model,
            kwargs=kwargs,
            start_time=start_time,
            end_time=end_time,
            logging_obj=logging_obj,
        )
        return {
            "result": litellm_model_response,
            "kwargs": kwargs,
        }

    def _get_user_from_metadata(
        self,
        passthrough_logging_payload: PassthroughStandardLoggingPayload,
    ) -> Optional[str]:
        """Extract the end-user id from the original request body, if present."""
        request_body = passthrough_logging_payload.get("request_body")
        if request_body:
            return get_end_user_id_from_request_body(request_body)
        return None

    def _create_response_logging_payload(
        self,
        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
        model: str,
        kwargs: dict,
        start_time: datetime,
        end_time: datetime,
        logging_obj: LiteLLMLoggingObj,
    ) -> dict:
        """
        Create the standard logging object for Generic LLM passthrough
        handles streaming and non-streaming responses

        Mutates and returns ``kwargs`` (cost, model, standard logging object)
        and updates ``logging_obj`` / the response object in place.
        """
        try:
            response_cost = litellm.completion_cost(
                completion_response=litellm_model_response,
                model=model,
            )
            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = (  # type: ignore
                kwargs.get("passthrough_logging_payload")
            )
            if passthrough_logging_payload:
                user = self._get_user_from_metadata(
                    passthrough_logging_payload=passthrough_logging_payload,
                )
                if user:
                    # Attach the end user to the proxy request body so
                    # user-level tracking picks it up downstream.
                    kwargs.setdefault("litellm_params", {})
                    kwargs["litellm_params"].update(
                        {"proxy_server_request": {"body": {"user": user}}}
                    )
            # Build the standard logging object consumed by success callbacks
            standard_logging_object = get_standard_logging_object_payload(
                kwargs=kwargs,
                init_response_obj=litellm_model_response,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                status="success",
            )
            # pretty print standard logging object
            verbose_proxy_logger.debug(
                "standard_logging_object= %s",
                json.dumps(standard_logging_object, indent=4),
            )
            kwargs["standard_logging_object"] = standard_logging_object
            # set litellm_call_id to logging response object
            litellm_model_response.id = logging_obj.litellm_call_id
            litellm_model_response.model = model
            logging_obj.model_call_details["model"] = model
            return kwargs
        except Exception as e:
            # Logging must never break the passthrough request itself —
            # return kwargs unenriched on failure.
            verbose_proxy_logger.exception(
                "Error creating LLM passthrough response logging payload: %s", e
            )
            return kwargs

    @abstractmethod
    def _build_complete_streaming_response(
        self,
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """
        Builds complete response from raw chunks
        - Converts str chunks to generic chunks
        - Converts generic chunks to litellm chunks (OpenAI format)
        - Builds complete response from litellm chunks
        """
        pass

    def _handle_logging_llm_collected_chunks(
        self,
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        end_time: datetime,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks collected from a passthrough streaming endpoint and logs them in litellm callbacks
        - Builds complete response from chunks
        - Creates standard logging object
        - Logs in litellm callbacks
        """
        model = request_body.get("model", "")
        complete_streaming_response = self._build_complete_streaming_response(
            all_chunks=all_chunks,
            litellm_logging_obj=litellm_logging_obj,
            model=model,
        )
        if complete_streaming_response is None:
            # NOTE(review): the message text predates generalization of this
            # base class beyond Anthropic.
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
            )
            return {
                "result": None,
                "kwargs": {},
            }
        kwargs = self._create_response_logging_payload(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs={},
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
        )
        return {
            "result": complete_streaming_response,
            "kwargs": kwargs,
        }

View File

@@ -0,0 +1,192 @@
from datetime import datetime
from typing import List, Optional, Union
import httpx
import litellm
from litellm import stream_chunk_builder
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.cohere.chat.v2_transformation import CohereV2ChatConfig
from litellm.llms.cohere.common_utils import (
ModelResponseIterator as CohereModelResponseIterator,
)
from litellm.llms.cohere.embed.v1_transformation import CohereEmbeddingConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import (
LlmProviders,
ModelResponse,
TextCompletionResponse,
)
from .base_passthrough_logging_handler import BasePassthroughLoggingHandler
class CoherePassthroughLoggingHandler(BasePassthroughLoggingHandler):
    """
    Cohere passthrough logging handler.

    Routes ``/v1/embed`` traffic through an embedding-specific transformation
    (with embedding cost tracking); every other route (e.g. ``/v2/chat``)
    falls back to the generic chat handler on the base class.
    """

    @property
    def llm_provider_name(self) -> LlmProviders:
        """Provider enum used for logging / cost attribution."""
        return LlmProviders.COHERE

    def get_provider_config(self, model: str) -> BaseConfig:
        """Cohere v2 chat config is used for non-embed passthrough routes."""
        return CohereV2ChatConfig()

    def _build_complete_streaming_response(
        self,
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """
        Build a complete ModelResponse from raw Cohere stream chunks.

        Each raw string chunk is converted to a generic chunk, then to an
        OpenAI-format chunk; the chunks are finally stitched together with
        ``stream_chunk_builder``.
        """
        cohere_model_response_iterator = CohereModelResponseIterator(
            streaming_response=None,
            sync_stream=False,
        )
        litellm_custom_stream_wrapper = CustomStreamWrapper(
            completion_stream=cohere_model_response_iterator,
            model=model,
            logging_obj=litellm_logging_obj,
            custom_llm_provider="cohere",
        )
        all_openai_chunks = []
        for _chunk_str in all_chunks:
            try:
                generic_chunk = (
                    cohere_model_response_iterator.convert_str_chunk_to_generic_chunk(
                        chunk=_chunk_str
                    )
                )
                litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
                    chunk=generic_chunk
                )
                if litellm_chunk is not None:
                    all_openai_chunks.append(litellm_chunk)
            except (StopIteration, StopAsyncIteration):
                # End of stream signalled by the parser; stop collecting.
                break
        complete_streaming_response = stream_chunk_builder(chunks=all_openai_chunks)
        return complete_streaming_response

    def cohere_passthrough_handler(  # noqa: PLR0915
        self,
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Handle Cohere passthrough logging with route detection and cost tracking.
        """
        # Embed endpoint: transform as an embedding response and track cost.
        if "/v1/embed" in url_route:
            model = request_body.get("model", response_body.get("model", ""))
            try:
                cohere_embed_config = CohereEmbeddingConfig()
                litellm_model_response = litellm.EmbeddingResponse()
                # Cohere v1 uses "texts"; fall back to OpenAI-style "input".
                input_texts = request_body.get("texts", [])
                if not input_texts:
                    input_texts = request_body.get("input", [])
                # Transform the response
                litellm_model_response = cohere_embed_config._transform_response(
                    response=httpx_response,
                    api_key="",
                    logging_obj=logging_obj,
                    data=request_body,
                    model_response=litellm_model_response,
                    model=model,
                    encoding=litellm.encoding,
                    input=input_texts,
                )
                # Calculate cost using LiteLLM's cost calculator
                response_cost = litellm.completion_cost(
                    completion_response=litellm_model_response,
                    model=model,
                    custom_llm_provider="cohere",
                    call_type="aembedding",
                )
                # Set the calculated cost in _hidden_params to prevent recalculation
                if not hasattr(litellm_model_response, "_hidden_params"):
                    litellm_model_response._hidden_params = {}
                litellm_model_response._hidden_params["response_cost"] = response_cost
                kwargs["response_cost"] = response_cost
                kwargs["model"] = model
                kwargs["custom_llm_provider"] = "cohere"
                # Extract user information for tracking
                passthrough_logging_payload: Optional[
                    PassthroughStandardLoggingPayload
                ] = kwargs.get("passthrough_logging_payload")
                if passthrough_logging_payload:
                    # Use the inherited helper directly instead of building a
                    # fresh handler instance (previous code instantiated
                    # CoherePassthroughLoggingHandler() redundantly).
                    user = self._get_user_from_metadata(
                        passthrough_logging_payload=passthrough_logging_payload,
                    )
                    if user:
                        kwargs.setdefault("litellm_params", {})
                        kwargs["litellm_params"].update(
                            {"proxy_server_request": {"body": {"user": user}}}
                        )
                # Create standard logging object
                if litellm_model_response is not None:
                    get_standard_logging_object_payload(
                        kwargs=kwargs,
                        init_response_obj=litellm_model_response,
                        start_time=start_time,
                        end_time=end_time,
                        logging_obj=logging_obj,
                        status="success",
                    )
                # Update logging object with cost information
                logging_obj.model_call_details["model"] = model
                logging_obj.model_call_details["custom_llm_provider"] = "cohere"
                logging_obj.model_call_details["response_cost"] = response_cost
                return {
                    "result": litellm_model_response,
                    "kwargs": kwargs,
                }
            except Exception:
                # Embed-specific handling failed; log it (non-blocking) before
                # falling back to the generic chat handler, instead of
                # swallowing the error silently.
                from litellm._logging import verbose_proxy_logger

                verbose_proxy_logger.exception(
                    "Error in Cohere embed passthrough logging, falling back to chat handler"
                )
                return super().passthrough_chat_handler(
                    httpx_response=httpx_response,
                    response_body=response_body,
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    request_body=request_body,
                    **kwargs,
                )
        # For non-embed routes (e.g., /v2/chat), fall back to chat handler
        return super().passthrough_chat_handler(
            httpx_response=httpx_response,
            response_body=response_body,
            logging_obj=logging_obj,
            url_route=url_route,
            result=result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=cache_hit,
            request_body=request_body,
            **kwargs,
        )

View File

@@ -0,0 +1,139 @@
"""
Cursor Cloud Agents API - Pass-through Logging Handler
Transforms Cursor API responses into standardized logging payloads
so they appear cleanly in the LiteLLM Logs page.
"""
from datetime import datetime
from typing import Dict
import httpx
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.utils import StandardPassThroughResponseObject
CURSOR_AGENT_ENDPOINTS: Dict[str, str] = {
"POST /v0/agents": "cursor:agent:create",
"GET /v0/agents": "cursor:agent:list",
"POST /v0/agents/{id}/followup": "cursor:agent:followup",
"POST /v0/agents/{id}/stop": "cursor:agent:stop",
"DELETE /v0/agents/{id}": "cursor:agent:delete",
"GET /v0/agents/{id}/conversation": "cursor:agent:conversation",
"GET /v0/agents/{id}": "cursor:agent:status",
"GET /v0/me": "cursor:account:info",
"GET /v0/models": "cursor:models:list",
"GET /v0/repositories": "cursor:repositories:list",
}
def _classify_cursor_request(method: str, path: str) -> str:
"""Classify a Cursor API request into a readable operation name."""
normalized = path.rstrip("/")
for pattern, operation in CURSOR_AGENT_ENDPOINTS.items():
pat_method, pat_path = pattern.split(" ", 1)
if method.upper() != pat_method:
continue
pat_parts = pat_path.strip("/").split("/")
req_parts = normalized.strip("/").split("/")
if len(pat_parts) != len(req_parts):
continue
match = True
for pp, rp in zip(pat_parts, req_parts):
if pp.startswith("{") and pp.endswith("}"):
continue
if pp != rp:
match = False
break
if match:
return operation
return f"cursor:{method.lower()}:{normalized}"
class CursorPassthroughLoggingHandler:
    """Handles logging for Cursor Cloud Agents pass-through requests."""

    @staticmethod
    def cursor_passthrough_handler(
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Transform a Cursor API response into a standard logging payload.

        On any failure, falls back to logging the raw ``result`` string so
        the request still shows up in the Logs page.
        """
        try:
            request_method = httpx_response.request.method
            request_path = httpx.URL(url_route).path
            operation = _classify_cursor_request(request_method, request_path)
            agent_id = response_body.get("id", "")
            agent_name = response_body.get("name", "")
            agent_status = response_body.get("status", "")
            model_name = f"cursor/{operation}"
            # Build a compact "k=v" summary of whichever agent fields are present.
            labelled_fields = [
                ("id", agent_id),
                ("name", agent_name),
                ("status", agent_status),
            ]
            summary_parts = [
                f"{label}={value}" for label, value in labelled_fields if value
            ]
            response_summary = ", ".join(summary_parts) if summary_parts else result
            # Cursor agent-management calls carry no token-based cost.
            kwargs["model"] = model_name
            kwargs["response_cost"] = 0.0
            logging_obj.model_call_details["model"] = model_name
            logging_obj.model_call_details["custom_llm_provider"] = "cursor"
            logging_obj.model_call_details["response_cost"] = 0.0
            standard_logging_object = get_standard_logging_object_payload(
                kwargs=kwargs,
                init_response_obj=StandardPassThroughResponseObject(
                    response=response_summary
                ),
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                status="success",
            )
            kwargs["standard_logging_object"] = standard_logging_object
            verbose_proxy_logger.debug(
                "Cursor passthrough logging: operation=%s, agent_id=%s",
                operation,
                agent_id,
            )
            return {
                "result": StandardPassThroughResponseObject(
                    response=response_summary
                ),
                "kwargs": kwargs,
            }
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error in Cursor passthrough logging handler: %s", e
            )
            return {
                "result": StandardPassThroughResponseObject(response=result),
                "kwargs": kwargs,
            }

View File

@@ -0,0 +1,254 @@
import re
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.gemini.videos.transformation import GeminiVideoConfig
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator as GeminiModelResponseIterator,
)
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.utils import (
ModelResponse,
TextCompletionResponse,
)
if TYPE_CHECKING:
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
from ..success_handler import PassThroughEndpointLogging
else:
PassThroughEndpointLogging = Any
EndpointType = Any
class GeminiPassthroughLoggingHandler:
    """
    Logging handler for Gemini (Google AI Studio) passthrough endpoints.

    Handles cost tracking and standard-logging-object creation for the
    ``predictLongRunning`` (video generation), ``generateContent`` and
    streaming ``streamGenerateContent`` routes.
    """

    @staticmethod
    def gemini_passthrough_handler(
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Route a non-streaming Gemini passthrough response to the matching
        transformation and attach cost info for downstream logging.

        Returns the transformed result (``None`` for unrecognized routes)
        plus the enriched logging ``kwargs``.
        """
        # Video generation route (long-running predict operation).
        if "predictLongRunning" in url_route:
            model = GeminiPassthroughLoggingHandler.extract_model_from_url(url_route)
            gemini_video_config = GeminiVideoConfig()
            litellm_video_response = (
                gemini_video_config.transform_video_create_response(
                    model=model,
                    raw_response=httpx_response,
                    logging_obj=logging_obj,
                    custom_llm_provider="gemini",
                    request_data=request_body,
                )
            )
            # Attribute the call to gemini on the logging object.
            logging_obj.model = model
            logging_obj.model_call_details["model"] = model
            logging_obj.model_call_details["custom_llm_provider"] = "gemini"
            logging_obj.custom_llm_provider = "gemini"
            response_cost = litellm.completion_cost(
                completion_response=litellm_video_response,
                model=model,
                custom_llm_provider="gemini",
                call_type="create_video",
            )
            # Set response_cost in _hidden_params to prevent recalculation
            if not hasattr(litellm_video_response, "_hidden_params"):
                litellm_video_response._hidden_params = {}
            litellm_video_response._hidden_params["response_cost"] = response_cost
            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            kwargs["custom_llm_provider"] = "gemini"
            logging_obj.model_call_details["response_cost"] = response_cost
            return {
                "result": litellm_video_response,
                "kwargs": kwargs,
            }
        # Text generation route. Note: this substring check also matches
        # "streamGenerateContent" URLs.
        if "generateContent" in url_route:
            model = GeminiPassthroughLoggingHandler.extract_model_from_url(url_route)
            # Use Gemini config for transformation
            instance_of_gemini_llm = litellm.GoogleAIStudioGeminiConfig()
            litellm_model_response: ModelResponse = (
                instance_of_gemini_llm.transform_response(
                    model=model,
                    # Original request messages are unknown for passthrough
                    # traffic; a placeholder message is used.
                    messages=[
                        {"role": "user", "content": "no-message-pass-through-endpoint"}
                    ],
                    raw_response=httpx_response,
                    model_response=litellm.ModelResponse(),
                    logging_obj=logging_obj,
                    optional_params={},
                    litellm_params={},
                    api_key="",
                    request_data={},
                    encoding=litellm.encoding,
                )
            )
            kwargs = GeminiPassthroughLoggingHandler._create_gemini_response_logging_payload_for_generate_content(
                litellm_model_response=litellm_model_response,
                model=model,
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                custom_llm_provider="gemini",
            )
            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }
        else:
            # Unrecognized route: nothing to transform; caller falls back to
            # generic handling.
            return {
                "result": None,
                "kwargs": kwargs,
            }

    @staticmethod
    def _handle_logging_gemini_collected_chunks(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        model: Optional[str],
        end_time: datetime,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from Gemini passthrough endpoint and logs them in litellm callbacks
        - Builds complete response from chunks
        - Creates standard logging object
        - Logs in litellm callbacks
        """
        kwargs: Dict[str, Any] = {}
        # Fall back to parsing the model out of the URL when not provided.
        model = model or GeminiPassthroughLoggingHandler.extract_model_from_url(
            url_route
        )
        complete_streaming_response = (
            GeminiPassthroughLoggingHandler._build_complete_streaming_response(
                all_chunks=all_chunks,
                litellm_logging_obj=litellm_logging_obj,
                model=model,
                url_route=url_route,
            )
        )
        if complete_streaming_response is None:
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Gemini passthrough endpoint, not logging..."
            )
            return {
                "result": None,
                "kwargs": kwargs,
            }
        kwargs = GeminiPassthroughLoggingHandler._create_gemini_response_logging_payload_for_generate_content(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs=kwargs,
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
            custom_llm_provider="gemini",
        )
        return {
            "result": complete_streaming_response,
            "kwargs": kwargs,
        }

    @staticmethod
    def _build_complete_streaming_response(
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
        url_route: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """
        Rebuild a complete response from raw streamed chunk strings.

        Returns ``None`` for non-generateContent routes or when no chunk
        could be parsed.
        """
        parsed_chunks = []
        if "generateContent" in url_route or "streamGenerateContent" in url_route:
            gemini_iterator: Any = GeminiModelResponseIterator(
                streaming_response=None,
                sync_stream=False,
                logging_obj=litellm_logging_obj,
            )
            # Reuse the iterator's chunk-parsing logic on each raw chunk.
            chunk_parsing_logic: Any = gemini_iterator._common_chunk_parsing_logic
            parsed_chunks = [chunk_parsing_logic(chunk) for chunk in all_chunks]
        else:
            return None
        if len(parsed_chunks) == 0:
            return None
        all_openai_chunks = []
        for parsed_chunk in parsed_chunks:
            # Unparseable chunks come back as None and are skipped.
            if parsed_chunk is None:
                continue
            all_openai_chunks.append(parsed_chunk)
        complete_streaming_response = litellm.stream_chunk_builder(
            chunks=all_openai_chunks
        )
        return complete_streaming_response

    @staticmethod
    def extract_model_from_url(url: str) -> str:
        """Extract the model name from a ``.../models/<model>:<action>`` URL."""
        pattern = r"/models/([^:]+)"
        match = re.search(pattern, url)
        if match:
            return match.group(1)
        return "unknown"

    @staticmethod
    def _create_gemini_response_logging_payload_for_generate_content(
        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
        model: str,
        kwargs: dict,
        start_time: datetime,
        end_time: datetime,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: str,
    ) -> dict:
        """
        Create the standard logging object for Gemini passthrough generateContent (streaming and non-streaming)
        """
        # NOTE(review): cost calculation hardcodes "gemini" while the rest of
        # the attribution uses the custom_llm_provider parameter.
        response_cost = litellm.completion_cost(
            completion_response=litellm_model_response,
            model=model,
            custom_llm_provider="gemini",
        )
        kwargs["response_cost"] = response_cost
        kwargs["model"] = model
        kwargs["custom_llm_provider"] = custom_llm_provider
        # pretty print standard logging object
        verbose_proxy_logger.debug("kwargs= %s", kwargs)
        # set litellm_call_id to logging response object
        litellm_model_response.id = logging_obj.litellm_call_id
        logging_obj.model = litellm_model_response.model or model
        logging_obj.model_call_details["model"] = logging_obj.model
        logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
        logging_obj.model_call_details["response_cost"] = response_cost
        return kwargs

View File

@@ -0,0 +1,608 @@
"""
OpenAI Passthrough Logging Handler
Handles cost tracking and logging for OpenAI passthrough endpoints, specifically /chat/completions.
"""
from datetime import datetime
from typing import List, Optional, Union
from urllib.parse import urlparse
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
from litellm.llms.openai.openai import OpenAIConfig
from litellm.llms.openai.openai import OpenAIConfig as OpenAIConfigType
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.base_passthrough_logging_handler import (
BasePassthroughLoggingHandler,
)
from litellm.proxy.pass_through_endpoints.success_handler import (
PassThroughEndpointLogging,
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
EndpointType,
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import ImageResponse, LlmProviders, PassthroughCallTypes
from litellm.utils import ModelResponse, TextCompletionResponse
class OpenAIPassthroughLoggingHandler(BasePassthroughLoggingHandler):
"""
OpenAI-specific passthrough logging handler that provides cost tracking for /chat/completions endpoints.
"""
@property
def llm_provider_name(self) -> LlmProviders:
    """Provider enum used for cost calculation / logging attribution."""
    return LlmProviders.OPENAI
def get_provider_config(self, model: str) -> OpenAIConfigType:
    """Get OpenAI provider configuration for the given model."""
    # The same OpenAIConfig serves all models; `model` is accepted to
    # satisfy the BasePassthroughLoggingHandler interface.
    return OpenAIConfig()
@staticmethod
def is_openai_chat_completions_route(url_route: str) -> bool:
"""Check if the URL route is an OpenAI chat completions endpoint."""
if not url_route:
return False
parsed_url = urlparse(url_route)
return bool(
parsed_url.hostname
and (
"api.openai.com" in parsed_url.hostname
or "openai.azure.com" in parsed_url.hostname
)
and "/v1/chat/completions" in parsed_url.path
)
@staticmethod
def is_openai_image_generation_route(url_route: str) -> bool:
"""Check if the URL route is an OpenAI image generation endpoint."""
if not url_route:
return False
parsed_url = urlparse(url_route)
return bool(
parsed_url.hostname
and (
"api.openai.com" in parsed_url.hostname
or "openai.azure.com" in parsed_url.hostname
)
and "/v1/images/generations" in parsed_url.path
)
@staticmethod
def is_openai_image_editing_route(url_route: str) -> bool:
"""Check if the URL route is an OpenAI image editing endpoint."""
if not url_route:
return False
parsed_url = urlparse(url_route)
return bool(
parsed_url.hostname
and (
"api.openai.com" in parsed_url.hostname
or "openai.azure.com" in parsed_url.hostname
)
and "/v1/images/edits" in parsed_url.path
)
@staticmethod
def is_openai_responses_route(url_route: str) -> bool:
"""Check if the URL route is an OpenAI responses API endpoint."""
if not url_route:
return False
parsed_url = urlparse(url_route)
return bool(
parsed_url.hostname
and (
"api.openai.com" in parsed_url.hostname
or "openai.azure.com" in parsed_url.hostname
)
and ("/v1/responses" in parsed_url.path or "/responses" in parsed_url.path)
)
def _get_user_from_metadata(
    self,
    passthrough_logging_payload: PassthroughStandardLoggingPayload,
) -> Optional[str]:
    """Extract user information from passthrough logging payload."""
    # Unlike the base class, OpenAI requests carry the user directly in
    # the request body's "user" field.
    request_body = passthrough_logging_payload.get("request_body")
    return request_body.get("user") if request_body else None
@staticmethod
def _calculate_image_generation_cost(
    model: str,
    response_body: dict,
    request_body: dict,
) -> float:
    """Calculate cost for OpenAI image generation."""
    try:
        # Use LiteLLM's default image cost calculator
        from litellm.cost_calculator import default_image_cost_calculator

        # Extract parameters from request; `n` may not be an int.
        raw_count = request_body.get("n", 1)
        try:
            image_count = int(raw_count)
        except Exception:
            image_count = 1
        return default_image_cost_calculator(
            model=model,
            custom_llm_provider="openai",
            quality=request_body.get("quality", None),
            n=image_count,
            size=request_body.get("size", "1024x1024"),
            optional_params=request_body,
        )
    except Exception as e:
        # Cost calculation is best-effort; never break the request.
        verbose_proxy_logger.warning(
            f"Error calculating image generation cost: {str(e)}"
        )
        return 0.0
@staticmethod
def _calculate_image_editing_cost(
    model: str,
    response_body: dict,
    request_body: dict,
) -> float:
    """Calculate cost for OpenAI image editing."""
    try:
        # Use LiteLLM's default image cost calculator
        from litellm.cost_calculator import default_image_cost_calculator

        # Image edit typically uses multipart/form-data (because of files),
        # so all fields arrive as strings (e.g., n = "1").
        raw_count = request_body.get("n", 1)
        try:
            image_count = int(raw_count)
        except Exception:
            image_count = 1
        return default_image_cost_calculator(
            model=model,
            custom_llm_provider="openai",
            quality=None,  # Image editing doesn't have quality parameter
            n=image_count,
            size=request_body.get("size", "1024x1024"),
            optional_params=request_body,
        )
    except Exception as e:
        # Cost calculation is best-effort; never break the request.
        verbose_proxy_logger.warning(
            f"Error calculating image editing cost: {str(e)}"
        )
        return 0.0
@staticmethod
def openai_passthrough_handler(  # noqa: PLR0915
    httpx_response: httpx.Response,
    response_body: dict,
    logging_obj: LiteLLMLoggingObj,
    url_route: str,
    result: str,
    start_time: datetime,
    end_time: datetime,
    cache_hit: bool,
    request_body: dict,
    **kwargs,
) -> PassThroughEndpointLoggingTypedDict:
    """
    Handle OpenAI passthrough logging with cost tracking for chat completions, image generation, image editing, and responses API.

    Unsupported routes return ``result=None`` so the caller falls back to
    generic handling; any cost-tracking failure falls back to the base
    chat handler (logged, without cost tracking).
    """
    # Check if this is a supported endpoint for cost tracking
    is_chat_completions = (
        OpenAIPassthroughLoggingHandler.is_openai_chat_completions_route(url_route)
    )
    is_image_generation = (
        OpenAIPassthroughLoggingHandler.is_openai_image_generation_route(url_route)
    )
    is_image_editing = (
        OpenAIPassthroughLoggingHandler.is_openai_image_editing_route(url_route)
    )
    is_responses = OpenAIPassthroughLoggingHandler.is_openai_responses_route(
        url_route
    )
    if not (
        is_chat_completions
        or is_image_generation
        or is_image_editing
        or is_responses
    ):
        # For unsupported endpoints, return None to let the system fall back to generic behavior
        return {
            "result": None,
            "kwargs": kwargs,
        }
    # Extract model from request or response
    model = request_body.get("model", response_body.get("model", ""))
    if not model:
        verbose_proxy_logger.warning(
            "No model found in request or response for OpenAI passthrough cost tracking"
        )
        base_handler = OpenAIPassthroughLoggingHandler()
        return base_handler.passthrough_chat_handler(
            httpx_response=httpx_response,
            response_body=response_body,
            logging_obj=logging_obj,
            url_route=url_route,
            result=result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=cache_hit,
            request_body=request_body,
            **kwargs,
        )
    try:
        response_cost = 0.0
        litellm_model_response: Optional[
            Union[ModelResponse, TextCompletionResponse, ImageResponse]
        ] = None
        handler_instance = OpenAIPassthroughLoggingHandler()
        custom_llm_provider = kwargs.get("custom_llm_provider", "openai")
        if is_chat_completions:
            # Handle chat completions with existing logic
            provider_config = handler_instance.get_provider_config(model=model)
            # Preserve existing litellm_params to maintain metadata tags
            existing_litellm_params = kwargs.get("litellm_params", {}) or {}
            litellm_model_response = provider_config.transform_response(
                raw_response=httpx_response,
                model_response=litellm.ModelResponse(),
                model=model,
                messages=request_body.get("messages", []),
                logging_obj=logging_obj,
                optional_params=request_body.get("optional_params", {}),
                api_key="",
                request_data=request_body,
                encoding=litellm.encoding,
                json_mode=request_body.get("response_format", {}).get("type")
                == "json_object",
                litellm_params=existing_litellm_params,
            )
            # Calculate cost using LiteLLM's cost calculator
            response_cost = litellm.completion_cost(
                completion_response=litellm_model_response,
                model=model,
                custom_llm_provider=custom_llm_provider,
            )
        elif is_image_generation:
            # Handle image generation cost calculation
            response_cost = (
                OpenAIPassthroughLoggingHandler._calculate_image_generation_cost(
                    model=model,
                    response_body=response_body,
                    request_body=request_body,
                )
            )
            # Mark call type for downstream image-aware logic/metrics
            try:
                logging_obj.call_type = (
                    PassthroughCallTypes.passthrough_image_generation.value
                )
            except Exception:
                pass
            # Create a simple response object for logging
            litellm_model_response = ImageResponse(
                data=response_body.get("data", []),
                model=model,
            )
            # Set the calculated cost in _hidden_params to prevent recalculation
            if not hasattr(litellm_model_response, "_hidden_params"):
                litellm_model_response._hidden_params = {}
            litellm_model_response._hidden_params["response_cost"] = response_cost
        elif is_image_editing:
            # Handle image editing cost calculation
            response_cost = (
                OpenAIPassthroughLoggingHandler._calculate_image_editing_cost(
                    model=model,
                    response_body=response_body,
                    request_body=request_body,
                )
            )
            # Mark call type for downstream image-aware logic/metrics
            try:
                logging_obj.call_type = (
                    PassthroughCallTypes.passthrough_image_generation.value
                )
            except Exception:
                pass
            # Create a simple response object for logging
            litellm_model_response = ImageResponse(
                data=response_body.get("data", []),
                model=model,
            )
            # Set the calculated cost in _hidden_params to prevent recalculation
            if not hasattr(litellm_model_response, "_hidden_params"):
                litellm_model_response._hidden_params = {}
            litellm_model_response._hidden_params["response_cost"] = response_cost
        elif is_responses:
            # Handle responses API cost calculation
            provider_config = handler_instance.get_provider_config(model=model)
            existing_litellm_params = kwargs.get("litellm_params", {}) or {}
            litellm_model_response = provider_config.transform_response(
                raw_response=httpx_response,
                model_response=litellm.ModelResponse(),
                model=model,
                messages=request_body.get("messages", []),
                logging_obj=logging_obj,
                optional_params=request_body.get("optional_params", {}),
                api_key="",
                request_data=request_body,
                encoding=litellm.encoding,
                json_mode=False,
                litellm_params=existing_litellm_params,
            )
            # Calculate cost using LiteLLM's cost calculator with responses call type
            response_cost = litellm.completion_cost(
                completion_response=litellm_model_response,
                model=model,
                custom_llm_provider=custom_llm_provider,
                call_type="responses",
            )
        # Update kwargs with cost information
        kwargs["response_cost"] = response_cost
        kwargs["model"] = model
        kwargs["custom_llm_provider"] = custom_llm_provider
        # Extract user information for tracking
        passthrough_logging_payload: Optional[
            PassthroughStandardLoggingPayload
        ] = kwargs.get("passthrough_logging_payload")
        if passthrough_logging_payload:
            user = handler_instance._get_user_from_metadata(
                passthrough_logging_payload=passthrough_logging_payload,
            )
            if user:
                # Bug fix: kwargs may not contain "litellm_params" (or it may
                # be None). The previous kwargs["litellm_params"] access
                # raised KeyError and silently routed the request through the
                # exception fallback, dropping cost tracking for user-tagged
                # requests.
                litellm_params = kwargs.get("litellm_params") or {}
                litellm_params.setdefault("proxy_server_request", {}).setdefault(
                    "body", {}
                )["user"] = user
                kwargs["litellm_params"] = litellm_params
        # Create standard logging object
        if litellm_model_response is not None:
            get_standard_logging_object_payload(
                kwargs=kwargs,
                init_response_obj=litellm_model_response,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                status="success",
            )
        # Update logging object with cost information
        logging_obj.model_call_details["model"] = model
        logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
        # Bug fix: label the responses API correctly in the debug line
        # (previously it fell through to "image_editing").
        logging_obj.model_call_details["response_cost"] = response_cost
        if is_chat_completions:
            endpoint_type = "chat_completions"
        elif is_image_generation:
            endpoint_type = "image_generation"
        elif is_image_editing:
            endpoint_type = "image_editing"
        else:
            endpoint_type = "responses"
        verbose_proxy_logger.debug(
            f"OpenAI passthrough cost tracking - Endpoint: {endpoint_type}, Model: {model}, Cost: ${response_cost:.6f}"
        )
        return {
            "result": litellm_model_response,
            "kwargs": kwargs,
        }
    except Exception as e:
        verbose_proxy_logger.error(
            f"Error in OpenAI passthrough cost tracking: {str(e)}"
        )
        # Fall back to base handler without cost tracking
        base_handler = OpenAIPassthroughLoggingHandler()
        return base_handler.passthrough_chat_handler(
            httpx_response=httpx_response,
            response_body=response_body,
            logging_obj=logging_obj,
            url_route=url_route,
            result=result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=cache_hit,
            request_body=request_body,
            **kwargs,
        )
def _build_complete_streaming_response(
self,
all_chunks: list,
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
"""
Builds complete response from raw chunks for OpenAI streaming responses.
- Converts str chunks to generic chunks
- Converts generic chunks to litellm chunks (OpenAI format)
- Builds complete response from litellm chunks
"""
try:
# OpenAI's response iterator to parse chunks
from litellm.llms.openai.openai import OpenAIChatCompletionResponseIterator
openai_iterator = OpenAIChatCompletionResponseIterator(
streaming_response=None,
sync_stream=False,
)
all_openai_chunks = []
for chunk_str in all_chunks:
try:
# Parse the string chunk using the base iterator's string parser
from litellm.llms.base_llm.base_model_iterator import (
BaseModelResponseIterator,
)
# Convert string chunk to dict
stripped_json_chunk = (
BaseModelResponseIterator._string_to_dict_parser(
str_line=chunk_str
)
)
if stripped_json_chunk:
# Parse the chunk using OpenAI's chunk parser
transformed_chunk = openai_iterator.chunk_parser(
chunk=stripped_json_chunk
)
if transformed_chunk is not None:
all_openai_chunks.append(transformed_chunk)
except (StopIteration, StopAsyncIteration, Exception) as e:
verbose_proxy_logger.debug(f"Error parsing streaming chunk: {e}")
continue
if not all_openai_chunks:
verbose_proxy_logger.warning(
"No valid chunks found in streaming response"
)
return None
# Build complete response from chunks
complete_streaming_response = litellm.stream_chunk_builder(
chunks=all_openai_chunks
)
return complete_streaming_response
except Exception as e:
verbose_proxy_logger.error(
f"Error building complete streaming response: {str(e)}"
)
return None
@staticmethod
def _handle_logging_openai_collected_chunks(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
request_body: dict,
endpoint_type: EndpointType,
start_time: datetime,
all_chunks: List[str],
end_time: datetime,
) -> PassThroughEndpointLoggingTypedDict:
"""
Handle logging for collected OpenAI streaming chunks with cost tracking.
"""
try:
# Extract model from request body
model = request_body.get("model", "gpt-4o")
# Build complete response from chunks using our streaming handler
handler = OpenAIPassthroughLoggingHandler()
handler_instance = handler
complete_response = handler._build_complete_streaming_response(
all_chunks=all_chunks,
litellm_logging_obj=litellm_logging_obj,
model=model,
)
if complete_response is None:
verbose_proxy_logger.warning(
"Failed to build complete response from OpenAI streaming chunks"
)
return {
"result": None,
"kwargs": {},
}
custom_llm_provider = litellm_logging_obj.model_call_details.get(
"custom_llm_provider", "openai"
)
# Calculate cost using LiteLLM's cost calculator
response_cost = litellm.completion_cost(
completion_response=complete_response,
model=model,
custom_llm_provider=custom_llm_provider,
)
# Preserve existing litellm_params to maintain metadata tags
existing_litellm_params = (
litellm_logging_obj.model_call_details.get("litellm_params", {}) or {}
)
# Prepare kwargs for logging
kwargs = {
"response_cost": response_cost,
"model": model,
"custom_llm_provider": custom_llm_provider,
"litellm_params": existing_litellm_params.copy(),
}
# Extract user information for tracking
passthrough_logging_payload: Optional[
PassthroughStandardLoggingPayload
] = litellm_logging_obj.model_call_details.get(
"passthrough_logging_payload"
)
if passthrough_logging_payload:
user = handler_instance._get_user_from_metadata(
passthrough_logging_payload=passthrough_logging_payload,
)
if user:
kwargs["litellm_params"].setdefault(
"proxy_server_request", {}
).setdefault("body", {})["user"] = user
# Create standard logging object
get_standard_logging_object_payload(
kwargs=kwargs,
init_response_obj=complete_response,
start_time=start_time,
end_time=end_time,
logging_obj=litellm_logging_obj,
status="success",
)
# Update logging object with cost information
litellm_logging_obj.model_call_details["model"] = model
litellm_logging_obj.model_call_details[
"custom_llm_provider"
] = custom_llm_provider
litellm_logging_obj.model_call_details["response_cost"] = response_cost
verbose_proxy_logger.debug(
f"OpenAI streaming passthrough cost tracking - Model: {model}, Cost: ${response_cost:.6f}"
)
return {
"result": complete_response,
"kwargs": kwargs,
}
except Exception as e:
verbose_proxy_logger.error(
f"Error in OpenAI streaming passthrough cost tracking: {str(e)}"
)
return {
"result": None,
"kwargs": {},
}

View File

@@ -0,0 +1,403 @@
"""
Vertex AI Live API WebSocket Passthrough Logging Handler
Handles cost tracking and logging for Vertex AI Live API WebSocket passthrough endpoints.
Supports different modalities: text, audio, video, and web search.
"""
from datetime import datetime
from typing import Any, Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.base_passthrough_logging_handler import (
BasePassthroughLoggingHandler,
)
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.openai_passthrough_logging_handler import (
PassThroughEndpointLoggingTypedDict,
)
from litellm.types.utils import LlmProviders, ModelResponse, Usage
from litellm.utils import get_model_info
class VertexAILivePassthroughLoggingHandler(BasePassthroughLoggingHandler):
    """
    Handles cost tracking and logging for Vertex AI Live API WebSocket passthrough.
    Supports:
    - Text tokens (input/output)
    - Audio tokens (input/output)
    - Video tokens (input/output)
    - Web search requests
    - Tool use tokens
    """
    def _build_complete_streaming_response(self, *args, **kwargs):
        """Not applicable for WebSocket passthrough."""
        # The Live API delivers discrete WebSocket frames; there is no SSE
        # chunk stream to assemble, so this base-class hook is a no-op.
        return None
    def get_provider_config(self, model: str):
        """Return Vertex AI provider configuration."""
        # Imported lazily inside the method (keeps module import light).
        from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
            VertexGeminiConfig,
        )
        return VertexGeminiConfig()
    @property
    def llm_provider_name(self) -> LlmProviders:
        """Return the LLM provider name."""
        return LlmProviders.VERTEX_AI
    @staticmethod
    def _extract_usage_metadata_from_websocket_messages(
        websocket_messages: List[Dict],
    ) -> Optional[Dict]:
        """
        Extract and aggregate usage metadata from a list of WebSocket messages.
        Args:
            websocket_messages: List of WebSocket messages from the Live API
        Returns:
            Dictionary containing aggregated usage metadata, or None if not found.
            When exactly one ``usageMetadata`` frame is present, that frame's
            dict is returned directly (not a copy).
        """
        all_usage_metadata = []
        # Collect all usage metadata messages
        for message in websocket_messages:
            if isinstance(message, dict) and "usageMetadata" in message:
                all_usage_metadata.append(message["usageMetadata"])
        if not all_usage_metadata:
            return None
        # If only one usage metadata, return it as-is
        if len(all_usage_metadata) == 1:
            return all_usage_metadata[0]
        # Aggregate multiple usage metadata messages
        aggregated: Dict[str, Any] = {
            "promptTokenCount": 0,
            "candidatesTokenCount": 0,
            "totalTokenCount": 0,
            "promptTokensDetails": [],
            "candidatesTokensDetails": [],
        }
        # Aggregate token counts
        for usage in all_usage_metadata:
            aggregated["promptTokenCount"] += usage.get("promptTokenCount", 0)
            aggregated["candidatesTokenCount"] += usage.get("candidatesTokenCount", 0)
            aggregated["totalTokenCount"] += usage.get("totalTokenCount", 0)
        # Aggregate token details by modality
        # modality name (e.g. "TEXT", "AUDIO") -> summed prompt/candidate counts
        modality_totals: Dict[str, Dict[str, int]] = {}
        for usage in all_usage_metadata:
            # Process prompt tokens details
            for detail in usage.get("promptTokensDetails", []):
                modality = detail.get("modality", "TEXT")
                token_count = detail.get("tokenCount", 0)
                if modality not in modality_totals:
                    modality_totals[modality] = {"prompt": 0, "candidate": 0}
                modality_totals[modality]["prompt"] += token_count
            # Process candidate tokens details
            for detail in usage.get("candidatesTokensDetails", []):
                modality = detail.get("modality", "TEXT")
                token_count = detail.get("tokenCount", 0)
                if modality not in modality_totals:
                    modality_totals[modality] = {"prompt": 0, "candidate": 0}
                modality_totals[modality]["candidate"] += token_count
        # Convert aggregated modality totals back to details format
        for modality, totals in modality_totals.items():
            if totals["prompt"] > 0:
                aggregated["promptTokensDetails"].append(
                    {"modality": modality, "tokenCount": totals["prompt"]}
                )
            if totals["candidate"] > 0:
                aggregated["candidatesTokensDetails"].append(
                    {"modality": modality, "tokenCount": totals["candidate"]}
                )
        # Add any additional fields from the first usage metadata
        # (e.g. toolUsePromptTokenCount); extra fields on later frames are
        # dropped by this loop.
        first_usage = all_usage_metadata[0]
        for key, value in first_usage.items():
            if key not in aggregated:
                aggregated[key] = value
        return aggregated
    @staticmethod
    def _calculate_live_api_cost(
        model: str,
        usage_metadata: Dict,
        custom_llm_provider: str = "vertex_ai",
    ) -> float:
        """
        Calculate cost for Vertex AI Live API based on usage metadata.
        Args:
            model: The model name (e.g., "gemini-2.0-flash-live-preview-04-09")
            usage_metadata: Usage metadata from the Live API response
            custom_llm_provider: The LLM provider (default: "vertex_ai")
        Returns:
            Total cost in USD (0.0 when pricing is unavailable or on error)
        """
        try:
            # Get model pricing information
            model_info = get_model_info(
                model=model, custom_llm_provider=custom_llm_provider
            )
            verbose_proxy_logger.debug(
                f"Vertex AI Live API model info for '{model}': {model_info}"
            )
            # Check if pricing info is available
            # NOTE(review): a truthiness check — a model whose
            # input_cost_per_token is legitimately 0.0 also takes this early
            # return; confirm that is acceptable.
            if not model_info or not model_info.get("input_cost_per_token"):
                verbose_proxy_logger.error(
                    f"No pricing info found for {model} in local model pricing database"
                )
                return 0.0
            total_cost = 0.0
            # Extract token counts from usage metadata
            prompt_token_count = usage_metadata.get("promptTokenCount", 0)
            candidates_token_count = usage_metadata.get("candidatesTokenCount", 0)
            # Calculate base text token costs
            input_cost_per_token = model_info.get("input_cost_per_token", 0.0)
            output_cost_per_token = model_info.get("output_cost_per_token", 0.0)
            total_cost += prompt_token_count * input_cost_per_token
            total_cost += candidates_token_count * output_cost_per_token
            # Handle modality-specific costs if present
            # NOTE(review): the AUDIO/VIDEO additions below are applied on top
            # of the base per-token cost already charged for the full
            # promptTokenCount/candidatesTokenCount — confirm this intentional
            # (otherwise modality tokens are double counted at the text rate).
            prompt_tokens_details = usage_metadata.get("promptTokensDetails", [])
            candidates_tokens_details = usage_metadata.get(
                "candidatesTokensDetails", []
            )
            # Process prompt tokens by modality
            for detail in prompt_tokens_details:
                modality = detail.get("modality", "TEXT")
                token_count = detail.get("tokenCount", 0)
                if modality == "AUDIO":
                    audio_cost_per_token = model_info.get(
                        "input_cost_per_audio_token", 0.0
                    )
                    total_cost += token_count * audio_cost_per_token
                elif modality == "VIDEO":
                    # Video tokens are typically per second, but we'll treat as per token for now
                    video_cost_per_token = model_info.get(
                        "input_cost_per_video_per_second", 0.0
                    )
                    total_cost += token_count * video_cost_per_token
                # TEXT tokens are already handled above
            # Process candidate tokens by modality
            for detail in candidates_tokens_details:
                modality = detail.get("modality", "TEXT")
                token_count = detail.get("tokenCount", 0)
                if modality == "AUDIO":
                    audio_cost_per_token = model_info.get(
                        "output_cost_per_audio_token", 0.0
                    )
                    total_cost += token_count * audio_cost_per_token
                elif modality == "VIDEO":
                    # Video tokens are typically per second, but we'll treat as per token for now
                    video_cost_per_token = model_info.get(
                        "output_cost_per_video_per_second", 0.0
                    )
                    total_cost += token_count * video_cost_per_token
                # TEXT tokens are already handled above
            # Handle web search costs if present
            tool_use_prompt_token_count = usage_metadata.get(
                "toolUsePromptTokenCount", 0
            )
            if tool_use_prompt_token_count > 0:
                # Web search typically has a fixed cost per request
                web_search_cost = model_info.get("web_search_cost_per_request", 0.0)
                if isinstance(web_search_cost, (int, float)) and web_search_cost > 0:
                    total_cost += web_search_cost
                else:
                    # Fallback to token-based pricing for tool use
                    total_cost += tool_use_prompt_token_count * input_cost_per_token
            verbose_proxy_logger.debug(
                f"Vertex AI Live API cost calculation - Model: {model}, "
                f"Prompt tokens: {prompt_token_count}, "
                f"Candidate tokens: {candidates_token_count}, "
                f"Total cost: ${total_cost:.6f}"
            )
            return total_cost
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error calculating Vertex AI Live API cost: {e}"
            )
            return 0.0
    @staticmethod
    def _create_usage_object_from_metadata(
        usage_metadata: Dict,
        model: str,
    ) -> Usage:
        """
        Create a LiteLLM Usage object from Live API usage metadata.
        Args:
            usage_metadata: Usage metadata from the Live API response
            model: The model name (accepted for interface symmetry; not read here)
        Returns:
            LiteLLM Usage object.
            NOTE: prompt/completion counts prefer the TEXT-modality detail
            counts while total_tokens is taken verbatim from the metadata, so
            total_tokens may exceed prompt + completion when audio/video
            modalities are present.
        """
        prompt_tokens = usage_metadata.get("promptTokenCount", 0)
        completion_tokens = usage_metadata.get("candidatesTokenCount", 0)
        total_tokens = usage_metadata.get("totalTokenCount", 0)
        # Create modality-specific token details if available
        prompt_tokens_details = usage_metadata.get("promptTokensDetails", [])
        candidates_tokens_details = usage_metadata.get("candidatesTokensDetails", [])
        # Extract text tokens from details
        text_prompt_tokens = 0
        text_completion_tokens = 0
        for detail in prompt_tokens_details:
            if detail.get("modality") == "TEXT":
                text_prompt_tokens = detail.get("tokenCount", 0)
                break
        for detail in candidates_tokens_details:
            if detail.get("modality") == "TEXT":
                text_completion_tokens = detail.get("tokenCount", 0)
                break
        # If no text tokens found in details, use total counts
        if text_prompt_tokens == 0:
            text_prompt_tokens = prompt_tokens
        if text_completion_tokens == 0:
            text_completion_tokens = completion_tokens
        return Usage(
            prompt_tokens=text_prompt_tokens,
            completion_tokens=text_completion_tokens,
            total_tokens=total_tokens,
        )
    def vertex_ai_live_passthrough_handler(
        self,
        websocket_messages: List[Dict],
        logging_obj,
        url_route: str,
        start_time: datetime,
        end_time: datetime,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Handle cost tracking and logging for Vertex AI Live API WebSocket passthrough.
        Args:
            websocket_messages: List of WebSocket messages from the Live API
            logging_obj: LiteLLM logging object (accepted for interface
                compatibility; not read directly in this method)
            url_route: The URL route that was called (not read directly here)
            start_time: Request start time (used for response id/timestamp)
            end_time: Request end time
            request_body: The original request body (not read directly here)
            **kwargs: Additional keyword arguments; ``model`` and
                ``custom_llm_provider`` are read from here
        Returns:
            Dictionary containing the result and kwargs for logging
        """
        try:
            # Extract model/provider from kwargs, with Live API defaults
            model = kwargs.get("model", "gemini-2.0-flash-live-preview-04-09")
            custom_llm_provider = kwargs.get("custom_llm_provider", "vertex_ai")
            verbose_proxy_logger.debug(
                f"Vertex AI Live API model: {model}, custom_llm_provider: {custom_llm_provider}"
            )
            # Extract usage metadata from WebSocket messages
            usage_metadata = self._extract_usage_metadata_from_websocket_messages(
                websocket_messages
            )
            if not usage_metadata:
                verbose_proxy_logger.warning(
                    "No usage metadata found in Vertex AI Live API WebSocket messages"
                )
                return {
                    "result": None,
                    "kwargs": kwargs,
                }
            # Calculate cost using Live API specific pricing
            response_cost = self._calculate_live_api_cost(
                model=model,
                usage_metadata=usage_metadata,
                custom_llm_provider=custom_llm_provider,
            )
            # Create Usage object for standard LiteLLM logging
            usage = self._create_usage_object_from_metadata(
                usage_metadata=usage_metadata,
                model=model,
            )
            # Create a mock ModelResponse for standard logging
            # (no choices — only usage/cost information is meaningful here)
            litellm_model_response = ModelResponse(
                id=f"vertex-ai-live-{start_time.timestamp()}",
                object="chat.completion",
                created=int(start_time.timestamp()),
                model=model,
                usage=usage,
                choices=[],
            )
            # Update kwargs with cost information
            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            kwargs["custom_llm_provider"] = custom_llm_provider
            # Safely log the model name: only allow known safe formats, redact otherwise.
            import re
            allowed_pattern = re.compile(r"^[A-Za-z0-9._\-:]+$")
            safe_model = (
                model
                if isinstance(model, str) and allowed_pattern.match(model)
                else "[REDACTED]"
            )
            verbose_proxy_logger.debug(
                f"Vertex AI Live API passthrough cost tracking - "
                f"Model: {safe_model}, Cost: ${response_cost:.6f}, "
                f"Prompt tokens: {usage.prompt_tokens}, "
                f"Completion tokens: {usage.completion_tokens}"
            )
            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error in Vertex AI Live API passthrough handler: {e}"
            )
            return {
                "result": None,
                "kwargs": kwargs,
            }

View File

@@ -0,0 +1,851 @@
import re
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
from urllib.parse import urlparse
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
ModelResponseIterator as VertexModelResponseIterator,
)
from litellm.llms.vertex_ai.vector_stores.search_api.transformation import (
VertexSearchAPIVectorStoreConfig,
)
from litellm.llms.vertex_ai.videos.transformation import VertexAIVideoConfig
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
from litellm.types.utils import (
Choices,
EmbeddingResponse,
ImageResponse,
ModelResponse,
SpecialEnums,
StandardPassThroughResponseObject,
TextCompletionResponse,
)
vertex_search_api_config = VertexSearchAPIVectorStoreConfig()
if TYPE_CHECKING:
from litellm.types.utils import LiteLLMBatch
from ..success_handler import PassThroughEndpointLogging
else:
PassThroughEndpointLogging = Any
LiteLLMBatch = Any
# Define EndpointType locally to avoid import issues
EndpointType = Any
class VertexPassthroughLoggingHandler:
    @staticmethod
    def vertex_passthrough_handler(
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: Optional[dict] = None,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Dispatch a completed Vertex AI passthrough response to the matching
        transformation + cost-tracking logic, keyed off substrings of
        ``url_route`` (first matching branch wins):

        - "predictLongRunning": video-generation job creation
        - "generateContent": Gemini chat completion responses
        - "predict": embeddings / image generation
        - "rawPredict" / "streamRawPredict": partner models (Anthropic, etc.)
        - "search": vector store search
        - "batchPredictionJobs": batch prediction job creation
        - anything else: ``{"result": None, "kwargs": kwargs}`` (no logging payload)

        Args:
            httpx_response: Raw provider HTTP response.
            logging_obj: LiteLLM logging object; mutated in place
                (model / provider / response_cost fields).
            url_route: The provider URL that was called; drives dispatch.
            result: Accepted for interface compatibility (forwarded only to
                the batch handler).
            start_time: Request start time.
            end_time: Request end time.
            cache_hit: Accepted for interface compatibility.
            request_body: Original request body (used by the video branch).
            **kwargs: Logging kwargs, augmented per-branch with cost/model.

        Returns:
            Dict with the transformed response object under "result" and the
            kwargs to use for logging under "kwargs".
        """
        if "predictLongRunning" in url_route:
            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
            vertex_video_config = VertexAIVideoConfig()
            litellm_video_response = (
                vertex_video_config.transform_video_create_response(
                    model=model,
                    raw_response=httpx_response,
                    logging_obj=logging_obj,
                    custom_llm_provider="vertex_ai",
                    request_data=request_body,
                )
            )
            logging_obj.model = model
            logging_obj.model_call_details["model"] = model
            logging_obj.model_call_details["custom_llm_provider"] = "vertex_ai"
            logging_obj.custom_llm_provider = "vertex_ai"
            response_cost = litellm.completion_cost(
                completion_response=litellm_video_response,
                model=model,
                custom_llm_provider="vertex_ai",
                call_type="create_video",
            )
            # Set response_cost in _hidden_params to prevent recalculation
            if not hasattr(litellm_video_response, "_hidden_params"):
                litellm_video_response._hidden_params = {}
            litellm_video_response._hidden_params["response_cost"] = response_cost
            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            kwargs["custom_llm_provider"] = "vertex_ai"
            logging_obj.model_call_details["response_cost"] = response_cost
            return {
                "result": litellm_video_response,
                "kwargs": kwargs,
            }
        elif "generateContent" in url_route:
            # Also matches streamGenerateContent routes (substring check).
            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
            instance_of_vertex_llm = litellm.VertexGeminiConfig()
            # transform_response requires a messages list; the placeholder
            # content marks this as a passthrough call with no stored prompt.
            litellm_model_response: ModelResponse = (
                instance_of_vertex_llm.transform_response(
                    model=model,
                    messages=[
                        {"role": "user", "content": "no-message-pass-through-endpoint"}
                    ],
                    raw_response=httpx_response,
                    model_response=litellm.ModelResponse(),
                    logging_obj=logging_obj,
                    optional_params={},
                    litellm_params={},
                    api_key="",
                    request_data={},
                    encoding=litellm.encoding,
                )
            )
            kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
                litellm_model_response=litellm_model_response,
                model=model,
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url(
                    url_route
                ),
            )
            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }
        elif "predict" in url_route:
            # Embeddings / image generation; branch logic lives in the helper.
            return VertexPassthroughLoggingHandler._handle_predict_response(
                httpx_response=httpx_response,
                logging_obj=logging_obj,
                url_route=url_route,
                kwargs=kwargs,
            )
        elif "rawPredict" in url_route or "streamRawPredict" in url_route:
            from litellm.llms.vertex_ai.vertex_ai_partner_models import (
                get_vertex_ai_partner_model_config,
            )
            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
            vertex_publisher_or_api_spec = VertexPassthroughLoggingHandler._get_vertex_publisher_or_api_spec_from_url(
                url_route
            )
            _json_response = httpx_response.json()
            # NOTE: when no known publisher is found in the URL this stays an
            # empty ModelResponse and is logged as-is.
            litellm_prediction_response = ModelResponse()
            if vertex_publisher_or_api_spec is not None:
                vertex_ai_partner_model_config = get_vertex_ai_partner_model_config(
                    model=model,
                    vertex_publisher_or_api_spec=vertex_publisher_or_api_spec,
                )
                litellm_prediction_response = (
                    vertex_ai_partner_model_config.transform_response(
                        model=model,
                        raw_response=httpx_response,
                        model_response=litellm_prediction_response,
                        logging_obj=logging_obj,
                        request_data={},
                        encoding=litellm.encoding,
                        optional_params={},
                        litellm_params={},
                        api_key="",
                        messages=[
                            {
                                "role": "user",
                                "content": "no-message-pass-through-endpoint",
                            }
                        ],
                    )
                )
            kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
                litellm_model_response=litellm_prediction_response,
                model="vertex_ai/" + model,
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                custom_llm_provider="vertex_ai",
            )
            return {
                "result": litellm_prediction_response,
                "kwargs": kwargs,
            }
        elif "search" in url_route:
            litellm_vs_response = (
                vertex_search_api_config.transform_search_vector_store_response(
                    response=httpx_response,
                    litellm_logging_obj=logging_obj,
                )
            )
            response_cost = litellm.completion_cost(
                completion_response=litellm_vs_response,
                model="vertex_ai/search_api",
                custom_llm_provider="vertex_ai",
                call_type="vector_store_search",
            )
            standard_pass_through_response_object: StandardPassThroughResponseObject = {
                "response": cast(dict, litellm_vs_response),
            }
            kwargs["response_cost"] = response_cost
            kwargs["model"] = "vertex_ai/search_api"
            # base_model is set so downstream cost attribution uses the
            # search_api pricing entry.
            logging_obj.model_call_details.setdefault("litellm_params", {})
            logging_obj.model_call_details["litellm_params"][
                "base_model"
            ] = "vertex_ai/search_api"
            logging_obj.model_call_details["response_cost"] = response_cost
            return {
                "result": standard_pass_through_response_object,
                "kwargs": kwargs,
            }
        elif "batchPredictionJobs" in url_route:
            return VertexPassthroughLoggingHandler.batch_prediction_jobs_handler(
                httpx_response=httpx_response,
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                **kwargs,
            )
        else:
            # Unrecognized route: nothing to transform or cost-track.
            return {
                "result": None,
                "kwargs": kwargs,
            }
    @staticmethod
    def _handle_predict_response(
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        kwargs: dict,
    ) -> PassThroughEndpointLoggingTypedDict:
        """Handle predict endpoint responses (embeddings, image generation).

        Inspects the JSON body to choose one of three transformations:
        image generation, multimodal embeddings, or (default) text
        embeddings. Mutates ``logging_obj`` (model / provider / cost,
        and call_type for image generation) and ``kwargs`` in place.

        Args:
            httpx_response: Raw provider HTTP response.
            logging_obj: LiteLLM logging object, updated in place.
            url_route: Provider URL; the model name is parsed from it.
            kwargs: Logging kwargs, augmented with cost/model/provider.

        Returns:
            Dict with the transformed response under "result" and the
            augmented ``kwargs`` under "kwargs".
        """
        from litellm.llms.vertex_ai.image_generation.image_generation_handler import (
            VertexImageGeneration,
        )
        from litellm.llms.vertex_ai.multimodal_embeddings.transformation import (
            VertexAIMultimodalEmbeddingConfig,
        )
        from litellm.types.utils import PassthroughCallTypes
        vertex_image_generation_class = VertexImageGeneration()
        model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
        _json_response = httpx_response.json()
        litellm_prediction_response: Union[
            ModelResponse, EmbeddingResponse, ImageResponse
        ] = ModelResponse()
        if vertex_image_generation_class.is_image_generation_response(_json_response):
            litellm_prediction_response = (
                vertex_image_generation_class.process_image_generation_response(
                    _json_response,
                    model_response=litellm.ImageResponse(),
                    model=model,
                )
            )
            # Mark the call so it is logged/costed as image generation.
            logging_obj.call_type = (
                PassthroughCallTypes.passthrough_image_generation.value
            )
        elif VertexPassthroughLoggingHandler._is_multimodal_embedding_response(
            json_response=_json_response,
        ):
            # Use multimodal embedding transformation
            vertex_multimodal_config = VertexAIMultimodalEmbeddingConfig()
            litellm_prediction_response = (
                vertex_multimodal_config.transform_embedding_response(
                    model=model,
                    raw_response=httpx_response,
                    model_response=litellm.EmbeddingResponse(),
                    logging_obj=logging_obj,
                    api_key="",
                    request_data={},
                    optional_params={},
                    litellm_params={},
                )
            )
        else:
            # Default: treat as a text embedding predict response.
            litellm_prediction_response = (
                litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
                    response=_json_response,
                    model=model,
                    model_response=litellm.EmbeddingResponse(),
                )
            )
        if isinstance(litellm_prediction_response, litellm.EmbeddingResponse):
            litellm_prediction_response.model = model
        logging_obj.model = model
        logging_obj.model_call_details["model"] = logging_obj.model
        logging_obj.model_call_details["custom_llm_provider"] = "vertex_ai"
        logging_obj.custom_llm_provider = "vertex_ai"
        response_cost = litellm.completion_cost(
            completion_response=litellm_prediction_response,
            model=model,
            custom_llm_provider="vertex_ai",
        )
        kwargs["response_cost"] = response_cost
        kwargs["model"] = model
        kwargs["custom_llm_provider"] = "vertex_ai"
        logging_obj.model_call_details["response_cost"] = response_cost
        return {
            "result": litellm_prediction_response,
            "kwargs": kwargs,
        }
@staticmethod
def _handle_logging_vertex_collected_chunks(
litellm_logging_obj: LiteLLMLoggingObj,
passthrough_success_handler_obj: PassThroughEndpointLogging,
url_route: str,
request_body: dict,
endpoint_type: EndpointType,
start_time: datetime,
all_chunks: List[str],
model: Optional[str],
end_time: datetime,
) -> PassThroughEndpointLoggingTypedDict:
"""
Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks
- Builds complete response from chunks
- Creates standard logging object
- Logs in litellm callbacks
"""
kwargs: Dict[str, Any] = {}
model = model or VertexPassthroughLoggingHandler.extract_model_from_url(
url_route
)
complete_streaming_response = (
VertexPassthroughLoggingHandler._build_complete_streaming_response(
all_chunks=all_chunks,
litellm_logging_obj=litellm_logging_obj,
model=model,
url_route=url_route,
)
)
if complete_streaming_response is None:
verbose_proxy_logger.error(
"Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
)
return {
"result": None,
"kwargs": kwargs,
}
kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
litellm_model_response=complete_streaming_response,
model=model,
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
logging_obj=litellm_logging_obj,
custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url(
url_route
),
)
return {
"result": complete_streaming_response,
"kwargs": kwargs,
}
@staticmethod
def _build_complete_streaming_response(
all_chunks: List[str],
litellm_logging_obj: LiteLLMLoggingObj,
model: str,
url_route: str,
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
parsed_chunks = []
if "generateContent" in url_route or "streamGenerateContent" in url_route:
vertex_iterator: Any = VertexModelResponseIterator(
streaming_response=None,
sync_stream=False,
logging_obj=litellm_logging_obj,
)
chunk_parsing_logic: Any = vertex_iterator._common_chunk_parsing_logic
parsed_chunks = [chunk_parsing_logic(chunk) for chunk in all_chunks]
elif "rawPredict" in url_route or "streamRawPredict" in url_route:
from litellm.llms.anthropic.chat.handler import ModelResponseIterator
from litellm.llms.base_llm.base_model_iterator import (
BaseModelResponseIterator,
)
vertex_iterator = ModelResponseIterator(
streaming_response=None,
sync_stream=False,
)
chunk_parsing_logic = vertex_iterator.chunk_parser
for chunk in all_chunks:
dict_chunk = BaseModelResponseIterator._string_to_dict_parser(chunk)
if dict_chunk is None:
continue
parsed_chunks.append(chunk_parsing_logic(dict_chunk))
else:
return None
if len(parsed_chunks) == 0:
return None
all_openai_chunks = []
for parsed_chunk in parsed_chunks:
if parsed_chunk is None:
continue
all_openai_chunks.append(parsed_chunk)
complete_streaming_response = litellm.stream_chunk_builder(
chunks=all_openai_chunks
)
return complete_streaming_response
@staticmethod
def extract_model_from_url(url: str) -> str:
pattern = r"/models/([^:]+)"
match = re.search(pattern, url)
if match:
return match.group(1)
return "unknown"
@staticmethod
def extract_model_name_from_vertex_path(vertex_model_path: str) -> str:
"""
Extract the actual model name from a Vertex AI model path.
Examples:
- publishers/google/models/gemini-2.5-flash -> gemini-2.5-flash
- projects/PROJECT_ID/locations/LOCATION/models/MODEL_ID -> MODEL_ID
Args:
vertex_model_path: The full Vertex AI model path
Returns:
The extracted model name for use with LiteLLM
"""
# Handle publishers/google/models/ format
if "publishers/" in vertex_model_path and "models/" in vertex_model_path:
# Extract everything after the last models/
parts = vertex_model_path.split("models/")
if len(parts) > 1:
return parts[-1]
# Handle projects/PROJECT_ID/locations/LOCATION/models/MODEL_ID format
elif "projects/" in vertex_model_path and "models/" in vertex_model_path:
# Extract everything after the last models/
parts = vertex_model_path.split("models/")
if len(parts) > 1:
return parts[-1]
# If no recognized pattern, return the original path
return vertex_model_path
@staticmethod
def _get_vertex_publisher_or_api_spec_from_url(url: str) -> Optional[str]:
# Check for specific Vertex AI partner publishers
if "/publishers/mistralai/" in url:
return "mistralai"
elif "/publishers/anthropic/" in url:
return "anthropic"
elif "/publishers/ai21/" in url:
return "ai21"
elif "/endpoints/openapi/" in url:
return "openapi"
return None
@staticmethod
def _get_custom_llm_provider_from_url(url: str) -> str:
parsed_url = urlparse(url)
if parsed_url.hostname and parsed_url.hostname.endswith(
"generativelanguage.googleapis.com"
):
return litellm.LlmProviders.GEMINI.value
return litellm.LlmProviders.VERTEX_AI.value
@staticmethod
def _is_multimodal_embedding_response(json_response: dict) -> bool:
"""
Detect if the response is from a multimodal embedding request.
Check if the response contains multimodal embedding fields:
- Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/multimodal-embeddings-api#response-body
Args:
json_response: The JSON response from Vertex AI
Returns:
bool: True if this is a multimodal embedding response
"""
# Check if response contains multimodal embedding fields
if "predictions" in json_response:
predictions = json_response["predictions"]
for prediction in predictions:
if isinstance(prediction, dict):
# Check for multimodal embedding response fields
if any(
key in prediction
for key in [
"textEmbedding",
"imageEmbedding",
"videoEmbeddings",
]
):
return True
return False
    @staticmethod
    def _create_vertex_response_logging_payload_for_generate_content(
        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
        model: str,
        kwargs: dict,
        start_time: datetime,
        end_time: datetime,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: str,
    ) -> dict:
        """
        Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming)

        Mutates ``kwargs`` (response_cost/model), ``litellm_model_response``
        (id) and ``logging_obj`` (model + provider fields) in place, then
        returns ``kwargs``.
        """
        # NOTE(review): cost is always computed with provider "vertex_ai"
        # even when `custom_llm_provider` is "gemini" (AI Studio routes) —
        # confirm this is intentional / pricing matches for both providers.
        response_cost = litellm.completion_cost(
            completion_response=litellm_model_response,
            model=model,
            custom_llm_provider="vertex_ai",
        )
        kwargs["response_cost"] = response_cost
        kwargs["model"] = model
        # pretty print standard logging object
        verbose_proxy_logger.debug("kwargs= %s", kwargs)
        # set litellm_call_id to logging response object
        litellm_model_response.id = logging_obj.litellm_call_id
        logging_obj.model = litellm_model_response.model or model
        logging_obj.model_call_details["model"] = logging_obj.model
        logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
        return kwargs
    @staticmethod
    def batch_prediction_jobs_handler( # noqa: PLR0915
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Handle batch prediction jobs passthrough logging.
        Creates a managed object for cost tracking when batch job is successfully created.

        On success (HTTP 200 with a "name" field), stores a managed batch object
        keyed by a base64-encoded unified id so the check_batch_cost poller can
        later attribute cost. On failure or exception, returns a synthetic
        ModelResponse with JOB_STATE_FAILED metadata and zero cost.
        """
        import base64
        from litellm._uuid import uuid
        from litellm.llms.vertex_ai.batches.transformation import (
            VertexAIBatchTransformation,
        )
        try:
            _json_response = httpx_response.json()
            # Only handle successful batch job creation (POST requests)
            if httpx_response.status_code == 200 and "name" in _json_response:
                # Transform Vertex AI response to LiteLLM batch format
                litellm_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
                    response=_json_response
                )
                # Extract batch ID and model from the response
                batch_id = VertexAIBatchTransformation._get_batch_id_from_vertex_ai_batch_response(
                    _json_response
                )
                model_name = _json_response.get("model", "unknown")
                # Create unified object ID for tracking
                # Format: base64(litellm_proxy;model_id:{};llm_batch_id:{})
                actual_model_id = (
                    VertexPassthroughLoggingHandler.get_actual_model_id_from_router(
                        model_name
                    )
                )
                unified_id_string = (
                    SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format(
                        actual_model_id, batch_id
                    )
                )
                # URL-safe base64 with padding stripped, so the id is safe in URLs/paths.
                unified_object_id = (
                    base64.urlsafe_b64encode(unified_id_string.encode())
                    .decode()
                    .rstrip("=")
                )
                # Store the managed object for cost tracking
                # This will be picked up by check_batch_cost polling mechanism
                VertexPassthroughLoggingHandler._store_batch_managed_object(
                    unified_object_id=unified_object_id,
                    batch_object=litellm_batch_response,
                    model_object_id=batch_id,
                    logging_obj=logging_obj,
                    **kwargs,
                )
                # Create a batch job response for logging
                litellm_model_response = ModelResponse()
                litellm_model_response.id = str(uuid.uuid4())
                litellm_model_response.model = model_name
                litellm_model_response.object = "batch_prediction_job"
                litellm_model_response.created = int(start_time.timestamp())
                # Add batch-specific metadata to indicate this is a pending batch job
                litellm_model_response.choices = [
                    Choices(
                        finish_reason="stop",
                        index=0,
                        message={
                            "role": "assistant",
                            "content": f"Batch prediction job {batch_id} created and is pending. Status will be updated when the batch completes.",
                            "tool_calls": None,
                            "function_call": None,
                            "provider_specific_fields": {
                                "batch_job_id": batch_id,
                                "batch_job_state": "JOB_STATE_PENDING",
                                "unified_object_id": unified_object_id,
                            },
                        },
                    )
                ]
                # Set response cost to 0 initially (will be updated when batch completes)
                response_cost = 0.0
                kwargs["response_cost"] = response_cost
                kwargs["model"] = model_name
                kwargs["batch_id"] = batch_id
                kwargs["unified_object_id"] = unified_object_id
                kwargs["batch_job_state"] = "JOB_STATE_PENDING"
                logging_obj.model = model_name
                logging_obj.model_call_details["model"] = logging_obj.model
                logging_obj.model_call_details["response_cost"] = response_cost
                logging_obj.model_call_details["batch_id"] = batch_id
                return {
                    "result": litellm_model_response,
                    "kwargs": kwargs,
                }
            else:
                # Handle non-successful responses
                litellm_model_response = ModelResponse()
                litellm_model_response.id = str(uuid.uuid4())
                litellm_model_response.model = "vertex_ai_batch"
                litellm_model_response.object = "batch_prediction_job"
                litellm_model_response.created = int(start_time.timestamp())
                # Add error-specific metadata
                litellm_model_response.choices = [
                    Choices(
                        finish_reason="stop",
                        index=0,
                        message={
                            "role": "assistant",
                            "content": f"Batch prediction job creation failed. Status: {httpx_response.status_code}",
                            "tool_calls": None,
                            "function_call": None,
                            "provider_specific_fields": {
                                "batch_job_state": "JOB_STATE_FAILED",
                                "status_code": httpx_response.status_code,
                            },
                        },
                    )
                ]
                kwargs["response_cost"] = 0.0
                kwargs["model"] = "vertex_ai_batch"
                kwargs["batch_job_state"] = "JOB_STATE_FAILED"
                return {
                    "result": litellm_model_response,
                    "kwargs": kwargs,
                }
        except Exception as e:
            verbose_proxy_logger.error(f"Error in batch_prediction_jobs_handler: {e}")
            # Return basic response on error
            litellm_model_response = ModelResponse()
            litellm_model_response.id = str(uuid.uuid4())
            litellm_model_response.model = "vertex_ai_batch"
            litellm_model_response.object = "batch_prediction_job"
            litellm_model_response.created = int(start_time.timestamp())
            # Add error-specific metadata
            litellm_model_response.choices = [
                Choices(
                    finish_reason="stop",
                    index=0,
                    message={
                        "role": "assistant",
                        "content": f"Error creating batch prediction job: {str(e)}",
                        "tool_calls": None,
                        "function_call": None,
                        "provider_specific_fields": {
                            "batch_job_state": "JOB_STATE_FAILED",
                            "error": str(e),
                        },
                    },
                )
            ]
            kwargs["response_cost"] = 0.0
            kwargs["model"] = "vertex_ai_batch"
            kwargs["batch_job_state"] = "JOB_STATE_FAILED"
            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }
    @staticmethod
    def _store_batch_managed_object(
        unified_object_id: str,
        batch_object: LiteLLMBatch,
        model_object_id: str,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> None:
        """
        Store batch managed object for cost tracking.
        This will be picked up by the check_batch_cost polling mechanism.

        Best-effort: any failure is logged and swallowed so passthrough
        responses are never blocked by bookkeeping errors.
        """
        try:
            # Get the managed files hook from the logging object
            # This is a bit of a hack, but we need access to the proxy logging system
            from litellm.proxy.proxy_server import proxy_logging_obj
            managed_files_hook = proxy_logging_obj.get_proxy_hook("managed_files")
            if managed_files_hook is not None and hasattr(
                managed_files_hook, "store_unified_object_id"
            ):
                # Create a mock user API key dict for the managed object storage
                from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
                user_api_key_dict = UserAPIKeyAuth(
                    user_id=kwargs.get("user_id", "default-user"),
                    api_key="",
                    team_id=None,
                    team_alias=None,
                    user_role=LitellmUserRoles.CUSTOMER,  # Use proper enum value
                    user_email=None,
                    max_budget=None,
                    spend=0.0,  # Set to 0.0 instead of None
                    models=[],  # Set to empty list instead of None
                    tpm_limit=None,
                    rpm_limit=None,
                    budget_duration=None,
                    budget_reset_at=None,
                    max_parallel_requests=None,
                    allowed_model_region=None,
                    metadata={},  # Set to empty dict instead of None
                    key_alias=None,
                    permissions={},  # Set to empty dict instead of None
                    model_max_budget={},  # Set to empty dict instead of None
                    model_spend={},  # Set to empty dict instead of None
                )
                # Store the unified object for batch cost tracking
                import asyncio
                # NOTE(review): the created task's reference is not retained; per
                # CPython docs an un-referenced task may be garbage-collected
                # before it completes — confirm a running loop keeps it alive here.
                asyncio.create_task(
                    managed_files_hook.store_unified_object_id(  # type: ignore
                        unified_object_id=unified_object_id,
                        file_object=batch_object,
                        litellm_parent_otel_span=None,
                        model_object_id=model_object_id,
                        file_purpose="batch",
                        user_api_key_dict=user_api_key_dict,
                    )
                )
                # Logged before the task actually runs; storage is asynchronous.
                verbose_proxy_logger.info(
                    f"Stored batch managed object with unified_object_id={unified_object_id}, batch_id={model_object_id}"
                )
            else:
                verbose_proxy_logger.warning(
                    "Managed files hook not available, cannot store batch object for cost tracking"
                )
        except Exception as e:
            verbose_proxy_logger.error(f"Error storing batch managed object: {e}")
@staticmethod
def get_actual_model_id_from_router(model_name: str) -> str:
from litellm.proxy.proxy_server import llm_router
if llm_router is not None:
# Try to find the model in the router by the extracted model name
extracted_model_name = (
VertexPassthroughLoggingHandler.extract_model_name_from_vertex_path(
model_name
)
)
# Use the existing get_model_ids method from router
model_ids = llm_router.get_model_ids(model_name=extracted_model_name)
if model_ids and len(model_ids) > 0:
# Use the first model ID found
actual_model_id = model_ids[0]
verbose_proxy_logger.info(
f"Found model ID in router: {actual_model_id}"
)
return actual_model_id
else:
# Fallback to constructed model name
actual_model_id = extracted_model_name
verbose_proxy_logger.warning(
f"Model not found in router, using constructed name: {actual_model_id}"
)
return actual_model_id
else:
# Fallback if router is not available
extracted_model_name = (
VertexPassthroughLoggingHandler.extract_model_name_from_vertex_path(
model_name
)
)
verbose_proxy_logger.warning(
f"Router not available, using constructed model name: {extracted_model_name}"
)
return extracted_model_name

View File

@@ -0,0 +1,212 @@
from typing import Dict, Optional
import litellm
from litellm._logging import verbose_router_logger
from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import (
LiteLLM_ManagedVectorStore,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials
class PassthroughEndpointRouter:
    """
    Use this class to Set/Get credentials for pass-through endpoints
    """
    def __init__(self):
        # Maps credential name (see _get_credential_name_for_provider) -> api key.
        self.credentials: Dict[str, str] = {}
        # Maps "{project_id}-{location}" (see _get_deployment_key) -> vertex creds.
        self.deployment_key_to_vertex_credentials: Dict[
            str, VertexPassThroughCredentials
        ] = {}
        # Fallback vertex config used when no deployment-specific entry matches.
        self.default_vertex_config: Optional[VertexPassThroughCredentials] = None
    def set_pass_through_credentials(
        self,
        custom_llm_provider: str,
        api_base: Optional[str],
        api_key: Optional[str],
    ):
        """
        Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI.
        Args:
            custom_llm_provider: The provider of the pass-through endpoint
            api_base: The base URL of the pass-through endpoint
            api_key: The API key for the pass-through endpoint

        Raises:
            ValueError: if api_key is None.
        """
        # Region (if derivable from api_base) is folded into the credential name,
        # so the same provider can have per-region keys.
        credential_name = self._get_credential_name_for_provider(
            custom_llm_provider=custom_llm_provider,
            region_name=self._get_region_name_from_api_base(
                api_base=api_base, custom_llm_provider=custom_llm_provider
            ),
        )
        if api_key is None:
            raise ValueError("api_key is required for setting pass-through credentials")
        self.credentials[credential_name] = api_key
    def get_credentials(
        self,
        custom_llm_provider: str,
        region_name: Optional[str],
    ) -> Optional[str]:
        # Lookup order: explicitly-set credentials first, then the provider's
        # default environment variable (e.g. COHERE_API_KEY).
        credential_name = self._get_credential_name_for_provider(
            custom_llm_provider=custom_llm_provider,
            region_name=region_name,
        )
        verbose_router_logger.debug(
            f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
        )
        if credential_name in self.credentials:
            verbose_router_logger.debug(f"Found credentials for {credential_name}")
            return self.credentials[credential_name]
        else:
            verbose_router_logger.debug(
                f"No credentials found for {credential_name}, looking for env variable"
            )
            _env_variable_name = (
                self._get_default_env_variable_name_passthrough_endpoint(
                    custom_llm_provider=custom_llm_provider,
                )
            )
            return get_secret_str(_env_variable_name)
    def _get_vertex_env_vars(self) -> VertexPassThroughCredentials:
        """
        Helper to get vertex pass through config from environment variables
        The following environment variables are used:
        - DEFAULT_VERTEXAI_PROJECT (project id)
        - DEFAULT_VERTEXAI_LOCATION (location)
        - DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
        """
        return VertexPassThroughCredentials(
            vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
            vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
            vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
        )
    def set_default_vertex_config(self, config: Optional[dict] = None):
        """Sets vertex configuration from provided config and/or environment variables
        Args:
            config (Optional[dict]): Configuration dictionary
            Example: {
                "vertex_project": "my-project-123",
                "vertex_location": "us-central1",
                "vertex_credentials": "os.environ/GOOGLE_CREDS"
            }
        """
        # Initialize config dictionary if None
        if config is None:
            self.default_vertex_config = self._get_vertex_env_vars()
            return
        if isinstance(config, dict):
            # Resolve "os.environ/VAR" indirections in-place before constructing.
            for key, value in config.items():
                if isinstance(value, str) and value.startswith("os.environ/"):
                    config[key] = get_secret_str(value)
        self.default_vertex_config = VertexPassThroughCredentials(**config)
    def add_vertex_credentials(
        self,
        project_id: str,
        location: str,
        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
    ):
        """
        Add the vertex credentials for the given project-id, location
        """
        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            verbose_router_logger.debug(
                "No deployment key found for project-id, location"
            )
            return
        vertex_pass_through_credentials = VertexPassThroughCredentials(
            vertex_project=project_id,
            vertex_location=location,
            vertex_credentials=vertex_credentials,
        )
        self.deployment_key_to_vertex_credentials[
            deployment_key
        ] = vertex_pass_through_credentials
    def _get_deployment_key(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[str]:
        """
        Get the deployment key for the given project-id, location

        Returns None when either component is missing; callers treat that
        as "use the default vertex config".
        """
        if project_id is None or location is None:
            return None
        return f"{project_id}-{location}"
    def get_vector_store_credentials(
        self, vector_store_id: str
    ) -> Optional[LiteLLM_ManagedVectorStore]:
        """
        Get the vector store credentials for the given vector store id
        """
        if litellm.vector_store_registry is None:
            return None
        vector_store_to_run: Optional[
            LiteLLM_ManagedVectorStore
        ] = litellm.vector_store_registry.get_litellm_managed_vector_store_from_registry(
            vector_store_id=vector_store_id
        )
        return vector_store_to_run
    def get_vertex_credentials(
        self, project_id: Optional[str], location: Optional[str]
    ) -> Optional[VertexPassThroughCredentials]:
        """
        Get the vertex credentials for the given project-id, location

        Falls back to the default vertex config when no exact
        project/location entry exists.
        """
        deployment_key = self._get_deployment_key(
            project_id=project_id,
            location=location,
        )
        if deployment_key is None:
            return self.default_vertex_config
        if deployment_key in self.deployment_key_to_vertex_credentials:
            return self.deployment_key_to_vertex_credentials[deployment_key]
        else:
            return self.default_vertex_config
    def _get_credential_name_for_provider(
        self,
        custom_llm_provider: str,
        region_name: Optional[str],
    ) -> str:
        # e.g. ("assemblyai", "eu") -> "ASSEMBLYAI_EU_API_KEY";
        # no region -> "ASSEMBLYAI_API_KEY".
        if region_name is None:
            return f"{custom_llm_provider.upper()}_API_KEY"
        return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"
    def _get_region_name_from_api_base(
        self,
        custom_llm_provider: str,
        api_base: Optional[str],
    ) -> Optional[str]:
        """
        Get the region name from the API base.
        Each provider might have a different way of specifying the region in the API base - this is where you can use conditional logic to handle that.
        """
        if custom_llm_provider == "assemblyai":
            # NOTE(review): substring match — any api_base containing "eu"
            # (e.g. a host with "eu" anywhere in it) is treated as the EU
            # region; confirm this is intentional.
            if api_base and "eu" in api_base:
                return "eu"
        return None
    @staticmethod
    def _get_default_env_variable_name_passthrough_endpoint(
        custom_llm_provider: str,
    ) -> str:
        # Conventional env-var name for a provider's API key, e.g. "COHERE_API_KEY".
        return f"{custom_llm_provider.upper()}_API_KEY"

View File

@@ -0,0 +1,335 @@
"""
Passthrough Guardrails Helper Module
Handles guardrail execution for passthrough endpoints with:
- Opt-in model (guardrails only run when explicitly configured)
- Field-level targeting using JSONPath expressions
- Automatic inheritance from org/team/key levels when enabled
"""
from typing import Any, Dict, List, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
PassThroughGuardrailsConfig,
PassThroughGuardrailSettings,
UserAPIKeyAuth,
)
from litellm.proxy.pass_through_endpoints.jsonpath_extractor import JsonPathExtractor
# Type for raw guardrails config input (before normalization)
# Can be a list of names or a dict with settings
PassThroughGuardrailsConfigInput = Union[
List[str], # Simple list: ["guard-1", "guard-2"]
PassThroughGuardrailsConfig, # Dict: {"guard-1": {"request_fields": [...]}}
]
class PassthroughGuardrailHandler:
    """
    Handles guardrail execution for passthrough endpoints.
    Passthrough endpoints use an opt-in model for guardrails:
    - Guardrails only run when explicitly configured on the endpoint
    - Supports field-level targeting using JSONPath expressions
    - Automatically inherits org/team/key level guardrails when enabled
    Guardrails can be specified as:
    - List format (simple): ["guardrail-1", "guardrail-2"]
    - Dict format (with settings): {"guardrail-1": {"request_fields": ["query"]}}
    """
    @staticmethod
    def normalize_config(
        guardrails_config: Optional[PassThroughGuardrailsConfigInput],
    ) -> Optional[PassThroughGuardrailsConfig]:
        """
        Normalize guardrails config to dict format.
        Accepts:
        - List of guardrail names: ["g1", "g2"] -> {"g1": None, "g2": None}
        - Dict with settings: {"g1": {"request_fields": [...]}}
        - None: returns None

        Any other type is logged at debug level and treated as "no config".
        """
        if guardrails_config is None:
            return None
        # Already a dict - return as-is
        if isinstance(guardrails_config, dict):
            return guardrails_config
        # List of guardrail names - convert to dict
        if isinstance(guardrails_config, list):
            return {name: None for name in guardrails_config}
        verbose_proxy_logger.debug(
            "Passthrough guardrails config is not a dict or list, got: %s",
            type(guardrails_config),
        )
        return None
    @staticmethod
    def is_enabled(
        guardrails_config: Optional[PassThroughGuardrailsConfigInput],
    ) -> bool:
        """
        Check if guardrails are enabled for a passthrough endpoint.
        Passthrough endpoints are opt-in only - guardrails only run when
        the guardrails config is set with at least one guardrail.
        """
        normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
        if normalized is None:
            return False
        return len(normalized) > 0
    @staticmethod
    def get_guardrail_names(
        guardrails_config: Optional[PassThroughGuardrailsConfigInput],
    ) -> List[str]:
        """Get the list of guardrail names configured for a passthrough endpoint."""
        normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
        if normalized is None:
            return []
        return list(normalized.keys())
    @staticmethod
    def get_settings(
        guardrails_config: Optional[PassThroughGuardrailsConfigInput],
        guardrail_name: str,
    ) -> Optional[PassThroughGuardrailSettings]:
        """Get settings for a specific guardrail from the passthrough config.

        Returns None when the guardrail has no per-guardrail settings
        (list-style config, or unknown guardrail name).
        """
        normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
        if normalized is None:
            return None
        settings = normalized.get(guardrail_name)
        if settings is None:
            return None
        # Raw dicts (e.g. from YAML config) are coerced to the typed model.
        if isinstance(settings, dict):
            return PassThroughGuardrailSettings(**settings)
        return settings
    @staticmethod
    def prepare_input(
        request_data: dict,
        guardrail_settings: Optional[PassThroughGuardrailSettings],
    ) -> str:
        """
        Prepare input text for guardrail execution based on field targeting settings.
        If request_fields is specified, extracts only those fields.
        Otherwise, uses the entire request payload as text.
        """
        if guardrail_settings is None or guardrail_settings.request_fields is None:
            from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
            return safe_dumps(request_data)
        return JsonPathExtractor.extract_fields(
            data=request_data,
            jsonpath_expressions=guardrail_settings.request_fields,
        )
    @staticmethod
    def prepare_output(
        response_data: dict,
        guardrail_settings: Optional[PassThroughGuardrailSettings],
    ) -> str:
        """
        Prepare output text for guardrail execution based on field targeting settings.
        If response_fields is specified, extracts only those fields.
        Otherwise, uses the entire response payload as text.
        """
        if guardrail_settings is None or guardrail_settings.response_fields is None:
            from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
            return safe_dumps(response_data)
        return JsonPathExtractor.extract_fields(
            data=response_data,
            jsonpath_expressions=guardrail_settings.response_fields,
        )
    @staticmethod
    async def execute(
        request_data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        guardrails_config: Optional[PassThroughGuardrailsConfig],
        event_type: str = "pre_call",
    ) -> dict:
        """
        Execute guardrails for a passthrough endpoint.
        This is the main entry point for passthrough guardrail execution.
        Args:
            request_data: The request payload
            user_api_key_dict: User API key authentication info
            guardrails_config: Passthrough-specific guardrails configuration
            event_type: "pre_call" for request, "post_call" for response
        Returns:
            The potentially modified request_data
        Raises:
            HTTPException if a guardrail blocks the request

        Note: this method only annotates request metadata and stores the
        config in request-scoped context; the actual guardrail hooks run
        elsewhere. `user_api_key_dict` and `event_type` are not read here.
        """
        if not PassthroughGuardrailHandler.is_enabled(guardrails_config):
            verbose_proxy_logger.debug(
                "Passthrough guardrails not enabled, skipping guardrail execution"
            )
            return request_data
        guardrail_names = PassthroughGuardrailHandler.get_guardrail_names(
            guardrails_config
        )
        verbose_proxy_logger.debug(
            "Executing passthrough guardrails: %s", guardrail_names
        )
        # Add to request metadata so guardrails know which to run
        from litellm.proxy.pass_through_endpoints.passthrough_context import (
            set_passthrough_guardrails_config,
        )
        if "metadata" not in request_data:
            request_data["metadata"] = {}
        # Set guardrails in metadata using dict format for compatibility
        request_data["metadata"]["guardrails"] = {
            name: True for name in guardrail_names
        }
        # Store passthrough guardrails config in request-scoped context
        set_passthrough_guardrails_config(guardrails_config)
        return request_data
    @staticmethod
    def collect_guardrails(
        user_api_key_dict: UserAPIKeyAuth,
        passthrough_guardrails_config: Optional[PassThroughGuardrailsConfigInput],
    ) -> Optional[Dict[str, bool]]:
        """
        Collect guardrails for a passthrough endpoint.
        Passthrough endpoints are opt-in only for guardrails. Guardrails only run when
        the guardrails config is set with at least one guardrail.
        Accepts both list and dict formats:
        - List: ["guardrail-1", "guardrail-2"]
        - Dict: {"guardrail-1": {"request_fields": [...]}}
        When enabled, this function collects:
        - Passthrough-specific guardrails from the config
        - Org/team/key level guardrails (automatic inheritance when passthrough is enabled)
        Args:
            user_api_key_dict: User API key authentication info
            passthrough_guardrails_config: List or Dict of guardrail names/settings
        Returns:
            Dict of guardrail names to run (format: {guardrail_name: True}), or None
        """
        from litellm.proxy.litellm_pre_call_utils import (
            _add_guardrails_from_key_or_team_metadata,
        )
        # Normalize config to dict format (handles both list and dict)
        normalized_config = PassthroughGuardrailHandler.normalize_config(
            passthrough_guardrails_config
        )
        if normalized_config is None:
            verbose_proxy_logger.debug(
                "Passthrough guardrails not configured, skipping guardrail collection"
            )
            return None
        if len(normalized_config) == 0:
            verbose_proxy_logger.debug(
                "Passthrough guardrails config is empty, skipping"
            )
            return None
        # Passthrough is enabled - collect guardrails
        guardrails_to_run: Dict[str, bool] = {}
        # Add passthrough-specific guardrails
        for guardrail_name in normalized_config.keys():
            guardrails_to_run[guardrail_name] = True
            verbose_proxy_logger.debug(
                "Added passthrough-specific guardrail: %s", guardrail_name
            )
        # Add org/team/key level guardrails using shared helper
        # temp_data is a scratch dict the helper writes inherited guardrails into.
        temp_data: Dict[str, Any] = {"metadata": {}}
        _add_guardrails_from_key_or_team_metadata(
            key_metadata=user_api_key_dict.metadata,
            team_metadata=user_api_key_dict.team_metadata,
            data=temp_data,
            metadata_variable_name="metadata",
        )
        # Merge inherited guardrails into guardrails_to_run
        inherited_guardrails = temp_data["metadata"].get("guardrails", [])
        for guardrail_name in inherited_guardrails:
            if guardrail_name not in guardrails_to_run:
                guardrails_to_run[guardrail_name] = True
                verbose_proxy_logger.debug(
                    "Added inherited guardrail (key/team level): %s", guardrail_name
                )
        verbose_proxy_logger.debug(
            "Collected guardrails for passthrough endpoint: %s",
            list(guardrails_to_run.keys()),
        )
        return guardrails_to_run if guardrails_to_run else None
    @staticmethod
    def get_field_targeted_text(
        data: dict,
        guardrail_name: str,
        is_request: bool = True,
    ) -> Optional[str]:
        """
        Get the text to check for a guardrail, respecting field targeting settings.
        Called by guardrail hooks to get the appropriate text based on
        passthrough field targeting configuration.
        Args:
            data: The request/response data dict
            guardrail_name: Name of the guardrail being executed
            is_request: True for request (pre_call), False for response (post_call)
        Returns:
            The text to check, or None to use default behavior
        """
        from litellm.proxy.pass_through_endpoints.passthrough_context import (
            get_passthrough_guardrails_config,
        )
        # Config is stored per-request by execute(); absent means no targeting.
        passthrough_config = get_passthrough_guardrails_config()
        if passthrough_config is None:
            return None
        settings = PassthroughGuardrailHandler.get_settings(
            passthrough_config, guardrail_name
        )
        if settings is None:
            return None
        if is_request:
            if settings.request_fields:
                return JsonPathExtractor.extract_fields(data, settings.request_fields)
        else:
            if settings.response_fields:
                return JsonPathExtractor.extract_fields(data, settings.response_fields)
        return None

View File

@@ -0,0 +1,248 @@
import asyncio
from datetime import datetime
from typing import List, Optional
import httpx
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.thread_pool_executor import executor
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
from litellm.types.utils import StandardPassThroughResponseObject
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.openai_passthrough_logging_handler import (
OpenAIPassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
from .success_handler import PassThroughEndpointLogging
class PassThroughStreamingHandler:
    """Streams a passthrough provider response to the client while collecting
    raw chunks for post-stream logging and (optionally) injecting cost into
    streamed usage blocks."""
    @staticmethod
    async def chunk_processor(
        response: httpx.Response,
        request_body: Optional[dict],
        litellm_logging_obj: LiteLLMLoggingObj,
        endpoint_type: EndpointType,
        start_time: datetime,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
    ):
        """
        - Yields chunks from the response
        - Collect non-empty chunks for post-processing (logging)
        - Inject cost into chunks if include_cost_in_streaming_usage is enabled

        Note: raw_bytes stores the ORIGINAL provider chunk (appended before any
        cost injection), so logging sees the untouched stream even when the
        client receives modified chunks.
        """
        try:
            raw_bytes: List[bytes] = []
            # Extract model name for cost injection
            model_name = PassThroughStreamingHandler._extract_model_for_cost_injection(
                request_body=request_body,
                url_route=url_route,
                endpoint_type=endpoint_type,
                litellm_logging_obj=litellm_logging_obj,
            )
            async for chunk in response.aiter_bytes():
                raw_bytes.append(chunk)
                # Cost injection is opt-in (litellm.include_cost_in_streaming_usage)
                # and requires a resolvable model name.
                if (
                    getattr(litellm, "include_cost_in_streaming_usage", False)
                    and model_name
                ):
                    if endpoint_type == EndpointType.VERTEX_AI:
                        # Only handle streamRawPredict (uses Anthropic format)
                        if "streamRawPredict" in url_route or "rawPredict" in url_route:
                            modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection(
                                chunk, model_name
                            )
                            if modified_chunk is not None:
                                chunk = modified_chunk
                    elif endpoint_type == EndpointType.ANTHROPIC:
                        modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection(
                            chunk, model_name
                        )
                        if modified_chunk is not None:
                            chunk = modified_chunk
                yield chunk
            # After all chunks are processed, handle post-processing
            end_time = datetime.now()
            # Fire-and-forget: logging must not delay closing the client stream.
            # NOTE(review): the task reference is not retained — confirm the
            # running loop keeps it alive until completion.
            asyncio.create_task(
                PassThroughStreamingHandler._route_streaming_logging_to_handler(
                    litellm_logging_obj=litellm_logging_obj,
                    passthrough_success_handler_obj=passthrough_success_handler_obj,
                    url_route=url_route,
                    request_body=request_body or {},
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    raw_bytes=raw_bytes,
                    end_time=end_time,
                )
            )
        except Exception as e:
            verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
            raise
    @staticmethod
    async def _route_streaming_logging_to_handler(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        raw_bytes: List[bytes],
        end_time: datetime,
        model: Optional[str] = None,
    ):
        """
        Route the logging for the collected chunks to the appropriate handler
        Supported endpoint types:
        - Anthropic
        - Vertex AI
        - OpenAI

        Exceptions are caught and logged; streaming has already completed, so
        a logging failure must not propagate to the caller.
        """
        try:
            all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
                raw_bytes
            )
            standard_logging_response_object: Optional[
                PassThroughEndpointLoggingResultValues
            ] = None
            kwargs: dict = {}
            if endpoint_type == EndpointType.ANTHROPIC:
                anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
                    litellm_logging_obj=litellm_logging_obj,
                    passthrough_success_handler_obj=passthrough_success_handler_obj,
                    url_route=url_route,
                    request_body=request_body,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    all_chunks=all_chunks,
                    end_time=end_time,
                )
                standard_logging_response_object = (
                    anthropic_passthrough_logging_handler_result["result"]
                )
                kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
            elif endpoint_type == EndpointType.VERTEX_AI:
                # Only the vertex handler receives the optional `model` hint.
                vertex_passthrough_logging_handler_result = VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
                    litellm_logging_obj=litellm_logging_obj,
                    passthrough_success_handler_obj=passthrough_success_handler_obj,
                    url_route=url_route,
                    request_body=request_body,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    all_chunks=all_chunks,
                    end_time=end_time,
                    model=model,
                )
                standard_logging_response_object = (
                    vertex_passthrough_logging_handler_result["result"]
                )
                kwargs = vertex_passthrough_logging_handler_result["kwargs"]
            elif endpoint_type == EndpointType.OPENAI:
                openai_passthrough_logging_handler_result = OpenAIPassthroughLoggingHandler._handle_logging_openai_collected_chunks(
                    litellm_logging_obj=litellm_logging_obj,
                    passthrough_success_handler_obj=passthrough_success_handler_obj,
                    url_route=url_route,
                    request_body=request_body,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    all_chunks=all_chunks,
                    end_time=end_time,
                )
                standard_logging_response_object = (
                    openai_passthrough_logging_handler_result["result"]
                )
                kwargs = openai_passthrough_logging_handler_result["kwargs"]
            # Fallback so a log entry is emitted even for unparseable streams.
            if standard_logging_response_object is None:
                standard_logging_response_object = StandardPassThroughResponseObject(
                    response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
                )
            await litellm_logging_obj.async_success_handler(
                result=standard_logging_response_object,
                start_time=start_time,
                end_time=end_time,
                cache_hit=False,
                **kwargs,
            )
            # Sync callbacks run on a thread pool, and only when configured to
            # also fire for async calls.
            if (
                litellm_logging_obj._should_run_sync_callbacks_for_async_calls()
                is False
            ):
                return
            executor.submit(
                litellm_logging_obj.success_handler,
                result=standard_logging_response_object,
                end_time=end_time,
                cache_hit=False,
                start_time=start_time,
                **kwargs,
            )
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error in _route_streaming_logging_to_handler: {str(e)}"
            )
    @staticmethod
    def _extract_model_for_cost_injection(
        request_body: Optional[dict],
        url_route: str,
        endpoint_type: EndpointType,
        litellm_logging_obj: LiteLLMLoggingObj,
    ) -> Optional[str]:
        """
        Extract model name for cost injection from various sources.

        Resolution order: request body "model" -> logging object's
        model_call_details -> (Vertex AI only) the model embedded in the URL.
        Returns None when no source yields a usable name.
        """
        # Try to get model from request body
        if request_body:
            model = request_body.get("model")
            if model:
                return model
        # Try to get model from logging object
        if hasattr(litellm_logging_obj, "model_call_details"):
            model = litellm_logging_obj.model_call_details.get("model")
            if model:
                return model
        # For Vertex AI, try to extract from URL
        if endpoint_type == EndpointType.VERTEX_AI:
            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
            if model and model != "unknown":
                return model
        return None
    @staticmethod
    def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]:
        """
        Converts a list of raw bytes into a list of string lines, similar to aiter_lines()
        Args:
            raw_bytes: List of bytes chunks from aiter.bytes()
        Returns:
            List of string lines, with each line being a complete data: {} chunk
        """
        # Combine all bytes and decode to string
        # NOTE(review): joining before decoding avoids splitting multi-byte
        # UTF-8 sequences across chunk boundaries; assumes the stream is UTF-8.
        combined_str = b"".join(raw_bytes).decode("utf-8")
        # Split by newlines and filter out empty lines
        lines = [line.strip() for line in combined_str.split("\n") if line.strip()]
        return lines

View File

@@ -0,0 +1,494 @@
import json
from datetime import datetime
from typing import Any, Optional, Union
from urllib.parse import urlparse
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
PassthroughStandardLoggingPayload,
)
from litellm.types.utils import StandardPassThroughResponseObject
from litellm.utils import executor as thread_pool_executor
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
AnthropicPassthroughLoggingHandler,
)
from .llm_provider_handlers.assembly_passthrough_logging_handler import (
AssemblyAIPassthroughLoggingHandler,
)
from .llm_provider_handlers.cohere_passthrough_logging_handler import (
CoherePassthroughLoggingHandler,
)
from .llm_provider_handlers.cursor_passthrough_logging_handler import (
CursorPassthroughLoggingHandler,
)
from .llm_provider_handlers.gemini_passthrough_logging_handler import (
GeminiPassthroughLoggingHandler,
)
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
VertexPassthroughLoggingHandler,
)
cohere_passthrough_logging_handler = CoherePassthroughLoggingHandler()
class PassThroughEndpointLogging:
def __init__(self):
    """
    Build the route tables used to pick a provider-specific logging handler.

    Each TRACKED_* list holds substrings (or path fragments) that the
    corresponding is_*_route() helper matches against the outgoing URL route.
    """
    # Vertex AI: RPC-style method names that may appear anywhere in the URL.
    self.TRACKED_VERTEX_ROUTES = [
        "generateContent",
        "streamGenerateContent",
        "predict",
        "rawPredict",
        "streamRawPredict",
        "search",
        "batchPredictionJobs",
        "predictLongRunning",
    ]
    # Anthropic
    self.TRACKED_ANTHROPIC_ROUTES = ["/messages", "/v1/messages/batches"]
    # Cohere
    self.TRACKED_COHERE_ROUTES = ["/v2/chat", "/v1/embed"]
    # AssemblyAI uses a dedicated handler instance (not just a route list).
    self.assemblyai_passthrough_logging_handler = (
        AssemblyAIPassthroughLoggingHandler()
    )
    # Langfuse
    self.TRACKED_LANGFUSE_ROUTES = ["/langfuse/"]
    # Gemini — note these overlap with the Vertex routes above; the Gemini
    # check additionally requires custom_llm_provider == "gemini".
    self.TRACKED_GEMINI_ROUTES = [
        "generateContent",
        "streamGenerateContent",
        "predictLongRunning",
    ]
    # Cursor Cloud Agents
    self.TRACKED_CURSOR_ROUTES = [
        "/v0/agents",
        "/v0/me",
        "/v0/models",
        "/v0/repositories",
    ]
    # Vertex AI Live API WebSocket
    self.TRACKED_VERTEX_AI_LIVE_ROUTES = ["/vertex_ai/live"]
async def _handle_logging(
    self,
    logging_obj: LiteLLMLoggingObj,
    standard_logging_response_object: Union[
        StandardPassThroughResponseObject,
        PassThroughEndpointLoggingResultValues,
        dict,
    ],
    result: str,
    start_time: datetime,
    end_time: datetime,
    cache_hit: bool,
    **kwargs,
):
    """
    Helper function to handle both sync and async logging operations.

    Fires the synchronous success handler on a thread pool (fire-and-forget)
    and then awaits the async success handler.
    """
    # Submit to thread pool for sync logging; the future is intentionally
    # not awaited/checked here.
    thread_pool_executor.submit(
        logging_obj.success_handler,
        standard_logging_response_object,
        start_time,
        end_time,
        cache_hit,
        **kwargs,
    )
    # Handle async logging.
    # NOTE(review): `result` is annotated as str, yet the isinstance(dict)
    # test implies some callers pass a dict — confirm. On the non-dict path
    # the standard_logging_response_object is logged instead of `result`.
    # NOTE(review): the async path hardcodes cache_hit=False while the sync
    # path forwards the caller's cache_hit — confirm this asymmetry is
    # intended.
    await logging_obj.async_success_handler(
        result=(
            json.dumps(result)
            if isinstance(result, dict)
            else standard_logging_response_object
        ),
        start_time=start_time,
        end_time=end_time,
        cache_hit=False,
        **kwargs,
    )
def normalize_llm_passthrough_logging_payload(
    self,
    httpx_response: httpx.Response,
    response_body: Optional[dict],
    request_body: dict,
    logging_obj: LiteLLMLoggingObj,
    url_route: str,
    result: str,
    start_time: datetime,
    end_time: datetime,
    cache_hit: bool,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
):
    """
    Dispatch the response to the first matching provider-specific logging
    handler and normalize its output.

    Providers are checked in order: Gemini, Vertex AI, Anthropic, Cohere,
    OpenAI, Cursor, Vertex AI Live. Branch order matters: Gemini is checked
    before Vertex AI because their tracked routes share substrings (e.g.
    "generateContent"); the Gemini check additionally requires
    custom_llm_provider == "gemini".

    Returns:
        dict with:
          - "standard_logging_response_object": the handler's parsed result,
            or None when no handler matched
          - "kwargs": the (possibly handler-augmented) logging kwargs
    """
    return_dict = {
        "standard_logging_response_object": None,
        "kwargs": kwargs,
    }
    standard_logging_response_object: Optional[Any] = None
    if self.is_gemini_route(url_route, custom_llm_provider):
        gemini_passthrough_logging_handler_result = (
            GeminiPassthroughLoggingHandler.gemini_passthrough_handler(
                httpx_response=httpx_response,
                response_body=response_body or {},
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                request_body=request_body,
                **kwargs,
            )
        )
        standard_logging_response_object = (
            gemini_passthrough_logging_handler_result["result"]
        )
        kwargs = gemini_passthrough_logging_handler_result["kwargs"]
    elif self.is_vertex_route(url_route):
        # Note: unlike the other handlers, this one is not passed
        # response_body.
        vertex_passthrough_logging_handler_result = (
            VertexPassthroughLoggingHandler.vertex_passthrough_handler(
                httpx_response=httpx_response,
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                request_body=request_body,
                **kwargs,
            )
        )
        standard_logging_response_object = (
            vertex_passthrough_logging_handler_result["result"]
        )
        kwargs = vertex_passthrough_logging_handler_result["kwargs"]
    elif self.is_anthropic_route(url_route):
        anthropic_passthrough_logging_handler_result = (
            AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
                httpx_response=httpx_response,
                response_body=response_body or {},
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                request_body=request_body,
                **kwargs,
            )
        )
        standard_logging_response_object = (
            anthropic_passthrough_logging_handler_result["result"]
        )
        kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
    elif self.is_cohere_route(url_route):
        cohere_passthrough_logging_handler_result = (
            cohere_passthrough_logging_handler.cohere_passthrough_handler(
                httpx_response=httpx_response,
                response_body=response_body or {},
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                request_body=request_body,
                **kwargs,
            )
        )
        standard_logging_response_object = (
            cohere_passthrough_logging_handler_result["result"]
        )
        kwargs = cohere_passthrough_logging_handler_result["kwargs"]
    elif self.is_openai_route(url_route) and self._is_supported_openai_endpoint(
        url_route
    ):
        # Imported lazily; OpenAI routes are only logged for the endpoints
        # the handler explicitly supports.
        from .llm_provider_handlers.openai_passthrough_logging_handler import (
            OpenAIPassthroughLoggingHandler,
        )
        openai_passthrough_logging_handler_result = (
            OpenAIPassthroughLoggingHandler.openai_passthrough_handler(
                httpx_response=httpx_response,
                response_body=response_body or {},
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                request_body=request_body,
                **kwargs,
            )
        )
        standard_logging_response_object = (
            openai_passthrough_logging_handler_result["result"]
        )
        kwargs = openai_passthrough_logging_handler_result["kwargs"]
    elif self.is_cursor_route(url_route, custom_llm_provider):
        cursor_passthrough_logging_handler_result = (
            CursorPassthroughLoggingHandler.cursor_passthrough_handler(
                httpx_response=httpx_response,
                response_body=response_body or {},
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                request_body=request_body,
                **kwargs,
            )
        )
        standard_logging_response_object = (
            cursor_passthrough_logging_handler_result["result"]
        )
        kwargs = cursor_passthrough_logging_handler_result["kwargs"]
    elif self.is_vertex_ai_live_route(url_route):
        from .llm_provider_handlers.vertex_ai_live_passthrough_logging_handler import (
            VertexAILivePassthroughLoggingHandler,
        )
        vertex_ai_live_handler = VertexAILivePassthroughLoggingHandler()
        # For WebSocket responses, response_body should be a list of messages
        websocket_messages: list[dict[str, Any]] = (
            response_body if isinstance(response_body, list) else []
        )
        vertex_ai_live_handler_result = (
            vertex_ai_live_handler.vertex_ai_live_passthrough_handler(
                websocket_messages=websocket_messages,
                logging_obj=logging_obj,
                url_route=url_route,
                start_time=start_time,
                end_time=end_time,
                request_body=request_body,
                **kwargs,
            )
        )
        standard_logging_response_object = vertex_ai_live_handler_result["result"]
        kwargs = vertex_ai_live_handler_result["kwargs"]
    return_dict[
        "standard_logging_response_object"
    ] = standard_logging_response_object
    return_dict["kwargs"] = kwargs
    return return_dict
async def pass_through_async_success_handler(
    self,
    httpx_response: httpx.Response,
    response_body: Optional[dict],
    logging_obj: LiteLLMLoggingObj,
    url_route: str,
    result: str,
    start_time: datetime,
    end_time: datetime,
    cache_hit: bool,
    request_body: dict,
    passthrough_logging_payload: PassthroughStandardLoggingPayload,
    custom_llm_provider: Optional[str] = None,
    **kwargs,
):
    """
    Async success hook for pass-through responses.

    Flow:
      1. Attach the raw passthrough payload to the logging object.
      2. AssemblyAI routes use their own handler (and may skip logging
         entirely depending on the HTTP method); Langfuse routes are never
         logged.
      3. All other routes are normalized via
         normalize_llm_passthrough_logging_payload(); if no provider handler
         matched, fall back to wrapping the raw response text.
      4. Apply any explicit per-request cost, then run sync + async logging.
    """
    standard_logging_response_object: Optional[
        PassThroughEndpointLoggingResultValues
    ] = None
    # Make the raw passthrough payload available to downstream loggers.
    logging_obj.model_call_details[
        "passthrough_logging_payload"
    ] = passthrough_logging_payload
    if self.is_assemblyai_route(url_route):
        # AssemblyAI decides per HTTP method whether the request should be
        # logged at all; bail out early when it opts out.
        if (
            AssemblyAIPassthroughLoggingHandler._should_log_request(
                httpx_response.request.method
            )
            is not True
        ):
            return
        self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler(
            httpx_response=httpx_response,
            response_body=response_body or {},
            logging_obj=logging_obj,
            url_route=url_route,
            result=result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=cache_hit,
            **kwargs,
        )
        return
    elif self.is_langfuse_route(url_route):
        # Don't log langfuse pass-through requests
        return
    else:
        normalized_llm_passthrough_logging_payload = (
            self.normalize_llm_passthrough_logging_payload(
                httpx_response=httpx_response,
                response_body=response_body,
                request_body=request_body,
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                custom_llm_provider=custom_llm_provider,
                **kwargs,
            )
        )
        standard_logging_response_object = (
            normalized_llm_passthrough_logging_payload[
                "standard_logging_response_object"
            ]
        )
        kwargs = normalized_llm_passthrough_logging_payload["kwargs"]
    # Fallback: no provider-specific handler matched — log the raw body text.
    if standard_logging_response_object is None:
        standard_logging_response_object = StandardPassThroughResponseObject(
            response=httpx_response.text
        )
    kwargs = self._set_cost_per_request(
        logging_obj=logging_obj,
        passthrough_logging_payload=passthrough_logging_payload,
        kwargs=kwargs,
    )
    await self._handle_logging(
        logging_obj=logging_obj,
        standard_logging_response_object=standard_logging_response_object,
        result=result,
        start_time=start_time,
        end_time=end_time,
        cache_hit=cache_hit,
        standard_pass_through_logging_payload=passthrough_logging_payload,
        **kwargs,
    )
def is_vertex_route(self, url_route: str):
for route in self.TRACKED_VERTEX_ROUTES:
if route in url_route:
return True
return False
def is_anthropic_route(self, url_route: str):
for route in self.TRACKED_ANTHROPIC_ROUTES:
if route in url_route:
return True
return False
def is_cohere_route(self, url_route: str):
for route in self.TRACKED_COHERE_ROUTES:
if route in url_route:
return True
def is_assemblyai_route(self, url_route: str):
    """True for AssemblyAI traffic: the official host, or any /transcript path."""
    parsed = urlparse(url_route)
    return parsed.hostname == "api.assemblyai.com" or "/transcript" in parsed.path
def is_langfuse_route(self, url_route: str):
parsed_url = urlparse(url_route)
for route in self.TRACKED_LANGFUSE_ROUTES:
if route in parsed_url.path:
return True
return False
def is_vertex_ai_live_route(self, url_route: str):
"""Check if the URL route is a Vertex AI Live API WebSocket route."""
if not url_route:
return False
for route in self.TRACKED_VERTEX_AI_LIVE_ROUTES:
if route in url_route:
return True
return False
def is_cursor_route(
    self, url_route: str, custom_llm_provider: Optional[str] = None
):
    """Check if the URL route is a Cursor Cloud Agents API route."""
    # An explicit provider hint wins outright.
    if custom_llm_provider == "cursor":
        return True
    # Match on the official Cursor API host.
    parsed_url = urlparse(url_route)
    if parsed_url.hostname and "api.cursor.com" in parsed_url.hostname:
        return True
    for route in self.TRACKED_CURSOR_ROUTES:
        if route in url_route:
            # Use the parsed path for absolute URLs, the raw route otherwise.
            path = parsed_url.path if parsed_url.scheme else url_route
            if path.startswith("/v0/"):
                # NOTE(review): custom_llm_provider == "cursor" already
                # returned True at the top of this method, so this comparison
                # is always False here — confirm whether bare /v0/ paths were
                # intended to match without the provider hint.
                return custom_llm_provider == "cursor"
    return False
def is_openai_route(self, url_route: str):
    """
    Check if the URL route is an OpenAI (or Azure OpenAI) API route.

    Matches on hostname only: api.openai.com and *.openai.azure.com.
    Always returns a bool — the previous version returned None (not False)
    for URLs without a hostname, e.g. a bare path like "/v1/chat".
    """
    if not url_route:
        return False
    hostname = urlparse(url_route).hostname
    return bool(
        hostname
        and ("api.openai.com" in hostname or "openai.azure.com" in hostname)
    )
def is_gemini_route(
self, url_route: str, custom_llm_provider: Optional[str] = None
):
"""Check if the URL route is a Gemini API route."""
for route in self.TRACKED_GEMINI_ROUTES:
if route in url_route and custom_llm_provider == "gemini":
return True
return False
def _is_supported_openai_endpoint(self, url_route: str) -> bool:
    """Check if the OpenAI endpoint is supported by the passthrough logging handler."""
    from .llm_provider_handlers.openai_passthrough_logging_handler import (
        OpenAIPassthroughLoggingHandler,
    )
    supported_route_checks = (
        OpenAIPassthroughLoggingHandler.is_openai_chat_completions_route,
        OpenAIPassthroughLoggingHandler.is_openai_image_generation_route,
        OpenAIPassthroughLoggingHandler.is_openai_image_editing_route,
    )
    # any() short-circuits on the first match, like the original or-chain.
    return any(check(url_route) for check in supported_route_checks)
def _set_cost_per_request(
self,
logging_obj: LiteLLMLoggingObj,
passthrough_logging_payload: PassthroughStandardLoggingPayload,
kwargs: dict,
):
"""
Helper function to set the cost per request in the logging object
Only set the cost per request if it's set in the passthrough logging payload.
If it's not set, don't set it in the logging object.
"""
#########################################################
# Check if cost per request is set
#########################################################
if passthrough_logging_payload.get("cost_per_request") is not None:
kwargs["response_cost"] = passthrough_logging_payload.get(
"cost_per_request"
)
logging_obj.model_call_details[
"response_cost"
] = passthrough_logging_payload.get("cost_per_request")
return kwargs