chore: initial public snapshot for github upload
@@ -0,0 +1,536 @@
"""Abstraction function for OpenAI's realtime API"""

import os
from typing import Any, Dict, Optional, cast

import litellm
from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES, request_timeout
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.llms.base_llm.realtime.transformation import BaseRealtimeConfig
from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler
from litellm.secret_managers.main import get_secret_str
from litellm.types.realtime import (
    RealtimeClientSecretRequest,
    RealtimeExpiresAfter,
    RealtimeQueryParams,
    RealtimeSessionConfig,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

from ..litellm_core_utils.get_litellm_params import get_litellm_params
from ..litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from ..llms.azure.realtime.handler import AzureOpenAIRealtime
from ..llms.bedrock.realtime.handler import BedrockRealtime
from ..llms.custom_httpx.http_handler import get_shared_realtime_ssl_context
from ..llms.openai.realtime.handler import OpenAIRealtime
from ..llms.vertex_ai.realtime.transformation import VertexAIRealtimeConfig
from ..llms.vertex_ai.vertex_llm_base import VertexBase
from ..llms.xai.realtime.handler import XAIRealtime
from ..utils import client as wrapper_client

azure_realtime = AzureOpenAIRealtime()
openai_realtime = OpenAIRealtime()
bedrock_realtime = BedrockRealtime()
xai_realtime = XAIRealtime()
vertex_llm_base = VertexBase()
base_llm_http_handler = BaseLLMHTTPHandler()


def _build_litellm_metadata(kwargs: dict) -> dict:
    """Build the litellm_metadata dict for guardrail checking (internal only, not forwarded to provider)."""
    metadata: dict = {**(kwargs.get("litellm_metadata") or {})}
    guardrails = (
        (kwargs.get("metadata") or {}).get("guardrails")
        or kwargs.get("guardrails")
        or []
    )
    if guardrails:
        metadata["guardrails"] = guardrails
    return metadata
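
# Illustrative behavior of the helper above (values are made up): guardrails may
# arrive either nested under "metadata" or as a top-level kwarg, and both
# resolve to the same internal shape:
#
#     _build_litellm_metadata({"metadata": {"guardrails": ["pii_mask"]}})
#     # -> {"guardrails": ["pii_mask"]}
#
#     _build_litellm_metadata(
#         {"litellm_metadata": {"user_api_key": "sk-..."}, "guardrails": ["pii_mask"]}
#     )
#     # -> {"user_api_key": "sk-...", "guardrails": ["pii_mask"]}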


def _get_realtime_http_provider_config(
    custom_llm_provider: str,
    dynamic_api_base: Optional[str],
    dynamic_api_key: Optional[str],
    litellm_params: GenericLiteLLMParams,
) -> tuple[Any, str, str]:
    """
    Return (provider_config, resolved_api_base, resolved_api_key) for the
    realtime HTTP endpoints (client_secrets / realtime_calls).

    Uses ProviderConfigManager so each provider keeps its credential-resolution
    and URL-construction logic in its own transformation class.
    """
    from litellm.llms.base_llm.realtime.http_transformation import (
        BaseRealtimeHTTPConfig,
    )

    provider_config: Optional[BaseRealtimeHTTPConfig] = None
    if custom_llm_provider in LlmProviders._member_map_.values():
        provider_config = ProviderConfigManager.get_provider_realtime_http_config(
            model="",
            provider=LlmProviders(custom_llm_provider),
        )

    raw_api_base = dynamic_api_base or litellm_params.api_base
    raw_api_key = dynamic_api_key or litellm_params.api_key

    if provider_config is not None:
        resolved_api_base = provider_config.get_api_base(api_base=raw_api_base)
        resolved_api_key = provider_config.get_api_key(api_key=raw_api_key)
    else:
        # Fallback for providers without a dedicated HTTP config (treated as OpenAI-compatible).
        resolved_api_base = raw_api_base or litellm.api_base or "https://api.openai.com"
        resolved_api_key = (
            raw_api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
            or ""
        )

    return provider_config, resolved_api_base.rstrip("/"), resolved_api_key
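
# Resolution-order sketch for the fallback path above (provider name and URL
# are illustrative): with no dedicated HTTP config, an explicit api_base wins,
# then the global litellm.api_base, then the OpenAI default:
#
#     _get_realtime_http_provider_config(
#         custom_llm_provider="my-openai-compatible",  # hypothetical provider
#         dynamic_api_base="https://gateway.example.com/v1",
#         dynamic_api_key=None,
#         litellm_params=GenericLiteLLMParams(),
#     )
#     # -> (None, "https://gateway.example.com/v1", <OPENAI_API_KEY or "">)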


@wrapper_client
async def acreate_realtime_client_secret(
    model: Optional[str] = None,
    session: Optional[Dict[str, Any]] = None,
    expires_after: Optional[Dict[str, Any]] = None,
    timeout: Optional[float] = None,
    **kwargs,
):
    """
    Create an ephemeral realtime client secret via the provider's HTTP
    client_secrets endpoint (e.g. for browser clients that must not hold a
    long-lived API key).
    """
    req = RealtimeClientSecretRequest(
        model=model,
        session=RealtimeSessionConfig(**session) if session else None,
        expires_after=RealtimeExpiresAfter(**expires_after) if expires_after else None,
    )
    model_name = (
        (req.session.model if req.session is not None else None)
        or req.model
        or "gpt-4o-realtime-preview"
    )
    litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj")  # type: ignore
    litellm_params = GenericLiteLLMParams(**kwargs)

    (
        model_name,
        custom_llm_provider,
        dynamic_api_key,
        dynamic_api_base,
    ) = get_llm_provider(
        model=model_name,
        api_base=litellm_params.api_base,
        api_key=litellm_params.api_key,
    )
    (
        provider_config,
        resolved_api_base,
        resolved_api_key,
    ) = _get_realtime_http_provider_config(
        custom_llm_provider=custom_llm_provider,
        dynamic_api_base=dynamic_api_base,
        dynamic_api_key=dynamic_api_key,
        litellm_params=litellm_params,
    )
    litellm_logging_obj.update_environment_variables(
        model=model_name,
        optional_params={"expires_after": expires_after, "session": session},
        litellm_params={"api_base": resolved_api_base},
        custom_llm_provider=custom_llm_provider,
    )
    request_data = req.model_dump(exclude_none=True, exclude={"model"})
    return await base_llm_http_handler.async_realtime_client_secret_handler(
        api_base=resolved_api_base,
        api_key=resolved_api_key,
        request_data=request_data,
        logging_obj=litellm_logging_obj,
        timeout=timeout or request_timeout,
        provider_config=provider_config,
        model=model_name,
        extra_headers=kwargs.get("extra_headers"),
        client=kwargs.get("client"),
        api_version=litellm_params.api_version,
    )
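
# Minimal usage sketch (assumes OPENAI_API_KEY is set; the expires_after shape
# mirrors OpenAI's client_secrets API and the values here are illustrative):
#
#     secret = await acreate_realtime_client_secret(
#         model="gpt-4o-realtime-preview",
#         session={"model": "gpt-4o-realtime-preview"},
#         expires_after={"anchor": "created_at", "seconds": 600},
#     )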


@wrapper_client
async def arealtime_calls(
    openai_ephemeral_key: str,
    sdp_body: bytes,
    model: Optional[str] = None,
    session: Optional[Dict[str, Any]] = None,
    timeout: Optional[float] = None,
    **kwargs,
):
    """
    Forward an SDP offer to the provider's realtime_calls HTTP endpoint,
    authenticating with an ephemeral client secret (WebRTC call setup).
    """
    model_name = model or "gpt-4o-realtime-preview"
    litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj")  # type: ignore
    litellm_params = GenericLiteLLMParams(**kwargs)

    (
        model_name,
        custom_llm_provider,
        dynamic_api_key,
        dynamic_api_base,
    ) = get_llm_provider(
        model=model_name,
        api_base=litellm_params.api_base,
        api_key=litellm_params.api_key,
    )
    provider_config, resolved_api_base, _ = _get_realtime_http_provider_config(
        custom_llm_provider=custom_llm_provider,
        dynamic_api_base=dynamic_api_base,
        dynamic_api_key=dynamic_api_key,
        litellm_params=litellm_params,
    )
    litellm_logging_obj.update_environment_variables(
        model=model_name,
        optional_params={"realtime_calls": True, "session": session},
        litellm_params={"api_base": resolved_api_base},
        custom_llm_provider=custom_llm_provider,
    )
    return await base_llm_http_handler.async_realtime_calls_handler(
        api_base=resolved_api_base,
        openai_ephemeral_key=openai_ephemeral_key,
        sdp_body=sdp_body,
        logging_obj=litellm_logging_obj,
        timeout=timeout or request_timeout,
        provider_config=provider_config,
        model=model_name,
        session_config=session,
        extra_headers=kwargs.get("extra_headers"),
        client=kwargs.get("client"),
        api_version=litellm_params.api_version,
    )
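
# Usage sketch (illustrative): pair this with acreate_realtime_client_secret,
# i.e. mint an ephemeral key, then exchange the client's SDP offer:
#
#     sdp_answer = await arealtime_calls(
#         openai_ephemeral_key=secret_value,   # from acreate_realtime_client_secret
#         sdp_body=offer_sdp.encode("utf-8"),  # the client's SDP offer
#         model="gpt-4o-realtime-preview",
#     )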


@wrapper_client
async def _arealtime(  # noqa: PLR0915
    model: str,
    websocket: Any,  # fastapi websocket
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    api_version: Optional[str] = None,
    azure_ad_token: Optional[str] = None,
    client: Optional[Any] = None,
    timeout: Optional[float] = None,
    query_params: Optional[RealtimeQueryParams] = None,
    **kwargs,
):
    """
    Private function to handle the realtime API call.

    For PROXY use only.
    """
    headers = cast(Optional[dict], kwargs.get("headers"))
    extra_headers = cast(Optional[dict], kwargs.get("extra_headers"))
    if headers is None:
        headers = {}
    if extra_headers is not None:
        headers.update(extra_headers)
    litellm_logging_obj: LiteLLMLogging = kwargs.get("litellm_logging_obj")  # type: ignore
    user = kwargs.get("user", None)
    litellm_params = GenericLiteLLMParams(**kwargs)

    litellm_params_dict = get_litellm_params(**kwargs)

    model, _custom_llm_provider, dynamic_api_key, dynamic_api_base = get_llm_provider(
        model=model,
        api_base=api_base,
        api_key=api_key,
    )

    # Ensure query params use the normalized provider model (no proxy aliases).
    if query_params is not None:
        query_params = {**query_params, "model": model}

    litellm_logging_obj.update_environment_variables(
        model=model,
        user=user,
        optional_params={},
        litellm_params=litellm_params_dict,
        custom_llm_provider=_custom_llm_provider,
    )

    provider_config: Optional[BaseRealtimeConfig] = None
    if _custom_llm_provider in LlmProviders._member_map_.values():
        provider_config = ProviderConfigManager.get_provider_realtime_config(
            model=model,
            provider=LlmProviders(_custom_llm_provider),
        )
    if provider_config is not None:
        await base_llm_http_handler.async_realtime(
            model=model,
            websocket=websocket,
            logging_obj=litellm_logging_obj,
            provider_config=provider_config,
            api_base=api_base,
            api_key=api_key,
            client=client,
            timeout=timeout,
            headers=headers,
            user_api_key_dict=kwargs.get("user_api_key_dict"),
            litellm_metadata=_build_litellm_metadata(kwargs),
        )
    elif _custom_llm_provider == "azure":
        api_base = (
            dynamic_api_base
            or litellm_params.api_base
            or litellm.api_base
            or get_secret_str("AZURE_API_BASE")
        )
        # set API KEY
        api_key = (
            dynamic_api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("AZURE_API_KEY")
        )

        api_version = api_version or litellm_params.api_version or "2024-10-01-preview"

        realtime_protocol = (
            kwargs.get("realtime_protocol")
            # GenericLiteLLMParams is a pydantic model, not a dict; use getattr
            # so a missing attribute falls through to the env var / default.
            or getattr(litellm_params, "realtime_protocol", None)
            or os.environ.get("LITELLM_AZURE_REALTIME_PROTOCOL")
            or "beta"
        )
        await azure_realtime.async_realtime(
            model=model,
            websocket=websocket,
            api_base=api_base,
            api_key=api_key,
            api_version=api_version,
            azure_ad_token=None,
            client=None,
            timeout=timeout,
            logging_obj=litellm_logging_obj,
            realtime_protocol=realtime_protocol,
            user_api_key_dict=kwargs.get("user_api_key_dict"),
            litellm_metadata=_build_litellm_metadata(kwargs),
        )
    elif _custom_llm_provider == "openai":
        api_base = (
            dynamic_api_base
            or litellm_params.api_base
            or litellm.api_base
            or "https://api.openai.com/"
        )
        # set API KEY
        api_key = (
            dynamic_api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )

        await openai_realtime.async_realtime(
            model=model,
            websocket=websocket,
            logging_obj=litellm_logging_obj,
            api_base=api_base,
            api_key=api_key,
            client=None,
            timeout=timeout,
            query_params=query_params,
            user_api_key_dict=kwargs.get("user_api_key_dict"),
            litellm_metadata=_build_litellm_metadata(kwargs),
        )
    elif _custom_llm_provider == "bedrock":
        # Extract AWS parameters from kwargs
        aws_region_name = kwargs.get("aws_region_name")
        aws_access_key_id = kwargs.get("aws_access_key_id")
        aws_secret_access_key = kwargs.get("aws_secret_access_key")
        aws_session_token = kwargs.get("aws_session_token")
        aws_role_name = kwargs.get("aws_role_name")
        aws_session_name = kwargs.get("aws_session_name")
        aws_profile_name = kwargs.get("aws_profile_name")
        aws_web_identity_token = kwargs.get("aws_web_identity_token")
        aws_sts_endpoint = kwargs.get("aws_sts_endpoint")
        aws_bedrock_runtime_endpoint = kwargs.get("aws_bedrock_runtime_endpoint")
        aws_external_id = kwargs.get("aws_external_id")

        await bedrock_realtime.async_realtime(
            model=model,
            websocket=websocket,
            logging_obj=litellm_logging_obj,
            api_base=dynamic_api_base or api_base,
            api_key=dynamic_api_key or api_key,
            timeout=timeout,
            aws_region_name=aws_region_name,
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_role_name=aws_role_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_external_id=aws_external_id,
        )
    elif _custom_llm_provider == "xai":
        api_base = (
            dynamic_api_base
            or litellm_params.api_base
            or get_secret_str("XAI_API_BASE")
            or "https://api.x.ai/v1"
        )
        # set API KEY
        api_key = dynamic_api_key or litellm.api_key or get_secret_str("XAI_API_KEY")

        await xai_realtime.async_realtime(
            model=model,
            websocket=websocket,
            logging_obj=litellm_logging_obj,
            api_base=api_base,
            api_key=api_key,
            client=None,
            timeout=timeout,
            query_params=query_params,
            user_api_key_dict=kwargs.get("user_api_key_dict"),
            litellm_metadata=_build_litellm_metadata(kwargs),
        )
    elif _custom_llm_provider == "vertex_ai":
        vertex_credentials = (
            kwargs.get("vertex_credentials")
            or kwargs.get("vertex_ai_credentials")
            or get_secret_str("VERTEXAI_CREDENTIALS")
        )
        vertex_project = (
            kwargs.get("vertex_project")
            or kwargs.get("vertex_ai_project")
            or litellm.vertex_project
            or get_secret_str("VERTEXAI_PROJECT")
        )
        vertex_location = (
            kwargs.get("vertex_location")
            or kwargs.get("vertex_ai_location")
            or litellm.vertex_location
            or get_secret_str("VERTEXAI_LOCATION")
        )

        resolved_location = vertex_llm_base.get_vertex_region(
            vertex_region=vertex_location, model=model
        )

        (
            access_token,
            resolved_project,
        ) = await vertex_llm_base._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider="vertex_ai",
        )

        vertex_realtime_config = VertexAIRealtimeConfig(
            access_token=access_token,
            project=resolved_project,
            location=resolved_location,
        )

        await base_llm_http_handler.async_realtime(
            model=model,
            websocket=websocket,
            logging_obj=litellm_logging_obj,
            provider_config=vertex_realtime_config,
            api_base=dynamic_api_base or litellm_params.api_base,
            api_key=None,
            client=client,
            timeout=timeout,
            headers=headers,
            user_api_key_dict=kwargs.get("user_api_key_dict"),
            litellm_metadata=_build_litellm_metadata(kwargs),
        )
    else:
        raise ValueError(
            f"Unsupported provider for realtime: {_custom_llm_provider} (model: {model})"
        )
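
# Proxy-side usage sketch (illustrative; the real call site is the proxy's
# websocket route, which supplies the FastAPI WebSocket and logging object):
#
#     await _arealtime(
#         model="azure/gpt-4o-realtime-preview",
#         websocket=fastapi_websocket,  # an accepted FastAPI WebSocket
#         api_base="https://example-resource.openai.azure.com",  # hypothetical endpoint
#         api_version="2024-10-01-preview",
#     )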


async def _realtime_health_check(
    model: str,
    custom_llm_provider: str,
    api_key: Optional[str],
    api_base: Optional[str] = None,
    api_version: Optional[str] = None,
    realtime_protocol: Optional[str] = None,
):
    """
    Health check for the realtime API - attempts a websocket connection to the
    provider's realtime endpoint.

    Args:
        model: str - model name
        custom_llm_provider: str - custom llm provider
        api_key: Optional[str] - api key
        api_base: Optional[str] - api base
        api_version: Optional[str] - api version
        realtime_protocol: Optional[str] - protocol version ("GA"/"v1" for GA path, "beta"/None for beta path)

    Returns:
        bool - True if connection is successful, False otherwise
    Raises:
        Exception - if the connection is not successful
    """
    import websockets

    url: Optional[str] = None
    if custom_llm_provider == "azure":
        url = azure_realtime._construct_url(
            api_base=api_base or "",
            model=model,
            api_version=api_version or "2024-10-01-preview",
            realtime_protocol=realtime_protocol,
        )
    elif custom_llm_provider == "openai":
        url = openai_realtime._construct_url(
            api_base=api_base or "https://api.openai.com/",
            query_params={"model": model},
        )
    elif custom_llm_provider == "xai":
        url = xai_realtime._construct_url(
            api_base=api_base or "https://api.x.ai/v1", query_params={"model": model}
        )
    elif custom_llm_provider == "vertex_ai":
        vertex_location = litellm.vertex_location or get_secret_str("VERTEXAI_LOCATION")
        resolved_location = vertex_llm_base.get_vertex_region(
            vertex_region=vertex_location, model=model
        )
        (
            access_token,
            resolved_project,
        ) = await vertex_llm_base._ensure_access_token_async(
            credentials=None,
            project_id=litellm.vertex_project or get_secret_str("VERTEXAI_PROJECT"),
            custom_llm_provider="vertex_ai",
        )
        vertex_realtime_config = VertexAIRealtimeConfig(
            access_token=access_token,
            project=resolved_project,
            location=resolved_location,
        )
        url = vertex_realtime_config.get_complete_url(api_base=api_base, model=model)
        ssl_context = get_shared_realtime_ssl_context()
        headers = vertex_realtime_config.validate_environment(
            headers={}, model=model, api_key=None
        )
        async with websockets.connect(  # type: ignore
            url,
            additional_headers=headers,
            max_size=REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
            ssl=ssl_context,
        ):
            return True
    else:
        raise ValueError(
            f"Unsupported provider for realtime health check: {custom_llm_provider} (model: {model})"
        )
    ssl_context = get_shared_realtime_ssl_context()
    async with websockets.connect(  # type: ignore
        url,
        additional_headers={
            "api-key": api_key,  # type: ignore
        },
        max_size=REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
        ssl=ssl_context,
    ):
        return True
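
# Usage sketch (illustrative values): returns True when the websocket handshake
# succeeds and raises on failure; the key is sent via the "api-key" header shown above.
#
#     ok = await _realtime_health_check(
#         model="gpt-4o-realtime-preview",
#         custom_llm_provider="openai",
#         api_key=os.environ.get("OPENAI_API_KEY"),
#     )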