chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,71 @@
|
||||
# Pass-Through Endpoints Architecture
|
||||
|
||||
## Why Pass-Through Endpoints Transform Requests
|
||||
|
||||
Even "pass-through" endpoints must perform essential transformations. The request **body** passes through unchanged, but:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Client
|
||||
participant Proxy as LiteLLM Proxy
|
||||
participant Provider as LLM Provider
|
||||
|
||||
Client->>Proxy: POST /vertex_ai/v1/projects/.../generateContent
|
||||
Note over Client,Proxy: Headers: Authorization: Bearer sk-litellm-key
|
||||
Note over Client,Proxy: Body: { "contents": [...] }
|
||||
|
||||
rect rgb(240, 240, 240)
|
||||
Note over Proxy: 1. URL Construction
|
||||
Note over Proxy: Build regional/provider-specific URL
|
||||
end
|
||||
|
||||
rect rgb(240, 240, 240)
|
||||
Note over Proxy: 2. Auth Header Replacement
|
||||
Note over Proxy: LiteLLM key → provider credentials
|
||||
end
|
||||
|
||||
rect rgb(240, 240, 240)
|
||||
Note over Proxy: 3. Extra Operations
|
||||
Note over Proxy: • x-pass-* headers (strip prefix, forward)
|
||||
Note over Proxy: • x-litellm-tags → metadata
|
||||
Note over Proxy: • Guardrails (opt-in)
|
||||
Note over Proxy: • Multipart form reconstruction
|
||||
end
|
||||
|
||||
Proxy->>Provider: POST https://us-central1-aiplatform.googleapis.com/...
|
||||
Note over Proxy,Provider: Headers: Authorization: Bearer ya29.google-oauth...
|
||||
Note over Proxy,Provider: Body: { "contents": [...] } ← UNCHANGED
|
||||
|
||||
Provider-->>Proxy: Response (streaming or non-streaming)
|
||||
|
||||
rect rgb(240, 240, 240)
|
||||
Note over Proxy: 4. Response Handling (async)
|
||||
Note over Proxy: • Collect streaming chunks for logging
|
||||
Note over Proxy: • Cost injection (if enabled)
|
||||
Note over Proxy: • Parse response → calculate cost → log
|
||||
end
|
||||
|
||||
Proxy-->>Client: Response (unchanged)
|
||||
```
|
||||
|
||||
## Essential Transformations
|
||||
|
||||
- **URL Construction** - Build correct provider URL (e.g., regional endpoints for Vertex AI, Bedrock)
|
||||
- **Auth Header Replacement** - Swap LiteLLM virtual key for actual provider credentials
|
||||
|
||||
## Extra Operations
|
||||
|
||||
| Operation | Description |
|
||||
|-----------|-------------|
|
||||
| `x-pass-*` headers | Strip prefix and forward (e.g., `x-pass-anthropic-beta` → `anthropic-beta`) |
|
||||
| `x-litellm-tags` header | Extract tags and add to request metadata for logging |
|
||||
| Streaming chunk collection | Collect chunks async for logging after stream completes |
|
||||
| Multipart form handling | Reconstruct multipart/form-data requests for file uploads |
|
||||
| Guardrails (opt-in) | Run content filtering when explicitly configured |
|
||||
| Cost injection | Inject cost into streaming chunks when `include_cost_in_streaming_usage` enabled |
|
||||
|
||||
## What Does NOT Change
|
||||
|
||||
- Request body
|
||||
- Response body
|
||||
- Provider-specific parameters
|
||||
@@ -0,0 +1,16 @@
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
def get_litellm_virtual_key(request: Request) -> str:
    """
    Resolve the LiteLLM virtual key for an incoming request.

    The custom `x-litellm-api-key` header takes priority over the standard
    `Authorization` header; some SDKs (e.g. the Vertex JS SDK) occupy
    `Authorization` themselves, so callers use `x-litellm-api-key` to pass
    the litellm virtual key.

    Returns a "Bearer <key>" string, or "" when neither header is present.
    """
    virtual_key = request.headers.get("x-litellm-api-key")
    if not virtual_key:
        # No custom header -- fall back to whatever the client sent in
        # Authorization (already fully formatted, e.g. "Bearer sk-...").
        return request.headers.get("Authorization", "")
    return f"Bearer {virtual_key}"
|
||||
@@ -0,0 +1,94 @@
|
||||
"""
|
||||
JSONPath Extractor Module
|
||||
|
||||
Extracts field values from data using simple JSONPath-like expressions.
|
||||
"""
|
||||
|
||||
from typing import Any, List, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
||||
class JsonPathExtractor:
    """Extracts field values from data using JSONPath-like expressions."""

    @staticmethod
    def extract_fields(
        data: dict,
        jsonpath_expressions: List[str],
    ) -> str:
        """
        Extract field values from *data* using JSONPath-like expressions.

        Supports simple expressions like:
        - "query" -> data["query"]
        - "documents[*].text" -> all text fields from documents array
        - "messages[*].content" -> all content fields from messages array

        Returns:
            All extracted values stringified and joined with newlines.
            Expressions that fail or match nothing contribute nothing.
        """
        pieces: List[str] = []

        for expression in jsonpath_expressions:
            try:
                matched = JsonPathExtractor.evaluate(data, expression)
                if not matched:
                    continue
                # Normalize to a list so both scalar and wildcard results
                # flow through the same filtering/stringification path.
                candidates = matched if isinstance(matched, list) else [matched]
                pieces.extend(str(item) for item in candidates if item)
            except Exception as err:
                verbose_proxy_logger.debug(
                    "Failed to extract field %s: %s", expression, str(err)
                )

        return "\n".join(pieces)

    @staticmethod
    def evaluate(data: dict, expr: str) -> Union[str, List[str], None]:
        """
        Evaluate a simple JSONPath-like expression against *data*.

        Supports:
        - Simple key: "query" -> data["query"]
        - Nested key: "foo.bar" -> data["foo"]["bar"]
        - Array wildcard: "items[*].text" -> [item["text"] for item in data["items"]]

        Returns None when the expression or data is empty, or when the
        path cannot be resolved.
        """
        if not expr or not data:
            return None

        # "[*]" becomes its own path token so it can be handled like a key.
        tokens = expr.replace("[*]", ".[*]").split(".")
        node: Any = data

        idx = 0
        while idx < len(tokens):
            if node is None:
                return None

            token = tokens[idx]
            if token == "[*]":
                # Wildcard: the current node must be a list.
                if not isinstance(node, list):
                    return None

                tail = ".".join(tokens[idx + 1 :])
                if not tail:
                    # Trailing wildcard -> return the list itself.
                    return node

                # Apply the remaining path to every dict element, flattening
                # nested list results into one flat list.
                collected: List[str] = []
                for element in node:
                    if not isinstance(element, dict):
                        continue
                    sub = JsonPathExtractor.evaluate(element, tail)
                    if not sub:
                        continue
                    if isinstance(sub, list):
                        collected.extend(sub)
                    else:
                        collected.append(sub)
                return collected if collected else None

            if not isinstance(node, dict):
                # Cannot descend into a scalar with a key token.
                return None
            node = node.get(token)
            idx += 1

        return node
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,619 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.anthropic import get_anthropic_config
|
||||
from litellm.llms.anthropic.chat.handler import (
|
||||
ModelResponseIterator as AnthropicModelResponseIterator,
|
||||
)
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
PassthroughStandardLoggingPayload,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch, ModelResponse, TextCompletionResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
|
||||
|
||||
from ..success_handler import PassThroughEndpointLogging
|
||||
else:
|
||||
PassThroughEndpointLogging = Any
|
||||
EndpointType = Any
|
||||
|
||||
|
||||
class AnthropicPassthroughLoggingHandler:
    """
    Logging handler for Anthropic passthrough endpoints.

    Converts raw Anthropic responses (streaming and non-streaming) into
    LiteLLM/OpenAI-format objects so downstream logging callbacks and cost
    tracking can treat them uniformly. Also handles batch-creation requests
    (/v1/messages/batches) by registering a managed object for async cost
    tracking.
    """

    @staticmethod
    def anthropic_passthrough_handler(
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: Optional[dict] = None,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Transforms Anthropic response to OpenAI response, generates a standard
        logging object so downstream logging can be handled.

        Args:
            httpx_response: Raw provider HTTP response.
            response_body: Parsed JSON body of the provider response.
            logging_obj: LiteLLM logging object for this request.
            url_route: Passthrough route that was called (used to detect
                batch-creation requests and select the Anthropic config).
            request_body: Original client request body, if available; falls
                back to kwargs["request_body"] for batch requests.

        Returns:
            Dict with "result" (the transformed ModelResponse) and "kwargs"
            (logging kwargs enriched with cost/model info).
        """
        # Check if this is a batch creation request
        if "/v1/messages/batches" in url_route and httpx_response.status_code == 200:
            # Get request body from parameter or kwargs
            request_body = request_body or kwargs.get("request_body", {})
            return AnthropicPassthroughLoggingHandler.batch_creation_handler(
                httpx_response=httpx_response,
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                request_body=request_body,
                **kwargs,
            )

        model = response_body.get("model", "")
        # Route-specific Anthropic config decides how to parse this response.
        anthropic_config = get_anthropic_config(url_route)
        litellm_model_response: ModelResponse = anthropic_config().transform_response(
            raw_response=httpx_response,
            model_response=litellm.ModelResponse(),
            model=model,
            messages=[],
            logging_obj=logging_obj,
            optional_params={},
            api_key="",
            request_data={},
            encoding=litellm.encoding,
            json_mode=False,
            litellm_params={},
        )

        # Enrich kwargs with cost/model/user info for downstream callbacks.
        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
            litellm_model_response=litellm_model_response,
            model=model,
            kwargs=kwargs,
            start_time=start_time,
            end_time=end_time,
            logging_obj=logging_obj,
        )

        return {
            "result": litellm_model_response,
            "kwargs": kwargs,
        }

    @staticmethod
    def _get_user_from_metadata(
        passthrough_logging_payload: PassthroughStandardLoggingPayload,
    ) -> Optional[str]:
        """Extract the end-user id from the logged request body, if present."""
        request_body = passthrough_logging_payload.get("request_body")
        if request_body:
            return get_end_user_id_from_request_body(request_body)
        return None

    @staticmethod
    def _create_anthropic_response_logging_payload(
        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
        model: str,
        kwargs: dict,
        start_time: datetime,
        end_time: datetime,
        logging_obj: LiteLLMLoggingObj,
    ) -> dict:
        """
        Create the standard logging object for Anthropic passthrough

        handles streaming and non-streaming responses

        Mutates and returns *kwargs* (adds response_cost/model/litellm_params)
        and mutates *litellm_model_response* and *logging_obj* in place.
        Best-effort: on any failure the partially-updated kwargs are returned
        so logging still proceeds.
        """
        try:
            # Get custom_llm_provider from logging object if available (e.g., azure_ai for Azure Anthropic)
            custom_llm_provider = logging_obj.model_call_details.get(
                "custom_llm_provider"
            )

            # Prepend custom_llm_provider to model if not already present
            model_for_cost = model
            if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
                model_for_cost = f"{custom_llm_provider}/{model}"

            response_cost = litellm.completion_cost(
                completion_response=litellm_model_response,
                model=model_for_cost,
                custom_llm_provider=custom_llm_provider,
            )

            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = (  # type: ignore
                kwargs.get("passthrough_logging_payload")
            )
            if passthrough_logging_payload:
                user = AnthropicPassthroughLoggingHandler._get_user_from_metadata(
                    passthrough_logging_payload=passthrough_logging_payload,
                )
                if user:
                    # Surface the end-user id where spend tracking expects it.
                    kwargs.setdefault("litellm_params", {})
                    kwargs["litellm_params"].update(
                        {"proxy_server_request": {"body": {"user": user}}}
                    )

            # pretty print standard logging object
            verbose_proxy_logger.debug(
                "kwargs= %s",
                json.dumps(kwargs, indent=4, default=str),
            )

            # set litellm_call_id to logging response object
            litellm_model_response.id = logging_obj.litellm_call_id
            litellm_model_response.model = model
            logging_obj.model_call_details["model"] = model
            # Default the provider to "anthropic" only when nothing upstream
            # (e.g. azure_ai) already set it.
            if not logging_obj.model_call_details.get("custom_llm_provider"):
                logging_obj.model_call_details[
                    "custom_llm_provider"
                ] = litellm.LlmProviders.ANTHROPIC.value
            return kwargs
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error creating Anthropic response logging payload: %s", e
            )
            return kwargs

    @staticmethod
    def _handle_logging_anthropic_collected_chunks(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        end_time: datetime,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks

        - Builds complete response from chunks
        - Creates standard logging object
        - Logs in litellm callbacks

        Returns {"result": None, "kwargs": {}} when the chunks cannot be
        assembled into a complete response (nothing is logged in that case).
        """

        model = request_body.get("model", "")
        # Check if it's available in the logging object
        if (
            not model
            and hasattr(litellm_logging_obj, "model_call_details")
            and litellm_logging_obj.model_call_details.get("model")
        ):
            model = cast(str, litellm_logging_obj.model_call_details.get("model"))

        complete_streaming_response = (
            AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
                all_chunks=all_chunks,
                litellm_logging_obj=litellm_logging_obj,
                model=model,
            )
        )
        if complete_streaming_response is None:
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
            )
            return {
                "result": None,
                "kwargs": {},
            }
        kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs={},
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
        )

        return {
            "result": complete_streaming_response,
            "kwargs": kwargs,
        }

    @staticmethod
    def _split_sse_chunk_into_events(chunk: Union[str, bytes]) -> List[str]:
        """
        Split a chunk that may contain multiple SSE events into individual events.

        SSE format: "event: type\ndata: {...}\n\n"
        Multiple events in a single chunk are separated by double newlines.

        Args:
            chunk: Raw chunk string that may contain multiple SSE events

        Returns:
            List of individual SSE event strings (each containing "event: X\ndata: {...}")
        """
        # Handle bytes input
        if isinstance(chunk, bytes):
            chunk = chunk.decode("utf-8")

        # Split on double newlines to separate SSE events
        # Filter out empty strings
        events = [event.strip() for event in chunk.split("\n\n") if event.strip()]

        return events

    @staticmethod
    def _build_complete_streaming_response(
        all_chunks: Sequence[Union[str, bytes]],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """
        Builds complete response from raw Anthropic chunks

        - Splits multi-event chunks into individual SSE events
        - Converts str chunks to generic chunks
        - Converts generic chunks to litellm chunks (OpenAI format)
        - Builds complete response from litellm chunks

        Returns None when stream_chunk_builder cannot assemble a response.
        """
        verbose_proxy_logger.debug(
            "Building complete streaming response from %d chunks", len(all_chunks)
        )
        # Iterator is used only for its chunk-conversion helper; no live
        # stream is attached.
        anthropic_model_response_iterator = AnthropicModelResponseIterator(
            streaming_response=None,
            sync_stream=False,
        )
        all_openai_chunks = []

        # Process each chunk - a chunk may contain multiple SSE events
        for _chunk_str in all_chunks:
            # Split chunk into individual SSE events
            individual_events = (
                AnthropicPassthroughLoggingHandler._split_sse_chunk_into_events(
                    _chunk_str
                )
            )

            # Process each individual event
            for event_str in individual_events:
                try:
                    transformed_openai_chunk = anthropic_model_response_iterator.convert_str_chunk_to_generic_chunk(
                        chunk=event_str
                    )
                    if transformed_openai_chunk is not None:
                        all_openai_chunks.append(transformed_openai_chunk)

                except (StopIteration, StopAsyncIteration):
                    # End-of-stream signal: stop processing this chunk's events.
                    break

        complete_streaming_response = litellm.stream_chunk_builder(
            chunks=all_openai_chunks,
            logging_obj=litellm_logging_obj,
        )
        verbose_proxy_logger.debug(
            "Complete streaming response built: %s", complete_streaming_response
        )
        return complete_streaming_response

    @staticmethod
    def batch_creation_handler(  # noqa: PLR0915
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: Optional[dict] = None,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Handle Anthropic batch creation passthrough logging.
        Creates a managed object for cost tracking when batch job is successfully created.

        Never raises: any failure yields a synthetic "failed" batch response
        with zero cost so logging still completes.
        """
        import base64

        from litellm._uuid import uuid
        from litellm.llms.anthropic.batches.transformation import (
            AnthropicBatchesConfig,
        )
        from litellm.types.utils import Choices, SpecialEnums

        try:
            _json_response = httpx_response.json()

            # Only handle successful batch job creation (status 200 with an "id" in the body)
            if httpx_response.status_code == 200 and "id" in _json_response:
                # Transform Anthropic response to LiteLLM batch format
                anthropic_batches_config = AnthropicBatchesConfig()
                litellm_batch_response = (
                    anthropic_batches_config.transform_retrieve_batch_response(
                        model=None,
                        raw_response=httpx_response,
                        logging_obj=logging_obj,
                        litellm_params={},
                    )
                )
                # Set status to "validating" for newly created batches so polling mechanism picks them up
                # The polling mechanism only looks for status="validating" jobs
                litellm_batch_response.status = "validating"

                # Extract batch ID from the response
                batch_id = _json_response.get("id", "")

                # Get model from request body (batch response doesn't include model)
                request_body = request_body or {}
                # Try to extract model from the batch request body, supporting Anthropic's nested structure
                model_name: str = "unknown"
                if isinstance(request_body, dict):
                    # Standard: {"model": ...}
                    model_name = request_body.get("model") or "unknown"
                    if model_name == "unknown":
                        # Anthropic batches: look under requests[0].params.model
                        requests_list = request_body.get("requests", [])
                        if isinstance(requests_list, list) and len(requests_list) > 0:
                            first_req = requests_list[0]
                            if isinstance(first_req, dict):
                                params = first_req.get("params", {})
                                if isinstance(params, dict):
                                    extracted_model = params.get("model")
                                    if extracted_model:
                                        model_name = extracted_model

                # Create unified object ID for tracking
                # Format: base64(litellm_proxy;model_id:{};llm_batch_id:{})
                # For Anthropic passthrough, prefix model with "anthropic/" so router can determine provider
                actual_model_id = (
                    AnthropicPassthroughLoggingHandler.get_actual_model_id_from_router(
                        model_name
                    )
                )

                # If model not in router, use "anthropic/{model_name}" format so router can determine provider
                if actual_model_id == model_name and not actual_model_id.startswith(
                    "anthropic/"
                ):
                    actual_model_id = f"anthropic/{model_name}"

                unified_id_string = (
                    SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format(
                        actual_model_id, batch_id
                    )
                )
                # URL-safe base64 without padding so the id is safe in URLs/keys.
                unified_object_id = (
                    base64.urlsafe_b64encode(unified_id_string.encode())
                    .decode()
                    .rstrip("=")
                )

                # Store the managed object for cost tracking
                # This will be picked up by check_batch_cost polling mechanism
                AnthropicPassthroughLoggingHandler._store_batch_managed_object(
                    unified_object_id=unified_object_id,
                    batch_object=litellm_batch_response,
                    model_object_id=batch_id,
                    logging_obj=logging_obj,
                    **kwargs,
                )

                # Create a batch job response for logging
                litellm_model_response = ModelResponse()
                litellm_model_response.id = str(uuid.uuid4())
                litellm_model_response.model = model_name
                litellm_model_response.object = "batch"
                litellm_model_response.created = int(start_time.timestamp())

                # Add batch-specific metadata to indicate this is a pending batch job
                litellm_model_response.choices = [
                    Choices(
                        finish_reason="stop",
                        index=0,
                        message={
                            "role": "assistant",
                            "content": f"Batch job {batch_id} created and is pending. Status will be updated when the batch completes.",
                            "tool_calls": None,
                            "function_call": None,
                            "provider_specific_fields": {
                                "batch_job_id": batch_id,
                                "batch_job_state": "in_progress",
                                "unified_object_id": unified_object_id,
                            },
                        },
                    )
                ]

                # Set response cost to 0 initially (will be updated when batch completes)
                response_cost = 0.0
                kwargs["response_cost"] = response_cost
                kwargs["model"] = model_name
                kwargs["batch_id"] = batch_id
                kwargs["unified_object_id"] = unified_object_id
                kwargs["batch_job_state"] = "in_progress"

                logging_obj.model = model_name
                logging_obj.model_call_details["model"] = logging_obj.model
                logging_obj.model_call_details["response_cost"] = response_cost
                logging_obj.model_call_details["batch_id"] = batch_id

                return {
                    "result": litellm_model_response,
                    "kwargs": kwargs,
                }
            else:
                # Handle non-successful responses
                litellm_model_response = ModelResponse()
                litellm_model_response.id = str(uuid.uuid4())
                litellm_model_response.model = "anthropic_batch"
                litellm_model_response.object = "batch"
                litellm_model_response.created = int(start_time.timestamp())

                # Add error-specific metadata
                litellm_model_response.choices = [
                    Choices(
                        finish_reason="stop",
                        index=0,
                        message={
                            "role": "assistant",
                            "content": f"Batch job creation failed. Status: {httpx_response.status_code}",
                            "tool_calls": None,
                            "function_call": None,
                            "provider_specific_fields": {
                                "batch_job_state": "failed",
                                "status_code": httpx_response.status_code,
                            },
                        },
                    )
                ]

                kwargs["response_cost"] = 0.0
                kwargs["model"] = "anthropic_batch"
                kwargs["batch_job_state"] = "failed"

                return {
                    "result": litellm_model_response,
                    "kwargs": kwargs,
                }

        except Exception as e:
            verbose_proxy_logger.error(f"Error in batch_creation_handler: {e}")
            # Return basic response on error
            litellm_model_response = ModelResponse()
            litellm_model_response.id = str(uuid.uuid4())
            litellm_model_response.model = "anthropic_batch"
            litellm_model_response.object = "batch"
            litellm_model_response.created = int(start_time.timestamp())

            # Add error-specific metadata
            litellm_model_response.choices = [
                Choices(
                    finish_reason="stop",
                    index=0,
                    message={
                        "role": "assistant",
                        "content": f"Error creating batch job: {str(e)}",
                        "tool_calls": None,
                        "function_call": None,
                        "provider_specific_fields": {
                            "batch_job_state": "failed",
                            "error": str(e),
                        },
                    },
                )
            ]

            kwargs["response_cost"] = 0.0
            kwargs["model"] = "anthropic_batch"
            kwargs["batch_job_state"] = "failed"

            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }

    @staticmethod
    def _store_batch_managed_object(
        unified_object_id: str,
        batch_object: LiteLLMBatch,
        model_object_id: str,
        logging_obj: LiteLLMLoggingObj,
        **kwargs,
    ) -> None:
        """
        Store batch managed object for cost tracking.
        This will be picked up by the check_batch_cost polling mechanism.

        Best-effort: logs a warning if the managed-files hook is unavailable
        and swallows (but logs) any other error.
        NOTE(review): asyncio.create_task assumes a running event loop in
        this thread -- confirm callers always invoke this from async context.
        """
        try:
            # Get the managed files hook from the logging object
            # This is a bit of a hack, but we need access to the proxy logging system
            from litellm.proxy.proxy_server import proxy_logging_obj

            managed_files_hook = proxy_logging_obj.get_proxy_hook("managed_files")
            if managed_files_hook is not None and hasattr(
                managed_files_hook, "store_unified_object_id"
            ):
                # Create a mock user API key dict for the managed object storage
                from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth

                user_api_key_dict = UserAPIKeyAuth(
                    user_id=kwargs.get("user_id", "default-user"),
                    api_key="",
                    team_id=None,
                    team_alias=None,
                    user_role=LitellmUserRoles.CUSTOMER,  # Use proper enum value
                    user_email=None,
                    max_budget=None,
                    spend=0.0,  # Set to 0.0 instead of None
                    models=[],  # Set to empty list instead of None
                    tpm_limit=None,
                    rpm_limit=None,
                    budget_duration=None,
                    budget_reset_at=None,
                    max_parallel_requests=None,
                    allowed_model_region=None,
                    metadata={},  # Set to empty dict instead of None
                    key_alias=None,
                    permissions={},  # Set to empty dict instead of None
                    model_max_budget={},  # Set to empty dict instead of None
                    model_spend={},  # Set to empty dict instead of None
                )

                # Store the unified object for batch cost tracking
                import asyncio

                asyncio.create_task(
                    managed_files_hook.store_unified_object_id(  # type: ignore
                        unified_object_id=unified_object_id,
                        file_object=batch_object,
                        litellm_parent_otel_span=None,
                        model_object_id=model_object_id,
                        file_purpose="batch",
                        user_api_key_dict=user_api_key_dict,
                    )
                )

                verbose_proxy_logger.info(
                    f"Stored Anthropic batch managed object with unified_object_id={unified_object_id}, batch_id={model_object_id}"
                )
            else:
                verbose_proxy_logger.warning(
                    "Managed files hook not available, cannot store batch object for cost tracking"
                )

        except Exception as e:
            verbose_proxy_logger.error(
                f"Error storing Anthropic batch managed object: {e}"
            )

    @staticmethod
    def get_actual_model_id_from_router(model_name: str) -> str:
        """
        Look up *model_name* in the proxy router and return its first model id.

        Falls back to returning *model_name* unchanged when the router is not
        initialized or has no entry for the name.
        """
        from litellm.proxy.proxy_server import llm_router

        if llm_router is not None:
            # Try to find the model in the router by the model name
            # Use the existing get_model_ids method from router
            model_ids = llm_router.get_model_ids(model_name=model_name)
            if model_ids and len(model_ids) > 0:
                # Use the first model ID found
                actual_model_id = model_ids[0]
                verbose_proxy_logger.info(
                    f"Found model ID in router: {actual_model_id}"
                )
                return actual_model_id
            else:
                # Fallback to model name
                actual_model_id = model_name
                verbose_proxy_logger.warning(
                    f"Model not found in router, using model name: {actual_model_id}"
                )
                return actual_model_id
        else:
            # Fallback if router is not available
            verbose_proxy_logger.warning(
                f"Router not available, using model name: {model_name}"
            )
            return model_name
|
||||
@@ -0,0 +1,333 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
get_standard_logging_object_payload,
|
||||
)
|
||||
from litellm.litellm_core_utils.thread_pool_executor import executor
|
||||
from litellm.types.passthrough_endpoints.assembly_ai import (
|
||||
ASSEMBLY_AI_MAX_POLLING_ATTEMPTS,
|
||||
ASSEMBLY_AI_POLLING_INTERVAL,
|
||||
)
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
PassthroughStandardLoggingPayload,
|
||||
)
|
||||
|
||||
|
||||
class AssemblyAITranscriptResponse(TypedDict, total=False):
    """
    Subset of an AssemblyAI transcript response used for cost tracking.

    ``total=False``: every key is optional, since a polled transcript may
    only be partially populated.
    """

    id: str  # transcript id, used to poll GET /transcript/{id}
    speech_model: str  # model name used for the transcription
    acoustic_model: str
    language_code: str
    status: str  # transcript processing status -- exact value set not visible here; confirm against AssemblyAI docs
    audio_duration: float  # duration of the audio -- presumably seconds; verify units against AssemblyAI docs
|
||||
|
||||
|
||||
class AssemblyAIPassthroughLoggingHandler:
|
||||
def __init__(self):
|
||||
self.assembly_ai_base_url = "https://api.assemblyai.com"
|
||||
self.assembly_ai_eu_base_url = "https://eu.assemblyai.com"
|
||||
"""
|
||||
The base URL for the AssemblyAI API
|
||||
"""
|
||||
|
||||
self.polling_interval: float = ASSEMBLY_AI_POLLING_INTERVAL
|
||||
"""
|
||||
The polling interval for the AssemblyAI API.
|
||||
litellm needs to poll the GET /transcript/{transcript_id} endpoint to get the status of the transcript.
|
||||
"""
|
||||
|
||||
self.max_polling_attempts = ASSEMBLY_AI_MAX_POLLING_ATTEMPTS
|
||||
"""
|
||||
The maximum number of polling attempts for the AssemblyAI API.
|
||||
"""
|
||||
|
||||
    def assemblyai_passthrough_logging_handler(
        self,
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        **kwargs,
    ) -> None:
        """
        Entry point for logging a successful AssemblyAI passthrough request.

        Since cost tracking requires polling the AssemblyAI API, we need to
        handle this in a separate thread. Hence the executor.submit.

        Args:
            httpx_response: Raw provider HTTP response.
            response_body: Parsed JSON body of the provider response.
            logging_obj: LiteLLM logging object for this request.
            url_route: The passthrough route that was called.
            result: Serialized result string (forwarded to the worker).
            start_time: Request start timestamp.
            end_time: Request end timestamp.
            cache_hit: Whether the response was served from cache.

        Note: the returned Future is intentionally discarded -- logging is
        fire-and-forget; positional argument order must match
        _handle_assemblyai_passthrough_logging exactly.
        """
        executor.submit(
            self._handle_assemblyai_passthrough_logging,
            httpx_response,
            response_body,
            logging_obj,
            url_route,
            result,
            start_time,
            end_time,
            cache_hit,
            **kwargs,
        )
|
||||
|
||||
def _handle_assemblyai_passthrough_logging(
|
||||
self,
|
||||
httpx_response: httpx.Response,
|
||||
response_body: dict,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
url_route: str,
|
||||
result: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
cache_hit: bool,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Handles logging for AssemblyAI successful passthrough requests
|
||||
"""
|
||||
from ..pass_through_endpoints import pass_through_endpoint_logging
|
||||
|
||||
model = response_body.get("speech_model", "")
|
||||
verbose_proxy_logger.debug(
|
||||
"response body %s", json.dumps(response_body, indent=4)
|
||||
)
|
||||
kwargs["model"] = model
|
||||
kwargs["custom_llm_provider"] = "assemblyai"
|
||||
response_cost: Optional[float] = None
|
||||
|
||||
transcript_id = response_body.get("id")
|
||||
if transcript_id is None:
|
||||
raise ValueError(
|
||||
"Transcript ID is required to log the cost of the transcription"
|
||||
)
|
||||
transcript_response = self._poll_assembly_for_transcript_response(
|
||||
transcript_id=transcript_id, url_route=url_route
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"finished polling assembly for transcript response- got transcript response %s",
|
||||
json.dumps(transcript_response, indent=4),
|
||||
)
|
||||
if transcript_response:
|
||||
cost = self.get_cost_for_assembly_transcript(
|
||||
speech_model=model,
|
||||
transcript_response=transcript_response,
|
||||
)
|
||||
response_cost = cost
|
||||
|
||||
# Make standard logging object for Vertex AI
|
||||
standard_logging_object = get_standard_logging_object_payload(
|
||||
kwargs=kwargs,
|
||||
init_response_obj=transcript_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=logging_obj,
|
||||
status="success",
|
||||
)
|
||||
|
||||
passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = ( # type: ignore
|
||||
kwargs.get("passthrough_logging_payload")
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"standard_passthrough_logging_object %s",
|
||||
json.dumps(passthrough_logging_payload, indent=4),
|
||||
)
|
||||
|
||||
# pretty print standard logging object
|
||||
verbose_proxy_logger.debug(
|
||||
"standard_logging_object= %s", json.dumps(standard_logging_object, indent=4)
|
||||
)
|
||||
logging_obj.model_call_details["model"] = model
|
||||
logging_obj.model_call_details["custom_llm_provider"] = "assemblyai"
|
||||
logging_obj.model_call_details["response_cost"] = response_cost
|
||||
|
||||
asyncio.run(
|
||||
pass_through_endpoint_logging._handle_logging(
|
||||
logging_obj=logging_obj,
|
||||
standard_logging_response_object=self._get_response_to_log(
|
||||
transcript_response
|
||||
),
|
||||
result=result,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
def _get_response_to_log(
|
||||
self, transcript_response: Optional[AssemblyAITranscriptResponse]
|
||||
) -> dict:
|
||||
if transcript_response is None:
|
||||
return {}
|
||||
return dict(transcript_response)
|
||||
|
||||
def _get_assembly_transcript(
|
||||
self,
|
||||
transcript_id: str,
|
||||
request_region: Optional[Literal["eu"]] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Get the transcript details from AssemblyAI API
|
||||
|
||||
Args:
|
||||
response_body (dict): Response containing the transcript ID
|
||||
|
||||
Returns:
|
||||
Optional[dict]: Transcript details if successful, None otherwise
|
||||
"""
|
||||
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||
passthrough_endpoint_router,
|
||||
)
|
||||
|
||||
_base_url = (
|
||||
self.assembly_ai_eu_base_url
|
||||
if request_region == "eu"
|
||||
else self.assembly_ai_base_url
|
||||
)
|
||||
_api_key = passthrough_endpoint_router.get_credentials(
|
||||
custom_llm_provider="assemblyai",
|
||||
region_name=request_region,
|
||||
)
|
||||
if _api_key is None:
|
||||
raise ValueError("AssemblyAI API key not found")
|
||||
try:
|
||||
url = f"{_base_url}/v2/transcript/{transcript_id}"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {_api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response = httpx.get(url, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"[Non blocking logging error] Error getting AssemblyAI transcript: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
def _poll_assembly_for_transcript_response(
|
||||
self,
|
||||
transcript_id: str,
|
||||
url_route: Optional[str] = None,
|
||||
) -> Optional[AssemblyAITranscriptResponse]:
|
||||
"""
|
||||
Poll the status of the transcript until it is completed or timeout (30 minutes)
|
||||
"""
|
||||
for _ in range(
|
||||
self.max_polling_attempts
|
||||
): # 180 attempts * 10s = 30 minutes max
|
||||
transcript = self._get_assembly_transcript(
|
||||
request_region=AssemblyAIPassthroughLoggingHandler._get_assembly_region_from_url(
|
||||
url=url_route
|
||||
),
|
||||
transcript_id=transcript_id,
|
||||
)
|
||||
if transcript is None:
|
||||
return None
|
||||
if (
|
||||
transcript.get("status") == "completed"
|
||||
or transcript.get("status") == "error"
|
||||
):
|
||||
return AssemblyAITranscriptResponse(**transcript)
|
||||
time.sleep(self.polling_interval)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_cost_for_assembly_transcript(
|
||||
transcript_response: AssemblyAITranscriptResponse,
|
||||
speech_model: str,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Get the cost for the assembly transcript
|
||||
"""
|
||||
_audio_duration = transcript_response.get("audio_duration")
|
||||
if _audio_duration is None:
|
||||
return None
|
||||
_cost_per_second = (
|
||||
AssemblyAIPassthroughLoggingHandler.get_cost_per_second_for_assembly_model(
|
||||
speech_model=speech_model
|
||||
)
|
||||
)
|
||||
if _cost_per_second is None:
|
||||
return None
|
||||
return _audio_duration * _cost_per_second
|
||||
|
||||
@staticmethod
|
||||
def get_cost_per_second_for_assembly_model(speech_model: str) -> Optional[float]:
|
||||
"""
|
||||
Get the cost per second for the assembly model.
|
||||
Falls back to assemblyai/nano if the specific speech model info cannot be found.
|
||||
"""
|
||||
try:
|
||||
# First try with the provided speech model
|
||||
try:
|
||||
model_info = litellm.get_model_info(
|
||||
model=speech_model,
|
||||
custom_llm_provider="assemblyai",
|
||||
)
|
||||
if model_info and model_info.get("input_cost_per_second") is not None:
|
||||
return model_info.get("input_cost_per_second")
|
||||
except Exception:
|
||||
pass # Continue to fallback if model not found
|
||||
|
||||
# Fallback to assemblyai/nano if speech model info not found
|
||||
try:
|
||||
model_info = litellm.get_model_info(
|
||||
model="assemblyai/nano",
|
||||
custom_llm_provider="assemblyai",
|
||||
)
|
||||
if model_info and model_info.get("input_cost_per_second") is not None:
|
||||
return model_info.get("input_cost_per_second")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"[Non blocking logging error] Error getting AssemblyAI model info: {str(e)}"
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _should_log_request(request_method: str) -> bool:
|
||||
"""
|
||||
only POST transcription jobs are logged. litellm will POLL assembly to wait for the transcription to complete to log the complete response / cost
|
||||
"""
|
||||
return request_method == "POST"
|
||||
|
||||
@staticmethod
|
||||
def _get_assembly_region_from_url(url: Optional[str]) -> Optional[Literal["eu"]]:
|
||||
"""
|
||||
Get the region from the URL
|
||||
"""
|
||||
if url is None:
|
||||
return None
|
||||
if urlparse(url).hostname == "eu.assemblyai.com":
|
||||
return "eu"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_assembly_base_url_from_region(region: Optional[Literal["eu"]]) -> str:
|
||||
"""
|
||||
Get the base URL for the AssemblyAI API
|
||||
if region == "eu", return "https://api.eu.assemblyai.com"
|
||||
else return "https://api.assemblyai.com"
|
||||
"""
|
||||
if region == "eu":
|
||||
return "https://api.eu.assemblyai.com"
|
||||
return "https://api.assemblyai.com"
|
||||
@@ -0,0 +1,221 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
get_standard_logging_object_payload,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
from litellm.proxy.auth.auth_utils import get_end_user_id_from_request_body
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
PassthroughStandardLoggingPayload,
|
||||
)
|
||||
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..success_handler import PassThroughEndpointLogging
|
||||
from ..types import EndpointType
|
||||
else:
|
||||
PassThroughEndpointLogging = Any
|
||||
EndpointType = Any
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BasePassthroughLoggingHandler(ABC):
    """
    Shared base for provider-specific passthrough logging handlers.

    Subclasses supply the provider identity and a transformation config; this
    base converts raw provider responses (streaming or non-streaming) into
    OpenAI format and builds the standard logging payload that downstream
    callbacks consume.
    """

    @property
    @abstractmethod
    def llm_provider_name(self) -> LlmProviders:
        # The provider enum value this handler logs for (e.g. LlmProviders.COHERE).
        pass

    @abstractmethod
    def get_provider_config(self, model: str) -> BaseConfig:
        # Config whose transform_response converts the provider's raw HTTP
        # response into a litellm ModelResponse.
        pass

    def passthrough_chat_handler(
        self,
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Transforms LLM response to OpenAI response, generates a standard logging object so downstream logging can be handled
        """
        # Prefer the model named in the request body; fall back to the response body.
        model = request_body.get("model", response_body.get("model", ""))
        provider_config = self.get_provider_config(model=model)
        litellm_model_response: ModelResponse = provider_config.transform_response(
            raw_response=httpx_response,
            model_response=litellm.ModelResponse(),
            model=model,
            # No original chat messages are available on a passthrough route.
            messages=[],
            logging_obj=logging_obj,
            optional_params={},
            api_key="",
            request_data={},
            encoding=litellm.encoding,
            json_mode=False,
            litellm_params={},
        )

        # Enrich kwargs with cost, user, and the standard logging object.
        kwargs = self._create_response_logging_payload(
            litellm_model_response=litellm_model_response,
            model=model,
            kwargs=kwargs,
            start_time=start_time,
            end_time=end_time,
            logging_obj=logging_obj,
        )

        return {
            "result": litellm_model_response,
            "kwargs": kwargs,
        }

    def _get_user_from_metadata(
        self,
        passthrough_logging_payload: PassthroughStandardLoggingPayload,
    ) -> Optional[str]:
        # Extract the end-user id from the original request body, if present.
        request_body = passthrough_logging_payload.get("request_body")
        if request_body:
            return get_end_user_id_from_request_body(request_body)
        return None

    def _create_response_logging_payload(
        self,
        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
        model: str,
        kwargs: dict,
        start_time: datetime,
        end_time: datetime,
        logging_obj: LiteLLMLoggingObj,
    ) -> dict:
        """
        Create the standard logging object for Generic LLM passthrough

        handles streaming and non-streaming responses

        Mutates and returns ``kwargs``; also stamps model / call id onto the
        response and logging objects as a side effect.
        """

        # Best-effort: any failure here must not break the request path, so the
        # whole body is wrapped and kwargs is returned as-is on error.
        try:
            response_cost = litellm.completion_cost(
                completion_response=litellm_model_response,
                model=model,
            )

            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            passthrough_logging_payload: Optional[PassthroughStandardLoggingPayload] = (  # type: ignore
                kwargs.get("passthrough_logging_payload")
            )
            if passthrough_logging_payload:
                user = self._get_user_from_metadata(
                    passthrough_logging_payload=passthrough_logging_payload,
                )
                if user:
                    # Surface the end user as if it came via the proxy request
                    # body, so user-level tracking picks it up.
                    kwargs.setdefault("litellm_params", {})
                    kwargs["litellm_params"].update(
                        {"proxy_server_request": {"body": {"user": user}}}
                    )

            # Make standard logging object for Anthropic
            standard_logging_object = get_standard_logging_object_payload(
                kwargs=kwargs,
                init_response_obj=litellm_model_response,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                status="success",
            )

            # pretty print standard logging object
            verbose_proxy_logger.debug(
                "standard_logging_object= %s",
                json.dumps(standard_logging_object, indent=4),
            )
            kwargs["standard_logging_object"] = standard_logging_object

            # set litellm_call_id to logging response object
            litellm_model_response.id = logging_obj.litellm_call_id
            litellm_model_response.model = model
            logging_obj.model_call_details["model"] = model
            return kwargs
        except Exception as e:
            verbose_proxy_logger.exception(
                "Error creating LLM passthrough response logging payload: %s", e
            )
            return kwargs

    @abstractmethod
    def _build_complete_streaming_response(
        self,
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """
        Builds complete response from raw chunks

        - Converts str chunks to generic chunks
        - Converts generic chunks to litellm chunks (OpenAI format)
        - Builds complete response from litellm chunks
        """
        pass

    def _handle_logging_llm_collected_chunks(
        self,
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        end_time: datetime,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from Anthropic passthrough endpoint and logs them in litellm callbacks

        - Builds complete response from chunks
        - Creates standard logging object
        - Logs in litellm callbacks
        """

        model = request_body.get("model", "")
        complete_streaming_response = self._build_complete_streaming_response(
            all_chunks=all_chunks,
            litellm_logging_obj=litellm_logging_obj,
            model=model,
        )
        if complete_streaming_response is None:
            # Without a rebuilt response there is nothing meaningful to log.
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Anthropic passthrough endpoint, not logging..."
            )
            return {
                "result": None,
                "kwargs": {},
            }
        kwargs = self._create_response_logging_payload(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs={},
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
        )

        return {
            "result": complete_streaming_response,
            "kwargs": kwargs,
        }
|
||||
@@ -0,0 +1,192 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import stream_chunk_builder
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
get_standard_logging_object_payload,
|
||||
)
|
||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.cohere.chat.v2_transformation import CohereV2ChatConfig
|
||||
from litellm.llms.cohere.common_utils import (
|
||||
ModelResponseIterator as CohereModelResponseIterator,
|
||||
)
|
||||
from litellm.llms.cohere.embed.v1_transformation import CohereEmbeddingConfig
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
PassthroughStandardLoggingPayload,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
LlmProviders,
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
)
|
||||
|
||||
from .base_passthrough_logging_handler import BasePassthroughLoggingHandler
|
||||
|
||||
|
||||
class CoherePassthroughLoggingHandler(BasePassthroughLoggingHandler):
    """Passthrough success-logging handler for Cohere (/v2/chat and /v1/embed)."""

    @property
    def llm_provider_name(self) -> LlmProviders:
        # Provider identity consumed by the shared base handler.
        return LlmProviders.COHERE

    def get_provider_config(self, model: str) -> BaseConfig:
        # Cohere chat responses are transformed via the v2 chat config.
        return CohereV2ChatConfig()

    def _build_complete_streaming_response(
        self,
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """
        Rebuild a complete ModelResponse from raw Cohere SSE chunk strings.

        Each raw chunk is converted to a generic chunk, run through
        CustomStreamWrapper to get OpenAI-format chunks, then merged with
        stream_chunk_builder.
        """
        cohere_model_response_iterator = CohereModelResponseIterator(
            streaming_response=None,
            sync_stream=False,
        )
        litellm_custom_stream_wrapper = CustomStreamWrapper(
            completion_stream=cohere_model_response_iterator,
            model=model,
            logging_obj=litellm_logging_obj,
            custom_llm_provider="cohere",
        )
        all_openai_chunks = []
        for _chunk_str in all_chunks:
            try:
                generic_chunk = (
                    cohere_model_response_iterator.convert_str_chunk_to_generic_chunk(
                        chunk=_chunk_str
                    )
                )
                litellm_chunk = litellm_custom_stream_wrapper.chunk_creator(
                    chunk=generic_chunk
                )
                if litellm_chunk is not None:
                    all_openai_chunks.append(litellm_chunk)
            except (StopIteration, StopAsyncIteration):
                # End-of-stream sentinels from the iterator; stop collecting.
                break
        complete_streaming_response = stream_chunk_builder(chunks=all_openai_chunks)
        return complete_streaming_response

    def cohere_passthrough_handler(  # noqa: PLR0915
        self,
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Handle Cohere passthrough logging with route detection and cost tracking.

        Embed routes (/v1/embed) get dedicated embedding cost tracking; all
        other routes (e.g. /v2/chat) fall back to the generic chat handler.
        """
        # Check if this is an embed endpoint
        if "/v1/embed" in url_route:
            model = request_body.get("model", response_body.get("model", ""))
            try:
                cohere_embed_config = CohereEmbeddingConfig()
                litellm_model_response = litellm.EmbeddingResponse()

                # Cohere v1 uses "texts"; accept "input" as a fallback.
                input_texts = request_body.get("texts", [])
                if not input_texts:
                    input_texts = request_body.get("input", [])

                # Transform the response
                litellm_model_response = cohere_embed_config._transform_response(
                    response=httpx_response,
                    api_key="",
                    logging_obj=logging_obj,
                    data=request_body,
                    model_response=litellm_model_response,
                    model=model,
                    encoding=litellm.encoding,
                    input=input_texts,
                )

                # Calculate cost using LiteLLM's cost calculator
                response_cost = litellm.completion_cost(
                    completion_response=litellm_model_response,
                    model=model,
                    custom_llm_provider="cohere",
                    call_type="aembedding",
                )

                # Set the calculated cost in _hidden_params to prevent recalculation
                if not hasattr(litellm_model_response, "_hidden_params"):
                    litellm_model_response._hidden_params = {}
                litellm_model_response._hidden_params["response_cost"] = response_cost

                kwargs["response_cost"] = response_cost
                kwargs["model"] = model
                kwargs["custom_llm_provider"] = "cohere"

                # Extract user information for tracking
                passthrough_logging_payload: Optional[
                    PassthroughStandardLoggingPayload
                ] = kwargs.get("passthrough_logging_payload")
                if passthrough_logging_payload:
                    # Use the inherited helper directly — no need to build a
                    # fresh handler instance.
                    user = self._get_user_from_metadata(
                        passthrough_logging_payload=passthrough_logging_payload,
                    )
                    if user:
                        kwargs.setdefault("litellm_params", {})
                        kwargs["litellm_params"].update(
                            {"proxy_server_request": {"body": {"user": user}}}
                        )

                # Create standard logging object
                if litellm_model_response is not None:
                    get_standard_logging_object_payload(
                        kwargs=kwargs,
                        init_response_obj=litellm_model_response,
                        start_time=start_time,
                        end_time=end_time,
                        logging_obj=logging_obj,
                        status="success",
                    )

                # Update logging object with cost information
                logging_obj.model_call_details["model"] = model
                logging_obj.model_call_details["custom_llm_provider"] = "cohere"
                logging_obj.model_call_details["response_cost"] = response_cost

                return {
                    "result": litellm_model_response,
                    "kwargs": kwargs,
                }
            except Exception as e:
                # Embed cost tracking is best-effort: log the failure (instead
                # of swallowing it silently) and fall back to the generic chat
                # handler so the request is still recorded.
                from litellm._logging import verbose_proxy_logger

                verbose_proxy_logger.exception(
                    "Error handling Cohere embed passthrough logging, "
                    "falling back to chat handler: %s",
                    e,
                )
                return super().passthrough_chat_handler(
                    httpx_response=httpx_response,
                    response_body=response_body,
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    request_body=request_body,
                    **kwargs,
                )

        # For non-embed routes (e.g., /v2/chat), fall back to chat handler
        return super().passthrough_chat_handler(
            httpx_response=httpx_response,
            response_body=response_body,
            logging_obj=logging_obj,
            url_route=url_route,
            result=result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=cache_hit,
            request_body=request_body,
            **kwargs,
        )
|
||||
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Cursor Cloud Agents API - Pass-through Logging Handler
|
||||
|
||||
Transforms Cursor API responses into standardized logging payloads
|
||||
so they appear cleanly in the LiteLLM Logs page.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
get_standard_logging_object_payload,
|
||||
)
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
from litellm.types.utils import StandardPassThroughResponseObject
|
||||
|
||||
|
||||
CURSOR_AGENT_ENDPOINTS: Dict[str, str] = {
|
||||
"POST /v0/agents": "cursor:agent:create",
|
||||
"GET /v0/agents": "cursor:agent:list",
|
||||
"POST /v0/agents/{id}/followup": "cursor:agent:followup",
|
||||
"POST /v0/agents/{id}/stop": "cursor:agent:stop",
|
||||
"DELETE /v0/agents/{id}": "cursor:agent:delete",
|
||||
"GET /v0/agents/{id}/conversation": "cursor:agent:conversation",
|
||||
"GET /v0/agents/{id}": "cursor:agent:status",
|
||||
"GET /v0/me": "cursor:account:info",
|
||||
"GET /v0/models": "cursor:models:list",
|
||||
"GET /v0/repositories": "cursor:repositories:list",
|
||||
}
|
||||
|
||||
|
||||
def _classify_cursor_request(method: str, path: str) -> str:
|
||||
"""Classify a Cursor API request into a readable operation name."""
|
||||
normalized = path.rstrip("/")
|
||||
|
||||
for pattern, operation in CURSOR_AGENT_ENDPOINTS.items():
|
||||
pat_method, pat_path = pattern.split(" ", 1)
|
||||
if method.upper() != pat_method:
|
||||
continue
|
||||
|
||||
pat_parts = pat_path.strip("/").split("/")
|
||||
req_parts = normalized.strip("/").split("/")
|
||||
|
||||
if len(pat_parts) != len(req_parts):
|
||||
continue
|
||||
|
||||
match = True
|
||||
for pp, rp in zip(pat_parts, req_parts):
|
||||
if pp.startswith("{") and pp.endswith("}"):
|
||||
continue
|
||||
if pp != rp:
|
||||
match = False
|
||||
break
|
||||
if match:
|
||||
return operation
|
||||
|
||||
return f"cursor:{method.lower()}:{normalized}"
|
||||
|
||||
|
||||
class CursorPassthroughLoggingHandler:
    """Handles logging for Cursor Cloud Agents pass-through requests."""

    @staticmethod
    def cursor_passthrough_handler(
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Transform a Cursor API response into a standard logging payload.

        The operation (create/list/followup/...) is derived from the request
        method + path, the logged "model" becomes "cursor/<operation>", and
        cost is recorded as 0.0 (Cursor requests carry no per-call LLM cost
        here). On any failure the raw result is logged instead.
        """
        try:
            method = httpx_response.request.method
            path = httpx.URL(url_route).path
            operation = _classify_cursor_request(method, path)

            # Best-effort agent metadata for the human-readable summary line.
            agent_id = response_body.get("id", "")
            agent_name = response_body.get("name", "")
            agent_status = response_body.get("status", "")

            model_name = f"cursor/{operation}"

            summary_parts = []
            if agent_id:
                summary_parts.append(f"id={agent_id}")
            if agent_name:
                summary_parts.append(f"name={agent_name}")
            if agent_status:
                summary_parts.append(f"status={agent_status}")

            # Fall back to the raw result string when no metadata was found.
            response_summary = ", ".join(summary_parts) if summary_parts else result

            kwargs["model"] = model_name
            kwargs["response_cost"] = 0.0
            logging_obj.model_call_details["model"] = model_name
            logging_obj.model_call_details["custom_llm_provider"] = "cursor"
            logging_obj.model_call_details["response_cost"] = 0.0

            standard_logging_object = get_standard_logging_object_payload(
                kwargs=kwargs,
                init_response_obj=StandardPassThroughResponseObject(
                    response=response_summary
                ),
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                status="success",
            )
            kwargs["standard_logging_object"] = standard_logging_object

            verbose_proxy_logger.debug(
                "Cursor passthrough logging: operation=%s, agent_id=%s",
                operation,
                agent_id,
            )

            return {
                "result": StandardPassThroughResponseObject(response=response_summary),
                "kwargs": kwargs,
            }
        except Exception as e:
            # Non-blocking: never let logging failures break the passthrough
            # response; log the raw result string instead.
            verbose_proxy_logger.exception(
                "Error in Cursor passthrough logging handler: %s", e
            )
            return {
                "result": StandardPassThroughResponseObject(response=result),
                "kwargs": kwargs,
            }
|
||||
@@ -0,0 +1,254 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.gemini.videos.transformation import GeminiVideoConfig
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
ModelResponseIterator as GeminiModelResponseIterator,
|
||||
)
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
from litellm.types.utils import (
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
|
||||
|
||||
from ..success_handler import PassThroughEndpointLogging
|
||||
else:
|
||||
PassThroughEndpointLogging = Any
|
||||
EndpointType = Any
|
||||
|
||||
|
||||
class GeminiPassthroughLoggingHandler:
|
||||
    @staticmethod
    def gemini_passthrough_handler(
        httpx_response: httpx.Response,
        response_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Route a non-streaming Gemini passthrough response to the right
        logging transformation.

        - "predictLongRunning" routes: video-generation jobs — transform via
          GeminiVideoConfig and compute cost with call_type="create_video".
        - "generateContent" routes: chat — transform to an OpenAI-format
          ModelResponse and build the standard logging payload.
        - Anything else: nothing to log (result is None).
        """
        if "predictLongRunning" in url_route:
            # Model id is embedded in the URL, not the request body.
            model = GeminiPassthroughLoggingHandler.extract_model_from_url(url_route)

            gemini_video_config = GeminiVideoConfig()
            litellm_video_response = (
                gemini_video_config.transform_video_create_response(
                    model=model,
                    raw_response=httpx_response,
                    logging_obj=logging_obj,
                    custom_llm_provider="gemini",
                    request_data=request_body,
                )
            )
            logging_obj.model = model
            logging_obj.model_call_details["model"] = model
            logging_obj.model_call_details["custom_llm_provider"] = "gemini"
            logging_obj.custom_llm_provider = "gemini"

            response_cost = litellm.completion_cost(
                completion_response=litellm_video_response,
                model=model,
                custom_llm_provider="gemini",
                call_type="create_video",
            )

            # Set response_cost in _hidden_params to prevent recalculation
            if not hasattr(litellm_video_response, "_hidden_params"):
                litellm_video_response._hidden_params = {}
            litellm_video_response._hidden_params["response_cost"] = response_cost

            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            kwargs["custom_llm_provider"] = "gemini"
            logging_obj.model_call_details["response_cost"] = response_cost
            return {
                "result": litellm_video_response,
                "kwargs": kwargs,
            }

        if "generateContent" in url_route:
            model = GeminiPassthroughLoggingHandler.extract_model_from_url(url_route)

            # Use Gemini config for transformation
            instance_of_gemini_llm = litellm.GoogleAIStudioGeminiConfig()
            litellm_model_response: ModelResponse = (
                instance_of_gemini_llm.transform_response(
                    model=model,
                    # Original messages are unavailable on a passthrough route;
                    # a placeholder is supplied for the transformer.
                    messages=[
                        {"role": "user", "content": "no-message-pass-through-endpoint"}
                    ],
                    raw_response=httpx_response,
                    model_response=litellm.ModelResponse(),
                    logging_obj=logging_obj,
                    optional_params={},
                    litellm_params={},
                    api_key="",
                    request_data={},
                    encoding=litellm.encoding,
                )
            )
            kwargs = GeminiPassthroughLoggingHandler._create_gemini_response_logging_payload_for_generate_content(
                litellm_model_response=litellm_model_response,
                model=model,
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                custom_llm_provider="gemini",
            )

            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }
        else:
            # Unknown route: nothing to transform or cost-track.
            return {
                "result": None,
                "kwargs": kwargs,
            }
|
||||
|
||||
    @staticmethod
    def _handle_logging_gemini_collected_chunks(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        model: Optional[str],
        end_time: datetime,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from Gemini passthrough endpoint and logs them in litellm callbacks

        - Builds complete response from chunks
        - Creates standard logging object
        - Logs in litellm callbacks
        """
        kwargs: Dict[str, Any] = {}
        # Fall back to parsing the model out of the URL when not provided.
        model = model or GeminiPassthroughLoggingHandler.extract_model_from_url(
            url_route
        )
        complete_streaming_response = (
            GeminiPassthroughLoggingHandler._build_complete_streaming_response(
                all_chunks=all_chunks,
                litellm_logging_obj=litellm_logging_obj,
                model=model,
                url_route=url_route,
            )
        )

        if complete_streaming_response is None:
            # Without a rebuilt response there is nothing meaningful to log.
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Gemini passthrough endpoint, not logging..."
            )
            return {
                "result": None,
                "kwargs": kwargs,
            }

        kwargs = GeminiPassthroughLoggingHandler._create_gemini_response_logging_payload_for_generate_content(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs=kwargs,
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
            custom_llm_provider="gemini",
        )

        return {
            "result": complete_streaming_response,
            "kwargs": kwargs,
        }
|
||||
|
||||
@staticmethod
def _build_complete_streaming_response(
    all_chunks: List[str],
    litellm_logging_obj: LiteLLMLoggingObj,
    model: str,
    url_route: str,
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
    """
    Reassemble raw Gemini stream chunks into one complete response.

    Only generateContent / streamGenerateContent routes are handled; any
    other route, or an empty chunk list, yields ``None``.
    """
    is_generate_content_route = (
        "generateContent" in url_route or "streamGenerateContent" in url_route
    )
    if not is_generate_content_route:
        return None

    # Reuse Gemini's streaming iterator purely for its chunk-parsing logic.
    gemini_iterator: Any = GeminiModelResponseIterator(
        streaming_response=None,
        sync_stream=False,
        logging_obj=litellm_logging_obj,
    )
    parse_chunk: Any = gemini_iterator._common_chunk_parsing_logic
    parsed_chunks = [parse_chunk(raw_chunk) for raw_chunk in all_chunks]
    if not parsed_chunks:
        return None

    # Drop chunks the parser could not handle before stitching them together.
    usable_chunks = [chunk for chunk in parsed_chunks if chunk is not None]

    return litellm.stream_chunk_builder(chunks=usable_chunks)
|
||||
|
||||
@staticmethod
|
||||
def extract_model_from_url(url: str) -> str:
|
||||
pattern = r"/models/([^:]+)"
|
||||
match = re.search(pattern, url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
def _create_gemini_response_logging_payload_for_generate_content(
    litellm_model_response: Union[ModelResponse, TextCompletionResponse],
    model: str,
    kwargs: dict,
    start_time: datetime,
    end_time: datetime,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: str,
):
    """
    Create the standard logging object for Gemini passthrough generateContent
    (streaming and non-streaming).

    Mutates *kwargs* and *logging_obj* in place with the calculated cost,
    model and provider, and returns the enriched kwargs.

    Args:
        litellm_model_response: Transformed response to price and log.
        model: Model name extracted from the request/URL.
        kwargs: Logging kwargs to enrich (mutated and returned).
        start_time: Request start time (accepted for interface parity; not
            read in this function body).
        end_time: Request end time (same as above).
        logging_obj: LiteLLM logging object updated with call details.
        custom_llm_provider: Provider tag recorded in the payload.
    """

    # Fix: use the caller-supplied provider instead of a hard-coded
    # "gemini" so the cost lookup always agrees with the provider that is
    # recorded in kwargs below. Every current caller passes "gemini", so
    # existing behavior is unchanged.
    response_cost = litellm.completion_cost(
        completion_response=litellm_model_response,
        model=model,
        custom_llm_provider=custom_llm_provider,
    )

    kwargs["response_cost"] = response_cost
    kwargs["model"] = model
    kwargs["custom_llm_provider"] = custom_llm_provider

    # pretty print standard logging object
    verbose_proxy_logger.debug("kwargs= %s", kwargs)

    # set litellm_call_id to logging response object
    litellm_model_response.id = logging_obj.litellm_call_id
    logging_obj.model = litellm_model_response.model or model
    logging_obj.model_call_details["model"] = logging_obj.model
    logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
    logging_obj.model_call_details["response_cost"] = response_cost
    return kwargs
|
||||
@@ -0,0 +1,608 @@
|
||||
"""
|
||||
OpenAI Passthrough Logging Handler
|
||||
|
||||
Handles cost tracking and logging for OpenAI passthrough endpoints, specifically /chat/completions.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
get_standard_logging_object_payload,
|
||||
)
|
||||
from litellm.llms.openai.openai import OpenAIConfig
|
||||
from litellm.llms.openai.openai import OpenAIConfig as OpenAIConfigType
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.base_passthrough_logging_handler import (
|
||||
BasePassthroughLoggingHandler,
|
||||
)
|
||||
from litellm.proxy.pass_through_endpoints.success_handler import (
|
||||
PassThroughEndpointLogging,
|
||||
)
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
EndpointType,
|
||||
PassthroughStandardLoggingPayload,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse, LlmProviders, PassthroughCallTypes
|
||||
from litellm.utils import ModelResponse, TextCompletionResponse
|
||||
|
||||
|
||||
class OpenAIPassthroughLoggingHandler(BasePassthroughLoggingHandler):
|
||||
"""
|
||||
OpenAI-specific passthrough logging handler that provides cost tracking for /chat/completions endpoints.
|
||||
"""
|
||||
|
||||
@property
def llm_provider_name(self) -> LlmProviders:
    """Identify this handler's provider (OpenAI) for provider-dispatch logic."""
    return LlmProviders.OPENAI
|
||||
|
||||
def get_provider_config(self, model: str) -> OpenAIConfigType:
    """Get OpenAI provider configuration for the given model.

    The *model* argument is currently unused — the same ``OpenAIConfig``
    is returned for every model.
    """
    return OpenAIConfig()
|
||||
|
||||
@staticmethod
|
||||
def is_openai_chat_completions_route(url_route: str) -> bool:
|
||||
"""Check if the URL route is an OpenAI chat completions endpoint."""
|
||||
if not url_route:
|
||||
return False
|
||||
parsed_url = urlparse(url_route)
|
||||
return bool(
|
||||
parsed_url.hostname
|
||||
and (
|
||||
"api.openai.com" in parsed_url.hostname
|
||||
or "openai.azure.com" in parsed_url.hostname
|
||||
)
|
||||
and "/v1/chat/completions" in parsed_url.path
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_openai_image_generation_route(url_route: str) -> bool:
|
||||
"""Check if the URL route is an OpenAI image generation endpoint."""
|
||||
if not url_route:
|
||||
return False
|
||||
parsed_url = urlparse(url_route)
|
||||
return bool(
|
||||
parsed_url.hostname
|
||||
and (
|
||||
"api.openai.com" in parsed_url.hostname
|
||||
or "openai.azure.com" in parsed_url.hostname
|
||||
)
|
||||
and "/v1/images/generations" in parsed_url.path
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_openai_image_editing_route(url_route: str) -> bool:
|
||||
"""Check if the URL route is an OpenAI image editing endpoint."""
|
||||
if not url_route:
|
||||
return False
|
||||
parsed_url = urlparse(url_route)
|
||||
return bool(
|
||||
parsed_url.hostname
|
||||
and (
|
||||
"api.openai.com" in parsed_url.hostname
|
||||
or "openai.azure.com" in parsed_url.hostname
|
||||
)
|
||||
and "/v1/images/edits" in parsed_url.path
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_openai_responses_route(url_route: str) -> bool:
|
||||
"""Check if the URL route is an OpenAI responses API endpoint."""
|
||||
if not url_route:
|
||||
return False
|
||||
parsed_url = urlparse(url_route)
|
||||
return bool(
|
||||
parsed_url.hostname
|
||||
and (
|
||||
"api.openai.com" in parsed_url.hostname
|
||||
or "openai.azure.com" in parsed_url.hostname
|
||||
)
|
||||
and ("/v1/responses" in parsed_url.path or "/responses" in parsed_url.path)
|
||||
)
|
||||
|
||||
def _get_user_from_metadata(
    self,
    passthrough_logging_payload: PassthroughStandardLoggingPayload,
) -> Optional[str]:
    """
    Pull the end-user identifier, if any, out of the passthrough logging
    payload's recorded request body. Returns None when the payload has no
    request body or the body carries no "user" key.
    """
    body = passthrough_logging_payload.get("request_body")
    return body.get("user") if body else None
|
||||
|
||||
@staticmethod
def _calculate_image_generation_cost(
    model: str,
    response_body: dict,
    request_body: dict,
) -> float:
    """
    Calculate cost for OpenAI image generation.

    Reads ``n``, ``size`` and ``quality`` from the request body and prices
    the call with LiteLLM's default image cost calculator. Falls back to
    0.0 (with a warning) when the cost cannot be computed.
    """
    try:
        # Use LiteLLM's default image cost calculator
        from litellm.cost_calculator import default_image_cost_calculator

        # Pull the pricing-relevant parameters out of the request.
        raw_n = request_body.get("n", 1)
        try:
            image_count = int(raw_n)
        except Exception:
            image_count = 1
        image_size = request_body.get("size", "1024x1024")
        image_quality = request_body.get("quality", None)

        return default_image_cost_calculator(
            model=model,
            custom_llm_provider="openai",
            quality=image_quality,
            n=image_count,
            size=image_size,
            optional_params=request_body,
        )
    except Exception as e:
        verbose_proxy_logger.warning(
            f"Error calculating image generation cost: {str(e)}"
        )
        return 0.0
|
||||
|
||||
@staticmethod
def _calculate_image_editing_cost(
    model: str,
    response_body: dict,
    request_body: dict,
) -> float:
    """
    Calculate cost for OpenAI image editing.

    Reads ``n`` and ``size`` from the request body and prices the call with
    LiteLLM's default image cost calculator. Falls back to 0.0 (with a
    warning) when the cost cannot be computed.
    """
    try:
        # Use LiteLLM's default image cost calculator
        from litellm.cost_calculator import default_image_cost_calculator

        # Image edit typically uses multipart/form-data (because of files), so all fields arrive as strings (e.g., n = "1").
        raw_n = request_body.get("n", 1)
        try:
            image_count = int(raw_n)
        except Exception:
            image_count = 1
        image_size = request_body.get("size", "1024x1024")

        return default_image_cost_calculator(
            model=model,
            custom_llm_provider="openai",
            quality=None,  # Image editing doesn't have quality parameter
            n=image_count,
            size=image_size,
            optional_params=request_body,
        )
    except Exception as e:
        verbose_proxy_logger.warning(
            f"Error calculating image editing cost: {str(e)}"
        )
        return 0.0
|
||||
|
||||
@staticmethod
def openai_passthrough_handler(  # noqa: PLR0915
    httpx_response: httpx.Response,
    response_body: dict,
    logging_obj: LiteLLMLoggingObj,
    url_route: str,
    result: str,
    start_time: datetime,
    end_time: datetime,
    cache_hit: bool,
    request_body: dict,
    **kwargs,
) -> PassThroughEndpointLoggingTypedDict:
    """
    Handle OpenAI passthrough logging with cost tracking for chat completions,
    image generation, image editing, and responses API.

    Unsupported routes return a ``None`` result so the caller falls back to
    generic passthrough behavior. Any error during cost tracking (and the
    no-model case) falls back to the base ``passthrough_chat_handler``
    without cost tracking, rather than failing the request's logging.

    Returns:
        PassThroughEndpointLoggingTypedDict with the transformed response
        (or None) and kwargs enriched with cost/model/provider.
    """
    # Check if this is a supported endpoint for cost tracking
    is_chat_completions = (
        OpenAIPassthroughLoggingHandler.is_openai_chat_completions_route(url_route)
    )
    is_image_generation = (
        OpenAIPassthroughLoggingHandler.is_openai_image_generation_route(url_route)
    )
    is_image_editing = (
        OpenAIPassthroughLoggingHandler.is_openai_image_editing_route(url_route)
    )
    is_responses = OpenAIPassthroughLoggingHandler.is_openai_responses_route(
        url_route
    )

    if not (
        is_chat_completions
        or is_image_generation
        or is_image_editing
        or is_responses
    ):
        # For unsupported endpoints, return None to let the system fall back to generic behavior
        return {
            "result": None,
            "kwargs": kwargs,
        }

    # Extract model from request or response
    model = request_body.get("model", response_body.get("model", ""))
    if not model:
        verbose_proxy_logger.warning(
            "No model found in request or response for OpenAI passthrough cost tracking"
        )
        base_handler = OpenAIPassthroughLoggingHandler()
        return base_handler.passthrough_chat_handler(
            httpx_response=httpx_response,
            response_body=response_body,
            logging_obj=logging_obj,
            url_route=url_route,
            result=result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=cache_hit,
            request_body=request_body,
            **kwargs,
        )

    try:
        response_cost = 0.0
        litellm_model_response: Optional[
            Union[ModelResponse, TextCompletionResponse, ImageResponse]
        ] = None
        handler_instance = OpenAIPassthroughLoggingHandler()

        custom_llm_provider = kwargs.get("custom_llm_provider", "openai")

        if is_chat_completions:
            # Handle chat completions with existing logic
            provider_config = handler_instance.get_provider_config(model=model)
            # Preserve existing litellm_params to maintain metadata tags
            existing_litellm_params = kwargs.get("litellm_params", {}) or {}
            litellm_model_response = provider_config.transform_response(
                raw_response=httpx_response,
                model_response=litellm.ModelResponse(),
                model=model,
                messages=request_body.get("messages", []),
                logging_obj=logging_obj,
                optional_params=request_body.get("optional_params", {}),
                api_key="",
                request_data=request_body,
                encoding=litellm.encoding,
                json_mode=request_body.get("response_format", {}).get("type")
                == "json_object",
                litellm_params=existing_litellm_params,
            )

            # Calculate cost using LiteLLM's cost calculator
            response_cost = litellm.completion_cost(
                completion_response=litellm_model_response,
                model=model,
                custom_llm_provider=custom_llm_provider,
            )
        elif is_image_generation:
            # Handle image generation cost calculation
            response_cost = (
                OpenAIPassthroughLoggingHandler._calculate_image_generation_cost(
                    model=model,
                    response_body=response_body,
                    request_body=request_body,
                )
            )
            # Mark call type for downstream image-aware logic/metrics
            try:
                logging_obj.call_type = (
                    PassthroughCallTypes.passthrough_image_generation.value
                )
            except Exception:
                pass
            # Create a simple response object for logging
            litellm_model_response = ImageResponse(
                data=response_body.get("data", []),
                model=model,
            )
            # Set the calculated cost in _hidden_params to prevent recalculation
            if not hasattr(litellm_model_response, "_hidden_params"):
                litellm_model_response._hidden_params = {}
            litellm_model_response._hidden_params["response_cost"] = response_cost
        elif is_image_editing:
            # Handle image editing cost calculation
            response_cost = (
                OpenAIPassthroughLoggingHandler._calculate_image_editing_cost(
                    model=model,
                    response_body=response_body,
                    request_body=request_body,
                )
            )
            # Mark call type for downstream image-aware logic/metrics
            try:
                logging_obj.call_type = (
                    PassthroughCallTypes.passthrough_image_generation.value
                )
            except Exception:
                pass
            # Create a simple response object for logging
            litellm_model_response = ImageResponse(
                data=response_body.get("data", []),
                model=model,
            )
            # Set the calculated cost in _hidden_params to prevent recalculation
            if not hasattr(litellm_model_response, "_hidden_params"):
                litellm_model_response._hidden_params = {}
            litellm_model_response._hidden_params["response_cost"] = response_cost
        elif is_responses:
            # Handle responses API cost calculation
            provider_config = handler_instance.get_provider_config(model=model)
            existing_litellm_params = kwargs.get("litellm_params", {}) or {}
            litellm_model_response = provider_config.transform_response(
                raw_response=httpx_response,
                model_response=litellm.ModelResponse(),
                model=model,
                messages=request_body.get("messages", []),
                logging_obj=logging_obj,
                optional_params=request_body.get("optional_params", {}),
                api_key="",
                request_data=request_body,
                encoding=litellm.encoding,
                json_mode=False,
                litellm_params=existing_litellm_params,
            )

            # Calculate cost using LiteLLM's cost calculator with responses call type
            response_cost = litellm.completion_cost(
                completion_response=litellm_model_response,
                model=model,
                custom_llm_provider=custom_llm_provider,
                call_type="responses",
            )

        # Update kwargs with cost information
        kwargs["response_cost"] = response_cost
        kwargs["model"] = model
        kwargs["custom_llm_provider"] = custom_llm_provider

        # Extract user information for tracking
        passthrough_logging_payload: Optional[
            PassthroughStandardLoggingPayload
        ] = kwargs.get("passthrough_logging_payload")
        if passthrough_logging_payload:
            user = handler_instance._get_user_from_metadata(
                passthrough_logging_payload=passthrough_logging_payload,
            )
            if user:
                # Fix: use kwargs.setdefault so a missing "litellm_params"
                # key no longer raises KeyError (which previously aborted
                # cost tracking via the except-fallback below).
                kwargs.setdefault("litellm_params", {}).setdefault(
                    "proxy_server_request", {}
                ).setdefault("body", {})["user"] = user

        # Create standard logging object
        if litellm_model_response is not None:
            get_standard_logging_object_payload(
                kwargs=kwargs,
                init_response_obj=litellm_model_response,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                status="success",
            )

        # Update logging object with cost information
        logging_obj.model_call_details["model"] = model
        logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
        logging_obj.model_call_details["response_cost"] = response_cost

        # Fix: include the responses case in the debug label; previously
        # responses-API calls fell through and were labeled "image_editing".
        endpoint_type = (
            "chat_completions"
            if is_chat_completions
            else "image_generation"
            if is_image_generation
            else "image_editing"
            if is_image_editing
            else "responses"
        )
        verbose_proxy_logger.debug(
            f"OpenAI passthrough cost tracking - Endpoint: {endpoint_type}, Model: {model}, Cost: ${response_cost:.6f}"
        )

        return {
            "result": litellm_model_response,
            "kwargs": kwargs,
        }

    except Exception as e:
        verbose_proxy_logger.error(
            f"Error in OpenAI passthrough cost tracking: {str(e)}"
        )
        # Fall back to base handler without cost tracking
        base_handler = OpenAIPassthroughLoggingHandler()
        return base_handler.passthrough_chat_handler(
            httpx_response=httpx_response,
            response_body=response_body,
            logging_obj=logging_obj,
            url_route=url_route,
            result=result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=cache_hit,
            request_body=request_body,
            **kwargs,
        )
|
||||
|
||||
def _build_complete_streaming_response(
    self,
    all_chunks: list,
    litellm_logging_obj: LiteLLMLoggingObj,
    model: str,
) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
    """
    Builds complete response from raw chunks for OpenAI streaming responses.

    - Converts str chunks to dicts via the base iterator's string parser
    - Converts dicts to litellm chunks via OpenAI's chunk parser
    - Builds a complete response from the litellm chunks

    Returns None when no chunk could be parsed or on any unexpected error.
    """
    try:
        # Fix: imports hoisted out of the per-chunk loop — re-running an
        # import statement for every chunk was pointless work.
        from litellm.llms.base_llm.base_model_iterator import (
            BaseModelResponseIterator,
        )

        # OpenAI's response iterator to parse chunks
        from litellm.llms.openai.openai import OpenAIChatCompletionResponseIterator

        openai_iterator = OpenAIChatCompletionResponseIterator(
            streaming_response=None,
            sync_stream=False,
        )

        all_openai_chunks = []
        for chunk_str in all_chunks:
            try:
                # Convert string chunk to dict
                stripped_json_chunk = (
                    BaseModelResponseIterator._string_to_dict_parser(
                        str_line=chunk_str
                    )
                )

                if stripped_json_chunk:
                    # Parse the chunk using OpenAI's chunk parser
                    transformed_chunk = openai_iterator.chunk_parser(
                        chunk=stripped_json_chunk
                    )
                    if transformed_chunk is not None:
                        all_openai_chunks.append(transformed_chunk)

            except Exception as e:
                # StopIteration / StopAsyncIteration are Exception
                # subclasses, so this single clause covers the original
                # (StopIteration, StopAsyncIteration, Exception) tuple.
                # A bad chunk is skipped instead of aborting the build.
                verbose_proxy_logger.debug(f"Error parsing streaming chunk: {e}")
                continue

        if not all_openai_chunks:
            verbose_proxy_logger.warning(
                "No valid chunks found in streaming response"
            )
            return None

        # Build complete response from chunks
        complete_streaming_response = litellm.stream_chunk_builder(
            chunks=all_openai_chunks
        )

        return complete_streaming_response

    except Exception as e:
        verbose_proxy_logger.error(
            f"Error building complete streaming response: {str(e)}"
        )
        return None
|
||||
|
||||
@staticmethod
def _handle_logging_openai_collected_chunks(
    litellm_logging_obj: LiteLLMLoggingObj,
    passthrough_success_handler_obj: PassThroughEndpointLogging,
    url_route: str,
    request_body: dict,
    endpoint_type: EndpointType,
    start_time: datetime,
    all_chunks: List[str],
    end_time: datetime,
) -> PassThroughEndpointLoggingTypedDict:
    """
    Handle logging for collected OpenAI streaming chunks with cost tracking.

    Rebuilds a complete response from the raw stream chunks, prices it with
    ``litellm.completion_cost``, builds the standard logging payload, and
    returns the complete response plus the logging kwargs. On any failure
    it returns ``{"result": None, "kwargs": {}}`` so the caller skips
    logging for this stream.

    Note: ``passthrough_success_handler_obj``, ``url_route`` and
    ``endpoint_type`` are accepted for interface parity with sibling
    handlers but are not read in this function body.
    """
    try:
        # Extract model from request body
        # NOTE(review): falls back to a hard-coded "gpt-4o" when the
        # request carries no model — cost may then be computed against the
        # wrong pricing; confirm this default is intended.
        model = request_body.get("model", "gpt-4o")

        # Build complete response from chunks using our streaming handler
        handler = OpenAIPassthroughLoggingHandler()
        handler_instance = handler  # alias to the same instance, reused for user extraction below
        complete_response = handler._build_complete_streaming_response(
            all_chunks=all_chunks,
            litellm_logging_obj=litellm_logging_obj,
            model=model,
        )

        if complete_response is None:
            verbose_proxy_logger.warning(
                "Failed to build complete response from OpenAI streaming chunks"
            )
            return {
                "result": None,
                "kwargs": {},
            }

        # Provider may have been recorded earlier on the logging object;
        # default to plain "openai" otherwise.
        custom_llm_provider = litellm_logging_obj.model_call_details.get(
            "custom_llm_provider", "openai"
        )
        # Calculate cost using LiteLLM's cost calculator
        response_cost = litellm.completion_cost(
            completion_response=complete_response,
            model=model,
            custom_llm_provider=custom_llm_provider,
        )

        # Preserve existing litellm_params to maintain metadata tags
        existing_litellm_params = (
            litellm_logging_obj.model_call_details.get("litellm_params", {}) or {}
        )

        # Prepare kwargs for logging
        # (litellm_params is copied so the user-injection below does not
        # mutate the logging object's own dict)
        kwargs = {
            "response_cost": response_cost,
            "model": model,
            "custom_llm_provider": custom_llm_provider,
            "litellm_params": existing_litellm_params.copy(),
        }

        # Extract user information for tracking
        passthrough_logging_payload: Optional[
            PassthroughStandardLoggingPayload
        ] = litellm_logging_obj.model_call_details.get(
            "passthrough_logging_payload"
        )
        if passthrough_logging_payload:
            user = handler_instance._get_user_from_metadata(
                passthrough_logging_payload=passthrough_logging_payload,
            )
            if user:
                # Thread the end-user id into litellm_params so downstream
                # logging attributes the spend to that user.
                kwargs["litellm_params"].setdefault(
                    "proxy_server_request", {}
                ).setdefault("body", {})["user"] = user

        # Create standard logging object
        get_standard_logging_object_payload(
            kwargs=kwargs,
            init_response_obj=complete_response,
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
            status="success",
        )

        # Update logging object with cost information
        litellm_logging_obj.model_call_details["model"] = model
        litellm_logging_obj.model_call_details[
            "custom_llm_provider"
        ] = custom_llm_provider
        litellm_logging_obj.model_call_details["response_cost"] = response_cost

        verbose_proxy_logger.debug(
            f"OpenAI streaming passthrough cost tracking - Model: {model}, Cost: ${response_cost:.6f}"
        )

        return {
            "result": complete_response,
            "kwargs": kwargs,
        }

    except Exception as e:
        verbose_proxy_logger.error(
            f"Error in OpenAI streaming passthrough cost tracking: {str(e)}"
        )
        return {
            "result": None,
            "kwargs": {},
        }
|
||||
@@ -0,0 +1,403 @@
|
||||
"""
|
||||
Vertex AI Live API WebSocket Passthrough Logging Handler
|
||||
|
||||
Handles cost tracking and logging for Vertex AI Live API WebSocket passthrough endpoints.
|
||||
Supports different modalities: text, audio, video, and web search.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.base_passthrough_logging_handler import (
|
||||
BasePassthroughLoggingHandler,
|
||||
)
|
||||
from litellm.proxy.pass_through_endpoints.llm_provider_handlers.openai_passthrough_logging_handler import (
|
||||
PassThroughEndpointLoggingTypedDict,
|
||||
)
|
||||
from litellm.types.utils import LlmProviders, ModelResponse, Usage
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
|
||||
class VertexAILivePassthroughLoggingHandler(BasePassthroughLoggingHandler):
|
||||
"""
|
||||
Handles cost tracking and logging for Vertex AI Live API WebSocket passthrough.
|
||||
|
||||
Supports:
|
||||
- Text tokens (input/output)
|
||||
- Audio tokens (input/output)
|
||||
- Video tokens (input/output)
|
||||
- Web search requests
|
||||
- Tool use tokens
|
||||
"""
|
||||
|
||||
def _build_complete_streaming_response(self, *args, **kwargs):
    """Not applicable for WebSocket passthrough.

    The Live API is WebSocket-based, so there is no HTTP chunk stream to
    reassemble; this always returns None.
    """
    return None
|
||||
|
||||
def get_provider_config(self, model: str):
    """Return Vertex AI provider configuration.

    The *model* argument is currently unused — the same
    ``VertexGeminiConfig`` is returned for every model.
    """
    # Imported lazily inside the method — presumably to avoid a heavy or
    # circular import at module load; confirm before hoisting to the top.
    from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
        VertexGeminiConfig,
    )

    return VertexGeminiConfig()
|
||||
|
||||
@property
def llm_provider_name(self) -> LlmProviders:
    """Return the LLM provider name (Vertex AI) for provider dispatch."""
    return LlmProviders.VERTEX_AI
|
||||
|
||||
@staticmethod
|
||||
def _extract_usage_metadata_from_websocket_messages(
|
||||
websocket_messages: List[Dict],
|
||||
) -> Optional[Dict]:
|
||||
"""
|
||||
Extract and aggregate usage metadata from a list of WebSocket messages.
|
||||
|
||||
Args:
|
||||
websocket_messages: List of WebSocket messages from the Live API
|
||||
|
||||
Returns:
|
||||
Dictionary containing aggregated usage metadata, or None if not found
|
||||
"""
|
||||
all_usage_metadata = []
|
||||
|
||||
# Collect all usage metadata messages
|
||||
for message in websocket_messages:
|
||||
if isinstance(message, dict) and "usageMetadata" in message:
|
||||
all_usage_metadata.append(message["usageMetadata"])
|
||||
|
||||
if not all_usage_metadata:
|
||||
return None
|
||||
|
||||
# If only one usage metadata, return it as-is
|
||||
if len(all_usage_metadata) == 1:
|
||||
return all_usage_metadata[0]
|
||||
|
||||
# Aggregate multiple usage metadata messages
|
||||
aggregated: Dict[str, Any] = {
|
||||
"promptTokenCount": 0,
|
||||
"candidatesTokenCount": 0,
|
||||
"totalTokenCount": 0,
|
||||
"promptTokensDetails": [],
|
||||
"candidatesTokensDetails": [],
|
||||
}
|
||||
|
||||
# Aggregate token counts
|
||||
for usage in all_usage_metadata:
|
||||
aggregated["promptTokenCount"] += usage.get("promptTokenCount", 0)
|
||||
aggregated["candidatesTokenCount"] += usage.get("candidatesTokenCount", 0)
|
||||
aggregated["totalTokenCount"] += usage.get("totalTokenCount", 0)
|
||||
|
||||
# Aggregate token details by modality
|
||||
modality_totals = {}
|
||||
|
||||
for usage in all_usage_metadata:
|
||||
# Process prompt tokens details
|
||||
for detail in usage.get("promptTokensDetails", []):
|
||||
modality = detail.get("modality", "TEXT")
|
||||
token_count = detail.get("tokenCount", 0)
|
||||
|
||||
if modality not in modality_totals:
|
||||
modality_totals[modality] = {"prompt": 0, "candidate": 0}
|
||||
modality_totals[modality]["prompt"] += token_count
|
||||
|
||||
# Process candidate tokens details
|
||||
for detail in usage.get("candidatesTokensDetails", []):
|
||||
modality = detail.get("modality", "TEXT")
|
||||
token_count = detail.get("tokenCount", 0)
|
||||
|
||||
if modality not in modality_totals:
|
||||
modality_totals[modality] = {"prompt": 0, "candidate": 0}
|
||||
modality_totals[modality]["candidate"] += token_count
|
||||
|
||||
# Convert aggregated modality totals back to details format
|
||||
for modality, totals in modality_totals.items():
|
||||
if totals["prompt"] > 0:
|
||||
aggregated["promptTokensDetails"].append(
|
||||
{"modality": modality, "tokenCount": totals["prompt"]}
|
||||
)
|
||||
if totals["candidate"] > 0:
|
||||
aggregated["candidatesTokensDetails"].append(
|
||||
{"modality": modality, "tokenCount": totals["candidate"]}
|
||||
)
|
||||
|
||||
# Add any additional fields from the first usage metadata
|
||||
first_usage = all_usage_metadata[0]
|
||||
for key, value in first_usage.items():
|
||||
if key not in aggregated:
|
||||
aggregated[key] = value
|
||||
|
||||
return aggregated
|
||||
|
||||
    @staticmethod
    def _calculate_live_api_cost(
        model: str,
        usage_metadata: Dict,
        custom_llm_provider: str = "vertex_ai",
    ) -> float:
        """
        Calculate cost for Vertex AI Live API based on usage metadata.

        Pricing is looked up via ``get_model_info``; if no entry (or no
        ``input_cost_per_token``) exists, the cost is reported as 0.0 rather
        than raising.

        Args:
            model: The model name (e.g., "gemini-2.0-flash-live-preview-04-09")
            usage_metadata: Usage metadata from the Live API response
            custom_llm_provider: The LLM provider (default: "vertex_ai")

        Returns:
            Total cost in USD (0.0 on missing pricing info or any error)
        """
        try:
            # Get model pricing information
            model_info = get_model_info(
                model=model, custom_llm_provider=custom_llm_provider
            )

            verbose_proxy_logger.debug(
                f"Vertex AI Live API model info for '{model}': {model_info}"
            )

            # Check if pricing info is available; without a per-token input
            # price there is nothing meaningful to compute.
            if not model_info or not model_info.get("input_cost_per_token"):
                verbose_proxy_logger.error(
                    f"No pricing info found for {model} in local model pricing database"
                )
                return 0.0

            total_cost = 0.0

            # Extract token counts from usage metadata (Vertex camelCase keys).
            prompt_token_count = usage_metadata.get("promptTokenCount", 0)
            candidates_token_count = usage_metadata.get("candidatesTokenCount", 0)

            # Calculate base text token costs.
            # NOTE(review): this base pass charges ALL prompt/candidate tokens
            # at the text rate, and the modality loops below then add AUDIO /
            # VIDEO charges for the same tokens on top. If the per-modality
            # details sum to the totals, AUDIO/VIDEO tokens are billed twice —
            # confirm whether the modality rates are meant as surcharges.
            input_cost_per_token = model_info.get("input_cost_per_token", 0.0)
            output_cost_per_token = model_info.get("output_cost_per_token", 0.0)

            total_cost += prompt_token_count * input_cost_per_token
            total_cost += candidates_token_count * output_cost_per_token

            # Handle modality-specific costs if present
            prompt_tokens_details = usage_metadata.get("promptTokensDetails", [])
            candidates_tokens_details = usage_metadata.get(
                "candidatesTokensDetails", []
            )

            # Process prompt (input-side) tokens by modality
            for detail in prompt_tokens_details:
                # Entries without an explicit modality are treated as TEXT.
                modality = detail.get("modality", "TEXT")
                token_count = detail.get("tokenCount", 0)

                if modality == "AUDIO":
                    audio_cost_per_token = model_info.get(
                        "input_cost_per_audio_token", 0.0
                    )
                    total_cost += token_count * audio_cost_per_token
                elif modality == "VIDEO":
                    # Video tokens are typically per second, but we'll treat as per token for now
                    video_cost_per_token = model_info.get(
                        "input_cost_per_video_per_second", 0.0
                    )
                    total_cost += token_count * video_cost_per_token
                # TEXT tokens are already handled above

            # Process candidate (output-side) tokens by modality
            for detail in candidates_tokens_details:
                modality = detail.get("modality", "TEXT")
                token_count = detail.get("tokenCount", 0)

                if modality == "AUDIO":
                    audio_cost_per_token = model_info.get(
                        "output_cost_per_audio_token", 0.0
                    )
                    total_cost += token_count * audio_cost_per_token
                elif modality == "VIDEO":
                    # Video tokens are typically per second, but we'll treat as per token for now
                    video_cost_per_token = model_info.get(
                        "output_cost_per_video_per_second", 0.0
                    )
                    total_cost += token_count * video_cost_per_token
                # TEXT tokens are already handled above

            # Handle web search costs if present
            tool_use_prompt_token_count = usage_metadata.get(
                "toolUsePromptTokenCount", 0
            )
            if tool_use_prompt_token_count > 0:
                # Web search typically has a fixed cost per request
                web_search_cost = model_info.get("web_search_cost_per_request", 0.0)
                if isinstance(web_search_cost, (int, float)) and web_search_cost > 0:
                    total_cost += web_search_cost
                else:
                    # Fallback to token-based pricing for tool use
                    total_cost += tool_use_prompt_token_count * input_cost_per_token

            verbose_proxy_logger.debug(
                f"Vertex AI Live API cost calculation - Model: {model}, "
                f"Prompt tokens: {prompt_token_count}, "
                f"Candidate tokens: {candidates_token_count}, "
                f"Total cost: ${total_cost:.6f}"
            )

            return total_cost

        except Exception as e:
            # Best-effort: cost tracking must never break the passthrough
            # request, so any failure is logged and reported as zero cost.
            verbose_proxy_logger.error(
                f"Error calculating Vertex AI Live API cost: {e}"
            )
            return 0.0
|
||||
|
||||
@staticmethod
|
||||
def _create_usage_object_from_metadata(
|
||||
usage_metadata: Dict,
|
||||
model: str,
|
||||
) -> Usage:
|
||||
"""
|
||||
Create a LiteLLM Usage object from Live API usage metadata.
|
||||
|
||||
Args:
|
||||
usage_metadata: Usage metadata from the Live API response
|
||||
model: The model name
|
||||
|
||||
Returns:
|
||||
LiteLLM Usage object
|
||||
"""
|
||||
prompt_tokens = usage_metadata.get("promptTokenCount", 0)
|
||||
completion_tokens = usage_metadata.get("candidatesTokenCount", 0)
|
||||
total_tokens = usage_metadata.get("totalTokenCount", 0)
|
||||
|
||||
# Create modality-specific token details if available
|
||||
prompt_tokens_details = usage_metadata.get("promptTokensDetails", [])
|
||||
candidates_tokens_details = usage_metadata.get("candidatesTokensDetails", [])
|
||||
|
||||
# Extract text tokens from details
|
||||
text_prompt_tokens = 0
|
||||
text_completion_tokens = 0
|
||||
|
||||
for detail in prompt_tokens_details:
|
||||
if detail.get("modality") == "TEXT":
|
||||
text_prompt_tokens = detail.get("tokenCount", 0)
|
||||
break
|
||||
|
||||
for detail in candidates_tokens_details:
|
||||
if detail.get("modality") == "TEXT":
|
||||
text_completion_tokens = detail.get("tokenCount", 0)
|
||||
break
|
||||
|
||||
# If no text tokens found in details, use total counts
|
||||
if text_prompt_tokens == 0:
|
||||
text_prompt_tokens = prompt_tokens
|
||||
if text_completion_tokens == 0:
|
||||
text_completion_tokens = completion_tokens
|
||||
|
||||
return Usage(
|
||||
prompt_tokens=text_prompt_tokens,
|
||||
completion_tokens=text_completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
)
|
||||
|
||||
    def vertex_ai_live_passthrough_handler(
        self,
        websocket_messages: List[Dict],
        logging_obj,
        url_route: str,
        start_time: datetime,
        end_time: datetime,
        request_body: dict,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Handle cost tracking and logging for Vertex AI Live API WebSocket passthrough.

        Builds a mock chat-completion ``ModelResponse`` carrying the usage and
        cost extracted from the collected WebSocket messages, so the standard
        LiteLLM logging pipeline can consume it.

        Args:
            websocket_messages: List of WebSocket messages from the Live API
            logging_obj: LiteLLM logging object (currently unused in this body;
                kept for interface parity with the other passthrough handlers)
            url_route: The URL route that was called
            start_time: Request start time
            end_time: Request end time (currently unused in this body)
            request_body: The original request body (currently unused in this body)
            **kwargs: Additional keyword arguments; ``model`` and
                ``custom_llm_provider`` are read from here when present

        Returns:
            Dictionary containing the result and kwargs for logging
            (``result`` is None when no usage metadata is found or on error)
        """
        try:
            # Extract model from kwargs, defaulting to the Live preview model.
            model = kwargs.get("model", "gemini-2.0-flash-live-preview-04-09")
            custom_llm_provider = kwargs.get("custom_llm_provider", "vertex_ai")
            verbose_proxy_logger.debug(
                f"Vertex AI Live API model: {model}, custom_llm_provider: {custom_llm_provider}"
            )

            # Extract usage metadata from WebSocket messages
            usage_metadata = self._extract_usage_metadata_from_websocket_messages(
                websocket_messages
            )

            # Without usage metadata there is nothing to cost-track or log.
            if not usage_metadata:
                verbose_proxy_logger.warning(
                    "No usage metadata found in Vertex AI Live API WebSocket messages"
                )
                return {
                    "result": None,
                    "kwargs": kwargs,
                }

            # Calculate cost using Live API specific pricing
            response_cost = self._calculate_live_api_cost(
                model=model,
                usage_metadata=usage_metadata,
                custom_llm_provider=custom_llm_provider,
            )

            # Create Usage object for standard LiteLLM logging
            usage = self._create_usage_object_from_metadata(
                usage_metadata=usage_metadata,
                model=model,
            )

            # Create a mock ModelResponse for standard logging (no choices —
            # only id/model/usage matter for cost tracking).
            litellm_model_response = ModelResponse(
                id=f"vertex-ai-live-{start_time.timestamp()}",
                object="chat.completion",
                created=int(start_time.timestamp()),
                model=model,
                usage=usage,
                choices=[],
            )

            # Update kwargs with cost information
            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            kwargs["custom_llm_provider"] = custom_llm_provider

            # Safely log the model name: only allow known safe formats, redact otherwise.
            import re

            allowed_pattern = re.compile(r"^[A-Za-z0-9._\-:]+$")
            safe_model = (
                model
                if isinstance(model, str) and allowed_pattern.match(model)
                else "[REDACTED]"
            )
            verbose_proxy_logger.debug(
                f"Vertex AI Live API passthrough cost tracking - "
                f"Model: {safe_model}, Cost: ${response_cost:.6f}, "
                f"Prompt tokens: {usage.prompt_tokens}, "
                f"Completion tokens: {usage.completion_tokens}"
            )

            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }

        except Exception as e:
            # Best-effort boundary: never let cost tracking break the
            # passthrough; log the error and skip logging for this call.
            verbose_proxy_logger.error(
                f"Error in Vertex AI Live API passthrough handler: {e}"
            )
            return {
                "result": None,
                "kwargs": kwargs,
            }
|
||||
@@ -0,0 +1,851 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import (
|
||||
ModelResponseIterator as VertexModelResponseIterator,
|
||||
)
|
||||
from litellm.llms.vertex_ai.vector_stores.search_api.transformation import (
|
||||
VertexSearchAPIVectorStoreConfig,
|
||||
)
|
||||
from litellm.llms.vertex_ai.videos.transformation import VertexAIVideoConfig
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingTypedDict
|
||||
from litellm.types.utils import (
|
||||
Choices,
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
ModelResponse,
|
||||
SpecialEnums,
|
||||
StandardPassThroughResponseObject,
|
||||
TextCompletionResponse,
|
||||
)
|
||||
|
||||
vertex_search_api_config = VertexSearchAPIVectorStoreConfig()
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
from ..success_handler import PassThroughEndpointLogging
|
||||
else:
|
||||
PassThroughEndpointLogging = Any
|
||||
LiteLLMBatch = Any
|
||||
|
||||
# Define EndpointType locally to avoid import issues
|
||||
EndpointType = Any
|
||||
|
||||
|
||||
class VertexPassthroughLoggingHandler:
|
||||
    @staticmethod
    def vertex_passthrough_handler(
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: Optional[dict] = None,
        **kwargs,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Route a completed Vertex AI passthrough HTTP response to the matching
        transformation + cost-tracking logic, keyed on substrings of url_route.

        Branch order matters: "predictLongRunning" must be tested before the
        broader "predict" substring (which it contains).

        Returns:
            Dict with "result" (the transformed LiteLLM response object, or
            None for unrecognized routes) and "kwargs" (logging kwargs
            enriched with model / provider / response_cost where applicable).
        """
        # --- Video generation (long-running predict) ---
        if "predictLongRunning" in url_route:
            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)

            vertex_video_config = VertexAIVideoConfig()
            litellm_video_response = (
                vertex_video_config.transform_video_create_response(
                    model=model,
                    raw_response=httpx_response,
                    logging_obj=logging_obj,
                    custom_llm_provider="vertex_ai",
                    request_data=request_body,
                )
            )

            logging_obj.model = model
            logging_obj.model_call_details["model"] = model
            logging_obj.model_call_details["custom_llm_provider"] = "vertex_ai"
            logging_obj.custom_llm_provider = "vertex_ai"

            response_cost = litellm.completion_cost(
                completion_response=litellm_video_response,
                model=model,
                custom_llm_provider="vertex_ai",
                call_type="create_video",
            )

            # Set response_cost in _hidden_params to prevent recalculation
            if not hasattr(litellm_video_response, "_hidden_params"):
                litellm_video_response._hidden_params = {}
            litellm_video_response._hidden_params["response_cost"] = response_cost

            kwargs["response_cost"] = response_cost
            kwargs["model"] = model
            kwargs["custom_llm_provider"] = "vertex_ai"
            logging_obj.model_call_details["response_cost"] = response_cost

            return {
                "result": litellm_video_response,
                "kwargs": kwargs,
            }

        # --- Chat completions (generateContent / streamGenerateContent) ---
        elif "generateContent" in url_route:
            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)

            instance_of_vertex_llm = litellm.VertexGeminiConfig()
            # Placeholder message: passthrough never sees the real prompt.
            litellm_model_response: ModelResponse = (
                instance_of_vertex_llm.transform_response(
                    model=model,
                    messages=[
                        {"role": "user", "content": "no-message-pass-through-endpoint"}
                    ],
                    raw_response=httpx_response,
                    model_response=litellm.ModelResponse(),
                    logging_obj=logging_obj,
                    optional_params={},
                    litellm_params={},
                    api_key="",
                    request_data={},
                    encoding=litellm.encoding,
                )
            )
            kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
                litellm_model_response=litellm_model_response,
                model=model,
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url(
                    url_route
                ),
            )

            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }

        # --- Embeddings / image generation (plain predict) ---
        elif "predict" in url_route:
            return VertexPassthroughLoggingHandler._handle_predict_response(
                httpx_response=httpx_response,
                logging_obj=logging_obj,
                url_route=url_route,
                kwargs=kwargs,
            )
        # --- Partner models (Anthropic / Mistral / AI21 / OpenAPI endpoints) ---
        elif "rawPredict" in url_route or "streamRawPredict" in url_route:
            from litellm.llms.vertex_ai.vertex_ai_partner_models import (
                get_vertex_ai_partner_model_config,
            )

            model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
            vertex_publisher_or_api_spec = VertexPassthroughLoggingHandler._get_vertex_publisher_or_api_spec_from_url(
                url_route
            )

            _json_response = httpx_response.json()

            # Falls back to an empty ModelResponse when the publisher cannot
            # be determined from the URL.
            litellm_prediction_response = ModelResponse()

            if vertex_publisher_or_api_spec is not None:
                vertex_ai_partner_model_config = get_vertex_ai_partner_model_config(
                    model=model,
                    vertex_publisher_or_api_spec=vertex_publisher_or_api_spec,
                )
                litellm_prediction_response = (
                    vertex_ai_partner_model_config.transform_response(
                        model=model,
                        raw_response=httpx_response,
                        model_response=litellm_prediction_response,
                        logging_obj=logging_obj,
                        request_data={},
                        encoding=litellm.encoding,
                        optional_params={},
                        litellm_params={},
                        api_key="",
                        messages=[
                            {
                                "role": "user",
                                "content": "no-message-pass-through-endpoint",
                            }
                        ],
                    )
                )

            kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
                litellm_model_response=litellm_prediction_response,
                # Prefixed so downstream cost lookup resolves the vertex entry.
                model="vertex_ai/" + model,
                kwargs=kwargs,
                start_time=start_time,
                end_time=end_time,
                logging_obj=logging_obj,
                custom_llm_provider="vertex_ai",
            )

            return {
                "result": litellm_prediction_response,
                "kwargs": kwargs,
            }
        # --- Vector store search ---
        elif "search" in url_route:
            litellm_vs_response = (
                vertex_search_api_config.transform_search_vector_store_response(
                    response=httpx_response,
                    litellm_logging_obj=logging_obj,
                )
            )
            response_cost = litellm.completion_cost(
                completion_response=litellm_vs_response,
                model="vertex_ai/search_api",
                custom_llm_provider="vertex_ai",
                call_type="vector_store_search",
            )

            standard_pass_through_response_object: StandardPassThroughResponseObject = {
                "response": cast(dict, litellm_vs_response),
            }

            kwargs["response_cost"] = response_cost
            kwargs["model"] = "vertex_ai/search_api"
            logging_obj.model_call_details.setdefault("litellm_params", {})
            logging_obj.model_call_details["litellm_params"][
                "base_model"
            ] = "vertex_ai/search_api"
            logging_obj.model_call_details["response_cost"] = response_cost

            return {
                "result": standard_pass_through_response_object,
                "kwargs": kwargs,
            }
        # --- Batch prediction job creation ---
        elif "batchPredictionJobs" in url_route:
            return VertexPassthroughLoggingHandler.batch_prediction_jobs_handler(
                httpx_response=httpx_response,
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                **kwargs,
            )
        else:
            # Unrecognized route: nothing to transform or cost-track.
            return {
                "result": None,
                "kwargs": kwargs,
            }
|
||||
|
||||
    @staticmethod
    def _handle_predict_response(
        httpx_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        kwargs: dict,
    ) -> PassThroughEndpointLoggingTypedDict:
        """Handle predict endpoint responses (embeddings, image generation).

        Inspects the JSON payload to decide among three transformations, in
        order: image generation, multimodal embedding, then plain text
        embedding as the fallback. Sets model/provider/cost on both
        ``logging_obj`` and ``kwargs`` before returning.
        """
        # Local imports keep the module import graph light for this handler.
        from litellm.llms.vertex_ai.image_generation.image_generation_handler import (
            VertexImageGeneration,
        )
        from litellm.llms.vertex_ai.multimodal_embeddings.transformation import (
            VertexAIMultimodalEmbeddingConfig,
        )
        from litellm.types.utils import PassthroughCallTypes

        vertex_image_generation_class = VertexImageGeneration()

        model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)

        _json_response = httpx_response.json()

        litellm_prediction_response: Union[
            ModelResponse, EmbeddingResponse, ImageResponse
        ] = ModelResponse()
        if vertex_image_generation_class.is_image_generation_response(_json_response):
            litellm_prediction_response = (
                vertex_image_generation_class.process_image_generation_response(
                    _json_response,
                    model_response=litellm.ImageResponse(),
                    model=model,
                )
            )

            # Reclassify the call so downstream logging treats it as an
            # image-generation passthrough.
            logging_obj.call_type = (
                PassthroughCallTypes.passthrough_image_generation.value
            )
        elif VertexPassthroughLoggingHandler._is_multimodal_embedding_response(
            json_response=_json_response,
        ):
            # Use multimodal embedding transformation
            vertex_multimodal_config = VertexAIMultimodalEmbeddingConfig()
            litellm_prediction_response = (
                vertex_multimodal_config.transform_embedding_response(
                    model=model,
                    raw_response=httpx_response,
                    model_response=litellm.EmbeddingResponse(),
                    logging_obj=logging_obj,
                    api_key="",
                    request_data={},
                    optional_params={},
                    litellm_params={},
                )
            )
        else:
            # Fallback: treat as a plain text-embedding predict response.
            litellm_prediction_response = (
                litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
                    response=_json_response,
                    model=model,
                    model_response=litellm.EmbeddingResponse(),
                )
            )
            if isinstance(litellm_prediction_response, litellm.EmbeddingResponse):
                litellm_prediction_response.model = model

        logging_obj.model = model
        logging_obj.model_call_details["model"] = logging_obj.model
        logging_obj.model_call_details["custom_llm_provider"] = "vertex_ai"
        logging_obj.custom_llm_provider = "vertex_ai"
        response_cost = litellm.completion_cost(
            completion_response=litellm_prediction_response,
            model=model,
            custom_llm_provider="vertex_ai",
        )

        kwargs["response_cost"] = response_cost
        kwargs["model"] = model
        kwargs["custom_llm_provider"] = "vertex_ai"
        logging_obj.model_call_details["response_cost"] = response_cost

        return {
            "result": litellm_prediction_response,
            "kwargs": kwargs,
        }
|
||||
|
||||
    @staticmethod
    def _handle_logging_vertex_collected_chunks(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        all_chunks: List[str],
        model: Optional[str],
        end_time: datetime,
    ) -> PassThroughEndpointLoggingTypedDict:
        """
        Takes raw chunks from Vertex passthrough endpoint and logs them in litellm callbacks

        - Builds complete response from chunks
        - Creates standard logging object
        - Logs in litellm callbacks

        Returns a dict with "result" (the rebuilt streaming response, or None
        when reconstruction failed) and "kwargs" (the logging payload).
        """
        kwargs: Dict[str, Any] = {}
        # Prefer the caller-supplied model; otherwise derive it from the URL.
        model = model or VertexPassthroughLoggingHandler.extract_model_from_url(
            url_route
        )
        complete_streaming_response = (
            VertexPassthroughLoggingHandler._build_complete_streaming_response(
                all_chunks=all_chunks,
                litellm_logging_obj=litellm_logging_obj,
                model=model,
                url_route=url_route,
            )
        )

        # If the chunks could not be assembled, skip logging entirely.
        if complete_streaming_response is None:
            verbose_proxy_logger.error(
                "Unable to build complete streaming response for Vertex passthrough endpoint, not logging..."
            )
            return {
                "result": None,
                "kwargs": kwargs,
            }

        kwargs = VertexPassthroughLoggingHandler._create_vertex_response_logging_payload_for_generate_content(
            litellm_model_response=complete_streaming_response,
            model=model,
            kwargs=kwargs,
            start_time=start_time,
            end_time=end_time,
            logging_obj=litellm_logging_obj,
            custom_llm_provider=VertexPassthroughLoggingHandler._get_custom_llm_provider_from_url(
                url_route
            ),
        )

        return {
            "result": complete_streaming_response,
            "kwargs": kwargs,
        }
|
||||
|
||||
    @staticmethod
    def _build_complete_streaming_response(
        all_chunks: List[str],
        litellm_logging_obj: LiteLLMLoggingObj,
        model: str,
        url_route: str,
    ) -> Optional[Union[ModelResponse, TextCompletionResponse]]:
        """Reassemble a full response from raw streamed chunk strings.

        Chooses a chunk parser from the URL route (Gemini generateContent vs
        Anthropic-style rawPredict), then folds the parsed chunks through
        ``litellm.stream_chunk_builder``. Returns None when the route is
        unsupported or no chunks could be parsed.
        """
        parsed_chunks = []
        if "generateContent" in url_route or "streamGenerateContent" in url_route:
            # Gemini streaming: reuse the Vertex iterator's chunk parser.
            vertex_iterator: Any = VertexModelResponseIterator(
                streaming_response=None,
                sync_stream=False,
                logging_obj=litellm_logging_obj,
            )
            chunk_parsing_logic: Any = vertex_iterator._common_chunk_parsing_logic
            parsed_chunks = [chunk_parsing_logic(chunk) for chunk in all_chunks]
        elif "rawPredict" in url_route or "streamRawPredict" in url_route:
            # Partner-model streaming (Anthropic wire format).
            from litellm.llms.anthropic.chat.handler import ModelResponseIterator
            from litellm.llms.base_llm.base_model_iterator import (
                BaseModelResponseIterator,
            )

            vertex_iterator = ModelResponseIterator(
                streaming_response=None,
                sync_stream=False,
            )
            chunk_parsing_logic = vertex_iterator.chunk_parser
            for chunk in all_chunks:
                # Raw SSE lines must first be decoded into dicts; undecodable
                # lines are silently skipped.
                dict_chunk = BaseModelResponseIterator._string_to_dict_parser(chunk)
                if dict_chunk is None:
                    continue
                parsed_chunks.append(chunk_parsing_logic(dict_chunk))
        else:
            # Unsupported route: caller treats None as "cannot rebuild".
            return None
        if len(parsed_chunks) == 0:
            return None
        # Drop chunks the parser rejected (returned None) before rebuilding.
        all_openai_chunks = []
        for parsed_chunk in parsed_chunks:
            if parsed_chunk is None:
                continue
            all_openai_chunks.append(parsed_chunk)

        complete_streaming_response = litellm.stream_chunk_builder(
            chunks=all_openai_chunks
        )

        return complete_streaming_response
|
||||
|
||||
@staticmethod
|
||||
def extract_model_from_url(url: str) -> str:
|
||||
pattern = r"/models/([^:]+)"
|
||||
match = re.search(pattern, url)
|
||||
if match:
|
||||
return match.group(1)
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def extract_model_name_from_vertex_path(vertex_model_path: str) -> str:
|
||||
"""
|
||||
Extract the actual model name from a Vertex AI model path.
|
||||
|
||||
Examples:
|
||||
- publishers/google/models/gemini-2.5-flash -> gemini-2.5-flash
|
||||
- projects/PROJECT_ID/locations/LOCATION/models/MODEL_ID -> MODEL_ID
|
||||
|
||||
Args:
|
||||
vertex_model_path: The full Vertex AI model path
|
||||
|
||||
Returns:
|
||||
The extracted model name for use with LiteLLM
|
||||
"""
|
||||
# Handle publishers/google/models/ format
|
||||
if "publishers/" in vertex_model_path and "models/" in vertex_model_path:
|
||||
# Extract everything after the last models/
|
||||
parts = vertex_model_path.split("models/")
|
||||
if len(parts) > 1:
|
||||
return parts[-1]
|
||||
|
||||
# Handle projects/PROJECT_ID/locations/LOCATION/models/MODEL_ID format
|
||||
elif "projects/" in vertex_model_path and "models/" in vertex_model_path:
|
||||
# Extract everything after the last models/
|
||||
parts = vertex_model_path.split("models/")
|
||||
if len(parts) > 1:
|
||||
return parts[-1]
|
||||
|
||||
# If no recognized pattern, return the original path
|
||||
return vertex_model_path
|
||||
|
||||
@staticmethod
|
||||
def _get_vertex_publisher_or_api_spec_from_url(url: str) -> Optional[str]:
|
||||
# Check for specific Vertex AI partner publishers
|
||||
if "/publishers/mistralai/" in url:
|
||||
return "mistralai"
|
||||
elif "/publishers/anthropic/" in url:
|
||||
return "anthropic"
|
||||
elif "/publishers/ai21/" in url:
|
||||
return "ai21"
|
||||
elif "/endpoints/openapi/" in url:
|
||||
return "openapi"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_custom_llm_provider_from_url(url: str) -> str:
|
||||
parsed_url = urlparse(url)
|
||||
if parsed_url.hostname and parsed_url.hostname.endswith(
|
||||
"generativelanguage.googleapis.com"
|
||||
):
|
||||
return litellm.LlmProviders.GEMINI.value
|
||||
return litellm.LlmProviders.VERTEX_AI.value
|
||||
|
||||
@staticmethod
|
||||
def _is_multimodal_embedding_response(json_response: dict) -> bool:
|
||||
"""
|
||||
Detect if the response is from a multimodal embedding request.
|
||||
|
||||
Check if the response contains multimodal embedding fields:
|
||||
- Docs: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/multimodal-embeddings-api#response-body
|
||||
|
||||
|
||||
Args:
|
||||
json_response: The JSON response from Vertex AI
|
||||
|
||||
Returns:
|
||||
bool: True if this is a multimodal embedding response
|
||||
"""
|
||||
# Check if response contains multimodal embedding fields
|
||||
if "predictions" in json_response:
|
||||
predictions = json_response["predictions"]
|
||||
for prediction in predictions:
|
||||
if isinstance(prediction, dict):
|
||||
# Check for multimodal embedding response fields
|
||||
if any(
|
||||
key in prediction
|
||||
for key in [
|
||||
"textEmbedding",
|
||||
"imageEmbedding",
|
||||
"videoEmbeddings",
|
||||
]
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
    @staticmethod
    def _create_vertex_response_logging_payload_for_generate_content(
        litellm_model_response: Union[ModelResponse, TextCompletionResponse],
        model: str,
        kwargs: dict,
        start_time: datetime,
        end_time: datetime,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: str,
    ) -> dict:
        """
        Create the standard logging object for Vertex passthrough generateContent (streaming and non-streaming)

        Computes the response cost, stamps cost/model onto ``kwargs``, and
        syncs model / provider / call id onto ``logging_obj``, returning the
        enriched ``kwargs``. ``start_time`` / ``end_time`` are accepted for
        interface parity but not read in this body.
        """

        response_cost = litellm.completion_cost(
            completion_response=litellm_model_response,
            model=model,
            # NOTE(review): provider is hardcoded to "vertex_ai" here even
            # though the caller passes custom_llm_provider (which may be
            # "gemini") — confirm whether cost lookup should use the parameter.
            custom_llm_provider="vertex_ai",
        )

        kwargs["response_cost"] = response_cost
        kwargs["model"] = model

        # pretty print standard logging object
        verbose_proxy_logger.debug("kwargs= %s", kwargs)

        # set litellm_call_id to logging response object
        litellm_model_response.id = logging_obj.litellm_call_id
        # Prefer the model name reported by the response; fall back to the
        # URL-derived name.
        logging_obj.model = litellm_model_response.model or model
        logging_obj.model_call_details["model"] = logging_obj.model
        logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
        return kwargs
|
||||
|
||||
@staticmethod
def batch_prediction_jobs_handler(  # noqa: PLR0915
    httpx_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    url_route: str,
    result: str,
    start_time: datetime,
    end_time: datetime,
    cache_hit: bool,
    **kwargs,
) -> PassThroughEndpointLoggingTypedDict:
    """
    Handle batch prediction jobs passthrough logging.

    Creates a managed object for cost tracking when batch job is successfully created.

    Three outcomes, all returning a {"result", "kwargs"} dict:
    - 200 + "name" in body: register a managed object (for later cost polling)
      and return a pending-job placeholder response with cost 0.0.
    - non-200 / missing "name": return a JOB_STATE_FAILED placeholder.
    - any exception: log it and return a JOB_STATE_FAILED placeholder.
    """
    import base64

    from litellm._uuid import uuid
    from litellm.llms.vertex_ai.batches.transformation import (
        VertexAIBatchTransformation,
    )

    try:
        _json_response = httpx_response.json()

        # Only handle successful batch job creation (POST requests)
        if httpx_response.status_code == 200 and "name" in _json_response:
            # Transform Vertex AI response to LiteLLM batch format
            litellm_batch_response = VertexAIBatchTransformation.transform_vertex_ai_batch_response_to_openai_batch_response(
                response=_json_response
            )

            # Extract batch ID and model from the response
            batch_id = VertexAIBatchTransformation._get_batch_id_from_vertex_ai_batch_response(
                _json_response
            )
            model_name = _json_response.get("model", "unknown")

            # Create unified object ID for tracking
            # Format: base64(litellm_proxy;model_id:{};llm_batch_id:{})
            actual_model_id = (
                VertexPassthroughLoggingHandler.get_actual_model_id_from_router(
                    model_name
                )
            )

            unified_id_string = (
                SpecialEnums.LITELLM_MANAGED_BATCH_COMPLETE_STR.value.format(
                    actual_model_id, batch_id
                )
            )
            # URL-safe base64 with padding stripped so the id can travel in URLs.
            unified_object_id = (
                base64.urlsafe_b64encode(unified_id_string.encode())
                .decode()
                .rstrip("=")
            )

            # Store the managed object for cost tracking
            # This will be picked up by check_batch_cost polling mechanism
            VertexPassthroughLoggingHandler._store_batch_managed_object(
                unified_object_id=unified_object_id,
                batch_object=litellm_batch_response,
                model_object_id=batch_id,
                logging_obj=logging_obj,
                **kwargs,
            )

            # Create a batch job response for logging
            litellm_model_response = ModelResponse()
            litellm_model_response.id = str(uuid.uuid4())
            litellm_model_response.model = model_name
            litellm_model_response.object = "batch_prediction_job"
            litellm_model_response.created = int(start_time.timestamp())

            # Add batch-specific metadata to indicate this is a pending batch job
            litellm_model_response.choices = [
                Choices(
                    finish_reason="stop",
                    index=0,
                    message={
                        "role": "assistant",
                        "content": f"Batch prediction job {batch_id} created and is pending. Status will be updated when the batch completes.",
                        "tool_calls": None,
                        "function_call": None,
                        "provider_specific_fields": {
                            "batch_job_id": batch_id,
                            "batch_job_state": "JOB_STATE_PENDING",
                            "unified_object_id": unified_object_id,
                        },
                    },
                )
            ]

            # Set response cost to 0 initially (will be updated when batch completes)
            response_cost = 0.0
            kwargs["response_cost"] = response_cost
            kwargs["model"] = model_name
            kwargs["batch_id"] = batch_id
            kwargs["unified_object_id"] = unified_object_id
            kwargs["batch_job_state"] = "JOB_STATE_PENDING"

            # Keep the logging object in sync so spend logs show the batch model.
            logging_obj.model = model_name
            logging_obj.model_call_details["model"] = logging_obj.model
            logging_obj.model_call_details["response_cost"] = response_cost
            logging_obj.model_call_details["batch_id"] = batch_id

            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }
        else:
            # Handle non-successful responses
            litellm_model_response = ModelResponse()
            litellm_model_response.id = str(uuid.uuid4())
            litellm_model_response.model = "vertex_ai_batch"
            litellm_model_response.object = "batch_prediction_job"
            litellm_model_response.created = int(start_time.timestamp())

            # Add error-specific metadata
            litellm_model_response.choices = [
                Choices(
                    finish_reason="stop",
                    index=0,
                    message={
                        "role": "assistant",
                        "content": f"Batch prediction job creation failed. Status: {httpx_response.status_code}",
                        "tool_calls": None,
                        "function_call": None,
                        "provider_specific_fields": {
                            "batch_job_state": "JOB_STATE_FAILED",
                            "status_code": httpx_response.status_code,
                        },
                    },
                )
            ]

            kwargs["response_cost"] = 0.0
            kwargs["model"] = "vertex_ai_batch"
            kwargs["batch_job_state"] = "JOB_STATE_FAILED"

            return {
                "result": litellm_model_response,
                "kwargs": kwargs,
            }

    except Exception as e:
        verbose_proxy_logger.error(f"Error in batch_prediction_jobs_handler: {e}")
        # Return basic response on error
        litellm_model_response = ModelResponse()
        litellm_model_response.id = str(uuid.uuid4())
        litellm_model_response.model = "vertex_ai_batch"
        litellm_model_response.object = "batch_prediction_job"
        litellm_model_response.created = int(start_time.timestamp())

        # Add error-specific metadata
        litellm_model_response.choices = [
            Choices(
                finish_reason="stop",
                index=0,
                message={
                    "role": "assistant",
                    "content": f"Error creating batch prediction job: {str(e)}",
                    "tool_calls": None,
                    "function_call": None,
                    "provider_specific_fields": {
                        "batch_job_state": "JOB_STATE_FAILED",
                        "error": str(e),
                    },
                },
            )
        ]

        kwargs["response_cost"] = 0.0
        kwargs["model"] = "vertex_ai_batch"
        kwargs["batch_job_state"] = "JOB_STATE_FAILED"

        return {
            "result": litellm_model_response,
            "kwargs": kwargs,
        }
@staticmethod
def _store_batch_managed_object(
    unified_object_id: str,
    batch_object: LiteLLMBatch,
    model_object_id: str,
    logging_obj: LiteLLMLoggingObj,
    **kwargs,
) -> None:
    """
    Store batch managed object for cost tracking.

    This will be picked up by the check_batch_cost polling mechanism.

    Best-effort: any failure is logged and swallowed so it never breaks
    the passthrough response path.
    """
    try:
        # Get the managed files hook from the logging object
        # This is a bit of a hack, but we need access to the proxy logging system
        from litellm.proxy.proxy_server import proxy_logging_obj

        managed_files_hook = proxy_logging_obj.get_proxy_hook("managed_files")
        if managed_files_hook is not None and hasattr(
            managed_files_hook, "store_unified_object_id"
        ):
            # Create a mock user API key dict for the managed object storage
            from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth

            user_api_key_dict = UserAPIKeyAuth(
                user_id=kwargs.get("user_id", "default-user"),
                api_key="",
                team_id=None,
                team_alias=None,
                user_role=LitellmUserRoles.CUSTOMER,  # Use proper enum value
                user_email=None,
                max_budget=None,
                spend=0.0,  # Set to 0.0 instead of None
                models=[],  # Set to empty list instead of None
                tpm_limit=None,
                rpm_limit=None,
                budget_duration=None,
                budget_reset_at=None,
                max_parallel_requests=None,
                allowed_model_region=None,
                metadata={},  # Set to empty dict instead of None
                key_alias=None,
                permissions={},  # Set to empty dict instead of None
                model_max_budget={},  # Set to empty dict instead of None
                model_spend={},  # Set to empty dict instead of None
            )

            # Store the unified object for batch cost tracking
            import asyncio

            # NOTE(review): fire-and-forget — no reference to the created task is
            # kept, so per asyncio docs it may be garbage-collected before it
            # completes, and any exception it raises is never observed. Consider
            # holding a reference (e.g. a module-level task set) — TODO confirm.
            asyncio.create_task(
                managed_files_hook.store_unified_object_id(  # type: ignore
                    unified_object_id=unified_object_id,
                    file_object=batch_object,
                    litellm_parent_otel_span=None,
                    model_object_id=model_object_id,
                    file_purpose="batch",
                    user_api_key_dict=user_api_key_dict,
                )
            )

            verbose_proxy_logger.info(
                f"Stored batch managed object with unified_object_id={unified_object_id}, batch_id={model_object_id}"
            )
        else:
            verbose_proxy_logger.warning(
                "Managed files hook not available, cannot store batch object for cost tracking"
            )

    except Exception as e:
        verbose_proxy_logger.error(f"Error storing batch managed object: {e}")
@staticmethod
def get_actual_model_id_from_router(model_name: str) -> str:
    """
    Resolve a vertex model path to a router model id when possible.

    Falls back to the name extracted from the vertex path when the router
    is unavailable or has no matching deployment.
    """
    from litellm.proxy.proxy_server import llm_router

    resolved_name = (
        VertexPassthroughLoggingHandler.extract_model_name_from_vertex_path(
            model_name
        )
    )

    # Guard: no router loaded — fall back to the constructed name.
    if llm_router is None:
        verbose_proxy_logger.warning(
            f"Router not available, using constructed model name: {resolved_name}"
        )
        return resolved_name

    # Use the existing get_model_ids method from router
    candidate_ids = llm_router.get_model_ids(model_name=resolved_name)
    if candidate_ids and len(candidate_ids) > 0:
        # Use the first model ID found
        router_model_id = candidate_ids[0]
        verbose_proxy_logger.info(
            f"Found model ID in router: {router_model_id}"
        )
        return router_model_id

    # Router loaded but no match — fall back to the constructed name.
    verbose_proxy_logger.warning(
        f"Model not found in router, using constructed name: {resolved_name}"
    )
    return resolved_name
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,212 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.integrations.vector_store_integrations.vector_store_pre_call_hook import (
|
||||
LiteLLM_ManagedVectorStore,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
from litellm.types.passthrough_endpoints.vertex_ai import VertexPassThroughCredentials
|
||||
|
||||
|
||||
class PassthroughEndpointRouter:
|
||||
"""
|
||||
Use this class to Set/Get credentials for pass-through endpoints
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.credentials: Dict[str, str] = {}
|
||||
self.deployment_key_to_vertex_credentials: Dict[
|
||||
str, VertexPassThroughCredentials
|
||||
] = {}
|
||||
self.default_vertex_config: Optional[VertexPassThroughCredentials] = None
|
||||
|
||||
def set_pass_through_credentials(
|
||||
self,
|
||||
custom_llm_provider: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
):
|
||||
"""
|
||||
Set credentials for a pass-through endpoint. Used when a user adds a pass-through LLM endpoint on the UI.
|
||||
|
||||
Args:
|
||||
custom_llm_provider: The provider of the pass-through endpoint
|
||||
api_base: The base URL of the pass-through endpoint
|
||||
api_key: The API key for the pass-through endpoint
|
||||
"""
|
||||
credential_name = self._get_credential_name_for_provider(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
region_name=self._get_region_name_from_api_base(
|
||||
api_base=api_base, custom_llm_provider=custom_llm_provider
|
||||
),
|
||||
)
|
||||
if api_key is None:
|
||||
raise ValueError("api_key is required for setting pass-through credentials")
|
||||
self.credentials[credential_name] = api_key
|
||||
|
||||
def get_credentials(
|
||||
self,
|
||||
custom_llm_provider: str,
|
||||
region_name: Optional[str],
|
||||
) -> Optional[str]:
|
||||
credential_name = self._get_credential_name_for_provider(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
region_name=region_name,
|
||||
)
|
||||
verbose_router_logger.debug(
|
||||
f"Pass-through llm endpoints router, looking for credentials for {credential_name}"
|
||||
)
|
||||
if credential_name in self.credentials:
|
||||
verbose_router_logger.debug(f"Found credentials for {credential_name}")
|
||||
return self.credentials[credential_name]
|
||||
else:
|
||||
verbose_router_logger.debug(
|
||||
f"No credentials found for {credential_name}, looking for env variable"
|
||||
)
|
||||
_env_variable_name = (
|
||||
self._get_default_env_variable_name_passthrough_endpoint(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
)
|
||||
return get_secret_str(_env_variable_name)
|
||||
|
||||
def _get_vertex_env_vars(self) -> VertexPassThroughCredentials:
|
||||
"""
|
||||
Helper to get vertex pass through config from environment variables
|
||||
|
||||
The following environment variables are used:
|
||||
- DEFAULT_VERTEXAI_PROJECT (project id)
|
||||
- DEFAULT_VERTEXAI_LOCATION (location)
|
||||
- DEFAULT_GOOGLE_APPLICATION_CREDENTIALS (path to credentials file)
|
||||
"""
|
||||
return VertexPassThroughCredentials(
|
||||
vertex_project=get_secret_str("DEFAULT_VERTEXAI_PROJECT"),
|
||||
vertex_location=get_secret_str("DEFAULT_VERTEXAI_LOCATION"),
|
||||
vertex_credentials=get_secret_str("DEFAULT_GOOGLE_APPLICATION_CREDENTIALS"),
|
||||
)
|
||||
|
||||
def set_default_vertex_config(self, config: Optional[dict] = None):
|
||||
"""Sets vertex configuration from provided config and/or environment variables
|
||||
|
||||
Args:
|
||||
config (Optional[dict]): Configuration dictionary
|
||||
Example: {
|
||||
"vertex_project": "my-project-123",
|
||||
"vertex_location": "us-central1",
|
||||
"vertex_credentials": "os.environ/GOOGLE_CREDS"
|
||||
}
|
||||
"""
|
||||
# Initialize config dictionary if None
|
||||
if config is None:
|
||||
self.default_vertex_config = self._get_vertex_env_vars()
|
||||
return
|
||||
|
||||
if isinstance(config, dict):
|
||||
for key, value in config.items():
|
||||
if isinstance(value, str) and value.startswith("os.environ/"):
|
||||
config[key] = get_secret_str(value)
|
||||
|
||||
self.default_vertex_config = VertexPassThroughCredentials(**config)
|
||||
|
||||
def add_vertex_credentials(
|
||||
self,
|
||||
project_id: str,
|
||||
location: str,
|
||||
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||
):
|
||||
"""
|
||||
Add the vertex credentials for the given project-id, location
|
||||
"""
|
||||
|
||||
deployment_key = self._get_deployment_key(
|
||||
project_id=project_id,
|
||||
location=location,
|
||||
)
|
||||
if deployment_key is None:
|
||||
verbose_router_logger.debug(
|
||||
"No deployment key found for project-id, location"
|
||||
)
|
||||
return
|
||||
vertex_pass_through_credentials = VertexPassThroughCredentials(
|
||||
vertex_project=project_id,
|
||||
vertex_location=location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
)
|
||||
self.deployment_key_to_vertex_credentials[
|
||||
deployment_key
|
||||
] = vertex_pass_through_credentials
|
||||
|
||||
def _get_deployment_key(
|
||||
self, project_id: Optional[str], location: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the deployment key for the given project-id, location
|
||||
"""
|
||||
if project_id is None or location is None:
|
||||
return None
|
||||
return f"{project_id}-{location}"
|
||||
|
||||
def get_vector_store_credentials(
|
||||
self, vector_store_id: str
|
||||
) -> Optional[LiteLLM_ManagedVectorStore]:
|
||||
"""
|
||||
Get the vector store credentials for the given vector store id
|
||||
"""
|
||||
if litellm.vector_store_registry is None:
|
||||
return None
|
||||
vector_store_to_run: Optional[
|
||||
LiteLLM_ManagedVectorStore
|
||||
] = litellm.vector_store_registry.get_litellm_managed_vector_store_from_registry(
|
||||
vector_store_id=vector_store_id
|
||||
)
|
||||
return vector_store_to_run
|
||||
|
||||
def get_vertex_credentials(
|
||||
self, project_id: Optional[str], location: Optional[str]
|
||||
) -> Optional[VertexPassThroughCredentials]:
|
||||
"""
|
||||
Get the vertex credentials for the given project-id, location
|
||||
"""
|
||||
deployment_key = self._get_deployment_key(
|
||||
project_id=project_id,
|
||||
location=location,
|
||||
)
|
||||
|
||||
if deployment_key is None:
|
||||
return self.default_vertex_config
|
||||
if deployment_key in self.deployment_key_to_vertex_credentials:
|
||||
return self.deployment_key_to_vertex_credentials[deployment_key]
|
||||
else:
|
||||
return self.default_vertex_config
|
||||
|
||||
def _get_credential_name_for_provider(
|
||||
self,
|
||||
custom_llm_provider: str,
|
||||
region_name: Optional[str],
|
||||
) -> str:
|
||||
if region_name is None:
|
||||
return f"{custom_llm_provider.upper()}_API_KEY"
|
||||
return f"{custom_llm_provider.upper()}_{region_name.upper()}_API_KEY"
|
||||
|
||||
def _get_region_name_from_api_base(
|
||||
self,
|
||||
custom_llm_provider: str,
|
||||
api_base: Optional[str],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the region name from the API base.
|
||||
|
||||
Each provider might have a different way of specifying the region in the API base - this is where you can use conditional logic to handle that.
|
||||
"""
|
||||
if custom_llm_provider == "assemblyai":
|
||||
if api_base and "eu" in api_base:
|
||||
return "eu"
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_default_env_variable_name_passthrough_endpoint(
|
||||
custom_llm_provider: str,
|
||||
) -> str:
|
||||
return f"{custom_llm_provider.upper()}_API_KEY"
|
||||
@@ -0,0 +1,335 @@
|
||||
"""
|
||||
Passthrough Guardrails Helper Module
|
||||
|
||||
Handles guardrail execution for passthrough endpoints with:
|
||||
- Opt-in model (guardrails only run when explicitly configured)
|
||||
- Field-level targeting using JSONPath expressions
|
||||
- Automatic inheritance from org/team/key levels when enabled
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
PassThroughGuardrailsConfig,
|
||||
PassThroughGuardrailSettings,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.pass_through_endpoints.jsonpath_extractor import JsonPathExtractor
|
||||
|
||||
# Type alias for the raw guardrails config input, before normalization by
# PassthroughGuardrailHandler.normalize_config. Accepts either form:
PassThroughGuardrailsConfigInput = Union[
    List[str],  # Simple list of guardrail names: ["guard-1", "guard-2"]
    PassThroughGuardrailsConfig,  # Dict with per-guardrail settings: {"guard-1": {"request_fields": [...]}}
]
|
||||
|
||||
class PassthroughGuardrailHandler:
    """
    Handles guardrail execution for passthrough endpoints.

    Passthrough endpoints use an opt-in model for guardrails:
    - Guardrails only run when explicitly configured on the endpoint
    - Supports field-level targeting using JSONPath expressions
    - Automatically inherits org/team/key level guardrails when enabled

    Guardrails can be specified as:
    - List format (simple): ["guardrail-1", "guardrail-2"]
    - Dict format (with settings): {"guardrail-1": {"request_fields": ["query"]}}
    """

    @staticmethod
    def normalize_config(
        guardrails_config: Optional[PassThroughGuardrailsConfigInput],
    ) -> Optional[PassThroughGuardrailsConfig]:
        """
        Normalize guardrails config to dict format.

        Accepts:
        - List of guardrail names: ["g1", "g2"] -> {"g1": None, "g2": None}
        - Dict with settings: {"g1": {"request_fields": [...]}}
        - None: returns None

        Any other type is logged at debug level and treated as no config.
        """
        if guardrails_config is None:
            return None

        # Already a dict - return as-is
        if isinstance(guardrails_config, dict):
            return guardrails_config

        # List of guardrail names - convert to dict (no per-guardrail settings)
        if isinstance(guardrails_config, list):
            return {name: None for name in guardrails_config}

        verbose_proxy_logger.debug(
            "Passthrough guardrails config is not a dict or list, got: %s",
            type(guardrails_config),
        )
        return None

    @staticmethod
    def is_enabled(
        guardrails_config: Optional[PassThroughGuardrailsConfigInput],
    ) -> bool:
        """
        Check if guardrails are enabled for a passthrough endpoint.

        Passthrough endpoints are opt-in only - guardrails only run when
        the guardrails config is set with at least one guardrail.
        """
        normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
        if normalized is None:
            return False
        return len(normalized) > 0

    @staticmethod
    def get_guardrail_names(
        guardrails_config: Optional[PassThroughGuardrailsConfigInput],
    ) -> List[str]:
        """Get the list of guardrail names configured for a passthrough endpoint."""
        normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
        if normalized is None:
            return []
        return list(normalized.keys())

    @staticmethod
    def get_settings(
        guardrails_config: Optional[PassThroughGuardrailsConfigInput],
        guardrail_name: str,
    ) -> Optional[PassThroughGuardrailSettings]:
        """
        Get settings for a specific guardrail from the passthrough config.

        Returns None when the guardrail has no per-guardrail settings
        (list-format config, or an unknown guardrail name).
        """
        normalized = PassthroughGuardrailHandler.normalize_config(guardrails_config)
        if normalized is None:
            return None

        settings = normalized.get(guardrail_name)
        if settings is None:
            return None

        # Coerce plain dicts into the typed settings object.
        if isinstance(settings, dict):
            return PassThroughGuardrailSettings(**settings)

        return settings

    @staticmethod
    def prepare_input(
        request_data: dict,
        guardrail_settings: Optional[PassThroughGuardrailSettings],
    ) -> str:
        """
        Prepare input text for guardrail execution based on field targeting settings.

        If request_fields is specified, extracts only those fields.
        Otherwise, uses the entire request payload as text.
        """
        if guardrail_settings is None or guardrail_settings.request_fields is None:
            from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

            # No field targeting: serialize the whole payload.
            return safe_dumps(request_data)

        return JsonPathExtractor.extract_fields(
            data=request_data,
            jsonpath_expressions=guardrail_settings.request_fields,
        )

    @staticmethod
    def prepare_output(
        response_data: dict,
        guardrail_settings: Optional[PassThroughGuardrailSettings],
    ) -> str:
        """
        Prepare output text for guardrail execution based on field targeting settings.

        If response_fields is specified, extracts only those fields.
        Otherwise, uses the entire response payload as text.
        """
        if guardrail_settings is None or guardrail_settings.response_fields is None:
            from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

            # No field targeting: serialize the whole payload.
            return safe_dumps(response_data)

        return JsonPathExtractor.extract_fields(
            data=response_data,
            jsonpath_expressions=guardrail_settings.response_fields,
        )

    @staticmethod
    async def execute(
        request_data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        guardrails_config: Optional[PassThroughGuardrailsConfig],
        event_type: str = "pre_call",
    ) -> dict:
        """
        Execute guardrails for a passthrough endpoint.

        This is the main entry point for passthrough guardrail execution.

        Args:
            request_data: The request payload
            user_api_key_dict: User API key authentication info
            guardrails_config: Passthrough-specific guardrails configuration
            event_type: "pre_call" for request, "post_call" for response

        Returns:
            The potentially modified request_data

        Raises:
            HTTPException if a guardrail blocks the request
        """
        if not PassthroughGuardrailHandler.is_enabled(guardrails_config):
            verbose_proxy_logger.debug(
                "Passthrough guardrails not enabled, skipping guardrail execution"
            )
            return request_data

        guardrail_names = PassthroughGuardrailHandler.get_guardrail_names(
            guardrails_config
        )
        verbose_proxy_logger.debug(
            "Executing passthrough guardrails: %s", guardrail_names
        )

        # Add to request metadata so guardrails know which to run
        from litellm.proxy.pass_through_endpoints.passthrough_context import (
            set_passthrough_guardrails_config,
        )

        if "metadata" not in request_data:
            request_data["metadata"] = {}

        # Set guardrails in metadata using dict format for compatibility
        request_data["metadata"]["guardrails"] = {
            name: True for name in guardrail_names
        }

        # Store passthrough guardrails config in request-scoped context
        # (read back later by get_field_targeted_text from guardrail hooks)
        set_passthrough_guardrails_config(guardrails_config)

        return request_data

    @staticmethod
    def collect_guardrails(
        user_api_key_dict: UserAPIKeyAuth,
        passthrough_guardrails_config: Optional[PassThroughGuardrailsConfigInput],
    ) -> Optional[Dict[str, bool]]:
        """
        Collect guardrails for a passthrough endpoint.

        Passthrough endpoints are opt-in only for guardrails. Guardrails only run when
        the guardrails config is set with at least one guardrail.

        Accepts both list and dict formats:
        - List: ["guardrail-1", "guardrail-2"]
        - Dict: {"guardrail-1": {"request_fields": [...]}}

        When enabled, this function collects:
        - Passthrough-specific guardrails from the config
        - Org/team/key level guardrails (automatic inheritance when passthrough is enabled)

        Args:
            user_api_key_dict: User API key authentication info
            passthrough_guardrails_config: List or Dict of guardrail names/settings

        Returns:
            Dict of guardrail names to run (format: {guardrail_name: True}), or None
        """
        from litellm.proxy.litellm_pre_call_utils import (
            _add_guardrails_from_key_or_team_metadata,
        )

        # Normalize config to dict format (handles both list and dict)
        normalized_config = PassthroughGuardrailHandler.normalize_config(
            passthrough_guardrails_config
        )

        if normalized_config is None:
            verbose_proxy_logger.debug(
                "Passthrough guardrails not configured, skipping guardrail collection"
            )
            return None

        if len(normalized_config) == 0:
            verbose_proxy_logger.debug(
                "Passthrough guardrails config is empty, skipping"
            )
            return None

        # Passthrough is enabled - collect guardrails
        guardrails_to_run: Dict[str, bool] = {}

        # Add passthrough-specific guardrails
        for guardrail_name in normalized_config.keys():
            guardrails_to_run[guardrail_name] = True
            verbose_proxy_logger.debug(
                "Added passthrough-specific guardrail: %s", guardrail_name
            )

        # Add org/team/key level guardrails using shared helper.
        # temp_data is a scratch dict shaped like a request so the shared
        # helper can write inherited guardrails into its metadata.
        temp_data: Dict[str, Any] = {"metadata": {}}
        _add_guardrails_from_key_or_team_metadata(
            key_metadata=user_api_key_dict.metadata,
            team_metadata=user_api_key_dict.team_metadata,
            data=temp_data,
            metadata_variable_name="metadata",
        )

        # Merge inherited guardrails into guardrails_to_run
        inherited_guardrails = temp_data["metadata"].get("guardrails", [])
        for guardrail_name in inherited_guardrails:
            if guardrail_name not in guardrails_to_run:
                guardrails_to_run[guardrail_name] = True
                verbose_proxy_logger.debug(
                    "Added inherited guardrail (key/team level): %s", guardrail_name
                )

        verbose_proxy_logger.debug(
            "Collected guardrails for passthrough endpoint: %s",
            list(guardrails_to_run.keys()),
        )

        return guardrails_to_run if guardrails_to_run else None

    @staticmethod
    def get_field_targeted_text(
        data: dict,
        guardrail_name: str,
        is_request: bool = True,
    ) -> Optional[str]:
        """
        Get the text to check for a guardrail, respecting field targeting settings.

        Called by guardrail hooks to get the appropriate text based on
        passthrough field targeting configuration.

        Args:
            data: The request/response data dict
            guardrail_name: Name of the guardrail being executed
            is_request: True for request (pre_call), False for response (post_call)

        Returns:
            The text to check, or None to use default behavior
        """
        from litellm.proxy.pass_through_endpoints.passthrough_context import (
            get_passthrough_guardrails_config,
        )

        # Config was stashed into request-scoped context by execute().
        passthrough_config = get_passthrough_guardrails_config()
        if passthrough_config is None:
            return None

        settings = PassthroughGuardrailHandler.get_settings(
            passthrough_config, guardrail_name
        )
        if settings is None:
            return None

        if is_request:
            if settings.request_fields:
                return JsonPathExtractor.extract_fields(data, settings.request_fields)
        else:
            if settings.response_fields:
                return JsonPathExtractor.extract_fields(data, settings.response_fields)

        # No matching field targeting — caller falls back to default behavior.
        return None
@@ -0,0 +1,248 @@
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.thread_pool_executor import executor
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import EndpointType
|
||||
from litellm.types.utils import StandardPassThroughResponseObject
|
||||
|
||||
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
|
||||
AnthropicPassthroughLoggingHandler,
|
||||
)
|
||||
from .llm_provider_handlers.openai_passthrough_logging_handler import (
|
||||
OpenAIPassthroughLoggingHandler,
|
||||
)
|
||||
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
|
||||
VertexPassthroughLoggingHandler,
|
||||
)
|
||||
from .success_handler import PassThroughEndpointLogging
|
||||
|
||||
|
||||
class PassThroughStreamingHandler:
|
||||
@staticmethod
async def chunk_processor(
    response: httpx.Response,
    request_body: Optional[dict],
    litellm_logging_obj: LiteLLMLoggingObj,
    endpoint_type: EndpointType,
    start_time: datetime,
    passthrough_success_handler_obj: PassThroughEndpointLogging,
    url_route: str,
):
    """
    Async generator over a streaming passthrough response.

    - Yields chunks from the response
    - Collects every raw chunk for post-processing (logging)
    - Injects cost into chunks if litellm.include_cost_in_streaming_usage is enabled
      (Anthropic endpoints, and Vertex rawPredict/streamRawPredict routes)

    After the stream is exhausted, schedules the logging handler as a
    fire-and-forget asyncio task so the client is not blocked on logging.
    """
    try:
        raw_bytes: List[bytes] = []
        # Extract model name for cost injection
        model_name = PassThroughStreamingHandler._extract_model_for_cost_injection(
            request_body=request_body,
            url_route=url_route,
            endpoint_type=endpoint_type,
            litellm_logging_obj=litellm_logging_obj,
        )

        async for chunk in response.aiter_bytes():
            # Always keep the original chunk for logging, even if a
            # cost-injected variant is yielded to the client below.
            raw_bytes.append(chunk)
            if (
                getattr(litellm, "include_cost_in_streaming_usage", False)
                and model_name
            ):
                if endpoint_type == EndpointType.VERTEX_AI:
                    # Only handle streamRawPredict (uses Anthropic format)
                    if "streamRawPredict" in url_route or "rawPredict" in url_route:
                        modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection(
                            chunk, model_name
                        )
                        if modified_chunk is not None:
                            chunk = modified_chunk
                elif endpoint_type == EndpointType.ANTHROPIC:
                    modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection(
                        chunk, model_name
                    )
                    if modified_chunk is not None:
                        chunk = modified_chunk

            yield chunk

        # After all chunks are processed, handle post-processing
        end_time = datetime.now()

        # NOTE(review): fire-and-forget — no reference to the task is kept,
        # so it may be garbage-collected before completion per asyncio docs.
        asyncio.create_task(
            PassThroughStreamingHandler._route_streaming_logging_to_handler(
                litellm_logging_obj=litellm_logging_obj,
                passthrough_success_handler_obj=passthrough_success_handler_obj,
                url_route=url_route,
                request_body=request_body or {},
                endpoint_type=endpoint_type,
                start_time=start_time,
                raw_bytes=raw_bytes,
                end_time=end_time,
            )
        )
    except Exception as e:
        # Re-raise so the caller can surface the stream failure to the client.
        verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
        raise
    @staticmethod
    async def _route_streaming_logging_to_handler(
        litellm_logging_obj: LiteLLMLoggingObj,
        passthrough_success_handler_obj: PassThroughEndpointLogging,
        url_route: str,
        request_body: dict,
        endpoint_type: EndpointType,
        start_time: datetime,
        raw_bytes: List[bytes],
        end_time: datetime,
        model: Optional[str] = None,
    ):
        """
        Route the logging for the collected chunks to the appropriate handler

        Supported endpoint types:
        - Anthropic
        - Vertex AI
        - OpenAI

        Each provider handler parses the collected stream lines into a
        standard response object and returns extra logging kwargs. Runs as a
        background task; exceptions are logged and swallowed (never raised).
        """
        try:
            # Re-assemble the raw byte chunks into SSE-style text lines.
            all_chunks = PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(
                raw_bytes
            )
            standard_logging_response_object: Optional[
                PassThroughEndpointLoggingResultValues
            ] = None
            kwargs: dict = {}
            if endpoint_type == EndpointType.ANTHROPIC:
                anthropic_passthrough_logging_handler_result = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
                    litellm_logging_obj=litellm_logging_obj,
                    passthrough_success_handler_obj=passthrough_success_handler_obj,
                    url_route=url_route,
                    request_body=request_body,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    all_chunks=all_chunks,
                    end_time=end_time,
                )
                standard_logging_response_object = (
                    anthropic_passthrough_logging_handler_result["result"]
                )
                kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
            elif endpoint_type == EndpointType.VERTEX_AI:
                # Vertex is the only handler that accepts an explicit model override.
                vertex_passthrough_logging_handler_result = VertexPassthroughLoggingHandler._handle_logging_vertex_collected_chunks(
                    litellm_logging_obj=litellm_logging_obj,
                    passthrough_success_handler_obj=passthrough_success_handler_obj,
                    url_route=url_route,
                    request_body=request_body,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    all_chunks=all_chunks,
                    end_time=end_time,
                    model=model,
                )
                standard_logging_response_object = (
                    vertex_passthrough_logging_handler_result["result"]
                )
                kwargs = vertex_passthrough_logging_handler_result["kwargs"]
            elif endpoint_type == EndpointType.OPENAI:
                openai_passthrough_logging_handler_result = OpenAIPassthroughLoggingHandler._handle_logging_openai_collected_chunks(
                    litellm_logging_obj=litellm_logging_obj,
                    passthrough_success_handler_obj=passthrough_success_handler_obj,
                    url_route=url_route,
                    request_body=request_body,
                    endpoint_type=endpoint_type,
                    start_time=start_time,
                    all_chunks=all_chunks,
                    end_time=end_time,
                )
                standard_logging_response_object = (
                    openai_passthrough_logging_handler_result["result"]
                )
                kwargs = openai_passthrough_logging_handler_result["kwargs"]

            # Fallback: still log something useful when no handler matched or
            # the handler could not parse the chunks.
            if standard_logging_response_object is None:
                standard_logging_response_object = StandardPassThroughResponseObject(
                    response=f"cannot parse chunks to standard response object. Chunks={all_chunks}"
                )
            await litellm_logging_obj.async_success_handler(
                result=standard_logging_response_object,
                start_time=start_time,
                end_time=end_time,
                cache_hit=False,
                **kwargs,
            )
            # Skip the sync callbacks when the logging object says they should
            # not run for async calls.
            if (
                litellm_logging_obj._should_run_sync_callbacks_for_async_calls()
                is False
            ):
                return

            # Sync callbacks run on the shared thread pool to avoid blocking
            # the event loop.
            executor.submit(
                litellm_logging_obj.success_handler,
                result=standard_logging_response_object,
                end_time=end_time,
                cache_hit=False,
                start_time=start_time,
                **kwargs,
            )
        except Exception as e:
            # Best-effort logging path: never propagate failures from here.
            verbose_proxy_logger.error(
                f"Error in _route_streaming_logging_to_handler: {str(e)}"
            )
@staticmethod
|
||||
def _extract_model_for_cost_injection(
|
||||
request_body: Optional[dict],
|
||||
url_route: str,
|
||||
endpoint_type: EndpointType,
|
||||
litellm_logging_obj: LiteLLMLoggingObj,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Extract model name for cost injection from various sources.
|
||||
"""
|
||||
# Try to get model from request body
|
||||
if request_body:
|
||||
model = request_body.get("model")
|
||||
if model:
|
||||
return model
|
||||
|
||||
# Try to get model from logging object
|
||||
if hasattr(litellm_logging_obj, "model_call_details"):
|
||||
model = litellm_logging_obj.model_call_details.get("model")
|
||||
if model:
|
||||
return model
|
||||
|
||||
# For Vertex AI, try to extract from URL
|
||||
if endpoint_type == EndpointType.VERTEX_AI:
|
||||
model = VertexPassthroughLoggingHandler.extract_model_from_url(url_route)
|
||||
if model and model != "unknown":
|
||||
return model
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _convert_raw_bytes_to_str_lines(raw_bytes: List[bytes]) -> List[str]:
|
||||
"""
|
||||
Converts a list of raw bytes into a list of string lines, similar to aiter_lines()
|
||||
|
||||
Args:
|
||||
raw_bytes: List of bytes chunks from aiter.bytes()
|
||||
|
||||
Returns:
|
||||
List of string lines, with each line being a complete data: {} chunk
|
||||
"""
|
||||
# Combine all bytes and decode to string
|
||||
combined_str = b"".join(raw_bytes).decode("utf-8")
|
||||
|
||||
# Split by newlines and filter out empty lines
|
||||
lines = [line.strip() for line in combined_str.split("\n") if line.strip()]
|
||||
|
||||
return lines
|
||||
@@ -0,0 +1,494 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.proxy._types import PassThroughEndpointLoggingResultValues
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
PassthroughStandardLoggingPayload,
|
||||
)
|
||||
from litellm.types.utils import StandardPassThroughResponseObject
|
||||
from litellm.utils import executor as thread_pool_executor
|
||||
|
||||
from .llm_provider_handlers.anthropic_passthrough_logging_handler import (
|
||||
AnthropicPassthroughLoggingHandler,
|
||||
)
|
||||
from .llm_provider_handlers.assembly_passthrough_logging_handler import (
|
||||
AssemblyAIPassthroughLoggingHandler,
|
||||
)
|
||||
from .llm_provider_handlers.cohere_passthrough_logging_handler import (
|
||||
CoherePassthroughLoggingHandler,
|
||||
)
|
||||
from .llm_provider_handlers.cursor_passthrough_logging_handler import (
|
||||
CursorPassthroughLoggingHandler,
|
||||
)
|
||||
from .llm_provider_handlers.gemini_passthrough_logging_handler import (
|
||||
GeminiPassthroughLoggingHandler,
|
||||
)
|
||||
from .llm_provider_handlers.vertex_passthrough_logging_handler import (
|
||||
VertexPassthroughLoggingHandler,
|
||||
)
|
||||
|
||||
# Module-level singleton: the Cohere handler is instantiated once and shared
# (the other provider handlers here are used via classmethods/staticmethods).
cohere_passthrough_logging_handler = CoherePassthroughLoggingHandler()
|
||||
|
||||
|
||||
class PassThroughEndpointLogging:
    """Dispatches completed pass-through responses to the matching
    provider-specific logging handler."""

    def __init__(self):
        # --- route substrings used by the is_*_route() predicates ---

        # Vertex AI REST verbs
        self.TRACKED_VERTEX_ROUTES = [
            "generateContent",
            "streamGenerateContent",
            "predict",
            "rawPredict",
            "streamRawPredict",
            "search",
            "batchPredictionJobs",
            "predictLongRunning",
        ]

        # Anthropic
        self.TRACKED_ANTHROPIC_ROUTES = ["/messages", "/v1/messages/batches"]

        # Cohere
        self.TRACKED_COHERE_ROUTES = ["/v2/chat", "/v1/embed"]

        # Gemini (same verbs as Vertex; disambiguated by custom_llm_provider)
        self.TRACKED_GEMINI_ROUTES = [
            "generateContent",
            "streamGenerateContent",
            "predictLongRunning",
        ]

        # Langfuse (matched against the URL path; these are never logged)
        self.TRACKED_LANGFUSE_ROUTES = ["/langfuse/"]

        # Cursor Cloud Agents
        self.TRACKED_CURSOR_ROUTES = [
            "/v0/agents",
            "/v0/me",
            "/v0/models",
            "/v0/repositories",
        ]

        # Vertex AI Live API WebSocket
        self.TRACKED_VERTEX_AI_LIVE_ROUTES = ["/vertex_ai/live"]

        # AssemblyAI uses a per-instance handler object rather than the
        # classmethod-style handlers used by the other providers.
        self.assemblyai_passthrough_logging_handler = (
            AssemblyAIPassthroughLoggingHandler()
        )
async def _handle_logging(
|
||||
self,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
standard_logging_response_object: Union[
|
||||
StandardPassThroughResponseObject,
|
||||
PassThroughEndpointLoggingResultValues,
|
||||
dict,
|
||||
],
|
||||
result: str,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
cache_hit: bool,
|
||||
**kwargs,
|
||||
):
|
||||
"""Helper function to handle both sync and async logging operations"""
|
||||
# Submit to thread pool for sync logging
|
||||
thread_pool_executor.submit(
|
||||
logging_obj.success_handler,
|
||||
standard_logging_response_object,
|
||||
start_time,
|
||||
end_time,
|
||||
cache_hit,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Handle async logging
|
||||
await logging_obj.async_success_handler(
|
||||
result=(
|
||||
json.dumps(result)
|
||||
if isinstance(result, dict)
|
||||
else standard_logging_response_object
|
||||
),
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=False,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
    def normalize_llm_passthrough_logging_payload(
        self,
        httpx_response: httpx.Response,
        response_body: Optional[dict],
        request_body: dict,
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """
        Dispatch a completed pass-through response to the provider-specific
        logging handler and normalize the handler's output.

        Returns a dict with:
        - "standard_logging_response_object": the handler-parsed response, or
          None when no handler matched (or it could not parse the response)
        - "kwargs": logging kwargs, possibly enriched by the handler

        Branch order matters: Gemini is checked before Vertex AI because both
        track generateContent-style verbs; the Gemini predicate additionally
        requires custom_llm_provider == "gemini".
        """
        return_dict = {
            "standard_logging_response_object": None,
            "kwargs": kwargs,
        }
        standard_logging_response_object: Optional[Any] = None

        if self.is_gemini_route(url_route, custom_llm_provider):
            gemini_passthrough_logging_handler_result = (
                GeminiPassthroughLoggingHandler.gemini_passthrough_handler(
                    httpx_response=httpx_response,
                    response_body=response_body or {},
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    request_body=request_body,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                gemini_passthrough_logging_handler_result["result"]
            )
            kwargs = gemini_passthrough_logging_handler_result["kwargs"]
        elif self.is_vertex_route(url_route):
            # NOTE: the Vertex handler does not take response_body — it parses
            # the httpx_response itself.
            vertex_passthrough_logging_handler_result = (
                VertexPassthroughLoggingHandler.vertex_passthrough_handler(
                    httpx_response=httpx_response,
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    request_body=request_body,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                vertex_passthrough_logging_handler_result["result"]
            )
            kwargs = vertex_passthrough_logging_handler_result["kwargs"]
        elif self.is_anthropic_route(url_route):
            anthropic_passthrough_logging_handler_result = (
                AnthropicPassthroughLoggingHandler.anthropic_passthrough_handler(
                    httpx_response=httpx_response,
                    response_body=response_body or {},
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    request_body=request_body,
                    **kwargs,
                )
            )

            standard_logging_response_object = (
                anthropic_passthrough_logging_handler_result["result"]
            )
            kwargs = anthropic_passthrough_logging_handler_result["kwargs"]
        elif self.is_cohere_route(url_route):
            # Uses the module-level Cohere handler instance.
            cohere_passthrough_logging_handler_result = (
                cohere_passthrough_logging_handler.cohere_passthrough_handler(
                    httpx_response=httpx_response,
                    response_body=response_body or {},
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    request_body=request_body,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                cohere_passthrough_logging_handler_result["result"]
            )
            kwargs = cohere_passthrough_logging_handler_result["kwargs"]
        elif self.is_openai_route(url_route) and self._is_supported_openai_endpoint(
            url_route
        ):
            # Imported lazily, mirroring _is_supported_openai_endpoint.
            from .llm_provider_handlers.openai_passthrough_logging_handler import (
                OpenAIPassthroughLoggingHandler,
            )

            openai_passthrough_logging_handler_result = (
                OpenAIPassthroughLoggingHandler.openai_passthrough_handler(
                    httpx_response=httpx_response,
                    response_body=response_body or {},
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    request_body=request_body,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                openai_passthrough_logging_handler_result["result"]
            )
            kwargs = openai_passthrough_logging_handler_result["kwargs"]

        elif self.is_cursor_route(url_route, custom_llm_provider):
            cursor_passthrough_logging_handler_result = (
                CursorPassthroughLoggingHandler.cursor_passthrough_handler(
                    httpx_response=httpx_response,
                    response_body=response_body or {},
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    request_body=request_body,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                cursor_passthrough_logging_handler_result["result"]
            )
            kwargs = cursor_passthrough_logging_handler_result["kwargs"]
        elif self.is_vertex_ai_live_route(url_route):
            from .llm_provider_handlers.vertex_ai_live_passthrough_logging_handler import (
                VertexAILivePassthroughLoggingHandler,
            )

            vertex_ai_live_handler = VertexAILivePassthroughLoggingHandler()

            # For WebSocket responses, response_body should be a list of messages
            websocket_messages: list[dict[str, Any]] = (
                response_body if isinstance(response_body, list) else []
            )

            vertex_ai_live_handler_result = (
                vertex_ai_live_handler.vertex_ai_live_passthrough_handler(
                    websocket_messages=websocket_messages,
                    logging_obj=logging_obj,
                    url_route=url_route,
                    start_time=start_time,
                    end_time=end_time,
                    request_body=request_body,
                    **kwargs,
                )
            )

            standard_logging_response_object = vertex_ai_live_handler_result["result"]
            kwargs = vertex_ai_live_handler_result["kwargs"]
        return_dict[
            "standard_logging_response_object"
        ] = standard_logging_response_object

        return_dict["kwargs"] = kwargs
        return return_dict
    async def pass_through_async_success_handler(
        self,
        httpx_response: httpx.Response,
        response_body: Optional[dict],
        logging_obj: LiteLLMLoggingObj,
        url_route: str,
        result: str,
        start_time: datetime,
        end_time: datetime,
        cache_hit: bool,
        request_body: dict,
        passthrough_logging_payload: PassthroughStandardLoggingPayload,
        custom_llm_provider: Optional[str] = None,
        **kwargs,
    ):
        """
        Entry point for logging a completed pass-through request.

        Special cases:
        - AssemblyAI routes use their own handler and may skip logging
          entirely based on the HTTP method.
        - Langfuse routes are never logged.
        All other routes are normalized via
        normalize_llm_passthrough_logging_payload and then logged through
        _handle_logging (sync + async callbacks).
        """
        standard_logging_response_object: Optional[
            PassThroughEndpointLoggingResultValues
        ] = None
        # Make the pass-through payload visible to downstream callbacks.
        logging_obj.model_call_details[
            "passthrough_logging_payload"
        ] = passthrough_logging_payload
        if self.is_assemblyai_route(url_route):
            # AssemblyAI decides per HTTP method whether the request is logged.
            if (
                AssemblyAIPassthroughLoggingHandler._should_log_request(
                    httpx_response.request.method
                )
                is not True
            ):
                return
            self.assemblyai_passthrough_logging_handler.assemblyai_passthrough_logging_handler(
                httpx_response=httpx_response,
                response_body=response_body or {},
                logging_obj=logging_obj,
                url_route=url_route,
                result=result,
                start_time=start_time,
                end_time=end_time,
                cache_hit=cache_hit,
                **kwargs,
            )
            return
        elif self.is_langfuse_route(url_route):
            # Don't log langfuse pass-through requests
            return
        else:
            normalized_llm_passthrough_logging_payload = (
                self.normalize_llm_passthrough_logging_payload(
                    httpx_response=httpx_response,
                    response_body=response_body,
                    request_body=request_body,
                    logging_obj=logging_obj,
                    url_route=url_route,
                    result=result,
                    start_time=start_time,
                    end_time=end_time,
                    cache_hit=cache_hit,
                    custom_llm_provider=custom_llm_provider,
                    **kwargs,
                )
            )
            standard_logging_response_object = (
                normalized_llm_passthrough_logging_payload[
                    "standard_logging_response_object"
                ]
            )
            kwargs = normalized_llm_passthrough_logging_payload["kwargs"]
        # Fallback: log the raw response text when no handler produced a
        # structured result.
        if standard_logging_response_object is None:
            standard_logging_response_object = StandardPassThroughResponseObject(
                response=httpx_response.text
            )

        # Apply a caller-configured flat cost, if any, before logging.
        kwargs = self._set_cost_per_request(
            logging_obj=logging_obj,
            passthrough_logging_payload=passthrough_logging_payload,
            kwargs=kwargs,
        )

        await self._handle_logging(
            logging_obj=logging_obj,
            standard_logging_response_object=standard_logging_response_object,
            result=result,
            start_time=start_time,
            end_time=end_time,
            cache_hit=cache_hit,
            standard_pass_through_logging_payload=passthrough_logging_payload,
            **kwargs,
        )
def is_vertex_route(self, url_route: str):
|
||||
for route in self.TRACKED_VERTEX_ROUTES:
|
||||
if route in url_route:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_anthropic_route(self, url_route: str):
|
||||
for route in self.TRACKED_ANTHROPIC_ROUTES:
|
||||
if route in url_route:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_cohere_route(self, url_route: str):
|
||||
for route in self.TRACKED_COHERE_ROUTES:
|
||||
if route in url_route:
|
||||
return True
|
||||
|
||||
def is_assemblyai_route(self, url_route: str):
|
||||
parsed_url = urlparse(url_route)
|
||||
if parsed_url.hostname == "api.assemblyai.com":
|
||||
return True
|
||||
elif "/transcript" in parsed_url.path:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_langfuse_route(self, url_route: str):
|
||||
parsed_url = urlparse(url_route)
|
||||
for route in self.TRACKED_LANGFUSE_ROUTES:
|
||||
if route in parsed_url.path:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_vertex_ai_live_route(self, url_route: str):
|
||||
"""Check if the URL route is a Vertex AI Live API WebSocket route."""
|
||||
if not url_route:
|
||||
return False
|
||||
for route in self.TRACKED_VERTEX_AI_LIVE_ROUTES:
|
||||
if route in url_route:
|
||||
return True
|
||||
return False
|
||||
|
||||
    def is_cursor_route(
        self, url_route: str, custom_llm_provider: Optional[str] = None
    ):
        """Check if the URL route is a Cursor Cloud Agents API route.

        Matches when the provider tag is "cursor" or the hostname contains
        api.cursor.com; otherwise falls through to the path-prefix loop below.
        """
        # Explicit provider tag wins.
        if custom_llm_provider == "cursor":
            return True
        parsed_url = urlparse(url_route)
        if parsed_url.hostname and "api.cursor.com" in parsed_url.hostname:
            return True
        for route in self.TRACKED_CURSOR_ROUTES:
            if route in url_route:
                # Relative routes have no scheme, so match against the raw string.
                path = parsed_url.path if parsed_url.scheme else url_route
                if path.startswith("/v0/"):
                    # NOTE(review): custom_llm_provider cannot be "cursor" here
                    # (that case returned True above), so this expression is
                    # always False — a bare "/v0/..." path never matches without
                    # the provider tag. Confirm this is intentional.
                    return custom_llm_provider == "cursor"
        return False
def is_openai_route(self, url_route: str):
|
||||
"""Check if the URL route is an OpenAI API route."""
|
||||
if not url_route:
|
||||
return False
|
||||
parsed_url = urlparse(url_route)
|
||||
return parsed_url.hostname and (
|
||||
"api.openai.com" in parsed_url.hostname
|
||||
or "openai.azure.com" in parsed_url.hostname
|
||||
)
|
||||
|
||||
def is_gemini_route(
|
||||
self, url_route: str, custom_llm_provider: Optional[str] = None
|
||||
):
|
||||
"""Check if the URL route is a Gemini API route."""
|
||||
for route in self.TRACKED_GEMINI_ROUTES:
|
||||
if route in url_route and custom_llm_provider == "gemini":
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_supported_openai_endpoint(self, url_route: str) -> bool:
|
||||
"""Check if the OpenAI endpoint is supported by the passthrough logging handler."""
|
||||
from .llm_provider_handlers.openai_passthrough_logging_handler import (
|
||||
OpenAIPassthroughLoggingHandler,
|
||||
)
|
||||
|
||||
return (
|
||||
OpenAIPassthroughLoggingHandler.is_openai_chat_completions_route(url_route)
|
||||
or OpenAIPassthroughLoggingHandler.is_openai_image_generation_route(
|
||||
url_route
|
||||
)
|
||||
or OpenAIPassthroughLoggingHandler.is_openai_image_editing_route(url_route)
|
||||
)
|
||||
|
||||
def _set_cost_per_request(
|
||||
self,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
passthrough_logging_payload: PassthroughStandardLoggingPayload,
|
||||
kwargs: dict,
|
||||
):
|
||||
"""
|
||||
Helper function to set the cost per request in the logging object
|
||||
|
||||
Only set the cost per request if it's set in the passthrough logging payload.
|
||||
If it's not set, don't set it in the logging object.
|
||||
"""
|
||||
#########################################################
|
||||
# Check if cost per request is set
|
||||
#########################################################
|
||||
if passthrough_logging_payload.get("cost_per_request") is not None:
|
||||
kwargs["response_cost"] = passthrough_logging_payload.get(
|
||||
"cost_per_request"
|
||||
)
|
||||
logging_obj.model_call_details[
|
||||
"response_cost"
|
||||
] = passthrough_logging_payload.get("cost_per_request")
|
||||
|
||||
return kwargs
|
||||
Reference in New Issue
Block a user