Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/custom_httpx/container_handler.py
2026-03-26 20:06:14 +08:00

420 lines
14 KiB
Python

"""
Generic container file handler for LiteLLM.
This module provides a single generic handler that can process any container file
endpoint defined in endpoints.json, eliminating the need for individual handler methods.
"""
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, Coroutine, Dict, Optional, Type, Union
from urllib.parse import quote

import httpx

import litellm
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.types.containers.main import (
    ContainerFileListResponse,
    ContainerFileObject,
    DeleteContainerFileResponse,
)
from litellm.types.router import GenericLiteLLMParams

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
    from litellm.llms.base_llm.containers.transformation import BaseContainerConfig
# Response type mapping
# Maps the "response_type" string stored in endpoints.json to the concrete
# class used to deserialize the provider's JSON payload. Endpoints whose
# response_type is absent from this mapping fall back to the raw JSON dict.
RESPONSE_TYPES: Dict[str, Type] = {
    "ContainerFileListResponse": ContainerFileListResponse,
    "ContainerFileObject": ContainerFileObject,
    "DeleteContainerFileResponse": DeleteContainerFileResponse,
}
def _load_endpoints_config() -> Dict:
    """Load the container endpoint definitions from ``containers/endpoints.json``.

    The file lives three directories above this module
    (litellm/containers/endpoints.json) and is re-read on every call.

    Returns:
        The parsed configuration dict; expected to contain an "endpoints" list.

    Raises:
        FileNotFoundError: If endpoints.json does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    config_path = Path(__file__).parent.parent.parent / "containers" / "endpoints.json"
    # Explicit UTF-8: JSON is UTF-8 by spec, while the platform default
    # encoding (e.g. on Windows) may differ.
    return json.loads(config_path.read_text(encoding="utf-8"))
def _get_endpoint_config(endpoint_name: str) -> Optional[Dict]:
    """Get the config for a specific endpoint by its sync or async name.

    Args:
        endpoint_name: Either the "name" or the "async_name" of an endpoint
            as declared in endpoints.json.

    Returns:
        The matching endpoint dict, or None when no endpoint matches.
    """
    config = _load_endpoints_config()
    for endpoint in config["endpoints"]:
        # Use .get() so entries that omit "async_name" (or even "name")
        # are skipped instead of raising KeyError mid-scan.
        if endpoint_name in (endpoint.get("name"), endpoint.get("async_name")):
            return endpoint
    return None
def _build_url(
api_base: str,
path_template: str,
path_params: Dict[str, str],
) -> str:
"""Build the full URL by substituting path parameters.
The api_base from get_complete_url already includes /containers,
so we need to strip that prefix from the path_template.
"""
# api_base ends with /containers, path_template starts with /containers
# So we need to strip /containers from the path
if path_template.startswith("/containers"):
path_template = path_template[len("/containers") :]
url = f"{api_base.rstrip('/')}{path_template}"
for param, value in path_params.items():
url = url.replace(f"{{{param}}}", value)
return url
def _build_query_params(
query_param_names: list,
kwargs: Dict[str, Any],
) -> Dict[str, str]:
"""Build query parameters from kwargs."""
params = {}
for param_name in query_param_names:
value = kwargs.get(param_name)
if value is not None:
params[param_name] = str(value) if not isinstance(value, str) else value
return params
def _prepare_multipart_file_upload(
    file: Any,
    headers: Dict[str, Any],
) -> tuple:
    """
    Prepare file and headers for a multipart upload.

    Args:
        file: The file-like input accepted by extract_file_data.
        headers: Existing request headers (not mutated).

    Returns:
        Tuple of (files_dict, headers_without_content_type).
    """
    from litellm.litellm_core_utils.prompt_templates.common_utils import (
        extract_file_data,
    )

    file_data = extract_file_data(file)
    files = {
        "file": (
            file_data.get("filename") or "file",
            file_data.get("content") or b"",
            file_data.get("content_type") or "application/octet-stream",
        )
    }
    # httpx must set its own multipart content-type (with boundary), so any
    # explicit content-type header is dropped from a copy of the headers.
    sanitized_headers = dict(headers)
    for key in ("content-type", "Content-Type"):
        sanitized_headers.pop(key, None)
    return files, sanitized_headers
class GenericContainerHandler:
    """
    Generic handler for container file API endpoints.

    This single handler can process any endpoint defined in endpoints.json,
    eliminating the need for individual handler methods per endpoint. The
    request preparation and response parsing steps are shared between the
    sync and async paths via private helpers.
    """

    def handle(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        timeout: Union[float, httpx.Timeout] = 600,
        _is_async: bool = False,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        **kwargs,
    ) -> Union[Any, Coroutine[Any, Any, Any]]:
        """
        Generic handler for any container file endpoint.

        Args:
            endpoint_name: Name of the endpoint (e.g., "list_container_files")
            container_provider_config: Provider-specific configuration
            litellm_params: LiteLLM parameters including api_key, api_base
            logging_obj: Logging object for request logging
            extra_headers: Additional HTTP headers
            extra_query: Additional query parameters
            timeout: Request timeout; applied when the client is created here.
                A caller-supplied ``client`` keeps its own configured timeout.
            _is_async: Whether to make async request
            client: Optional HTTP client
            **kwargs: Path params and query params (e.g., container_id,
                file_id, after, limit)

        Returns:
            The parsed response for sync calls, or a coroutine resolving to
            it when ``_is_async`` is True.
        """
        handler = self._async_handle if _is_async else self._sync_handle
        return handler(
            endpoint_name=endpoint_name,
            container_provider_config=container_provider_config,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            extra_headers=extra_headers,
            extra_query=extra_query,
            timeout=timeout,
            client=client,
            **kwargs,
        )

    def _prepare_request(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]],
        extra_query: Optional[Dict[str, Any]],
        kwargs: Dict[str, Any],
    ) -> tuple:
        """Resolve endpoint config and build URL/headers/query params.

        Shared by the sync and async paths. Also emits the pre_call log.

        Returns:
            Tuple of (endpoint_config, url, headers, query_params).

        Raises:
            ValueError: If endpoint_name is not defined in endpoints.json.
        """
        endpoint_config = _get_endpoint_config(endpoint_name)
        if not endpoint_config:
            raise ValueError(f"Unknown endpoint: {endpoint_name}")
        headers = container_provider_config.validate_environment(
            headers=extra_headers or {},
            api_key=litellm_params.get("api_key", None),
        )
        # Re-apply extra_headers so they win over anything set by
        # validate_environment.
        if extra_headers:
            headers.update(extra_headers)
        api_base = container_provider_config.get_complete_url(
            api_base=litellm_params.get("api_base", None),
            litellm_params=dict(litellm_params),
        )
        # Missing path params substitute as "" (matches historical behavior).
        path_params = {
            p: kwargs.get(p, "") for p in endpoint_config.get("path_params", [])
        }
        url = _build_url(api_base, endpoint_config["path"], path_params)
        query_params = _build_query_params(
            endpoint_config.get("query_params", []), kwargs
        )
        if extra_query:
            query_params.update(extra_query)
        logging_obj.pre_call(
            input="",
            api_key="",
            additional_args={
                "api_base": url,
                "headers": headers,
                "params": query_params,
            },
        )
        return endpoint_config, url, headers, query_params

    def _process_response(
        self, response: httpx.Response, endpoint_config: Dict
    ) -> Any:
        """Convert an HTTP response into the endpoint's result.

        Shared by the sync and async paths.

        Raises:
            BaseLLMException: If the JSON payload contains an "error" key.
        """
        # Binary endpoints (e.g. file content download) return raw bytes.
        if endpoint_config.get("returns_binary", False):
            return response.content
        response_json = response.json()
        if "error" in response_json:
            from litellm.llms.base_llm.chat.transformation import BaseLLMException

            error_msg = response_json.get("error", {}).get(
                "message", str(response_json)
            )
            raise BaseLLMException(
                status_code=response.status_code,
                message=error_msg,
                headers=dict(response.headers),
            )
        # Deserialize into the declared response type when one is mapped;
        # otherwise hand back the raw JSON dict.
        response_type = RESPONSE_TYPES.get(endpoint_config["response_type"])
        if response_type:
            return response_type(**response_json)
        return response_json

    def _sync_handle(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        timeout: Union[float, httpx.Timeout] = 600,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        **kwargs,
    ) -> Any:
        """Synchronous request handler."""
        endpoint_config, url, headers, query_params = self._prepare_request(
            endpoint_name=endpoint_name,
            container_provider_config=container_provider_config,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            extra_headers=extra_headers,
            extra_query=extra_query,
            kwargs=kwargs,
        )
        if client is None or not isinstance(client, HTTPHandler):
            # Fix: forward `timeout` into client construction — previously the
            # parameter was accepted but never applied to any request.
            http_client = _get_httpx_client(
                params={
                    "ssl_verify": litellm_params.get("ssl_verify", None),
                    "timeout": timeout,
                }
            )
        else:
            http_client = client
        method = endpoint_config["method"].upper()
        if method == "GET":
            response = http_client.get(
                url=url, headers=headers, params=query_params
            )
        elif method == "DELETE":
            response = http_client.delete(
                url=url, headers=headers, params=query_params
            )
        elif method == "POST":
            if endpoint_config.get("is_multipart", False) and "file" in kwargs:
                files, headers = _prepare_multipart_file_upload(
                    kwargs["file"], headers
                )
                response = http_client.post(
                    url=url, headers=headers, params=query_params, files=files
                )
            else:
                response = http_client.post(
                    url=url, headers=headers, params=query_params
                )
        else:
            raise ValueError(f"Unsupported HTTP method: {method}")
        return self._process_response(response, endpoint_config)

    async def _async_handle(
        self,
        endpoint_name: str,
        container_provider_config: "BaseContainerConfig",
        litellm_params: GenericLiteLLMParams,
        logging_obj: "LiteLLMLoggingObj",
        extra_headers: Optional[Dict[str, Any]] = None,
        extra_query: Optional[Dict[str, Any]] = None,
        timeout: Union[float, httpx.Timeout] = 600,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        **kwargs,
    ) -> Any:
        """Asynchronous request handler."""
        endpoint_config, url, headers, query_params = self._prepare_request(
            endpoint_name=endpoint_name,
            container_provider_config=container_provider_config,
            litellm_params=litellm_params,
            logging_obj=logging_obj,
            extra_headers=extra_headers,
            extra_query=extra_query,
            kwargs=kwargs,
        )
        if client is None or not isinstance(client, AsyncHTTPHandler):
            # Fix: forward `timeout` into client construction — previously the
            # parameter was accepted but never applied to any request.
            http_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.OPENAI,
                params={
                    "ssl_verify": litellm_params.get("ssl_verify", None),
                    "timeout": timeout,
                },
            )
        else:
            http_client = client
        method = endpoint_config["method"].upper()
        if method == "GET":
            response = await http_client.get(
                url=url, headers=headers, params=query_params
            )
        elif method == "DELETE":
            response = await http_client.delete(
                url=url, headers=headers, params=query_params
            )
        elif method == "POST":
            if endpoint_config.get("is_multipart", False) and "file" in kwargs:
                files, headers = _prepare_multipart_file_upload(
                    kwargs["file"], headers
                )
                response = await http_client.post(
                    url=url, headers=headers, params=query_params, files=files
                )
            else:
                response = await http_client.post(
                    url=url, headers=headers, params=query_params
                )
        else:
            raise ValueError(f"Unsupported HTTP method: {method}")
        return self._process_response(response, endpoint_config)
# Singleton instance
# Module-level shared instance; the handler holds no per-request state
# (no __init__, no instance attributes), so a single object is safe to reuse.
generic_container_handler = GenericContainerHandler()