chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,270 @@
|
||||
"""Support for OpenAI gpt-5 model family."""
|
||||
|
||||
from typing import Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.utils import _supports_factory
|
||||
|
||||
from .gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
def _normalize_reasoning_effort_for_chat_completion(
|
||||
value: Union[str, dict, None],
|
||||
) -> Optional[str]:
|
||||
"""Convert reasoning_effort to the string format expected by OpenAI chat completion API.
|
||||
|
||||
The chat completion API expects a simple string: 'none', 'low', 'medium', 'high', or 'xhigh'.
|
||||
Config/deployments may pass the Responses API format: {'effort': 'high', 'summary': 'detailed'}.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict) and "effort" in value:
|
||||
return value["effort"]
|
||||
return None
|
||||
|
||||
|
||||
def _get_effort_level(value: Union[str, dict, None]) -> Optional[str]:
|
||||
"""Extract the effective effort level from reasoning_effort (string or dict).
|
||||
|
||||
Use this for guards that compare effort level (e.g. xhigh validation, "none" checks).
|
||||
Ensures dict inputs like {"effort": "none", "summary": "detailed"} are correctly
|
||||
treated as effort="none" for validation purposes.
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
if isinstance(value, dict) and "effort" in value:
|
||||
return value["effort"]
|
||||
return None
|
||||
|
||||
|
||||
class OpenAIGPT5Config(OpenAIGPTConfig):
    """Configuration for gpt-5 models including GPT-5-Codex variants.

    Handles OpenAI API quirks for the gpt-5 series like:

    - Mapping ``max_tokens`` -> ``max_completion_tokens``.
    - Dropping unsupported ``temperature`` values when requested.
    - Support for GPT-5-Codex models optimized for code generation.
    """

    @classmethod
    def is_model_gpt_5_model(cls, model: str) -> bool:
        """Check if the model is a gpt-5 reasoning model (excludes gpt-5-chat*)."""
        # gpt-5-chat* behaves like a regular chat model (supports temperature, etc.)
        # Don't route it through GPT-5 reasoning-specific parameter restrictions.
        return "gpt-5" in model and "gpt-5-chat" not in model

    @classmethod
    def is_model_gpt_5_search_model(cls, model: str) -> bool:
        """Check if the model is a GPT-5 search variant (e.g. gpt-5-search-api).

        Search-only models have a severely restricted parameter set compared to
        regular GPT-5 models. They are identified by name convention (contain
        both ``gpt-5`` and ``search``). Note: ``supports_web_search`` in model
        info is a *different* concept — it indicates a model can *use* web
        search as a tool, which many non-search-only models also support.
        """
        return "gpt-5" in model and "search" in model

    @classmethod
    def is_model_gpt_5_codex_model(cls, model: str) -> bool:
        """Check if the model is specifically a GPT-5 Codex variant."""
        return "gpt-5-codex" in model

    @classmethod
    def is_model_gpt_5_2_model(cls, model: str) -> bool:
        """Check if the model is a gpt-5.2 or gpt-5.4 variant (including pro).

        NOTE: despite the name, this deliberately also matches gpt-5.4 models
        (the previous docstring under-described the check — callers rely on the
        5.4 match, so behavior is kept).
        """
        model_name = model.split("/")[-1]
        return model_name.startswith("gpt-5.2") or model_name.startswith("gpt-5.4")

    @classmethod
    def is_model_gpt_5_4_model(cls, model: str) -> bool:
        """Check if the model is a gpt-5.4 variant (including pro)."""
        model_name = model.split("/")[-1]
        return model_name.startswith("gpt-5.4")

    @classmethod
    def is_model_gpt_5_4_plus_model(cls, model: str) -> bool:
        """Check if the model is gpt-5.4 or newer (5.4, 5.5, 5.6, etc., including pro)."""
        model_name = model.split("/")[-1]
        if not model_name.startswith("gpt-5."):
            return False
        try:
            # "gpt-5.4-pro" -> "4"; "gpt-5.10" -> "10"
            version_str = model_name.replace("gpt-5.", "").split("-")[0]
            major = version_str.split(".")[0]
            return int(major) >= 4
        except (ValueError, IndexError):
            # Unparseable version suffix -> safe fallback.
            return False

    @classmethod
    def _supports_reasoning_effort_level(cls, model: str, level: str) -> bool:
        """Check if the model supports a specific reasoning_effort level.

        Looks up ``supports_{level}_reasoning_effort`` in the model map via
        the shared ``_supports_factory`` helper.
        Returns False for unknown models (safe fallback).
        """
        return _supports_factory(
            model=model,
            custom_llm_provider=None,
            key=f"supports_{level}_reasoning_effort",
        )

    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI params supported by ``model`` within the gpt-5 family."""
        if self.is_model_gpt_5_search_model(model):
            # Search-only models accept a severely restricted parameter set.
            return [
                "max_tokens",
                "max_completion_tokens",
                "stream",
                "stream_options",
                "web_search_options",
                "service_tier",
                "safety_identifier",
                "response_format",
                "user",
                "store",
                "verbosity",
                "max_retries",
                "extra_headers",
            ]

        from litellm.utils import supports_tool_choice

        base_gpt_series_params = super().get_supported_openai_params(model=model)
        gpt_5_only_params = ["reasoning_effort", "verbosity"]
        base_gpt_series_params.extend(gpt_5_only_params)
        # Guarded removal: previously an unguarded list.remove() that would raise
        # ValueError if the base list ever stopped including "tool_choice".
        if (
            not supports_tool_choice(model=model)
            and "tool_choice" in base_gpt_series_params
        ):
            base_gpt_series_params.remove("tool_choice")

        non_supported_params = [
            "presence_penalty",
            "frequency_penalty",
            "stop",
            "logit_bias",
            "modalities",
            "prediction",
            "audio",
            "web_search_options",
        ]

        # gpt-5.1/5.2 support logprobs, top_p, top_logprobs when reasoning_effort="none"
        if not self._supports_reasoning_effort_level(model, "none"):
            non_supported_params.extend(["logprobs", "top_p", "top_logprobs"])

        return [
            param
            for param in base_gpt_series_params
            if param not in non_supported_params
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Map/validate OpenAI params for gpt-5 models, enforcing family rules.

        Raises:
            litellm.utils.UnsupportedParamsError: if an unsupported param/value
                is passed and neither ``litellm.drop_params`` nor ``drop_params``
                is enabled.
        """
        if self.is_model_gpt_5_search_model(model):
            if "max_tokens" in non_default_params:
                optional_params["max_completion_tokens"] = non_default_params.pop(
                    "max_tokens"
                )
            return super()._map_openai_params(
                non_default_params=non_default_params,
                optional_params=optional_params,
                model=model,
                drop_params=drop_params,
            )

        # Get raw reasoning_effort and effective effort level for all guards.
        # Use effective_effort (extracted string) for xhigh validation, "none" checks, and
        # tool/sampling guards — dict inputs like {"effort": "none", "summary": "detailed"}
        # must be treated as effort="none" to avoid incorrect tool-drop or sampling errors.
        raw_reasoning_effort = non_default_params.get(
            "reasoning_effort"
        ) or optional_params.get("reasoning_effort")
        effective_effort = _get_effort_level(raw_reasoning_effort)

        # Normalize dict reasoning_effort to string for Chat Completions API.
        # Example: {"effort": "high", "summary": "detailed"} -> "high"
        if isinstance(raw_reasoning_effort, dict) and "effort" in raw_reasoning_effort:
            normalized = _normalize_reasoning_effort_for_chat_completion(
                raw_reasoning_effort
            )
            if normalized is not None:
                if "reasoning_effort" in non_default_params:
                    non_default_params["reasoning_effort"] = normalized
                if "reasoning_effort" in optional_params:
                    optional_params["reasoning_effort"] = normalized

        # (== "xhigh" already implies not-None; the extra None check was redundant)
        if effective_effort == "xhigh":
            if not self._supports_reasoning_effort_level(model, "xhigh"):
                if litellm.drop_params or drop_params:
                    non_default_params.pop("reasoning_effort", None)
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        message=(
                            "reasoning_effort='xhigh' is only supported for gpt-5.1-codex-max, gpt-5.2, and gpt-5.4+ models."
                        ),
                        status_code=400,
                    )

        ################################################################
        # max_tokens is not supported for gpt-5 models on OpenAI API
        # Relevant issue: https://github.com/BerriAI/litellm/issues/13381
        ################################################################
        if "max_tokens" in non_default_params:
            optional_params["max_completion_tokens"] = non_default_params.pop(
                "max_tokens"
            )

        # gpt-5.1/5.2 support logprobs, top_p, top_logprobs only when reasoning_effort="none"
        supports_none = self._supports_reasoning_effort_level(model, "none")
        if supports_none:
            sampling_params = ["logprobs", "top_logprobs", "top_p"]
            has_sampling = any(p in non_default_params for p in sampling_params)
            if has_sampling and effective_effort not in (None, "none"):
                if litellm.drop_params or drop_params:
                    for p in sampling_params:
                        non_default_params.pop(p, None)
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        message=(
                            "gpt-5.1/5.2/5.4 only support logprobs, top_p, top_logprobs when "
                            "reasoning_effort='none'. Current reasoning_effort='{}'. "
                            "To drop unsupported params set `litellm.drop_params = True`"
                        ).format(effective_effort),
                        status_code=400,
                    )

        if "temperature" in non_default_params:
            temperature_value: Optional[float] = non_default_params.pop("temperature")
            if temperature_value is not None:
                # models supporting reasoning_effort="none" also support flexible temperature
                if supports_none and (
                    effective_effort == "none" or effective_effort is None
                ):
                    optional_params["temperature"] = temperature_value
                elif temperature_value == 1:
                    # Only the default temperature is accepted on reasoning models.
                    optional_params["temperature"] = temperature_value
                elif litellm.drop_params or drop_params:
                    pass
                else:
                    raise litellm.utils.UnsupportedParamsError(
                        message=(
                            "gpt-5 models (including gpt-5-codex) don't support temperature={}. "
                            "Only temperature=1 is supported. "
                            "For gpt-5.1, temperature is supported when reasoning_effort='none' (or not specified, as it defaults to 'none'). "
                            "To drop unsupported params set `litellm.drop_params = True`"
                        ).format(temperature_value),
                        status_code=400,
                    )
        return super()._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Support for GPT-4o audio Family
|
||||
|
||||
OpenAI Doc: https://platform.openai.com/docs/guides/audio/quickstart?audio-generation-quickstart-example=audio-in&lang=python
|
||||
"""
|
||||
|
||||
import litellm
|
||||
|
||||
from .gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class OpenAIGPTAudioConfig(OpenAIGPTConfig):
    """
    Reference: https://platform.openai.com/docs/guides/audio
    """

    @classmethod
    def get_config(cls):
        return super().get_config()

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the `gpt-audio` models

        """
        supported = list(super().get_supported_openai_params(model=model))
        # `audio` is the only param specific to the audio model family.
        supported.append("audio")
        return supported

    def is_model_gpt_audio_model(self, model: str) -> bool:
        """True for known OpenAI chat-completion models whose name contains 'audio'."""
        return bool(
            model in litellm.open_ai_chat_completion_models and "audio" in model
        )

    def _map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # No audio-specific remapping needed; defer to the base implementation.
        return super()._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
@@ -0,0 +1,819 @@
|
||||
"""
|
||||
Support for gpt model family
|
||||
"""
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Coroutine,
|
||||
Iterator,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||
_extract_reasoning_content,
|
||||
_handle_invalid_parallel_tool_calls,
|
||||
_should_convert_tool_call_to_json_mode,
|
||||
)
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import get_tool_call_names
|
||||
from litellm.litellm_core_utils.prompt_templates.image_handling import (
|
||||
async_convert_url_to_base64,
|
||||
convert_url_to_base64,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionFileObject,
|
||||
ChatCompletionFileObjectFile,
|
||||
ChatCompletionImageObject,
|
||||
ChatCompletionImageUrlObject,
|
||||
OpenAIChatCompletionChoices,
|
||||
OpenAIMessageContentListBlock,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionMessageToolCall,
|
||||
Choices,
|
||||
Function,
|
||||
Message,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
)
|
||||
from litellm.utils import convert_to_model_response_object
|
||||
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.types.llms.openai import ChatCompletionToolParam
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `OpenAIConfig` provides configuration for the OpenAI's Chat API interface. Below are the parameters:

    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby minimizing repetition.

    - `function_call` (string or object): This optional parameter controls how the model calls functions.

    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `max_tokens` (integer or null): This optional parameter helps to set the maximum number of tokens to generate in the chat completion.

    - `n` (integer or null): This optional parameter helps to set how many chat completion choices to generate for each input message.

    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on if they appear in the text so far, hence increasing the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    # Add a class variable to track if this is the base class
    # (flipped to False in __init__; checked by async_transform_request to
    # decide whether subclass-specific transform_request should run)
    _is_base_class = True

    # NOTE(review): several numeric params below are annotated Optional[int]
    # although the docstring (and the OpenAI API) describes them as numbers
    # that may be fractional (temperature 0-2, top_p, penalties -2.0..2.0) —
    # annotations kept as-is to avoid churn; confirm before tightening.
    frequency_penalty: Optional[int] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[int] = None
    top_p: Optional[int] = None
    response_format: Optional[dict] = None
    def __init__(
        self,
        frequency_penalty: Optional[int] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[int] = None,
        top_p: Optional[int] = None,
        response_format: Optional[dict] = None,
    ) -> None:
        """Record any provided defaults as *class-level* attributes.

        NOTE: values are stored on ``self.__class__`` (not the instance), so
        constructing one config object changes the defaults seen by every
        later lookup on that class. This mirrors the class-attribute
        declarations above.
        """
        locals_ = locals().copy()
        for key, value in locals_.items():
            # Skip `self` and any unset (None) params; everything else
            # overwrites the class attribute of the same name.
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

        # Mark that this (sub)class has been instantiated/configured —
        # async_transform_request reads this flag to dispatch to subclass
        # transform_request overrides.
        self.__class__._is_base_class = False
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
base_params = [
|
||||
"frequency_penalty",
|
||||
"logit_bias",
|
||||
"logprobs",
|
||||
"top_logprobs",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"modalities",
|
||||
"prediction",
|
||||
"n",
|
||||
"presence_penalty",
|
||||
"seed",
|
||||
"stop",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"function_call",
|
||||
"functions",
|
||||
"max_retries",
|
||||
"extra_headers",
|
||||
"parallel_tool_calls",
|
||||
"audio",
|
||||
"web_search_options",
|
||||
"service_tier",
|
||||
"safety_identifier",
|
||||
"prompt_cache_key",
|
||||
"prompt_cache_retention",
|
||||
"store",
|
||||
] # works across all models
|
||||
|
||||
model_specific_params = []
|
||||
if (
|
||||
model != "gpt-3.5-turbo-16k" and model != "gpt-4"
|
||||
): # gpt-4 does not support 'response_format'
|
||||
model_specific_params.append("response_format")
|
||||
|
||||
# Normalize model name for responses API (e.g., "responses/gpt-4.1" -> "gpt-4.1")
|
||||
model_for_check = (
|
||||
model.split("responses/", 1)[1] if "responses/" in model else model
|
||||
)
|
||||
if (
|
||||
model_for_check in litellm.open_ai_chat_completion_models
|
||||
) or model_for_check in litellm.open_ai_text_completion_models:
|
||||
model_specific_params.append(
|
||||
"user"
|
||||
) # user is not a param supported by all openai-compatible endpoints - e.g. azure ai
|
||||
return base_params + model_specific_params
|
||||
|
||||
def _map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
If any supported_openai_params are in non_default_params, add them to optional_params, so they are use in API call
|
||||
|
||||
Args:
|
||||
non_default_params (dict): Non-default parameters to filter.
|
||||
optional_params (dict): Optional parameters to update.
|
||||
model (str): Model name for parameter support check.
|
||||
|
||||
Returns:
|
||||
dict: Updated optional_params with supported non-default parameters.
|
||||
"""
|
||||
supported_openai_params = self.get_supported_openai_params(model)
|
||||
for param, value in non_default_params.items():
|
||||
if param in supported_openai_params:
|
||||
optional_params[param] = value
|
||||
return optional_params
|
||||
|
||||
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Public entry point for param mapping; base class simply delegates
        to :meth:`_map_openai_params` (subclasses override this to add
        model-family-specific validation/mapping)."""
        return self._map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )
def contains_pdf_url(self, content_item: ChatCompletionFileObjectFile) -> bool:
|
||||
potential_pdf_url_starts = ["https://", "http://", "www."]
|
||||
file_id = content_item.get("file_id")
|
||||
if file_id and any(
|
||||
file_id.startswith(start) for start in potential_pdf_url_starts
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _handle_pdf_url(
|
||||
self, content_item: ChatCompletionFileObjectFile
|
||||
) -> ChatCompletionFileObjectFile:
|
||||
content_copy = content_item.copy()
|
||||
file_id = content_copy.get("file_id")
|
||||
if file_id is not None:
|
||||
base64_data = convert_url_to_base64(file_id)
|
||||
content_copy["file_data"] = base64_data
|
||||
content_copy["filename"] = "my_file.pdf"
|
||||
content_copy.pop("file_id")
|
||||
return content_copy
|
||||
|
||||
async def _async_handle_pdf_url(
|
||||
self, content_item: ChatCompletionFileObjectFile
|
||||
) -> ChatCompletionFileObjectFile:
|
||||
file_id = content_item.get("file_id")
|
||||
if file_id is not None: # check for file id being url done in _handle_pdf_url
|
||||
base64_data = await async_convert_url_to_base64(file_id)
|
||||
content_item["file_data"] = base64_data
|
||||
content_item["filename"] = "my_file.pdf"
|
||||
content_item.pop("file_id")
|
||||
return content_item
|
||||
|
||||
def _common_file_data_check(
|
||||
self, content_item: ChatCompletionFileObjectFile
|
||||
) -> ChatCompletionFileObjectFile:
|
||||
file_data = content_item.get("file_data")
|
||||
filename = content_item.get("filename")
|
||||
if file_data is not None and filename is None:
|
||||
content_item["filename"] = "my_file.pdf"
|
||||
return content_item
|
||||
|
||||
def _apply_common_transform_content_item(
|
||||
self,
|
||||
content_item: OpenAIMessageContentListBlock,
|
||||
) -> OpenAIMessageContentListBlock:
|
||||
litellm_specific_params = {"format"}
|
||||
if content_item.get("type") == "image_url":
|
||||
content_item = cast(ChatCompletionImageObject, content_item)
|
||||
if isinstance(content_item["image_url"], str):
|
||||
content_item["image_url"] = {
|
||||
"url": content_item["image_url"],
|
||||
}
|
||||
elif isinstance(content_item["image_url"], dict):
|
||||
new_image_url_obj = ChatCompletionImageUrlObject(
|
||||
**{ # type: ignore
|
||||
k: v
|
||||
for k, v in content_item["image_url"].items()
|
||||
if k not in litellm_specific_params
|
||||
}
|
||||
)
|
||||
content_item["image_url"] = new_image_url_obj
|
||||
elif content_item.get("type") == "file":
|
||||
content_item = cast(ChatCompletionFileObject, content_item)
|
||||
file_obj = content_item["file"]
|
||||
new_file_obj = ChatCompletionFileObjectFile(
|
||||
**{ # type: ignore
|
||||
k: v
|
||||
for k, v in file_obj.items()
|
||||
if k not in litellm_specific_params
|
||||
}
|
||||
)
|
||||
content_item["file"] = new_file_obj
|
||||
|
||||
return content_item
|
||||
|
||||
def _transform_content_item(
|
||||
self,
|
||||
content_item: OpenAIMessageContentListBlock,
|
||||
) -> OpenAIMessageContentListBlock:
|
||||
content_item = self._apply_common_transform_content_item(content_item)
|
||||
content_item_type = content_item.get("type")
|
||||
potential_file_obj = content_item.get("file")
|
||||
if content_item_type == "file" and potential_file_obj:
|
||||
file_obj = cast(ChatCompletionFileObjectFile, potential_file_obj)
|
||||
content_item_typed = cast(ChatCompletionFileObject, content_item)
|
||||
if self.contains_pdf_url(file_obj):
|
||||
file_obj = self._handle_pdf_url(file_obj)
|
||||
file_obj = self._common_file_data_check(file_obj)
|
||||
content_item_typed["file"] = file_obj
|
||||
content_item = content_item_typed
|
||||
return content_item
|
||||
|
||||
async def _async_transform_content_item(
|
||||
self, content_item: OpenAIMessageContentListBlock, is_async: bool = False
|
||||
) -> OpenAIMessageContentListBlock:
|
||||
content_item = self._apply_common_transform_content_item(content_item)
|
||||
content_item_type = content_item.get("type")
|
||||
potential_file_obj = content_item.get("file")
|
||||
if content_item_type == "file" and potential_file_obj:
|
||||
file_obj = cast(ChatCompletionFileObjectFile, potential_file_obj)
|
||||
content_item_typed = cast(ChatCompletionFileObject, content_item)
|
||||
if self.contains_pdf_url(file_obj):
|
||||
file_obj = await self._async_handle_pdf_url(file_obj)
|
||||
file_obj = self._common_file_data_check(file_obj)
|
||||
content_item_typed["file"] = file_obj
|
||||
content_item = content_item_typed
|
||||
return content_item
|
||||
|
||||
    # fmt: off

    @overload
    def _transform_messages(
        self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
    ) -> Coroutine[Any, Any, List[AllMessageValues]]:
        ...

    @overload
    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
        is_async: Literal[False] = False,
    ) -> List[AllMessageValues]:
        ...

    # fmt: on

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str, is_async: bool = False
    ) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
        """OpenAI no longer supports image_url as a string, so we need to convert it to a dict"""

        async def _async_transform():
            # Mirrors the sync branch below, but awaits the per-item transform
            # (needed when file blocks require a URL -> base64 download).
            for message in messages:
                message_content = message.get("content")
                message_role = message.get("role")

                # Only user messages with list-style content carry image/file blocks.
                if (
                    message_role == "user"
                    and message_content
                    and isinstance(message_content, list)
                ):
                    message_content_types = cast(
                        List[OpenAIMessageContentListBlock], message_content
                    )
                    for i, content_item in enumerate(message_content_types):
                        # In-place replacement: mutates the caller's message list.
                        message_content_types[
                            i
                        ] = await self._async_transform_content_item(
                            cast(OpenAIMessageContentListBlock, content_item),
                        )
            return messages

        if is_async:
            # Caller is expected to await the returned coroutine (see overloads).
            return _async_transform()
        else:
            for message in messages:
                message_content = message.get("content")
                message_role = message.get("role")
                if (
                    message_role == "user"
                    and message_content
                    and isinstance(message_content, list)
                ):
                    message_content_types = cast(
                        List[OpenAIMessageContentListBlock], message_content
                    )
                    for i, content_item in enumerate(message_content):
                        message_content_types[i] = self._transform_content_item(
                            cast(OpenAIMessageContentListBlock, content_item)
                        )
            return messages
||||
def remove_cache_control_flag_from_messages_and_tools(
|
||||
self,
|
||||
model: str, # allows overrides to selectively run this
|
||||
messages: List[AllMessageValues],
|
||||
tools: Optional[List["ChatCompletionToolParam"]] = None,
|
||||
) -> Tuple[List[AllMessageValues], Optional[List["ChatCompletionToolParam"]]]:
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
filter_value_from_dict,
|
||||
)
|
||||
from litellm.types.llms.openai import ChatCompletionToolParam
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
messages[i] = cast(
|
||||
AllMessageValues, filter_value_from_dict(message, "cache_control") # type: ignore
|
||||
)
|
||||
if tools is not None:
|
||||
for i, tool in enumerate(tools):
|
||||
tools[i] = cast(
|
||||
ChatCompletionToolParam,
|
||||
filter_value_from_dict(tool, "cache_control"), # type: ignore
|
||||
)
|
||||
return messages, tools
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the overall request to be sent to the API.
|
||||
|
||||
Returns:
|
||||
dict: The transformed request. Sent as the body of the API call.
|
||||
"""
|
||||
messages = self._transform_messages(messages=messages, model=model)
|
||||
messages, tools = self.remove_cache_control_flag_from_messages_and_tools(
|
||||
model=model, messages=messages, tools=optional_params.get("tools", [])
|
||||
)
|
||||
if tools is not None and len(tools) > 0:
|
||||
optional_params["tools"] = tools
|
||||
|
||||
optional_params.pop("max_retries", None)
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**optional_params,
|
||||
}
|
||||
|
||||
    async def async_transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Async request transform: awaits message transforms (which may
        download file URLs), then builds the request body."""
        transformed_messages = await self._transform_messages(
            messages=messages, model=model, is_async=True
        )
        (
            transformed_messages,
            tools,
        ) = self.remove_cache_control_flag_from_messages_and_tools(
            model=model,
            messages=transformed_messages,
            tools=optional_params.get("tools", []),
        )
        if tools is not None and len(tools) > 0:
            optional_params["tools"] = tools
        if self.__class__._is_base_class:
            # Base class: assemble the body directly from transformed messages.
            return {
                "model": model,
                "messages": transformed_messages,
                **optional_params,
            }
        else:
            ## allow for any object specific behaviour to be handled
            # NOTE(review): this passes the ORIGINAL `messages` (not
            # transformed_messages) into transform_request, which re-runs the
            # sync transform — confirm this double-transform is intended.
            return self.transform_request(
                model, messages, optional_params, litellm_params, headers
            )
||||
def _passed_in_tools(self, optional_params: dict) -> bool:
|
||||
return optional_params.get("tools", None) is not None
|
||||
|
||||
def _check_and_fix_if_content_is_tool_call(
|
||||
self, content: str, optional_params: dict
|
||||
) -> Optional[ChatCompletionMessageToolCall]:
|
||||
"""
|
||||
Check if the content is a tool call
|
||||
"""
|
||||
import json
|
||||
|
||||
if not self._passed_in_tools(optional_params):
|
||||
return None
|
||||
tool_call_names = get_tool_call_names(optional_params.get("tools", []))
|
||||
try:
|
||||
json_content = json.loads(content)
|
||||
if (
|
||||
json_content.get("type") == "function"
|
||||
and json_content.get("name") in tool_call_names
|
||||
):
|
||||
return ChatCompletionMessageToolCall(
|
||||
function=Function(
|
||||
name=json_content.get("name"),
|
||||
arguments=json_content.get("arguments"),
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
def _get_finish_reason(self, message: Message, received_finish_reason: str) -> str:
|
||||
if message.tool_calls is not None:
|
||||
return "tool_calls"
|
||||
else:
|
||||
return received_finish_reason
|
||||
|
||||
def _transform_choices(
    self,
    choices: List[OpenAIChatCompletionChoices],
    json_mode: Optional[bool] = None,
    optional_params: Optional[dict] = None,
) -> List[Choices]:
    """Translate raw OpenAI-format choices into LiteLLM ``Choices`` objects.

    Handles three quirks along the way:
    - rebuilds tool calls (repairing invalid parallel tool calls);
    - when no tool calls were returned, checks whether the plain-text content
      is actually a tool call emitted as JSON and converts it;
    - when json_mode is on, converts a lone tool call back into plain content
      (to support response_format on claude-style models).

    Args:
        choices: Raw provider choices (dict-shaped).
        json_mode: When truthy, a single tool call may be converted into
            string content with finish_reason "stop".
        optional_params: The request's optional params; used to look up the
            tools the caller declared.

    Returns:
        List[Choices]: fully translated choices with mapped finish reasons.
    """
    transformed_choices = []

    for choice in choices:
        ## HANDLE JSON MODE - anthropic returns single function call
        tool_calls = choice["message"].get("tool_calls", None)
        new_tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
        message_content = choice["message"].get("content", None)
        if tool_calls is not None:
            # Re-wrap raw dicts as typed tool calls, then repair any
            # malformed parallel tool calls.
            _openai_tool_calls = []
            for _tc in tool_calls:
                _openai_tc = ChatCompletionMessageToolCall(**_tc)  # type: ignore
                _openai_tool_calls.append(_openai_tc)
            fixed_tool_calls = _handle_invalid_parallel_tool_calls(
                _openai_tool_calls
            )

            if fixed_tool_calls is not None:
                new_tool_calls = fixed_tool_calls
        elif (
            optional_params is not None
            and message_content
            and isinstance(message_content, str)
        ):
            # No tool_calls field: the model may have emitted the tool call
            # as JSON inside plain text content — try to recover it.
            new_tool_call = self._check_and_fix_if_content_is_tool_call(
                message_content, optional_params
            )
            if new_tool_call is not None:
                choice["message"]["content"] = None  # remove the content
                new_tool_calls = [new_tool_call]

        translated_message: Optional[Message] = None
        finish_reason: Optional[str] = None
        if new_tool_calls and _should_convert_tool_call_to_json_mode(
            tool_calls=new_tool_calls,
            convert_tool_call_to_json_mode=json_mode,
        ):
            # to support response_format on claude models: surface the tool
            # call's arguments as the message content instead
            json_mode_content_str: Optional[str] = (
                str(new_tool_calls[0]["function"].get("arguments", "")) or None
            )
            if json_mode_content_str is not None:
                translated_message = Message(content=json_mode_content_str)
                finish_reason = "stop"

        if translated_message is None:
            ## get the reasoning content (splits provider-specific reasoning
            ## fields out of the raw message dict)
            (
                reasoning_content,
                content_str,
            ) = _extract_reasoning_content(cast(dict, choice["message"]))

            translated_message = Message(
                role="assistant",
                content=content_str,
                reasoning_content=reasoning_content,
                thinking_blocks=None,
                tool_calls=new_tool_calls,
            )

        if finish_reason is None:
            finish_reason = choice["finish_reason"]

        translated_choice = Choices(
            finish_reason=finish_reason,
            index=choice["index"],
            message=translated_message,
            logprobs=None,
            enhancements=None,
        )

        # Recompute finish_reason: a message carrying tool calls always maps
        # to "tool_calls", then normalize via map_finish_reason.
        translated_choice.finish_reason = map_finish_reason(
            self._get_finish_reason(translated_message, choice["finish_reason"])
        )
        transformed_choices.append(translated_choice)

    return transformed_choices
|
||||
|
||||
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """
    Transform the raw HTTP response from the API into a ModelResponse.

    Args:
        raw_response: The provider's raw httpx response.
        model_response: The ModelResponse shell to populate.
        logging_obj: Used to log the completed call (post_call).
        request_data: The request payload, forwarded to logging.

    Returns:
        ModelResponse: The populated response object.

    Raises:
        OpenAIError: If the response body is not valid JSON.
    """

    ## LOGGING — record the raw response before any parsing can fail
    logging_obj.post_call(
        input=messages,
        api_key=api_key,
        original_response=raw_response.text,
        additional_args={"complete_input_dict": request_data},
    )

    ## RESPONSE OBJECT
    try:
        completion_response = raw_response.json()
    except Exception as e:
        # Non-JSON body (e.g. HTML error page) — surface as OpenAIError with
        # the original text so callers can see what came back.
        response_headers = getattr(raw_response, "headers", None)
        raise OpenAIError(
            message="Unable to get json response - {}, Original Response: {}".format(
                str(e), raw_response.text
            ),
            status_code=raw_response.status_code,
            headers=response_headers,
        )
    # Response headers are exposed twice: in hidden_params and directly.
    raw_response_headers = dict(raw_response.headers)
    final_response_obj = convert_to_model_response_object(
        response_object=completion_response,
        model_response_object=model_response,
        hidden_params={"headers": raw_response_headers},
        _response_headers=raw_response_headers,
    )

    return cast(ModelResponse, final_response_obj)
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """Wrap a provider failure in an OpenAIError carrying status, message, and headers."""
    typed_headers = cast(httpx.Headers, headers)
    return OpenAIError(
        status_code=status_code,
        message=error_message,
        headers=typed_headers,
    )
|
||||
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """
    Build the chat-completions URL for this request.

    Falls back to the public OpenAI host when no api_base is given, strips a
    trailing slash, and does not append the endpoint when the caller's
    api_base already contains it.

    Returns:
        str: The complete URL for the API call.
    """
    endpoint = "chat/completions"
    base = (api_base or "https://api.openai.com").rstrip("/")

    # The caller may already point at the full endpoint (e.g. a proxy URL).
    if endpoint in base:
        return base

    return f"{base}/{endpoint}"
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """
    Attach the auth and content-type headers required for the API call.

    Mutates and returns ``headers``:
    - adds ``Authorization: Bearer <api_key>`` when an api_key is provided;
    - adds ``Content-Type: application/json`` unless a content-type header
      is already present (any casing).

    Returns:
        dict: The same headers dict, updated in place.
    """
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    # HTTP header field names are case-insensitive. The previous check only
    # matched the exact keys "content-type"/"Content-Type", so a caller-set
    # "CONTENT-TYPE" (or other casing) would get a duplicate header added.
    has_content_type = any(
        isinstance(key, str) and key.lower() == "content-type" for key in headers
    )
    if not has_content_type:
        headers["Content-Type"] = "application/json"

    return headers
|
||||
|
||||
def get_models(
    self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> List[str]:
    """
    Calls OpenAI's `/v1/models` endpoint and returns the list of model ids.

    Raises:
        Exception: when the endpoint returns a non-200 status.
    """

    if api_base is None:
        api_base = "https://api.openai.com"
    if api_key is None:
        api_key = get_secret_str("OPENAI_API_KEY")

    # Reduce api_base to scheme + host (+ port) so the request always hits
    # the canonical /v1/models path, regardless of any path in api_base.
    parsed = httpx.URL(api_base)
    host_url = f"{parsed.scheme}://{parsed.host}"
    if parsed.port:
        host_url = f"{host_url}:{parsed.port}"

    response = litellm.module_level_client.get(
        url=f"{host_url}/v1/models",
        headers={"Authorization": f"Bearer {api_key}"},
    )

    if response.status_code != 200:
        raise Exception(f"Failed to get models: {response.text}")

    return [entry["id"] for entry in response.json()["data"]]
|
||||
|
||||
@staticmethod
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
    """Resolve the OpenAI API key: explicit argument, then litellm globals, then the environment."""
    resolved = api_key or litellm.api_key or litellm.openai_key
    if resolved:
        return resolved
    return get_secret_str("OPENAI_API_KEY")
|
||||
|
||||
@staticmethod
def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
    """Resolve the API base URL: explicit argument, litellm global, env vars, then the public OpenAI endpoint."""
    resolved = api_base or litellm.api_base
    if not resolved:
        # Short-circuit preserved: OPENAI_API_BASE is only consulted when
        # OPENAI_BASE_URL is unset/empty.
        resolved = get_secret_str("OPENAI_BASE_URL") or get_secret_str(
            "OPENAI_API_BASE"
        )
    return resolved or "https://api.openai.com/v1"
|
||||
|
||||
@staticmethod
|
||||
def get_base_model(model: Optional[str] = None) -> Optional[str]:
|
||||
return model
|
||||
|
||||
def get_token_counter(self) -> Optional["BaseTokenCounter"]:
    """Return the token counter implementation for this provider."""
    # Function-scope import — presumably to avoid a circular import at module
    # load time; confirm against the package layout.
    from litellm.llms.openai.responses.count_tokens.token_counter import (
        OpenAITokenCounter,
    )

    return OpenAITokenCounter()
|
||||
|
||||
def get_model_response_iterator(
    self,
    streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
    sync_stream: bool,
    json_mode: Optional[bool] = False,
) -> Any:
    """Return the streaming handler that parses raw chunks for this provider.

    Args:
        streaming_response: The raw sync/async stream, or a materialized
            ModelResponse.
        sync_stream: True when the caller consumes the stream synchronously.
        json_mode: Forwarded to the handler (presumably for JSON-mode
            conversion of streamed tool calls — confirm in the base iterator).
    """
    return OpenAIChatCompletionStreamingHandler(
        streaming_response=streaming_response,
        sync_stream=sync_stream,
        json_mode=json_mode,
    )
|
||||
|
||||
|
||||
class OpenAIChatCompletionStreamingHandler(BaseModelResponseIterator):
    """Parses raw OpenAI-format streaming chunks into ModelResponseStream objects."""

    def _map_reasoning_to_reasoning_content(self, choices: list) -> list:
        """
        Rename ``delta.reasoning`` to ``delta.reasoning_content`` in each choice.

        Some OpenAI-compatible providers (e.g., GLM-5, hosted_vllm) stream the
        reasoning text under ``delta.reasoning``, while LiteLLM expects
        ``delta.reasoning_content``.

        Args:
            choices: List of choice dicts from the streaming chunk.

        Returns:
            The same list, with the field renamed in place where present.
        """
        for entry in choices:
            delta = entry.get("delta", {})
            if "reasoning" in delta:
                delta["reasoning_content"] = delta.pop("reasoning")
        return choices

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """Convert one raw streaming chunk (dict) into a ModelResponseStream."""
        try:
            normalized_choices = self._map_reasoning_to_reasoning_content(
                chunk.get("choices", [])
            )

            stream_kwargs = {
                "id": chunk["id"],
                "object": "chat.completion.chunk",
                "created": chunk.get("created"),
                "model": chunk.get("model"),
                "choices": normalized_choices,
            }
            # Only forward usage when the provider attached it to this chunk.
            usage = chunk.get("usage")
            if usage is not None:
                stream_kwargs["usage"] = usage
            return ModelResponseStream(**stream_kwargs)
        except Exception as e:
            raise e
|
||||
@@ -0,0 +1,3 @@
|
||||
Translation of OpenAI `/chat/completions` input and output to a custom guardrail.
|
||||
|
||||
This enables guardrails to be applied to OpenAI `/chat/completions` requests and responses.
|
||||
@@ -0,0 +1,12 @@
|
||||
"""OpenAI Chat Completions message handler for Unified Guardrails."""
|
||||
|
||||
from litellm.llms.openai.chat.guardrail_translation.handler import (
|
||||
OpenAIChatCompletionsHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
guardrail_translation_mappings = {
|
||||
CallTypes.completion: OpenAIChatCompletionsHandler,
|
||||
CallTypes.acompletion: OpenAIChatCompletionsHandler,
|
||||
}
|
||||
__all__ = ["guardrail_translation_mappings"]
|
||||
@@ -0,0 +1,808 @@
|
||||
"""
|
||||
OpenAI Chat Completions Message Handler for Unified Guardrails
|
||||
|
||||
This module provides a class-based handler for OpenAI-format chat completions.
|
||||
The class methods can be overridden for custom behavior.
|
||||
|
||||
Pattern Overview:
|
||||
-----------------
|
||||
1. Extract text content from messages/responses (both string and list formats)
|
||||
2. Create async tasks to apply guardrails to each text segment
|
||||
3. Track mappings to know where each response belongs
|
||||
4. Apply guardrail responses back to the original structure
|
||||
|
||||
This pattern can be replicated for other message formats (e.g., Anthropic).
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.main import stream_chunk_builder
|
||||
from litellm.types.llms.openai import ChatCompletionToolParam
|
||||
from litellm.types.utils import (
|
||||
Choices,
|
||||
GenericGuardrailAPIInputs,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
|
||||
|
||||
class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
"""
|
||||
Handler for processing OpenAI chat completions messages with guardrails.
|
||||
|
||||
This class provides methods to:
|
||||
1. Process input messages (pre-call hook)
|
||||
2. Process output responses (post-call hook)
|
||||
|
||||
Methods can be overridden to customize behavior for different message formats.
|
||||
"""
|
||||
|
||||
async def process_input_messages(
    self,
    data: dict,
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
) -> Any:
    """
    Process input messages by applying guardrails to text content.

    Extracts texts, images, and tool calls from ``data["messages"]``, runs
    them through the guardrail in one batch, then writes the guardrailed
    texts/tool calls back into the original message structure (in place).

    Args:
        data: The request body; ``messages`` and (optionally) ``tools``/
            ``model`` are read, and ``messages``/``tools`` may be mutated.
        guardrail_to_apply: The guardrail instance to run.
        litellm_logging_obj: Optional logging object forwarded to the guardrail.

    Returns:
        The same ``data`` dict, with guardrailed content applied.
    """
    messages = data.get("messages")
    if messages is None:
        return data

    texts_to_check: List[str] = []
    images_to_check: List[str] = []
    tool_calls_to_check: List[ChatCompletionToolParam] = []
    text_task_mappings: List[Tuple[int, Optional[int]]] = []
    tool_call_task_mappings: List[Tuple[int, int]] = []
    # text_task_mappings: Track (message_index, content_index) for each text
    # content_index is None for string content, int for list content
    # tool_call_task_mappings: Track (message_index, tool_call_index) for each tool call

    # Step 1: Extract all text content, images, and tool calls
    for msg_idx, message in enumerate(messages):
        self._extract_inputs(
            message=message,
            msg_idx=msg_idx,
            texts_to_check=texts_to_check,
            images_to_check=images_to_check,
            tool_calls_to_check=tool_calls_to_check,
            text_task_mappings=text_task_mappings,
            tool_call_task_mappings=tool_call_task_mappings,
        )

    # Step 2: Apply guardrail to all texts and tool calls in batch
    if texts_to_check or tool_calls_to_check:
        inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
        if images_to_check:
            inputs["images"] = images_to_check
        if tool_calls_to_check:
            inputs["tool_calls"] = tool_calls_to_check  # type: ignore
        if messages:
            inputs[
                "structured_messages"
            ] = messages  # pass the openai /chat/completions messages to the guardrail, as-is
        # Pass tools (function definitions) to the guardrail
        tools = data.get("tools")
        if tools:
            inputs["tools"] = tools
        # Include model information if available
        model = data.get("model")
        if model:
            inputs["model"] = model

        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )

        guardrailed_texts = guardrailed_inputs.get("texts", [])
        guardrailed_tool_calls = guardrailed_inputs.get("tool_calls", [])
        guardrailed_tools = guardrailed_inputs.get("tools")
        if guardrailed_tools is not None:
            data["tools"] = guardrailed_tools

        # Step 3: Map guardrail responses back to original message structure
        if guardrailed_texts and texts_to_check:
            await self._apply_guardrail_responses_to_input_texts(
                messages=messages,
                responses=guardrailed_texts,
                task_mappings=text_task_mappings,
            )

        # Step 4: Apply guardrailed tool calls back to messages
        if guardrailed_tool_calls:
            # Note: The guardrail may modify tool_calls_to_check in place
            # or we may need to handle returned tool calls differently
            await self._apply_guardrail_responses_to_input_tool_calls(
                messages=messages,
                tool_calls=guardrailed_tool_calls,  # type: ignore
                task_mappings=tool_call_task_mappings,
            )

    verbose_proxy_logger.debug(
        "OpenAI Chat Completions: Processed input messages: %s", messages
    )

    return data
|
||||
|
||||
def extract_request_tool_names(self, data: dict) -> List[str]:
    """Extract tool names from OpenAI chat completions request (tools[].function.name, functions[].name)."""
    tool_names: List[str] = []

    # Current "tools" format: [{"type": "function", "function": {"name": ...}}]
    for tool in data.get("tools") or []:
        if not isinstance(tool, dict) or tool.get("type") != "function":
            continue
        fn = tool.get("function")
        if isinstance(fn, dict) and fn.get("name"):
            tool_names.append(str(fn["name"]))

    # Legacy "functions" format: [{"name": ...}]
    for fn in data.get("functions") or []:
        if isinstance(fn, dict) and fn.get("name"):
            tool_names.append(str(fn["name"]))

    return tool_names
|
||||
|
||||
def _extract_inputs(
    self,
    message: Dict[str, Any],
    msg_idx: int,
    texts_to_check: List[str],
    images_to_check: List[str],
    tool_calls_to_check: List[ChatCompletionToolParam],
    text_task_mappings: List[Tuple[int, Optional[int]]],
    tool_call_task_mappings: List[Tuple[int, int]],
) -> None:
    """
    Extract text content, images, and tool calls from one request message.

    Appends findings to the caller-owned accumulator lists and records
    (message_index, content_index) / (message_index, tool_call_index)
    mappings so that guardrail results can be written back afterwards.

    Override this method to customize text/image/tool call extraction logic.
    """
    content = message.get("content", None)

    if isinstance(content, str):
        # Plain string content maps to (msg_idx, None)
        texts_to_check.append(content)
        text_task_mappings.append((msg_idx, None))
    elif isinstance(content, list):
        # Multimodal content: scan each part for text and image_url entries
        for part_idx, part in enumerate(content):
            part_text = part.get("text", None)
            if part_text is not None:
                texts_to_check.append(part_text)
                text_task_mappings.append((msg_idx, int(part_idx)))

            if part.get("type") == "image_url":
                image_url = part.get("image_url", {})
                if isinstance(image_url, str):
                    images_to_check.append(image_url)
                elif isinstance(image_url, dict):
                    url = image_url.get("url")
                    if url:
                        images_to_check.append(url)

    # Tool calls (typically present on assistant messages)
    tool_calls = message.get("tool_calls", None)
    if isinstance(tool_calls, list):
        for tc_idx, tc in enumerate(tool_calls):
            if isinstance(tc, dict):
                tool_calls_to_check.append(cast(ChatCompletionToolParam, tc))
                tool_call_task_mappings.append((msg_idx, int(tc_idx)))
|
||||
|
||||
async def _apply_guardrail_responses_to_input_texts(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
responses: List[str],
|
||||
task_mappings: List[Tuple[int, Optional[int]]],
|
||||
) -> None:
|
||||
"""
|
||||
Apply guardrail responses back to input message text content.
|
||||
|
||||
Override this method to customize how text responses are applied.
|
||||
"""
|
||||
for task_idx, guardrail_response in enumerate(responses):
|
||||
mapping = task_mappings[task_idx]
|
||||
msg_idx = cast(int, mapping[0])
|
||||
content_idx_optional = cast(Optional[int], mapping[1])
|
||||
|
||||
# Handle content
|
||||
content = messages[msg_idx].get("content", None)
|
||||
if content is None:
|
||||
continue
|
||||
|
||||
if isinstance(content, str) and content_idx_optional is None:
|
||||
# Replace string content with guardrail response
|
||||
messages[msg_idx]["content"] = guardrail_response
|
||||
|
||||
elif isinstance(content, list) and content_idx_optional is not None:
|
||||
# Replace specific text item in list content
|
||||
messages[msg_idx]["content"][content_idx_optional][
|
||||
"text"
|
||||
] = guardrail_response
|
||||
|
||||
async def _apply_guardrail_responses_to_input_tool_calls(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
tool_calls: List[Dict[str, Any]],
|
||||
task_mappings: List[Tuple[int, int]],
|
||||
) -> None:
|
||||
"""
|
||||
Apply guardrailed tool calls back to input messages.
|
||||
|
||||
The guardrail may have modified the tool_calls list in place,
|
||||
so we apply the modified tool calls back to the original messages.
|
||||
|
||||
Override this method to customize how tool call responses are applied.
|
||||
"""
|
||||
for task_idx, (msg_idx, tool_call_idx) in enumerate(task_mappings):
|
||||
if task_idx < len(tool_calls):
|
||||
guardrailed_tool_call = tool_calls[task_idx]
|
||||
message_tool_calls = messages[msg_idx].get("tool_calls", None)
|
||||
if message_tool_calls is not None and isinstance(
|
||||
message_tool_calls, list
|
||||
):
|
||||
if tool_call_idx < len(message_tool_calls):
|
||||
# Replace the tool call with the guardrailed version
|
||||
message_tool_calls[tool_call_idx] = guardrailed_tool_call
|
||||
|
||||
async def process_output_response(
    self,
    response: "ModelResponse",
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
    user_api_key_dict: Optional[Any] = None,
) -> Any:
    """
    Process output response by applying guardrails to text content.

    Args:
        response: LiteLLM ModelResponse object (mutated in place)
        guardrail_to_apply: The guardrail instance to apply
        litellm_logging_obj: Optional logging object
        user_api_key_dict: User API key metadata to pass to guardrails

    Returns:
        Modified response with guardrail applied to content

    Response Format Support:
        - String content: choice.message.content = "text here"
        - List content: choice.message.content = [{"type": "text", "text": "text here"}, ...]
    """

    # Step 0: Check if response has any text content to process
    if not self._has_text_content(response):
        verbose_proxy_logger.warning(
            "OpenAI Chat Completions: No text content in response, skipping guardrail"
        )
        return response

    texts_to_check: List[str] = []
    images_to_check: List[str] = []
    tool_calls_to_check: List[Dict[str, Any]] = []
    text_task_mappings: List[Tuple[int, Optional[int]]] = []
    tool_call_task_mappings: List[Tuple[int, int]] = []
    # text_task_mappings: Track (choice_index, content_index) for each text
    # content_index is None for string content, int for list content
    # tool_call_task_mappings: Track (choice_index, tool_call_index) for each tool call

    # Step 1: Extract all text content, images, and tool calls from response choices
    for choice_idx, choice in enumerate(response.choices):
        self._extract_output_text_images_and_tool_calls(
            choice=choice,
            choice_idx=choice_idx,
            texts_to_check=texts_to_check,
            images_to_check=images_to_check,
            tool_calls_to_check=tool_calls_to_check,
            text_task_mappings=text_task_mappings,
            tool_call_task_mappings=tool_call_task_mappings,
        )

    # Step 2: Apply guardrail to all texts and tool calls in batch
    if texts_to_check or tool_calls_to_check:
        # Create a request_data dict with response info and user API key metadata
        request_data: dict = {"response": response}

        # Add user API key metadata with prefixed keys
        user_metadata = self.transform_user_api_key_dict_to_metadata(
            user_api_key_dict
        )
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata

        inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
        if images_to_check:
            inputs["images"] = images_to_check
        if tool_calls_to_check:
            inputs["tool_calls"] = tool_calls_to_check  # type: ignore
        # Include model information from the response if available
        if hasattr(response, "model") and response.model:
            inputs["model"] = response.model

        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )

        guardrailed_texts = guardrailed_inputs.get("texts", [])

        # Step 3: Map guardrail responses back to original response structure
        if guardrailed_texts and texts_to_check:
            await self._apply_guardrail_responses_to_output_texts(
                response=response,
                responses=guardrailed_texts,
                task_mappings=text_task_mappings,
            )

        # Step 4: Apply guardrailed tool calls back to response
        # NOTE(review): this re-applies tool_calls_to_check (relying on the
        # guardrail mutating it in place) rather than the guardrail's returned
        # "tool_calls" — unlike the input path, which uses the returned list.
        # Confirm this asymmetry is intentional.
        if tool_calls_to_check:
            await self._apply_guardrail_responses_to_output_tool_calls(
                response=response,
                tool_calls=tool_calls_to_check,
                task_mappings=tool_call_task_mappings,
            )

    verbose_proxy_logger.debug(
        "OpenAI Chat Completions: Processed output response: %s", response
    )

    return response
|
||||
|
||||
async def process_output_streaming_response(
    self,
    responses_so_far: List["ModelResponseStream"],
    guardrail_to_apply: "CustomGuardrail",
    litellm_logging_obj: Optional[Any] = None,
    user_api_key_dict: Optional[Any] = None,
) -> List["ModelResponseStream"]:
    """
    Process output streaming responses by applying guardrails to text content.

    Two paths:
    - If any chunk carries a finish_reason, the stream has ended: the chunks
      are assembled into one ModelResponse and routed through
      process_output_response (so the non-streaming logic runs once).
    - Otherwise, per-choice delta texts are concatenated across chunks,
      guardrailed in one batch, and written back across the chunks.

    Args:
        responses_so_far: List of LiteLLM ModelResponseStream objects
        guardrail_to_apply: The guardrail instance to apply
        litellm_logging_obj: Optional logging object
        user_api_key_dict: User API key metadata to pass to guardrails

    Returns:
        Modified list of responses with guardrail applied to content

    Response Format Support:
        - String content: choice.delta.content = "text here"
        - List content: choice.delta.content = [{"type": "text", "text": "text here"}, ...]
    """
    # check if the stream has ended (only the first choice is inspected)
    has_stream_ended = False
    for chunk in responses_so_far:
        if chunk.choices and chunk.choices[0].finish_reason is not None:
            has_stream_ended = True
            break

    if has_stream_ended:
        # convert accumulated chunks to a single model response
        model_response = cast(
            ModelResponse,
            stream_chunk_builder(
                chunks=responses_so_far, logging_obj=litellm_logging_obj
            ),
        )
        # run process_output_response on the assembled response; the
        # original chunks are returned unchanged
        await self.process_output_response(
            response=model_response,
            guardrail_to_apply=guardrail_to_apply,
            litellm_logging_obj=litellm_logging_obj,
            user_api_key_dict=user_api_key_dict,
        )

        return responses_so_far

    # Step 0: Check if any response has text content to process
    has_any_text_content = False
    for response in responses_so_far:
        if self._has_text_content(response):
            has_any_text_content = True
            break

    if not has_any_text_content:
        verbose_proxy_logger.warning(
            "OpenAI Chat Completions: No text content in streaming responses, skipping guardrail"
        )
        return responses_so_far

    # Step 1: Combine all streaming chunks into complete text per choice
    # For streaming, we need to concatenate all delta.content across all chunks
    # Key: (choice_idx, content_idx), Value: combined text
    combined_texts = self._combine_streaming_texts(responses_so_far)

    # Step 2: Create lists for guardrail processing
    texts_to_check: List[str] = []
    images_to_check: List[str] = []
    task_mappings: List[Tuple[int, Optional[int]]] = []
    # Track (choice_index, content_index) for each combined text

    for (map_choice_idx, map_content_idx), combined_text in combined_texts.items():
        texts_to_check.append(combined_text)
        task_mappings.append((map_choice_idx, map_content_idx))

    # Step 3: Apply guardrail to all combined texts in batch
    if texts_to_check:
        # Create a request_data dict with response info and user API key metadata
        request_data: dict = {"responses": responses_so_far}

        # Add user API key metadata with prefixed keys
        user_metadata = self.transform_user_api_key_dict_to_metadata(
            user_api_key_dict
        )
        if user_metadata:
            request_data["litellm_metadata"] = user_metadata

        inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
        # images_to_check is never populated on the streaming path; kept for
        # symmetry with the non-streaming handler
        if images_to_check:
            inputs["images"] = images_to_check
        # Include model information from the first response if available
        if (
            responses_so_far
            and hasattr(responses_so_far[0], "model")
            and responses_so_far[0].model
        ):
            inputs["model"] = responses_so_far[0].model
        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )

        guardrailed_texts = guardrailed_inputs.get("texts", [])

        # Step 4: Apply guardrailed text back to all streaming chunks
        # For each choice, replace the combined text across all chunks
        await self._apply_guardrail_responses_to_output_streaming(
            responses=responses_so_far,
            guardrailed_texts=guardrailed_texts,
            task_mappings=task_mappings,
        )

    verbose_proxy_logger.debug(
        "OpenAI Chat Completions: Processed output streaming responses: %s",
        responses_so_far,
    )

    return responses_so_far
|
||||
|
||||
def _combine_streaming_texts(
    self, responses_so_far: List["ModelResponseStream"]
) -> Dict[Tuple[int, Optional[int]], str]:
    """
    Concatenate delta/message text across all streaming chunks, per choice.

    Args:
        responses_so_far: List of LiteLLM ModelResponseStream objects

    Returns:
        Dict mapping (choice_idx, content_idx) to the accumulated text;
        content_idx is None for plain string content and an int for list
        (multimodal) content items.
    """
    accumulated: Dict[Tuple[int, Optional[int]], str] = {}

    def _append(key: Tuple[int, Optional[int]], text: str) -> None:
        # Accumulate text under its (choice, content) key.
        accumulated[key] = accumulated.get(key, "") + text

    for chunk in responses_so_far:
        for choice_idx, choice in enumerate(chunk.choices):
            if isinstance(choice, litellm.StreamingChoices):
                content = choice.delta.content
            elif isinstance(choice, litellm.Choices):
                content = choice.message.content
            else:
                continue

            if content is None:
                continue

            if isinstance(content, str):
                _append((choice_idx, None), content)
            elif isinstance(content, list):
                for content_idx, item in enumerate(content):
                    item_text = item.get("text")
                    if item_text:
                        _append((choice_idx, content_idx), item_text)

    return accumulated
|
||||
|
||||
def _has_text_content(
    self, response: Union["ModelResponse", "ModelResponseStream"]
) -> bool:
    """
    Return True when the response carries guardrail-relevant content:
    string text, or a non-empty tool-call list.

    Override this method to customize text content detection.
    """
    from litellm.types.utils import ModelResponse, ModelResponseStream

    if isinstance(response, ModelResponse):
        for choice in response.choices:
            if not isinstance(choice, litellm.Choices):
                continue
            message = choice.message
            # Non-empty string content counts
            if message.content and isinstance(message.content, str):
                return True
            # A non-empty tool-call list counts
            if isinstance(message.tool_calls, list) and message.tool_calls:
                return True
    elif isinstance(response, ModelResponseStream):
        for choice in response.choices:
            if not isinstance(choice, litellm.StreamingChoices):
                continue
            delta = choice.delta
            if delta.content and isinstance(delta.content, str):
                return True
            if isinstance(delta.tool_calls, list) and delta.tool_calls:
                return True

    return False
|
||||
|
||||
def _extract_output_text_images_and_tool_calls(
    self,
    choice: Union[Choices, StreamingChoices],
    choice_idx: int,
    texts_to_check: List[str],
    images_to_check: List[str],
    tool_calls_to_check: List[Dict[str, Any]],
    text_task_mappings: List[Tuple[int, Optional[int]]],
    tool_call_task_mappings: List[Tuple[int, int]],
) -> None:
    """
    Extract text content, images, and tool calls from a response choice.

    Found items are appended to the accumulator lists, and
    (choice_idx, content_idx) mappings are recorded so guardrailed results
    can later be written back to the correct position.

    Override this method to customize text/image/tool call extraction logic.
    """
    verbose_proxy_logger.debug(
        "OpenAI Chat Completions: Processing choice: %s", choice
    )

    # Determine content source and tool calls based on choice type
    if isinstance(choice, litellm.Choices):
        content = choice.message.content
        tool_calls: Optional[List[Any]] = choice.message.tool_calls
    elif isinstance(choice, litellm.StreamingChoices):
        content = choice.delta.content
        tool_calls = choice.delta.tool_calls
    else:
        # Unknown choice type - nothing to extract
        return

    if content and isinstance(content, str):
        # Plain string content maps to (choice_idx, None)
        texts_to_check.append(content)
        text_task_mappings.append((choice_idx, None))
    elif content and isinstance(content, list):
        # Multimodal list content: each item may carry text and/or an image
        for content_idx, content_item in enumerate(content):
            item_text = content_item.get("text")
            if item_text:
                texts_to_check.append(item_text)
                text_task_mappings.append((choice_idx, int(content_idx)))

            if content_item.get("type") == "image_url":
                image_url = content_item.get("image_url", {})
                if isinstance(image_url, dict):
                    url = image_url.get("url")
                    if url:
                        images_to_check.append(url)

    # Tool calls are normalized to plain dicts for guardrail processing
    if tool_calls is not None and isinstance(tool_calls, list):
        for tool_call_idx, tool_call in enumerate(tool_calls):
            tool_call_dict = self._convert_tool_call_to_dict(tool_call)
            if tool_call_dict:
                tool_calls_to_check.append(tool_call_dict)
                tool_call_task_mappings.append((choice_idx, int(tool_call_idx)))
|
||||
def _convert_tool_call_to_dict(
|
||||
self, tool_call: Union[Dict[str, Any], Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Convert a tool call object to dictionary format.
|
||||
|
||||
Tool calls can be either dict or object depending on the type.
|
||||
"""
|
||||
if isinstance(tool_call, dict):
|
||||
return tool_call
|
||||
elif hasattr(tool_call, "id") and hasattr(tool_call, "function"):
|
||||
# Convert object to dict
|
||||
function = tool_call.function
|
||||
function_dict = {}
|
||||
if hasattr(function, "name"):
|
||||
function_dict["name"] = function.name
|
||||
if hasattr(function, "arguments"):
|
||||
function_dict["arguments"] = function.arguments
|
||||
|
||||
tool_call_dict = {
|
||||
"id": tool_call.id if hasattr(tool_call, "id") else None,
|
||||
"type": tool_call.type if hasattr(tool_call, "type") else "function",
|
||||
"function": function_dict,
|
||||
}
|
||||
return tool_call_dict
|
||||
return None
|
||||
|
||||
async def _apply_guardrail_responses_to_output_texts(
    self,
    response: "ModelResponse",
    responses: List[str],
    task_mappings: List[Tuple[int, Optional[int]]],
) -> None:
    """
    Write guardrailed text back into the output response.

    ``task_mappings[i]`` gives the (choice_idx, content_idx) destination of
    ``responses[i]``; content_idx is None for plain string content.

    Override this method to customize how text responses are applied.
    """
    for task_idx, guardrail_response in enumerate(responses):
        mapping = task_mappings[task_idx]
        choice_idx = cast(int, mapping[0])
        content_idx_optional = cast(Optional[int], mapping[1])

        target_choice = cast(Choices, response.choices[choice_idx])

        existing_content = target_choice.message.content
        if existing_content is None:
            # Nothing to overwrite for this choice
            continue

        if content_idx_optional is None and isinstance(existing_content, str):
            # Plain string content: replace wholesale
            target_choice.message.content = guardrail_response
        elif content_idx_optional is not None and isinstance(existing_content, list):
            # Multimodal list content: replace only the targeted text item
            target_choice.message.content[content_idx_optional]["text"] = guardrail_response  # type: ignore
|
||||
async def _apply_guardrail_responses_to_output_tool_calls(
    self,
    response: "ModelResponse",
    tool_calls: List[Dict[str, Any]],
    task_mappings: List[Tuple[int, int]],
) -> None:
    """
    Apply guardrailed tool calls back to output response.

    The guardrail may have modified the tool_calls list in place,
    so we apply the modified tool calls back to the original response.

    Override this method to customize how tool call responses are applied.
    """
    for task_idx, (choice_idx, tool_call_idx) in enumerate(task_mappings):
        if task_idx >= len(tool_calls):
            continue
        guardrailed_tool_call = tool_calls[task_idx]

        target_choice = cast(Choices, response.choices[choice_idx])
        existing_tool_calls = target_choice.message.tool_calls

        if existing_tool_calls is None or not isinstance(existing_tool_calls, list):
            continue
        if tool_call_idx >= len(existing_tool_calls):
            continue
        if "function" not in guardrailed_tool_call:
            continue

        # Output responses always carry typed tool-call objects, so we
        # write back through attribute assignment.
        existing_tool_call = existing_tool_calls[tool_call_idx]
        func_dict = guardrailed_tool_call["function"]
        if "arguments" in func_dict:
            existing_tool_call.function.arguments = func_dict["arguments"]
        if "name" in func_dict:
            existing_tool_call.function.name = func_dict["name"]
|
||||
async def _apply_guardrail_responses_to_output_streaming(
    self,
    responses: List["ModelResponseStream"],
    guardrailed_texts: List[str],
    task_mappings: List[Tuple[int, Optional[int]]],
) -> None:
    """
    Apply guardrail responses back to output streaming responses.

    For streaming responses, the guardrailed text (which is the combined text from all chunks)
    is placed in the first chunk, and subsequent chunks are cleared.

    Args:
        responses: List of ModelResponseStream objects to modify
        guardrailed_texts: List of guardrailed text responses (combined from all chunks)
        task_mappings: List of tuples (choice_idx, content_idx)

    Override this method to customize how responses are applied to streaming responses.
    """
    # Build a mapping of what guardrailed text to use for each (choice_idx, content_idx)
    guardrail_map: Dict[Tuple[int, Optional[int]], str] = {}
    for task_idx, guardrail_response in enumerate(guardrailed_texts):
        mapping = task_mappings[task_idx]
        choice_idx = cast(int, mapping[0])
        content_idx_optional = cast(Optional[int], mapping[1])
        guardrail_map[(choice_idx, content_idx_optional)] = guardrail_response

    # Track which choices we've already set the guardrailed text for
    # Key: (choice_idx, content_idx), Value: boolean (True if already set)
    already_set: Dict[Tuple[int, Optional[int]], bool] = {}

    # Iterate through all responses and update content
    for response_idx, response in enumerate(responses):
        for choice_idx_in_response, choice in enumerate(response.choices):
            # Content lives on delta for streaming choices, on message otherwise
            if isinstance(choice, litellm.StreamingChoices):
                content = choice.delta.content
            elif isinstance(choice, litellm.Choices):
                content = choice.message.content
            else:
                continue

            if content is None:
                continue

            if isinstance(content, str):
                # String content
                str_key: Tuple[int, Optional[int]] = (choice_idx_in_response, None)
                if str_key in guardrail_map:
                    if str_key not in already_set:
                        # First chunk - set the complete guardrailed text
                        if isinstance(choice, litellm.StreamingChoices):
                            choice.delta.content = guardrail_map[str_key]
                        elif isinstance(choice, litellm.Choices):
                            choice.message.content = guardrail_map[str_key]
                        already_set[str_key] = True
                    else:
                        # Subsequent chunks - clear the content so the
                        # combined text is not duplicated across chunks
                        if isinstance(choice, litellm.StreamingChoices):
                            choice.delta.content = ""
                        elif isinstance(choice, litellm.Choices):
                            choice.message.content = ""

            elif isinstance(content, list):
                # List content - handle each content item
                for content_idx, content_item in enumerate(content):
                    if "text" in content_item:
                        list_key: Tuple[int, Optional[int]] = (
                            choice_idx_in_response,
                            content_idx,
                        )
                        if list_key in guardrail_map:
                            if list_key not in already_set:
                                # First chunk - set the complete guardrailed text
                                content_item["text"] = guardrail_map[list_key]
                                already_set[list_key] = True
                            else:
                                # Subsequent chunks - clear the text
                                content_item["text"] = ""
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
LLM Calling done in `openai/openai.py`
|
||||
"""
|
||||
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Support for o1/o3 model family
|
||||
|
||||
https://platform.openai.com/docs/guides/reasoning
|
||||
|
||||
Translations handled by LiteLLM:
|
||||
- modalities: image => drop param (if user opts in to dropping param)
|
||||
- role: system ==> translate to role 'user'
|
||||
- streaming => faked by LiteLLM
|
||||
- Tools, response_format => drop param (if user opts in to dropping param)
|
||||
- Logprobs => drop param (if user opts in to dropping param)
|
||||
"""
|
||||
|
||||
from typing import Any, Coroutine, List, Literal, Optional, Union, cast, overload
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionUserMessage
|
||||
from litellm.utils import (
|
||||
supports_function_calling,
|
||||
supports_parallel_function_calling,
|
||||
supports_response_schema,
|
||||
supports_system_messages,
|
||||
)
|
||||
|
||||
from .gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class OpenAIOSeriesConfig(OpenAIGPTConfig):
    """
    Config for OpenAI o-series (o1/o3) reasoning models.

    Reference: https://platform.openai.com/docs/guides/reasoning
    """

    @classmethod
    def get_config(cls):
        return super().get_config()

    def translate_developer_role_to_system_role(
        self, messages: List[AllMessageValues]
    ) -> List[AllMessageValues]:
        """
        O-series models support `developer` role, so no translation is needed;
        messages are returned unchanged.
        """
        return messages

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the supported OpenAI params for the given model.

        Starts from the base GPT param list, adds `reasoning_effort`, then
        removes params the o-series never supports plus any capabilities
        (tools, response_format, parallel tool calls) the specific model
        lacks per litellm's capability lookups.
        """

        all_openai_params = super().get_supported_openai_params(model=model)
        # Params the o-series API rejects regardless of model
        non_supported_params = [
            "logprobs",
            "top_p",
            "presence_penalty",
            "frequency_penalty",
            "top_logprobs",
        ]

        o_series_only_param = ["reasoning_effort"]

        all_openai_params.extend(o_series_only_param)

        try:
            model, custom_llm_provider, api_base, api_key = get_llm_provider(
                model=model
            )
        except Exception:
            # Provider inference can fail for bare model names; assume openai
            verbose_logger.debug(
                f"Unable to infer model provider for model={model}, defaulting to openai for o1 supported param check"
            )
            custom_llm_provider = "openai"

        _supports_function_calling = supports_function_calling(
            model, custom_llm_provider
        )
        _supports_response_schema = supports_response_schema(model, custom_llm_provider)
        _supports_parallel_tool_calls = supports_parallel_function_calling(
            model, custom_llm_provider
        )

        if not _supports_function_calling:
            non_supported_params.append("tools")
            non_supported_params.append("tool_choice")
            non_supported_params.append("function_call")
            non_supported_params.append("functions")

        if not _supports_parallel_tool_calls:
            non_supported_params.append("parallel_tool_calls")

        if not _supports_response_schema:
            non_supported_params.append("response_format")

        return [
            param for param in all_openai_params if param not in non_supported_params
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        # o-series uses max_completion_tokens instead of max_tokens
        if "max_tokens" in non_default_params:
            optional_params["max_completion_tokens"] = non_default_params.pop(
                "max_tokens"
            )
        if "temperature" in non_default_params:
            temperature_value: Optional[float] = non_default_params.pop("temperature")
            if temperature_value is not None:
                if temperature_value == 1:
                    # Only the default temperature of 1 is accepted
                    optional_params["temperature"] = temperature_value
                else:
                    ## UNSUPPORTED TOOL CHOICE VALUE
                    if litellm.drop_params is True or drop_params is True:
                        # User opted into silently dropping unsupported params
                        pass
                    else:
                        raise litellm.utils.UnsupportedParamsError(
                            message="O-series models don't support temperature={}. Only temperature=1 is supported. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
                                temperature_value
                            ),
                            status_code=400,
                        )

        return super()._map_openai_params(
            non_default_params, optional_params, model, drop_params
        )

    def is_model_o_series_model(self, model: str) -> bool:
        # Heuristic: name starts with 'o' + digit (e.g. "o1", "o3-mini")
        # AND litellm recognizes it as an OpenAI chat model.
        model = model.split("/")[-1]  # could be "openai/o3" or "o3"
        return (
            len(model) > 1
            and model[0] == "o"
            and model[1].isdigit()
            and model in litellm.open_ai_chat_completion_models
        )

    @overload
    def _transform_messages(
        self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
    ) -> Coroutine[Any, Any, List[AllMessageValues]]:
        ...

    @overload
    def _transform_messages(
        self,
        messages: List[AllMessageValues],
        model: str,
        is_async: Literal[False] = False,
    ) -> List[AllMessageValues]:
        ...

    def _transform_messages(
        self, messages: List[AllMessageValues], model: str, is_async: bool = False
    ) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
        """
        Handles limitations of O-1 model family.
        - modalities: image => drop param (if user opts in to dropping param)
        - role: system ==> translate to role 'user'
        """
        _supports_system_messages = supports_system_messages(model, "openai")
        for i, message in enumerate(messages):
            if message["role"] == "system" and not _supports_system_messages:
                # Models without system-message support get the content as a
                # user message instead
                new_message = ChatCompletionUserMessage(
                    content=message["content"], role="user"
                )
                messages[i] = new_message  # Replace the old message with the new one

        if is_async:
            return super()._transform_messages(
                messages, model, is_async=cast(Literal[True], True)
            )
        else:
            return super()._transform_messages(
                messages, model, is_async=cast(Literal[False], False)
            )
||||
@@ -0,0 +1,290 @@
|
||||
"""
|
||||
Common helpers / utils across all OpenAI endpoints
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import ssl
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiohttp import ClientSession
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
|
||||
AsyncHTTPHandler,
|
||||
get_ssl_configuration,
|
||||
)
|
||||
|
||||
|
||||
def _get_client_init_params(cls: type) -> Tuple[str, ...]:
|
||||
"""Extract __init__ parameter names (excluding 'self') from a class."""
|
||||
return tuple(p for p in inspect.signature(cls.__init__).parameters if p != "self") # type: ignore[misc]
|
||||
|
||||
|
||||
_OPENAI_INIT_PARAMS: Tuple[str, ...] = _get_client_init_params(OpenAI)
|
||||
_AZURE_OPENAI_INIT_PARAMS: Tuple[str, ...] = _get_client_init_params(AzureOpenAI)
|
||||
|
||||
|
||||
class OpenAIError(BaseLLMException):
    """Exception raised for errors returned by the OpenAI API.

    Always attaches an httpx request/response pair so downstream error
    handling can inspect them even when the failure happened before a
    real response was received.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        body: Optional[dict] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.headers = headers
        # Fall back to synthetic request/response objects when none given
        self.request = request or httpx.Request(
            method="POST", url="https://api.openai.com/v1"
        )
        self.response = response or httpx.Response(
            status_code=status_code, request=self.request
        )
        super().__init__(
            status_code=status_code,
            message=self.message,
            headers=self.headers,
            request=self.request,
            response=self.response,
            body=body,
        )
|
||||
|
||||
|
||||
####### Error Handling Utils for OpenAI API #######################
|
||||
###################################################################
|
||||
def drop_params_from_unprocessable_entity_error(
    e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
    data: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper function to read OpenAI UnprocessableEntityError and drop the params that raised an error from the error message.

    Args:
        e (UnprocessableEntityError): The UnprocessableEntityError exception
        data (Dict[str, Any]): The original data dictionary containing all parameters

    Returns:
        Dict[str, Any]: A new dictionary with invalid parameters removed
    """
    invalid_params: List[str] = []

    # Extract the error body from either exception flavor
    if isinstance(e, httpx.HTTPStatusError):
        error_body = e.response.json().get("error", {})
    else:
        error_body = e.body

    if error_body is not None and isinstance(error_body, dict) and error_body.get(
        "message"
    ):
        message = error_body.get("message", {})
        if isinstance(message, str):
            # The message may itself be JSON-encoded; fall back to wrapping
            # plain text in a {"detail": ...} dict
            try:
                message = json.loads(message)
            except json.JSONDecodeError:
                message = {"detail": message}
        detail = message.get("detail")

        # Pydantic-style validation errors: each entry's `loc` is
        # ["body", "<param_name>"], so index 1 is the offending param
        if isinstance(detail, List) and len(detail) > 0 and isinstance(detail[0], dict):
            for error_dict in detail:
                loc = error_dict.get("loc")
                if loc and isinstance(loc, list) and len(loc) == 2:
                    invalid_params.append(loc[1])

    return {k: v for k, v in data.items() if k not in invalid_params}
|
||||
|
||||
|
||||
class BaseOpenAILLM:
    """
    Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
    """

    @staticmethod
    def get_cached_openai_client(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
        """Retrieves the OpenAI client from the in-memory cache based on the client initialization parameters"""
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        # Returns None on cache miss
        _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
        return _cached_client

    @staticmethod
    def set_cached_openai_client(
        openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI],
        client_type: Literal["openai", "azure"],
        client_initialization_params: dict,
    ):
        """Stores the OpenAI client in the in-memory cache for _DEFAULT_TTL_FOR_HTTPX_CLIENTS SECONDS"""
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        litellm.in_memory_llm_clients_cache.set_cache(
            key=_cache_key,
            value=openai_client,
            ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
        )

    @staticmethod
    def get_openai_client_cache_key(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> str:
        """Creates a cache key for the OpenAI client based on the client initialization parameters"""
        # The API key is hashed so the raw secret never appears in the key
        hashed_api_key = None
        if client_initialization_params.get("api_key") is not None:
            hash_object = hashlib.sha256(
                client_initialization_params.get("api_key", "").encode()
            )
            # Hexadecimal representation of the hash
            hashed_api_key = hash_object.hexdigest()

        # Create a more readable cache key using a list of key-value pairs
        key_parts = [
            f"hashed_api_key={hashed_api_key}",
            f"is_async={client_initialization_params.get('is_async')}",
        ]

        # Params litellm accepts that are not part of the SDK client __init__
        LITELLM_CLIENT_SPECIFIC_PARAMS = (
            "timeout",
            "max_retries",
            "organization",
            "api_base",
        )
        openai_client_fields = (
            BaseOpenAILLM.get_openai_client_initialization_param_fields(
                client_type=client_type
            )
            + LITELLM_CLIENT_SPECIFIC_PARAMS
        )

        for param in openai_client_fields:
            key_parts.append(f"{param}={client_initialization_params.get(param)}")

        _cache_key = ",".join(key_parts)
        return _cache_key

    @staticmethod
    def get_openai_client_initialization_param_fields(
        client_type: Literal["openai", "azure"]
    ) -> Tuple[str, ...]:
        """Returns a tuple of fields that are used to initialize the OpenAI client"""
        if client_type == "openai":
            return _OPENAI_INIT_PARAMS
        else:
            return _AZURE_OPENAI_INIT_PARAMS

    @staticmethod
    def _get_async_http_client(
        shared_session: Optional["ClientSession"] = None,
    ) -> Optional[httpx.AsyncClient]:
        """Build the async httpx client, honoring user session, mock mode, and SSL config (in that order)."""
        # A user-supplied session always wins
        if litellm.aclient_session is not None:
            return litellm.aclient_session

        # Test/mock mode: route all traffic through the mock transport
        if getattr(litellm, "network_mock", False):
            from litellm.llms.custom_httpx.mock_transport import MockOpenAITransport

            return httpx.AsyncClient(transport=MockOpenAITransport())

        # Get unified SSL configuration
        ssl_config = get_ssl_configuration()

        # ssl_config may be an SSLContext or a bool; pass it to the
        # matching transport kwarg and leave the other as None
        return httpx.AsyncClient(
            verify=ssl_config,
            transport=AsyncHTTPHandler._create_async_transport(
                ssl_context=ssl_config
                if isinstance(ssl_config, ssl.SSLContext)
                else None,
                ssl_verify=ssl_config if isinstance(ssl_config, bool) else None,
                shared_session=shared_session,
            ),
            follow_redirects=True,
        )

    @staticmethod
    def _get_sync_http_client() -> Optional[httpx.Client]:
        """Build the sync httpx client, honoring user session, mock mode, and SSL config (in that order)."""
        if litellm.client_session is not None:
            return litellm.client_session

        if getattr(litellm, "network_mock", False):
            from litellm.llms.custom_httpx.mock_transport import MockOpenAITransport

            return httpx.Client(transport=MockOpenAITransport())

        # Get unified SSL configuration
        ssl_config = get_ssl_configuration()

        return httpx.Client(
            verify=ssl_config,
            follow_redirects=True,
        )
|
||||
|
||||
|
||||
class OpenAICredentials(NamedTuple):
    """Resolved OpenAI connection settings (base URL, key, organization)."""

    # Fully-resolved API base URL; always set (falls back to the public endpoint)
    api_base: str
    # API key; None when no key was found in params, litellm globals, or env
    api_key: Optional[str]
    # Optional OpenAI organization id
    organization: Optional[str]
|
||||
|
||||
|
||||
def get_openai_credentials(
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    organization: Optional[str] = None,
) -> OpenAICredentials:
    """Resolve OpenAI credentials from params, litellm globals, and env vars.

    Precedence for each field: explicit argument > litellm module global >
    environment variable > default (public endpoint for api_base, None
    otherwise).

    Args:
        api_base: Explicit API base URL override.
        api_key: Explicit API key override.
        organization: Explicit organization id override.

    Returns:
        OpenAICredentials with all three fields resolved.
    """
    resolved_api_base = (
        api_base
        or litellm.api_base
        or os.getenv("OPENAI_BASE_URL")
        or os.getenv("OPENAI_API_BASE")
        or "https://api.openai.com/v1"
    )
    # Trailing `or None` collapses an empty-string env var to None
    resolved_organization = (
        organization
        or litellm.organization
        or os.getenv("OPENAI_ORGANIZATION", None)
        or None
    )
    resolved_api_key = (
        api_key or litellm.api_key or litellm.openai_key or os.getenv("OPENAI_API_KEY")
    )
    return OpenAICredentials(
        api_base=resolved_api_base,
        api_key=resolved_api_key,
        organization=resolved_organization,
    )
|
||||
@@ -0,0 +1,158 @@
|
||||
# OpenAI Text Completion Guardrail Translation Handler
|
||||
|
||||
Handler for processing OpenAI's text completion endpoint (`/v1/completions`) with guardrails.
|
||||
|
||||
## Overview
|
||||
|
||||
This handler processes text completion requests by:
|
||||
1. Extracting the text prompt(s) from the request
|
||||
2. Applying guardrails to the prompt text(s)
|
||||
3. Updating the request with the guardrailed prompt(s)
|
||||
4. Applying guardrails to the completion output text
|
||||
|
||||
## Data Format
|
||||
|
||||
### Input Format
|
||||
|
||||
**Single Prompt:**
|
||||
```json
|
||||
{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "Say this is a test",
|
||||
"max_tokens": 7,
|
||||
"temperature": 0
|
||||
}
|
||||
```
|
||||
|
||||
**Multiple Prompts (Batch):**
|
||||
```json
|
||||
{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": [
|
||||
"Tell me a joke",
|
||||
"Write a poem"
|
||||
],
|
||||
"max_tokens": 50
|
||||
}
|
||||
```
|
||||
|
||||
### Output Format
|
||||
|
||||
```json
|
||||
{
|
||||
"id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7",
|
||||
"object": "text_completion",
|
||||
"created": 1589478378,
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"choices": [
|
||||
{
|
||||
"text": "\n\nThis is indeed a test",
|
||||
"index": 0,
|
||||
"logprobs": null,
|
||||
"finish_reason": "length"
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 5,
|
||||
"completion_tokens": 7,
|
||||
"total_tokens": 12
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The handler is automatically discovered and applied when guardrails are used with the text completion endpoint.
|
||||
|
||||
### Example: Using Guardrails with Text Completion
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "Say this is a test",
|
||||
"guardrails": ["content_moderation"],
|
||||
"max_tokens": 7
|
||||
}'
|
||||
```
|
||||
|
||||
The guardrail will be applied to both:
|
||||
- **Input**: The prompt text before sending to the LLM
|
||||
- **Output**: The completion text in the response
|
||||
|
||||
### Example: PII Masking in Prompts and Completions
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": "My name is John Doe and my email is john@example.com",
|
||||
"guardrails": ["mask_pii"],
|
||||
"metadata": {
|
||||
"guardrails": ["mask_pii"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
### Example: Batch Prompts with Guardrails
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/completions' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo-instruct",
|
||||
"prompt": [
|
||||
"Tell me about AI",
|
||||
"What is machine learning?"
|
||||
],
|
||||
"guardrails": ["content_filter"],
|
||||
"max_tokens": 100
|
||||
}'
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Input Processing
|
||||
|
||||
- **Field**: `prompt` (string or list of strings)
|
||||
- **Processing**:
|
||||
- String prompts: Apply guardrail directly
|
||||
- List prompts: Apply guardrail to each string in the list
|
||||
- **Result**: Updated prompt(s) in request
|
||||
|
||||
### Output Processing
|
||||
|
||||
- **Field**: `choices[*].text` (string)
|
||||
- **Processing**: Applies guardrail to each completion text
|
||||
- **Result**: Updated completion texts in response
|
||||
|
||||
### Supported Prompt Types
|
||||
|
||||
1. **String**: Single prompt as a string
|
||||
2. **List of Strings**: Multiple prompts for batch completion
|
||||
3. **List of Lists**: Token-based prompts (passed through unchanged)
|
||||
|
||||
## Extension
|
||||
|
||||
Override these methods to customize behavior:
|
||||
|
||||
- `process_input_messages()`: Customize how prompts are processed
|
||||
- `process_output_response()`: Customize how completion texts are processed
|
||||
|
||||
## Supported Call Types
|
||||
|
||||
- `CallTypes.text_completion` - Synchronous text completion
|
||||
- `CallTypes.atext_completion` - Asynchronous text completion
|
||||
|
||||
## Notes
|
||||
|
||||
- The handler processes both input prompts and output completion texts
|
||||
- List prompts are processed individually (each string in the list)
|
||||
- Non-string prompt items (e.g., token lists) are passed through unchanged
|
||||
- Both sync and async call types use the same handler
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""OpenAI Text Completion handler for Unified Guardrails."""
|
||||
|
||||
from litellm.llms.openai.completion.guardrail_translation.handler import (
|
||||
OpenAITextCompletionHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
guardrail_translation_mappings = {
|
||||
CallTypes.text_completion: OpenAITextCompletionHandler,
|
||||
CallTypes.atext_completion: OpenAITextCompletionHandler,
|
||||
}
|
||||
|
||||
__all__ = ["guardrail_translation_mappings", "OpenAITextCompletionHandler"]
|
||||
@@ -0,0 +1,194 @@
|
||||
"""
|
||||
OpenAI Text Completion Handler for Unified Guardrails
|
||||
|
||||
This module provides guardrail translation support for OpenAI's text completion endpoint.
|
||||
The handler processes the 'prompt' parameter for guardrails.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.types.utils import TextCompletionResponse
|
||||
|
||||
|
||||
class OpenAITextCompletionHandler(BaseTranslation):
    """
    Handler for processing OpenAI text completion requests with guardrails.

    This class provides methods to:
    1. Process input prompt (pre-call hook)
    2. Process output response (post-call hook)

    The handler specifically processes the 'prompt' parameter which can be:
    - A single string
    - A list of strings (for batch completions)

    Non-string prompt items (e.g. token lists) are passed through unchanged.
    """

    @staticmethod
    async def _apply_guardrail_to_texts(
        texts,
        model,
        request_data,
        guardrail_to_apply,
        litellm_logging_obj,
        input_type,
    ):
        """Run one guardrail pass over ``texts`` and return the guardrailed texts.

        Shared by the request (prompt) and response (choices) paths so the
        GenericGuardrailAPIInputs construction and apply_guardrail invocation
        live in exactly one place.
        """
        inputs = GenericGuardrailAPIInputs(texts=texts)
        # Include model information when available
        if model:
            inputs["model"] = model
        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type=input_type,
            logging_obj=litellm_logging_obj,
        )
        return guardrailed_inputs.get("texts", [])

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Process input prompt by applying guardrails to text content.

        Args:
            data: Request data dictionary containing 'prompt' parameter
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object

        Returns:
            Modified data with guardrails applied to prompt
        """
        prompt = data.get("prompt")
        if prompt is None:
            verbose_proxy_logger.debug(
                "OpenAI Text Completion: No prompt found in request data"
            )
            return data

        if isinstance(prompt, str):
            # Single string prompt
            guardrailed_texts = await self._apply_guardrail_to_texts(
                texts=[prompt],
                model=data.get("model"),
                request_data=data,
                guardrail_to_apply=guardrail_to_apply,
                litellm_logging_obj=litellm_logging_obj,
                input_type="request",
            )
            data["prompt"] = guardrailed_texts[0] if guardrailed_texts else prompt

            verbose_proxy_logger.debug(
                "OpenAI Text Completion: Applied guardrail to string prompt. "
                "Original length: %d, New length: %d",
                len(prompt),
                len(data["prompt"]),
            )

        elif isinstance(prompt, list):
            # List of string prompts (batch completion). Non-string entries
            # (e.g. token arrays) are left untouched.
            texts_to_check = []
            text_indices = []  # Track which prompts are strings

            for idx, p in enumerate(prompt):
                if isinstance(p, str):
                    texts_to_check.append(p)
                    text_indices.append(idx)

            if texts_to_check:
                guardrailed_texts = await self._apply_guardrail_to_texts(
                    texts=texts_to_check,
                    model=data.get("model"),
                    request_data=data,
                    guardrail_to_apply=guardrail_to_apply,
                    litellm_logging_obj=litellm_logging_obj,
                    input_type="request",
                )

                # Replace guardrailed texts back at their original positions
                for guardrail_idx, prompt_idx in enumerate(text_indices):
                    if guardrail_idx < len(guardrailed_texts):
                        data["prompt"][prompt_idx] = guardrailed_texts[guardrail_idx]
                        verbose_proxy_logger.debug(
                            "OpenAI Text Completion: Applied guardrail to prompt[%d]. "
                            "Original length: %d, New length: %d",
                            prompt_idx,
                            len(texts_to_check[guardrail_idx]),
                            len(guardrailed_texts[guardrail_idx]),
                        )

        else:
            verbose_proxy_logger.warning(
                "OpenAI Text Completion: Unexpected prompt type: %s. Expected string or list.",
                type(prompt),
            )

        return data

    async def process_output_response(
        self,
        response: "TextCompletionResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Process output response by applying guardrails to completion text.

        Args:
            response: Text completion response object
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata to pass to guardrails

        Returns:
            Modified response with guardrails applied to completion text
        """
        if not hasattr(response, "choices") or not response.choices:
            verbose_proxy_logger.debug(
                "OpenAI Text Completion: No choices in response to process"
            )
            return response

        # Collect all string choice texts so the guardrail runs once, in batch
        texts_to_check = []
        choice_indices = []

        for idx, choice in enumerate(response.choices):
            if hasattr(choice, "text") and isinstance(choice.text, str):
                texts_to_check.append(choice.text)
                choice_indices.append(idx)

        if texts_to_check:
            # request_data carries the response plus user API key metadata
            request_data: dict = {"response": response}

            # Add user API key metadata with prefixed keys
            user_metadata = self.transform_user_api_key_dict_to_metadata(
                user_api_key_dict
            )
            if user_metadata:
                request_data["litellm_metadata"] = user_metadata

            guardrailed_texts = await self._apply_guardrail_to_texts(
                texts=texts_to_check,
                model=getattr(response, "model", None),
                request_data=request_data,
                guardrail_to_apply=guardrail_to_apply,
                litellm_logging_obj=litellm_logging_obj,
                input_type="response",
            )

            # Write guardrailed texts back onto their choices
            for guardrail_idx, choice_idx in enumerate(choice_indices):
                if guardrail_idx < len(guardrailed_texts):
                    original_text = response.choices[choice_idx].text
                    response.choices[choice_idx].text = guardrailed_texts[guardrail_idx]

                    verbose_proxy_logger.debug(
                        "OpenAI Text Completion: Applied guardrail to choice[%d] text. "
                        "Original length: %d, New length: %d",
                        choice_idx,
                        len(original_text),
                        len(guardrailed_texts[guardrail_idx]),
                    )

        return response
|
||||
@@ -0,0 +1,318 @@
|
||||
import json
|
||||
from typing import Callable, List, Optional, Union
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
|
||||
from litellm.types.utils import LlmProviders, ModelResponse, TextCompletionResponse
|
||||
from litellm.utils import ProviderConfigManager
|
||||
|
||||
from ..common_utils import BaseOpenAILLM, OpenAIError
|
||||
from .transformation import OpenAITextCompletionConfig
|
||||
|
||||
|
||||
class OpenAITextCompletion(BaseLLM):
    """Client for OpenAI's legacy `/v1/completions` (text completion) endpoint.

    Dispatches sync/async and streaming/non-streaming variants, and normalizes
    every upstream failure into an ``OpenAIError`` via ``_raise_openai_error``
    (previously the same normalization block was duplicated five times).
    """

    openai_text_completion_global_config = OpenAITextCompletionConfig()

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _raise_openai_error(e: Exception) -> None:
        """Normalize any exception into an ``OpenAIError``.

        Preserves upstream status code, message text, and response headers.
        httpx-style exceptions keep headers on ``e.response`` rather than on
        the exception itself, so fall back to that when needed.
        """
        status_code = getattr(e, "status_code", 500)
        error_headers = getattr(e, "headers", None)
        error_text = getattr(e, "text", str(e))
        error_response = getattr(e, "response", None)
        if error_headers is None and error_response:
            error_headers = getattr(error_response, "headers", None)
        raise OpenAIError(
            status_code=status_code, message=error_text, headers=error_headers
        )

    def validate_environment(self, api_key):
        """Build JSON request headers; adds a bearer token when api_key is set."""
        headers = {
            "content-type": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def completion(
        self,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        timeout: float,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        print_verbose: Optional[Callable] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        litellm_params=None,
        logger_fn=None,
        client=None,
        organization: Optional[str] = None,
        headers: Optional[dict] = None,
    ):
        """Entry point for text completion calls.

        Transforms ``messages`` into the provider request, then dispatches to
        the sync/async and streaming/non-streaming variants. Returns either a
        ``TextCompletionResponse``, a stream wrapper, or a coroutine (async).

        Raises:
            OpenAIError: on missing model/messages or any upstream failure.
        """
        try:
            if headers is None:
                headers = self.validate_environment(api_key=api_key)
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")

            provider_config = ProviderConfigManager.get_provider_text_completion_config(
                model=model,
                provider=LlmProviders(custom_llm_provider),
            )

            data = provider_config.transform_text_completion_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                headers=headers,
            )
            # don't send max retries to the api, if set
            max_retries = data.pop("max_retries", 2)

            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )
            if acompletion is True:
                if optional_params.get("stream", False):
                    return self.async_streaming(
                        logging_obj=logging_obj,
                        api_base=api_base,
                        api_key=api_key,
                        data=data,
                        headers=headers,
                        model_response=model_response,
                        model=model,
                        timeout=timeout,
                        max_retries=max_retries,
                        client=client,
                        organization=organization,
                    )
                else:
                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout, max_retries=max_retries, organization=organization, client=client)  # type: ignore
            elif optional_params.get("stream", False):
                return self.streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    api_key=api_key,
                    data=data,
                    headers=headers,
                    model_response=model_response,
                    model=model,
                    timeout=timeout,
                    max_retries=max_retries,  # type: ignore
                    client=client,
                    organization=organization,
                )
            else:
                if client is None:
                    openai_client = OpenAI(
                        api_key=api_key,
                        base_url=api_base,
                        http_client=litellm.client_session,
                        timeout=timeout,
                        max_retries=max_retries,  # type: ignore
                        organization=organization,
                    )
                else:
                    openai_client = client

                raw_response = openai_client.completions.with_raw_response.create(**data)  # type: ignore
                response = raw_response.parse()
                response_json = response.model_dump()

                ## LOGGING
                logging_obj.post_call(
                    api_key=api_key,
                    original_response=response_json,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                    },
                )

                ## RESPONSE OBJECT
                return TextCompletionResponse(**response_json)
        except Exception as e:
            self._raise_openai_error(e)

    async def acompletion(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        timeout: float,
        max_retries: int,
        organization: Optional[str] = None,
        client=None,
    ):
        """Async non-streaming text completion. Returns TextCompletionResponse."""
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=BaseOpenAILLM._get_async_http_client(),
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                )
            else:
                openai_aclient = client

            raw_response = await openai_aclient.completions.with_raw_response.create(
                **data
            )
            response = raw_response.parse()
            response_json = response.model_dump()

            ## LOGGING
            logging_obj.post_call(
                api_key=api_key,
                original_response=response,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                },
            )
            ## RESPONSE OBJECT
            response_obj = TextCompletionResponse(**response_json)
            # keep the raw provider payload so callers can inspect it verbatim
            response_obj._hidden_params.original_response = json.dumps(response_json)
            return response_obj
        except Exception as e:
            self._raise_openai_error(e)

    def streaming(
        self,
        logging_obj,
        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
        api_base: Optional[str] = None,
        max_retries=None,
        client=None,
        organization=None,
    ):
        """Sync streaming text completion; yields chunks via CustomStreamWrapper."""
        if client is None:
            openai_client = OpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,  # type: ignore
                organization=organization,
            )
        else:
            openai_client = client

        try:
            raw_response = openai_client.completions.with_raw_response.create(**data)
            response = raw_response.parse()
        except Exception as e:
            self._raise_openai_error(e)
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
        )

        try:
            for chunk in streamwrapper:
                yield chunk
        except Exception as e:
            self._raise_openai_error(e)

    async def async_streaming(
        self,
        logging_obj,
        api_key: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
        max_retries: int,
        api_base: Optional[str] = None,
        client=None,
        organization=None,
    ):
        """Async streaming text completion; yields chunks via CustomStreamWrapper."""
        if client is None:
            openai_client = AsyncOpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.aclient_session,
                timeout=timeout,
                max_retries=max_retries,
                organization=organization,
            )
        else:
            openai_client = client

        raw_response = await openai_client.completions.with_raw_response.create(**data)
        response = raw_response.parse()
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="text-completion-openai",
            logging_obj=logging_obj,
            stream_options=data.get("stream_options", None),
        )

        try:
            async for transformed_chunk in streamwrapper:
                yield transformed_chunk
        except Exception as e:
            self._raise_openai_error(e)
|
||||
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
Support for gpt model family
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm.llms.base_llm.completion.transformation import BaseTextCompletionConfig
|
||||
from litellm.types.llms.openai import AllMessageValues, OpenAITextCompletionUserMessage
|
||||
from litellm.types.utils import Choices, Message, ModelResponse, TextCompletionResponse
|
||||
|
||||
from ..chat.gpt_transformation import OpenAIGPTConfig
|
||||
from .utils import _transform_prompt
|
||||
|
||||
|
||||
class OpenAITextCompletionConfig(BaseTextCompletionConfig, OpenAIGPTConfig):
    """
    Reference: https://platform.openai.com/docs/api-reference/completions/create

    The class `OpenAITextCompletionConfig` provides configuration for the OpenAI's text completion API interface. Below are the parameters:

    - `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.

    - `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.

    - `frequency_penalty` (number or null): Defaults to 0. It is a numbers from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line.

    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.

    - `logprobs` (integer or null): This optional parameter includes the log probabilities on the most likely tokens as well as the chosen tokens.

    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.

    - `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.

    - `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.

    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.

    - `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.

    - `temperature` (number or null): This optional parameter defines the sampling temperature to use.

    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    best_of: Optional[int] = None
    echo: Optional[bool] = None
    frequency_penalty: Optional[int] = None
    logit_bias: Optional[dict] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[int] = None
    stop: Optional[Union[str, list]] = None
    suffix: Optional[str] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[int] = None,
        logit_bias: Optional[dict] = None,
        logprobs: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[int] = None,
        stop: Optional[Union[str, list]] = None,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        # NOTE: litellm config convention — non-None init args are stored on
        # the CLASS (not the instance) so get_config() sees the overrides.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return super().get_config()

    def convert_to_chat_model_response_object(
        self,
        response_object: Optional[TextCompletionResponse] = None,
        model_response_object: Optional[ModelResponse] = None,
    ):
        """Map a text-completion response onto a chat-style ModelResponse.

        Each text choice becomes an assistant Message; usage/id/model are
        copied over, and the original response is stashed in _hidden_params.

        Raises:
            ValueError: if either argument is None.
        """
        if response_object is None or model_response_object is None:
            raise ValueError("Error in response object format")
        choice_list: List[Choices] = []
        for idx, text_choice in enumerate(response_object["choices"]):
            message = Message(
                content=text_choice["text"],
                role="assistant",
            )
            chat_choice = Choices(
                finish_reason=text_choice["finish_reason"],
                index=idx,
                message=message,
                logprobs=text_choice.get("logprobs", None),
            )
            choice_list.append(chat_choice)
        model_response_object.choices = choice_list  # type: ignore

        if "usage" in response_object:
            setattr(model_response_object, "usage", response_object["usage"])

        if "id" in response_object:
            model_response_object.id = response_object["id"]

        if "model" in response_object:
            model_response_object.model = response_object["model"]

        model_response_object._hidden_params[
            "original_response"
        ] = response_object  # track original response, if users make a litellm.text_completion() request, we can return the original response
        return model_response_object

    def get_supported_openai_params(self, model: str) -> List:
        """OpenAI params accepted by the text-completion endpoint."""
        return [
            "functions",
            "function_call",
            "temperature",
            "top_p",
            "n",
            "stream",
            "stream_options",
            "stop",
            "max_tokens",
            "presence_penalty",
            "frequency_penalty",
            "logit_bias",
            "user",
            "response_format",
            "seed",
            "tools",
            "tool_choice",
            "max_retries",
            "logprobs",
            "top_logprobs",
            "extra_headers",
        ]

    def transform_text_completion_request(
        self,
        model: str,
        messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """Build the `/v1/completions` request body from chat-style messages."""
        prompt = _transform_prompt(messages)
        return {
            "model": model,
            "prompt": prompt,
            **optional_params,
        }
|
||||
@@ -0,0 +1,50 @@
|
||||
from typing import List, Union, cast
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
AllPromptValues,
|
||||
OpenAITextCompletionUserMessage,
|
||||
)
|
||||
|
||||
|
||||
def is_tokens_or_list_of_tokens(value: List):
    """Return True when *value* is a token array or a batch of token arrays.

    A "token array" is a list of ints; a batch is a list of such lists.
    Anything else (including mixed content) returns False.
    """

    def _all_ints(seq) -> bool:
        return all(isinstance(element, int) for element in seq)

    if not isinstance(value, list):
        return False
    # e.g. [1, 2, 3] — a single token array
    if _all_ints(value):
        return True
    # e.g. [[1, 2], [3]] — a batch of token arrays
    if all(isinstance(element, list) and _all_ints(element) for element in value):
        return True
    return False
|
||||
|
||||
|
||||
def _transform_prompt(
    messages: Union[List[AllMessageValues], List[OpenAITextCompletionUserMessage]],
) -> AllPromptValues:
    """Convert chat-style ``messages`` into a text-completion ``prompt`` value.

    Single message: pass token arrays (list[int] / list[list[int]]) through
    unchanged, otherwise flatten the content to one string. Multiple
    messages: each message's content becomes one entry of a list[str] prompt
    (token arrays are only supported in the single-message case).
    """
    if len(messages) == 1:  # base case
        message_content = messages[0].get("content")
        if (
            message_content
            and isinstance(message_content, list)
            and is_tokens_or_list_of_tokens(message_content)
        ):
            # Already tokens / batch of tokens — forward untouched.
            openai_prompt: AllPromptValues = cast(AllPromptValues, message_content)
        else:
            openai_prompt = convert_content_list_to_str(
                cast(AllMessageValues, messages[0])
            )
    else:
        openai_prompt = [
            convert_content_list_to_str(cast(AllMessageValues, m)) for m in messages
        ]
    return openai_prompt
|
||||
@@ -0,0 +1,343 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import (
|
||||
StandardBuiltInToolCostTracking,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.containers.main import (
|
||||
ContainerCreateOptionalRequestParams,
|
||||
ContainerFileListResponse,
|
||||
ContainerListResponse,
|
||||
ContainerObject,
|
||||
DeleteContainerResult,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
from ...base_llm.containers.transformation import BaseContainerConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
from ...base_llm.chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
class OpenAIContainerConfig(BaseContainerConfig):
|
||||
"""Configuration class for OpenAI container API."""
|
||||
|
||||
    def __init__(self):
        # No container-specific state to initialize; defer to BaseContainerConfig.
        super().__init__()
|
||||
|
||||
def get_supported_openai_params(self) -> list:
|
||||
"""Get the list of supported OpenAI parameters for container API."""
|
||||
return [
|
||||
"name",
|
||||
"expires_after",
|
||||
"file_ids",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
container_create_optional_params: ContainerCreateOptionalRequestParams,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
"""No mapping applied since inputs are in OpenAI spec already"""
|
||||
return dict(container_create_optional_params)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.openai_key
|
||||
or get_secret_str("OPENAI_API_KEY")
|
||||
)
|
||||
headers.update(
|
||||
{
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
},
|
||||
)
|
||||
return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""Get the complete URL for OpenAI container API."""
|
||||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("OPENAI_BASE_URL")
|
||||
or get_secret_str("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
|
||||
return f"{api_base.rstrip('/')}/containers"
|
||||
|
||||
def transform_container_create_request(
|
||||
self,
|
||||
name: str,
|
||||
container_create_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Dict:
|
||||
"""Transform the container creation request for OpenAI API."""
|
||||
# Remove extra_headers from optional params as they're handled separately
|
||||
container_create_optional_request_params = {
|
||||
k: v
|
||||
for k, v in container_create_optional_request_params.items()
|
||||
if k not in ["extra_headers"]
|
||||
}
|
||||
|
||||
# Create the request data
|
||||
request_dict = {
|
||||
"name": name,
|
||||
**container_create_optional_request_params,
|
||||
}
|
||||
|
||||
return request_dict
|
||||
|
||||
    def transform_container_create_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ContainerObject:
        """Parse the POST /containers response and attach the session cost."""
        response_data = raw_response.json()

        # Transform the response data
        container_obj = ContainerObject(**response_data)  # type: ignore[arg-type]

        # Add cost for container creation (OpenAI containers are code interpreter sessions)
        # https://platform.openai.com/docs/pricing
        # Each container creation is 1 code interpreter session
        container_cost = StandardBuiltInToolCostTracking.get_cost_for_code_interpreter(
            sessions=1,
            provider="openai",
        )

        # Surface the cost through the hidden "additional_headers" channel so
        # downstream cost-tracking layers can read it off the response object.
        if (
            not hasattr(container_obj, "_hidden_params")
            or container_obj._hidden_params is None
        ):
            container_obj._hidden_params = {}
        if "additional_headers" not in container_obj._hidden_params:
            container_obj._hidden_params["additional_headers"] = {}
        container_obj._hidden_params["additional_headers"][
            "llm_provider-x-litellm-response-cost"
        ] = container_cost

        return container_obj
|
||||
|
||||
def transform_container_list_request(
|
||||
self,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
order: Optional[str] = None,
|
||||
extra_query: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""Transform the container list request for OpenAI API.
|
||||
|
||||
OpenAI API expects the following request:
|
||||
- GET /v1/containers
|
||||
"""
|
||||
# Use the api_base directly for container list
|
||||
url = api_base
|
||||
|
||||
# Prepare query parameters
|
||||
params = {}
|
||||
if after is not None:
|
||||
params["after"] = after
|
||||
if limit is not None:
|
||||
params["limit"] = str(limit)
|
||||
if order is not None:
|
||||
params["order"] = order
|
||||
|
||||
# Add any extra query parameters
|
||||
if extra_query:
|
||||
params.update(extra_query)
|
||||
|
||||
return url, params
|
||||
|
||||
def transform_container_list_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ContainerListResponse:
|
||||
"""Transform the OpenAI container list response."""
|
||||
response_data = raw_response.json()
|
||||
|
||||
# Transform the response data
|
||||
container_list = ContainerListResponse(**response_data) # type: ignore[arg-type]
|
||||
|
||||
return container_list
|
||||
|
||||
def transform_container_retrieve_request(
|
||||
self,
|
||||
container_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""Transform the OpenAI container retrieve request."""
|
||||
# For container retrieve, we just need to construct the URL
|
||||
url = f"{api_base.rstrip('/')}/{container_id}"
|
||||
|
||||
# No additional data needed for GET request
|
||||
data: Dict[str, Any] = {}
|
||||
|
||||
return url, data
|
||||
|
||||
def transform_container_retrieve_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ContainerObject:
|
||||
"""Transform the OpenAI container retrieve response."""
|
||||
response_data = raw_response.json()
|
||||
# Transform the response data
|
||||
container_obj = ContainerObject(**response_data) # type: ignore[arg-type]
|
||||
|
||||
return container_obj
|
||||
|
||||
def transform_container_delete_request(
|
||||
self,
|
||||
container_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""Transform the container delete request for OpenAI API.
|
||||
|
||||
OpenAI API expects the following request:
|
||||
- DELETE /v1/containers/{container_id}
|
||||
"""
|
||||
# Construct the URL for container delete
|
||||
url = f"{api_base.rstrip('/')}/{container_id}"
|
||||
|
||||
# No data needed for DELETE request
|
||||
data: Dict[str, Any] = {}
|
||||
|
||||
return url, data
|
||||
|
||||
def transform_container_delete_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> DeleteContainerResult:
|
||||
"""Transform the OpenAI container delete response."""
|
||||
response_data = raw_response.json()
|
||||
|
||||
# Transform the response data
|
||||
delete_result = DeleteContainerResult(**response_data) # type: ignore[arg-type]
|
||||
|
||||
return delete_result
|
||||
|
||||
def transform_container_file_list_request(
|
||||
self,
|
||||
container_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
order: Optional[str] = None,
|
||||
extra_query: Optional[Dict[str, Any]] = None,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""Transform the container file list request for OpenAI API.
|
||||
|
||||
OpenAI API expects the following request:
|
||||
- GET /v1/containers/{container_id}/files
|
||||
"""
|
||||
# Construct the URL for container files
|
||||
url = f"{api_base.rstrip('/')}/{container_id}/files"
|
||||
|
||||
# Prepare query parameters
|
||||
params: Dict[str, Any] = {}
|
||||
if after is not None:
|
||||
params["after"] = after
|
||||
if limit is not None:
|
||||
params["limit"] = str(limit)
|
||||
if order is not None:
|
||||
params["order"] = order
|
||||
|
||||
# Add any extra query parameters
|
||||
if extra_query:
|
||||
params.update(extra_query)
|
||||
|
||||
return url, params
|
||||
|
||||
def transform_container_file_list_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> ContainerFileListResponse:
|
||||
"""Transform the OpenAI container file list response."""
|
||||
response_data = raw_response.json()
|
||||
|
||||
# Transform the response data
|
||||
file_list = ContainerFileListResponse(**response_data) # type: ignore[arg-type]
|
||||
|
||||
return file_list
|
||||
|
||||
def transform_container_file_content_request(
|
||||
self,
|
||||
container_id: str,
|
||||
file_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""Transform the container file content request for OpenAI API.
|
||||
|
||||
OpenAI API expects the following request:
|
||||
- GET /v1/containers/{container_id}/files/{file_id}/content
|
||||
"""
|
||||
# Construct the URL for container file content
|
||||
url = f"{api_base.rstrip('/')}/{container_id}/files/{file_id}/content"
|
||||
|
||||
# No query parameters needed
|
||||
params: Dict[str, Any] = {}
|
||||
|
||||
return url, params
|
||||
|
||||
def transform_container_file_content_response(
|
||||
self,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
) -> bytes:
|
||||
"""Transform the OpenAI container file content response.
|
||||
|
||||
Returns the raw binary content of the file.
|
||||
"""
|
||||
return raw_response.content
|
||||
|
||||
def get_error_class(
|
||||
self,
|
||||
error_message: str,
|
||||
status_code: int,
|
||||
headers: Union[dict, httpx.Headers],
|
||||
) -> BaseLLMException:
|
||||
from ...base_llm.chat.transformation import BaseLLMException
|
||||
|
||||
raise BaseLLMException(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=headers,
|
||||
)
|
||||
@@ -0,0 +1,177 @@
|
||||
"""
|
||||
Helper util for handling openai-specific cost calculation
|
||||
- e.g.: prompt caching
|
||||
"""
|
||||
|
||||
from typing import Literal, Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
|
||||
from litellm.types.utils import CallTypes, ModelInfo, Usage
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
|
||||
def cost_router(call_type: CallTypes) -> Literal["cost_per_token", "cost_per_second"]:
    """Pick the cost-calculation strategy for an OpenAI call type.

    Transcription calls (sync and async) are billed by audio duration;
    every other call type is billed per token.
    """
    is_transcription = call_type in (CallTypes.transcription, CallTypes.atranscription)
    return "cost_per_second" if is_transcription else "cost_per_token"
def cost_per_token(
    model: str, usage: Usage, service_tier: Optional[str] = None
) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given model, prompt tokens, and completion tokens.

    Delegates to the shared ``generic_cost_per_token`` helper, which already
    handles prompt-caching and audio-token pricing.

    Input:
        - model: str, the model name without provider prefix
        - usage: LiteLLM Usage block, containing anthropic caching information
        - service_tier: optional OpenAI service tier affecting pricing

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    # NOTE: the hand-rolled caching/audio cost math that used to follow this
    # return (as commented-out dead code) was removed; generic_cost_per_token
    # is the single source of truth for this calculation.
    return generic_cost_per_token(
        model=model,
        usage=usage,
        custom_llm_provider="openai",
        service_tier=service_tier,
    )
def cost_per_second(
    model: str, custom_llm_provider: Optional[str], duration: float = 0.0
) -> Tuple[float, float]:
    """Compute duration-based (audio/speech) cost for a model.

    Input:
        - model: str, the model name without provider prefix
        - custom_llm_provider: str, the custom llm provider
        - duration: float, the duration of the response in seconds

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    # Look up pricing metadata for the model (defaults to openai).
    model_info = get_model_info(
        model=model, custom_llm_provider=custom_llm_provider or "openai"
    )

    prompt_cost = 0.0
    completion_cost = 0.0

    # Output-side per-second pricing takes precedence over input-side pricing.
    output_rate = model_info.get("output_cost_per_second")
    input_rate = model_info.get("input_cost_per_second")
    if output_rate is not None:
        verbose_logger.debug(
            f"For model={model} - output_cost_per_second: {model_info.get('output_cost_per_second')}; duration: {duration}"
        )
        completion_cost = output_rate * duration
    elif input_rate is not None:
        verbose_logger.debug(
            f"For model={model} - input_cost_per_second: {model_info.get('input_cost_per_second')}; duration: {duration}"
        )
        prompt_cost = input_rate * duration

    return prompt_cost, completion_cost
def video_generation_cost(
    model: str,
    duration_seconds: float,
    custom_llm_provider: Optional[str] = None,
    model_info: Optional[ModelInfo] = None,
) -> float:
    """Compute the cost of a generated video from its duration.

    Input:
        - model: str, the model name without provider prefix
        - duration_seconds: float, the duration of the generated video in seconds
        - custom_llm_provider: str, the custom llm provider
        - model_info: Optional[dict], deployment-level model info containing
          custom video pricing. When provided, skips the global
          get_model_info() lookup so that deployment-specific pricing is used.

    Returns:
        float - total_cost_in_usd
    """
    # Deployment-level model_info, when supplied, overrides the global lookup.
    if model_info is None:
        model_info = get_model_info(
            model=model, custom_llm_provider=custom_llm_provider or "openai"
        )

    # Prefer video-specific pricing, then fall back to generic per-second pricing.
    for rate_key in ("output_cost_per_video_per_second", "output_cost_per_second"):
        rate = model_info.get(rate_key)
        if rate is not None:
            verbose_logger.debug(
                f"For model={model} - {rate_key}: {rate}; duration: {duration_seconds}"
            )
            return rate * duration_seconds

    # No applicable pricing metadata: warn and charge nothing.
    verbose_logger.warning(
        f"No cost information found for video model {model}. Please add pricing to model_prices_and_context_window.json"
    )
    return 0.0
@@ -0,0 +1,13 @@
|
||||
"""OpenAI Embeddings handler for Unified Guardrails."""
|
||||
|
||||
from litellm.llms.openai.embeddings.guardrail_translation.handler import (
|
||||
OpenAIEmbeddingsHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
guardrail_translation_mappings = {
|
||||
CallTypes.embedding: OpenAIEmbeddingsHandler,
|
||||
CallTypes.aembedding: OpenAIEmbeddingsHandler,
|
||||
}
|
||||
|
||||
__all__ = ["guardrail_translation_mappings", "OpenAIEmbeddingsHandler"]
|
||||
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
OpenAI Embeddings Handler for Unified Guardrails
|
||||
|
||||
This module provides guardrail translation support for OpenAI's embeddings endpoint.
|
||||
The handler processes the 'input' parameter for guardrails.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
|
||||
class OpenAIEmbeddingsHandler(BaseTranslation):
    """
    Handler for processing OpenAI embeddings requests with guardrails.

    This class provides methods to:
    1. Process input text (pre-call hook)
    2. Process output response (post-call hook) - embeddings don't typically need output guardrails

    The handler specifically processes the 'input' parameter which can be:
    - A single string
    - A list of strings (for batch embeddings)
    - A list of integers (token IDs - not processed by guardrails)
    - A list of lists of integers (batch token IDs - not processed by guardrails)
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Process input text by applying guardrails to text content.

        Args:
            data: Request data dictionary containing 'input' parameter
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object

        Returns:
            Modified data with guardrails applied to input
        """
        input_data = data.get("input")
        if input_data is None:
            verbose_proxy_logger.debug(
                "OpenAI Embeddings: No input found in request data"
            )
            return data

        # Dispatch on the shape of 'input'; anything other than str/list is
        # logged and passed through unmodified.
        if isinstance(input_data, str):
            data = await self._process_string_input(
                data, input_data, guardrail_to_apply, litellm_logging_obj
            )
        elif isinstance(input_data, list):
            data = await self._process_list_input(
                data, input_data, guardrail_to_apply, litellm_logging_obj
            )
        else:
            verbose_proxy_logger.warning(
                "OpenAI Embeddings: Unexpected input type: %s. Expected string or list.",
                type(input_data),
            )

        return data

    async def _process_string_input(
        self,
        data: dict,
        input_data: str,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any],
    ) -> dict:
        """Process a single string input through the guardrail."""
        inputs = GenericGuardrailAPIInputs(texts=[input_data])
        # Forward the model name when present so the guardrail can use it.
        if model := data.get("model"):
            inputs["model"] = model

        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )

        # Replace the request input with the (possibly redacted) guardrail
        # output; a missing/empty "texts" leaves the request untouched.
        if guardrailed_texts := guardrailed_inputs.get("texts"):
            data["input"] = guardrailed_texts[0]
            verbose_proxy_logger.debug(
                "OpenAI Embeddings: Applied guardrail to string input. "
                "Original length: %d, New length: %d",
                len(input_data),
                len(data["input"]),
            )

        return data

    async def _process_list_input(
        self,
        data: dict,
        input_data: List[Union[str, int, List[int]]],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any],
    ) -> dict:
        """Process a list input through the guardrail (if it contains strings)."""
        if len(input_data) == 0:
            return data

        # NOTE(review): element type is inferred from the FIRST item only;
        # a mixed str/int list would be treated uniformly — confirm callers
        # never send heterogeneous lists.
        first_item = input_data[0]

        # Skip non-text inputs (token IDs)
        if isinstance(first_item, (int, list)):
            verbose_proxy_logger.debug(
                "OpenAI Embeddings: Input is token IDs, skipping guardrail processing"
            )
            return data

        if not isinstance(first_item, str):
            verbose_proxy_logger.warning(
                "OpenAI Embeddings: Unexpected input list item type: %s",
                type(first_item),
            )
            return data

        # List of strings - apply guardrail
        inputs = GenericGuardrailAPIInputs(texts=input_data)  # type: ignore
        if model := data.get("model"):
            inputs["model"] = model

        guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )

        if guardrailed_texts := guardrailed_inputs.get("texts"):
            data["input"] = guardrailed_texts
            verbose_proxy_logger.debug(
                "OpenAI Embeddings: Applied guardrail to %d inputs",
                len(guardrailed_texts),
            )

        return data

    async def process_output_response(
        self,
        response: "EmbeddingResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Process output response - embeddings responses contain vectors, not text.

        For embeddings, the output is numerical vectors, so there's typically
        no text content to apply guardrails to. This method is a no-op but
        is included for interface consistency.

        Args:
            response: Embedding response object
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata

        Returns:
            Unmodified response (embeddings don't have text output to guard)
        """
        verbose_proxy_logger.debug(
            "OpenAI Embeddings: Output response processing skipped - "
            "embeddings contain vectors, not text"
        )
        return response
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
OpenAI Evals API configuration
|
||||
"""
|
||||
|
||||
from .transformation import OpenAIEvalsConfig
|
||||
|
||||
__all__ = ["OpenAIEvalsConfig"]
|
||||
@@ -0,0 +1,426 @@
|
||||
"""
|
||||
OpenAI Evals API configuration and transformations
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.evals.transformation import (
|
||||
BaseEvalsAPIConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai_evals import (
|
||||
CancelEvalResponse,
|
||||
CancelRunResponse,
|
||||
CreateEvalRequest,
|
||||
CreateRunRequest,
|
||||
DeleteEvalResponse,
|
||||
Eval,
|
||||
ListEvalsParams,
|
||||
ListEvalsResponse,
|
||||
ListRunsParams,
|
||||
ListRunsResponse,
|
||||
Run,
|
||||
RunDeleteResponse,
|
||||
UpdateEvalRequest,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
|
||||
class OpenAIEvalsConfig(BaseEvalsAPIConfig):
    """OpenAI-specific Evals API configuration.

    Transforms litellm Evals/Runs requests into OpenAI `/v1/evals` HTTP calls
    and parses the responses back into typed objects.
    """

    @property
    def custom_llm_provider(self) -> LlmProviders:
        return LlmProviders.OPENAI

    def _resolve_api_base(
        self, litellm_params: Optional[GenericLiteLLMParams]
    ) -> str:
        """Return the configured api_base, defaulting to the public OpenAI API."""
        if litellm_params and litellm_params.api_base:
            return litellm_params.api_base
        return "https://api.openai.com"

    @staticmethod
    def _extract_query_params(
        list_params: Dict[str, Any], keys: Tuple[str, ...]
    ) -> Dict[str, Any]:
        """Copy the given keys out of ``list_params`` when present AND truthy.

        Shared by the evals/runs list endpoints; preserves the original inline
        semantics (falsy values such as "" are intentionally dropped).
        """
        return {
            k: list_params[k] for k in keys if k in list_params and list_params[k]
        }

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Add OpenAI-specific auth/content headers.

        Raises:
            ValueError: if no API key can be resolved from litellm_params,
                litellm globals, or the OPENAI_API_KEY secret.
        """
        import litellm
        from litellm.secret_managers.main import get_secret_str

        # Get API key following OpenAI pattern: explicit param first, then
        # litellm globals, then the environment/secret manager.
        api_key = None
        if litellm_params:
            api_key = litellm_params.api_key

        api_key = (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )

        if not api_key:
            raise ValueError("OPENAI_API_KEY is required for Evals API")

        # Add required headers
        headers["Authorization"] = f"Bearer {api_key}"
        headers["Content-Type"] = "application/json"

        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        endpoint: str,
        eval_id: Optional[str] = None,
    ) -> str:
        """Get complete URL for OpenAI Evals API.

        When ``eval_id`` is given, the URL targets that specific eval
        (``/v1/evals/{eval_id}``) and ``endpoint`` is ignored.
        """
        if api_base is None:
            api_base = "https://api.openai.com"

        if eval_id:
            return f"{api_base}/v1/evals/{eval_id}"
        return f"{api_base}/v1/{endpoint}"

    def transform_create_eval_request(
        self,
        create_request: CreateEvalRequest,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """Transform create eval request for OpenAI."""
        verbose_logger.debug("Transforming create eval request: %s", create_request)

        # OpenAI expects the request body directly; drop unset (None) fields.
        request_body = {k: v for k, v in create_request.items() if v is not None}

        return request_body

    def transform_create_eval_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Eval:
        """Transform OpenAI response to Eval object."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming create eval response: %s", response_json)

        return Eval(**response_json)

    def transform_list_evals_request(
        self,
        list_params: ListEvalsParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Transform list evals request for OpenAI."""
        url = self.get_complete_url(
            api_base=self._resolve_api_base(litellm_params), endpoint="evals"
        )

        # Build query parameters (deduplicated via the shared helper).
        query_params = self._extract_query_params(
            list_params, ("limit", "after", "before", "order", "order_by")
        )

        verbose_logger.debug(
            "List evals request made to OpenAI Evals endpoint with params: %s",
            query_params,
        )

        return url, query_params

    def transform_list_evals_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ListEvalsResponse:
        """Transform OpenAI response to ListEvalsResponse."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming list evals response: %s", response_json)

        return ListEvalsResponse(**response_json)

    def transform_get_eval_request(
        self,
        eval_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Transform get eval request for OpenAI."""
        url = self.get_complete_url(
            api_base=api_base, endpoint="evals", eval_id=eval_id
        )

        verbose_logger.debug("Get eval request - URL: %s", url)

        return url, headers

    def transform_get_eval_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Eval:
        """Transform OpenAI response to Eval object."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming get eval response: %s", response_json)

        return Eval(**response_json)

    def transform_update_eval_request(
        self,
        eval_id: str,
        update_request: UpdateEvalRequest,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict, Dict]:
        """Transform update eval request for OpenAI."""
        url = self.get_complete_url(
            api_base=api_base, endpoint="evals", eval_id=eval_id
        )

        # Build request body; drop unset (None) fields.
        request_body = {k: v for k, v in update_request.items() if v is not None}

        verbose_logger.debug(
            "Update eval request - URL: %s, body: %s", url, request_body
        )

        return url, headers, request_body

    def transform_update_eval_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Eval:
        """Transform OpenAI response to Eval object."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming update eval response: %s", response_json)

        return Eval(**response_json)

    def transform_delete_eval_request(
        self,
        eval_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Transform delete eval request for OpenAI."""
        url = self.get_complete_url(
            api_base=api_base, endpoint="evals", eval_id=eval_id
        )

        verbose_logger.debug("Delete eval request - URL: %s", url)

        return url, headers

    def transform_delete_eval_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteEvalResponse:
        """Transform OpenAI response to DeleteEvalResponse."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming delete eval response: %s", response_json)

        return DeleteEvalResponse(**response_json)

    def transform_cancel_eval_request(
        self,
        eval_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict, Dict]:
        """Transform cancel eval request for OpenAI."""
        url = f"{self.get_complete_url(api_base=api_base, endpoint='evals', eval_id=eval_id)}/cancel"

        # Empty body for cancel request
        request_body: Dict[str, Any] = {}

        verbose_logger.debug("Cancel eval request - URL: %s", url)

        return url, headers, request_body

    def transform_cancel_eval_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> CancelEvalResponse:
        """Transform OpenAI response to CancelEvalResponse."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming cancel eval response: %s", response_json)

        return CancelEvalResponse(**response_json)

    # Run API Transformations
    def transform_create_run_request(
        self,
        eval_id: str,
        create_request: CreateRunRequest,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Transform create run request for OpenAI."""
        api_base = self._resolve_api_base(litellm_params)

        url = f"{api_base}/v1/evals/{eval_id}/runs"

        # Build request body; drop unset (None) fields.
        request_body = {k: v for k, v in create_request.items() if v is not None}

        verbose_logger.debug(
            "Create run request - URL: %s, body: %s", url, request_body
        )

        return url, request_body

    def transform_create_run_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Run:
        """Transform OpenAI response to Run object."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming create run response: %s", response_json)

        return Run(**response_json)

    def transform_list_runs_request(
        self,
        eval_id: str,
        list_params: ListRunsParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Transform list runs request for OpenAI."""
        api_base = self._resolve_api_base(litellm_params)

        url = f"{api_base}/v1/evals/{eval_id}/runs"

        # Build query parameters (runs do not support order_by).
        query_params = self._extract_query_params(
            list_params, ("limit", "after", "before", "order")
        )

        verbose_logger.debug(
            "List runs request made to OpenAI Evals endpoint with params: %s",
            query_params,
        )

        return url, query_params

    def transform_list_runs_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ListRunsResponse:
        """Transform OpenAI response to ListRunsResponse."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming list runs response: %s", response_json)

        return ListRunsResponse(**response_json)

    def transform_get_run_request(
        self,
        eval_id: str,
        run_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """Transform get run request for OpenAI."""
        url = f"{api_base}/v1/evals/{eval_id}/runs/{run_id}"

        verbose_logger.debug("Get run request - URL: %s", url)

        return url, headers

    def transform_get_run_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Run:
        """Transform OpenAI response to Run object."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming get run response: %s", response_json)

        return Run(**response_json)

    def transform_cancel_run_request(
        self,
        eval_id: str,
        run_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict, Dict]:
        """Transform cancel run request for OpenAI."""
        url = f"{api_base}/v1/evals/{eval_id}/runs/{run_id}/cancel"

        # Empty body for cancel request
        request_body: Dict[str, Any] = {}

        verbose_logger.debug("Cancel run request - URL: %s", url)

        return url, headers, request_body

    def transform_cancel_run_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> CancelRunResponse:
        """Transform OpenAI response to CancelRunResponse."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming cancel run response: %s", response_json)

        return CancelRunResponse(**response_json)

    def transform_delete_run_request(
        self,
        eval_id: str,
        run_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict, Dict]:
        """Transform delete run request for OpenAI."""
        url = f"{api_base}/v1/evals/{eval_id}/runs/{run_id}"

        # Empty body for delete request
        request_body: Dict[str, Any] = {}

        verbose_logger.debug("Delete run request - URL: %s", url)

        return url, headers, request_body

    def transform_delete_run_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> RunDeleteResponse:
        """Transform OpenAI response to RunDeleteResponse."""
        response_json = raw_response.json()
        verbose_logger.debug("Transforming delete run response: %s", response_json)

        return RunDeleteResponse(**response_json)
@@ -0,0 +1,278 @@
|
||||
from typing import Any, Coroutine, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.utils import LiteLLMFineTuningJob
|
||||
|
||||
|
||||
class OpenAIFineTuningAPI:
|
||||
"""
|
||||
OpenAI methods to support for batches
|
||||
"""
|
||||
|
||||
    def __init__(self) -> None:
        # No state of its own; delegate to the (cooperative) base initializer.
        super().__init__()
def get_openai_client(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
_is_async: bool = False,
|
||||
api_version: Optional[str] = None,
|
||||
litellm_params: Optional[dict] = None,
|
||||
) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI,]]:
|
||||
received_args = locals()
|
||||
openai_client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None
|
||||
if client is None:
|
||||
data = {}
|
||||
for k, v in received_args.items():
|
||||
if k == "self" or k == "client" or k == "_is_async":
|
||||
pass
|
||||
elif k == "api_base" and v is not None:
|
||||
data["base_url"] = v
|
||||
elif v is not None:
|
||||
data[k] = v
|
||||
if _is_async is True:
|
||||
openai_client = AsyncOpenAI(**data)
|
||||
else:
|
||||
openai_client = OpenAI(**data) # type: ignore
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
return openai_client
|
||||
|
||||
async def acreate_fine_tuning_job(
|
||||
self,
|
||||
create_fine_tuning_job_data: dict,
|
||||
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
||||
) -> LiteLLMFineTuningJob:
|
||||
response = await openai_client.fine_tuning.jobs.create(
|
||||
**create_fine_tuning_job_data
|
||||
)
|
||||
|
||||
return LiteLLMFineTuningJob(**response.model_dump())
|
||||
|
||||
def create_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
create_fine_tuning_job_data: dict,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
|
||||
openai_client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
api_version=api_version,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
|
||||
raise ValueError(
|
||||
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||
)
|
||||
return self.acreate_fine_tuning_job( # type: ignore
|
||||
create_fine_tuning_job_data=create_fine_tuning_job_data,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"creating fine tuning job, args= %s", create_fine_tuning_job_data
|
||||
)
|
||||
response = cast(OpenAI, openai_client).fine_tuning.jobs.create(
|
||||
**create_fine_tuning_job_data
|
||||
)
|
||||
return LiteLLMFineTuningJob(**response.model_dump())
|
||||
|
||||
async def acancel_fine_tuning_job(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
||||
) -> LiteLLMFineTuningJob:
|
||||
response = await openai_client.fine_tuning.jobs.cancel(
|
||||
fine_tuning_job_id=fine_tuning_job_id
|
||||
)
|
||||
return LiteLLMFineTuningJob(**response.model_dump())
|
||||
|
||||
def cancel_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
fine_tuning_job_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
|
||||
openai_client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
api_version=api_version,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
|
||||
raise ValueError(
|
||||
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||
)
|
||||
return self.acancel_fine_tuning_job( # type: ignore
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug("canceling fine tuning job, args= %s", fine_tuning_job_id)
|
||||
response = cast(OpenAI, openai_client).fine_tuning.jobs.cancel(
|
||||
fine_tuning_job_id=fine_tuning_job_id
|
||||
)
|
||||
return LiteLLMFineTuningJob(**response.model_dump())
|
||||
|
||||
async def alist_fine_tuning_jobs(
|
||||
self,
|
||||
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
):
|
||||
response = await openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
|
||||
return response
|
||||
|
||||
def list_fine_tuning_jobs(
|
||||
self,
|
||||
_is_async: bool,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
):
|
||||
openai_client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
api_version=api_version,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, (AsyncOpenAI, AsyncAzureOpenAI)):
|
||||
raise ValueError(
|
||||
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||
)
|
||||
return self.alist_fine_tuning_jobs( # type: ignore
|
||||
after=after,
|
||||
limit=limit,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug("list fine tuning job, after= %s, limit= %s", after, limit)
|
||||
response = openai_client.fine_tuning.jobs.list(after=after, limit=limit) # type: ignore
|
||||
return response
|
||||
|
||||
async def aretrieve_fine_tuning_job(
|
||||
self,
|
||||
fine_tuning_job_id: str,
|
||||
openai_client: Union[AsyncOpenAI, AsyncAzureOpenAI],
|
||||
) -> LiteLLMFineTuningJob:
|
||||
response = await openai_client.fine_tuning.jobs.retrieve(
|
||||
fine_tuning_job_id=fine_tuning_job_id
|
||||
)
|
||||
return LiteLLMFineTuningJob(**response.model_dump())
|
||||
|
||||
def retrieve_fine_tuning_job(
|
||||
self,
|
||||
_is_async: bool,
|
||||
fine_tuning_job_id: str,
|
||||
api_key: Optional[str],
|
||||
api_base: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
organization: Optional[str],
|
||||
client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = None,
|
||||
) -> Union[LiteLLMFineTuningJob, Coroutine[Any, Any, LiteLLMFineTuningJob]]:
|
||||
openai_client: Optional[
|
||||
Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]
|
||||
] = self.get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
organization=organization,
|
||||
client=client,
|
||||
_is_async=_is_async,
|
||||
api_version=api_version,
|
||||
)
|
||||
if openai_client is None:
|
||||
raise ValueError(
|
||||
"OpenAI client is not initialized. Make sure api_key is passed or OPENAI_API_KEY is set in the environment."
|
||||
)
|
||||
|
||||
if _is_async is True:
|
||||
if not isinstance(openai_client, AsyncOpenAI):
|
||||
raise ValueError(
|
||||
"OpenAI client is not an instance of AsyncOpenAI. Make sure you passed an AsyncOpenAI client."
|
||||
)
|
||||
return self.aretrieve_fine_tuning_job( # type: ignore
|
||||
fine_tuning_job_id=fine_tuning_job_id,
|
||||
openai_client=openai_client,
|
||||
)
|
||||
verbose_logger.debug("retrieving fine tuning job, id= %s", fine_tuning_job_id)
|
||||
response = cast(OpenAI, openai_client).fine_tuning.jobs.retrieve(
|
||||
fine_tuning_job_id=fine_tuning_job_id
|
||||
)
|
||||
return LiteLLMFineTuningJob(**response.model_dump())
|
||||
@@ -0,0 +1,29 @@
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
|
||||
from .dalle2_transformation import DallE2ImageEditConfig
|
||||
from .transformation import OpenAIImageEditConfig
|
||||
|
||||
__all__ = [
|
||||
"OpenAIImageEditConfig",
|
||||
"DallE2ImageEditConfig",
|
||||
"get_openai_image_edit_config",
|
||||
]
|
||||
|
||||
|
||||
def get_openai_image_edit_config(model: str) -> BaseImageEditConfig:
    """
    Get the appropriate OpenAI image edit config based on the model.

    Args:
        model: The model name (e.g., "dall-e-2", "gpt-image-1")

    Returns:
        The appropriate config instance for the model
    """
    # Normalize away case and separators so "dall-e-2", "DALL_E_2", etc. match.
    normalized = model.lower().replace("-", "").replace("_", "")
    if normalized == "dalle2":
        return DallE2ImageEditConfig()
    # Default to standard OpenAI config for gpt-image-1 and other models.
    return OpenAIImageEditConfig()
|
||||
@@ -0,0 +1,104 @@
|
||||
from io import BufferedReader
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
import litellm
|
||||
from litellm.images.utils import ImageEditRequestUtils
|
||||
from litellm.types.images.main import ImageEditRequestParams
|
||||
from litellm.types.llms.openai import FileTypes
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
from .transformation import OpenAIImageEditConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class DallE2ImageEditConfig(OpenAIImageEditConfig):
    """
    DALL-E-2 specific configuration for image edit API.

    DALL-E-2 only supports editing a single image (not an array).
    Uses "image" field name instead of "image[]".
    """

    def transform_image_edit_request(
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles]:
        """
        Transform image edit request for DALL-E-2.

        DALL-E-2 only accepts a single image with field name "image" (not "image[]").

        Returns:
            Tuple of (form data without file fields, multipart files list).

        Raises:
            litellm.BadRequestError: if more than one image is supplied.
        """
        params = {
            "model": model,
            **image_edit_optional_request_params,
        }
        if image is not None:
            params["image"] = image
        if prompt is not None:
            params["prompt"] = prompt

        typed_request = ImageEditRequestParams(**params)
        request_dict = cast(Dict, typed_request)

        # Images and masks go out as multipart `files`; everything else as `data`.
        raw_image = request_dict.get("image")
        raw_mask = request_dict.get("mask")
        form_data = {
            key: value
            for key, value in request_dict.items()
            if key not in ("image", "mask")
        }
        files: List[Tuple[str, Any]] = []

        if raw_image is not None:
            images = raw_image if isinstance(raw_image, list) else [raw_image]

            # DALL-E-2 rejects multi-image edits — fail fast with a clear error.
            if len(images) > 1:
                raise litellm.BadRequestError(
                    message="DALL-E-2 only supports editing a single image. Please provide one image.",
                    model=model,
                    llm_provider="openai",
                )

            # Singular "image" field name for DALL-E-2 (parent uses "image[]").
            for item in images:
                if item is not None:
                    self._add_image_to_files(
                        files_list=files,
                        image=item,
                        field_name="image",
                    )

        if raw_mask is not None:
            # A list-valued mask collapses to its first element.
            if isinstance(raw_mask, list):
                raw_mask = raw_mask[0] if raw_mask else None

            if raw_mask is not None:
                mask_content_type: str = ImageEditRequestUtils.get_image_content_type(
                    raw_mask
                )
                if isinstance(raw_mask, BufferedReader):
                    files.append(("mask", (raw_mask.name, raw_mask, mask_content_type)))
                else:
                    files.append(("mask", ("mask.png", raw_mask, mask_content_type)))

        return form_data, files
|
||||
@@ -0,0 +1,202 @@
|
||||
from io import BufferedReader
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
import litellm
|
||||
from litellm.images.utils import ImageEditRequestUtils
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import (
|
||||
ImageEditOptionalRequestParams,
|
||||
ImageEditRequestParams,
|
||||
)
|
||||
from litellm.types.llms.openai import FileTypes
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import ImageResponse
|
||||
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class OpenAIImageEditConfig(BaseImageEditConfig):
    """
    Base configuration for OpenAI image edit API.
    Used for models like gpt-image-1 that support multiple images.
    """

    def get_supported_openai_params(self, model: str) -> list:
        """
        All OpenAI Image Edits params are supported
        """
        return [
            "image",
            "prompt",
            "background",
            "input_fidelity",
            "mask",
            "model",
            "n",
            "quality",
            "response_format",
            "size",
            "user",
            "extra_headers",
            "extra_query",
            "extra_body",
            "timeout",
        ]

    def map_openai_params(
        self,
        image_edit_optional_params: ImageEditOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """No mapping applied since inputs are in OpenAI spec already"""
        return dict(image_edit_optional_params)

    def _add_image_to_files(
        self,
        files_list: List[Tuple[str, Any]],
        image: Any,
        field_name: str,
    ) -> None:
        """Add an image to the files list with appropriate content type.

        Mutates ``files_list`` in place. BufferedReader inputs keep their real
        filename; everything else is sent under the placeholder "image.png".
        """
        image_content_type = ImageEditRequestUtils.get_image_content_type(image)

        if isinstance(image, BufferedReader):
            files_list.append((field_name, (image.name, image, image_content_type)))
        else:
            files_list.append((field_name, ("image.png", image, image_content_type)))

    def transform_image_edit_request(
        self,
        model: str,
        prompt: Optional[str],
        image: Optional[FileTypes],
        image_edit_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles]:
        """
        Transform image edit request to OpenAI API format.

        Handles multipart/form-data for images. Uses "image[]" field name
        to support multiple images (e.g., for gpt-image-1).

        Returns a (data, files) pair: non-file parameters as form data, and
        image/mask entries as multipart files.
        """
        # Build request params, only including non-None values
        request_params = {
            "model": model,
            **image_edit_optional_request_params,
        }
        if image is not None:
            request_params["image"] = image
        if prompt is not None:
            request_params["prompt"] = prompt

        request = ImageEditRequestParams(**request_params)
        request_dict = cast(Dict, request)

        #########################################################
        # Separate images and masks as `files` and send other parameters as `data`
        #########################################################
        _image_list = request_dict.get("image")
        _mask = request_dict.get("mask")
        data_without_files = {
            k: v for k, v in request_dict.items() if k not in ["image", "mask"]
        }
        files_list: List[Tuple[str, Any]] = []

        # Handle image parameter — a single image is wrapped into a list so
        # one and many images go through the same path.
        if _image_list is not None:
            image_list = (
                [_image_list] if not isinstance(_image_list, list) else _image_list
            )

            for _image in image_list:
                if _image is not None:
                    self._add_image_to_files(
                        files_list=files_list,
                        image=_image,
                        field_name="image[]",
                    )
        # Handle mask parameter if provided
        if _mask is not None:
            # Handle case where mask can be a list (extract first mask)
            if isinstance(_mask, list):
                _mask = _mask[0] if _mask else None

            if _mask is not None:
                mask_content_type: str = ImageEditRequestUtils.get_image_content_type(
                    _mask
                )
                if isinstance(_mask, BufferedReader):
                    files_list.append(("mask", (_mask.name, _mask, mask_content_type)))
                else:
                    files_list.append(("mask", ("mask.png", _mask, mask_content_type)))

        return data_without_files, files_list

    def transform_image_edit_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ImageResponse:
        """No transform applied since outputs are in OpenAI spec already.

        Raises:
            OpenAIError: when the response body is not valid JSON; the raw
                text and status code are surfaced to the caller.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        return ImageResponse(**raw_response_json)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Attach the Bearer auth header, resolving the key from the explicit
        argument, litellm globals, or the OPENAI_API_KEY secret — in that order.
        """
        api_key = (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the endpoint for the OpenAI image edits API
        (``{api_base}/images/edits``).
        """
        api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_BASE_URL")
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )

        # Remove trailing slashes
        api_base = api_base.rstrip("/")

        return f"{api_base}/images/edits"
|
||||
@@ -0,0 +1,28 @@
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
|
||||
from .dall_e_2_transformation import DallE2ImageGenerationConfig
|
||||
from .dall_e_3_transformation import DallE3ImageGenerationConfig
|
||||
from .gpt_transformation import GPTImageGenerationConfig
|
||||
from .guardrail_translation import (
|
||||
OpenAIImageGenerationHandler,
|
||||
guardrail_translation_mappings,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"DallE2ImageGenerationConfig",
|
||||
"DallE3ImageGenerationConfig",
|
||||
"GPTImageGenerationConfig",
|
||||
"OpenAIImageGenerationHandler",
|
||||
"guardrail_translation_mappings",
|
||||
]
|
||||
|
||||
|
||||
def get_openai_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """Return the image-generation config matching *model*.

    An empty model name is treated as dall-e-2; anything that is not a
    dall-e-2/dall-e-3 variant falls through to the gpt-image config.
    """
    if model == "" or model.startswith("dall-e-2"):  # empty model is dall-e-2
        return DallE2ImageGenerationConfig()
    if model.startswith("dall-e-3"):
        return DallE3ImageGenerationConfig()
    return GPTImageGenerationConfig()
|
||||
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
Cost calculator for OpenAI image generation models (gpt-image-1, gpt-image-1-mini)
|
||||
|
||||
These models use token-based pricing instead of pixel-based pricing like DALL-E.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
|
||||
from litellm.types.utils import ImageResponse, Usage
|
||||
|
||||
|
||||
def cost_calculator(
    model: str,
    image_response: ImageResponse,
    custom_llm_provider: Optional[str] = None,
) -> float:
    """
    Calculate cost for OpenAI gpt-image-1 and gpt-image-1-mini models.

    Uses the same usage format as Responses API, so we reuse the helper
    to transform to chat completion format and use generic_cost_per_token.

    Args:
        model: The model name (e.g., "gpt-image-1", "gpt-image-1-mini")
        image_response: The ImageResponse containing usage data
        custom_llm_provider: Optional provider name

    Returns:
        float: Total cost in USD (0.0 when no usage data is present)
    """
    usage = getattr(image_response, "usage", None)
    if usage is None:
        verbose_logger.debug(
            f"No usage data available for {model}, cannot calculate token-based cost"
        )
        return 0.0

    if isinstance(usage, Usage) and usage.completion_tokens_details is not None:
        # Already a chat-style Usage (transformed in convert_to_image_response);
        # use it as-is.
        pricing_usage = usage
    else:
        # ImageUsage shares the ResponseAPIUsage shape, so the Responses API
        # helper converts it to chat-completion Usage.
        from litellm.responses.utils import ResponseAPILoggingUtils

        pricing_usage = (
            ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage(usage)
        )

    prompt_cost, completion_cost = generic_cost_per_token(
        model=model,
        usage=pricing_usage,
        custom_llm_provider=custom_llm_provider or "openai",
    )
    total_cost = prompt_cost + completion_cost

    verbose_logger.debug(
        f"OpenAI gpt-image cost calculation for {model}: "
        f"prompt_cost=${prompt_cost:.6f}, completion_cost=${completion_cost:.6f}, "
        f"total=${total_cost:.6f}"
    )

    return total_cost
|
||||
@@ -0,0 +1,87 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
|
||||
from litellm.types.utils import ImageResponse
|
||||
from litellm.utils import convert_to_model_response_object
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
|
||||
class DallE2ImageGenerationConfig(BaseImageGenerationConfig):
    """
    OpenAI dall-e-2 image generation config
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        """Parameters dall-e-2 accepts on the image generation endpoint."""
        return ["n", "response_format", "quality", "size", "user"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported params into ``optional_params``.

        Unsupported params are silently dropped when ``drop_params`` is True,
        otherwise a ValueError is raised. Params already present in
        ``optional_params`` are left untouched.
        """
        supported = self.get_supported_openai_params(model)
        for param_name, param_value in non_default_params.items():
            if param_name in optional_params.keys():
                continue
            if param_name in supported:
                optional_params[param_name] = param_value
            elif not drop_params:
                raise ValueError(
                    f"Parameter {param_name} is not supported for model {model}. Supported parameters are {supported}. Set drop_params=True to drop unsupported parameters."
                )
        return optional_params

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: "LiteLLMLoggingObj",
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """Parse the raw HTTP response into ``model_response`` and stamp on
        size/quality/output_format defaults for dall-e-2."""
        payload = raw_response.json()

        ## LOGGING
        logging_obj.post_call(
            input=request_data.get("prompt", ""),
            api_key=api_key,
            additional_args={"complete_input_dict": request_data},
            original_response=payload,
        )
        image_response: ImageResponse = convert_to_model_response_object(  # type: ignore
            response_object=payload,
            model_response_object=model_response,
            response_type="image_generation",
        )

        # set optional params (dall-e-2 defaults: 1024x1024, standard, png)
        image_response.size = optional_params.get("size", "1024x1024")
        image_response.quality = optional_params.get("quality", "standard")
        image_response.output_format = optional_params.get("output_format", "png")

        return image_response
|
||||
@@ -0,0 +1,87 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
|
||||
from litellm.types.utils import ImageResponse
|
||||
from litellm.utils import convert_to_model_response_object
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
|
||||
class DallE3ImageGenerationConfig(BaseImageGenerationConfig):
    """
    OpenAI dall-e-3 image generation config
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        # dall-e-3 additionally supports "style" compared to dall-e-2.
        return ["n", "response_format", "quality", "size", "user", "style"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported params into ``optional_params``; drop or raise on
        unsupported ones depending on ``drop_params``. Keys already present in
        ``optional_params`` are never overwritten.
        """
        supported_params = self.get_supported_openai_params(model)
        for k in non_default_params.keys():
            if k not in optional_params.keys():
                if k in supported_params:
                    optional_params[k] = non_default_params[k]
                elif drop_params:
                    pass
                else:
                    raise ValueError(
                        f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
                    )

        return optional_params

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: "LiteLLMLoggingObj",
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """Parse the raw HTTP response into ``model_response`` and fill in
        dall-e-3 defaults for size/quality/output_format."""
        response = raw_response.json()

        stringified_response = response
        ## LOGGING
        logging_obj.post_call(
            input=request_data.get("prompt", ""),
            api_key=api_key,
            additional_args={"complete_input_dict": request_data},
            original_response=stringified_response,
        )
        image_response: ImageResponse = convert_to_model_response_object(  # type: ignore
            response_object=stringified_response,
            model_response_object=model_response,
            response_type="image_generation",
        )

        # set optional params
        image_response.size = optional_params.get(
            "size", "1024x1024"
        )  # default is always 1024x1024
        image_response.quality = optional_params.get(
            "quality", "hd"
        )  # default quality for dall-e-3 is hd
        image_response.output_format = optional_params.get(
            "output_format", "png"
        )  # output format defaults to png

        return image_response
|
||||
@@ -0,0 +1,96 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import OpenAIImageGenerationOptionalParams
|
||||
from litellm.types.utils import ImageResponse
|
||||
from litellm.utils import convert_to_model_response_object
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.logging import Logging as LiteLLMLoggingObj
|
||||
|
||||
|
||||
class GPTImageGenerationConfig(BaseImageGenerationConfig):
    """
    OpenAI gpt-image-1 image generation config
    """

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageGenerationOptionalParams]:
        # gpt-image-1 takes "output_format"/"output_compression" instead of
        # the dall-e "response_format"/"style" params.
        return [
            "background",
            "moderation",
            "n",
            "output_compression",
            "output_format",
            "quality",
            "size",
            "user",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported params into ``optional_params``; drop or raise on
        unsupported ones depending on ``drop_params``. Keys already present in
        ``optional_params`` are never overwritten.
        """
        supported_params = self.get_supported_openai_params(model)
        for k in non_default_params.keys():
            if k not in optional_params.keys():
                if k in supported_params:
                    optional_params[k] = non_default_params[k]
                elif drop_params:
                    pass
                else:
                    raise ValueError(
                        f"Parameter {k} is not supported for model {model}. Supported parameters are {supported_params}. Set drop_params=True to drop unsupported parameters."
                    )

        return optional_params

    def transform_image_generation_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ImageResponse,
        logging_obj: "LiteLLMLoggingObj",
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ImageResponse:
        """Parse the raw HTTP response into ``model_response`` and fill in
        gpt-image-1 defaults for size/quality/output_format."""
        response = raw_response.json()

        stringified_response = response
        ## LOGGING
        logging_obj.post_call(
            input=request_data.get("prompt", ""),
            api_key=api_key,
            additional_args={"complete_input_dict": request_data},
            original_response=stringified_response,
        )
        image_response: ImageResponse = convert_to_model_response_object(  # type: ignore
            response_object=stringified_response,
            model_response_object=model_response,
            response_type="image_generation",
        )

        # set optional params
        image_response.size = optional_params.get(
            "size", "1024x1024"
        )  # default is always 1024x1024
        image_response.quality = optional_params.get(
            "quality", "high"
        )  # gpt-image-1 defaults to high quality
        # BUGFIX: previously read "response_format", which is not a supported
        # param for this config (see get_supported_openai_params) and so could
        # never appear in optional_params — the reported format was always
        # "png" even when the caller set output_format.
        image_response.output_format = optional_params.get(
            "output_format", "png"
        )  # output format defaults to png

        return image_response
|
||||
@@ -0,0 +1,106 @@
|
||||
# OpenAI Image Generation Guardrail Translation Handler
|
||||
|
||||
Handler for processing OpenAI's image generation endpoint with guardrails.
|
||||
|
||||
## Overview
|
||||
|
||||
This handler processes image generation requests by:
|
||||
1. Extracting the text prompt from the request
|
||||
2. Applying guardrails to the prompt text
|
||||
3. Updating the request with the guardrailed prompt
|
||||
|
||||
## Data Format
|
||||
|
||||
### Input Format
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "dall-e-3",
|
||||
"prompt": "A cute baby sea otter",
|
||||
"n": 1,
|
||||
"size": "1024x1024",
|
||||
"quality": "standard"
|
||||
}
|
||||
```
|
||||
|
||||
### Output Format
|
||||
|
||||
```json
|
||||
{
|
||||
"created": 1589478378,
|
||||
"data": [
|
||||
{
|
||||
"url": "https://...",
|
||||
"revised_prompt": "A cute baby sea otter..."
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The handler is automatically discovered and applied when guardrails are used with the image generation endpoint.
|
||||
|
||||
### Example: Using Guardrails with Image Generation
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/images/generations' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "dall-e-3",
|
||||
"prompt": "A cute baby sea otter wearing a hat",
|
||||
"guardrails": ["content_moderation"],
|
||||
"size": "1024x1024"
|
||||
}'
|
||||
```
|
||||
|
||||
The guardrail will be applied to the prompt text before the image generation request is sent to the provider.
|
||||
|
||||
### Example: PII Masking in Prompts
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/images/generations' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "dall-e-3",
|
||||
"prompt": "Generate an image of John Doe at john@example.com",
|
||||
"guardrails": ["mask_pii"],
|
||||
"metadata": {
|
||||
"guardrails": ["mask_pii"]
|
||||
}
|
||||
}'
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Input Processing
|
||||
|
||||
- **Field**: `prompt` (string)
|
||||
- **Processing**: Applies guardrail to prompt text
|
||||
- **Result**: Updated prompt in request
|
||||
|
||||
### Output Processing
|
||||
|
||||
- **Processing**: Not applicable (images don't contain text to guardrail)
|
||||
- **Result**: Response returned unchanged
|
||||
|
||||
## Extension
|
||||
|
||||
Override these methods to customize behavior:
|
||||
|
||||
- `process_input_messages()`: Customize how the prompt is processed
|
||||
- `process_output_response()`: Add custom processing for image metadata if needed
|
||||
|
||||
## Supported Call Types
|
||||
|
||||
- `CallTypes.image_generation` - Synchronous image generation
|
||||
- `CallTypes.aimage_generation` - Asynchronous image generation
|
||||
|
||||
## Notes
|
||||
|
||||
- The handler only processes the `prompt` parameter
|
||||
- Output processing is a no-op since images don't contain text
|
||||
- Both sync and async call types use the same handler
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""OpenAI Image Generation handler for Unified Guardrails."""
|
||||
|
||||
from litellm.llms.openai.image_generation.guardrail_translation.handler import (
|
||||
OpenAIImageGenerationHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
guardrail_translation_mappings = {
|
||||
CallTypes.image_generation: OpenAIImageGenerationHandler,
|
||||
CallTypes.aimage_generation: OpenAIImageGenerationHandler,
|
||||
}
|
||||
|
||||
__all__ = ["guardrail_translation_mappings", "OpenAIImageGenerationHandler"]
|
||||
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
OpenAI Image Generation Handler for Unified Guardrails
|
||||
|
||||
This module provides guardrail translation support for OpenAI's image generation endpoint.
|
||||
The handler processes the 'prompt' parameter for guardrails.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.utils import ImageResponse
|
||||
|
||||
|
||||
class OpenAIImageGenerationHandler(BaseTranslation):
    """
    Guardrail translation handler for OpenAI image generation requests.

    Only the request-side ``prompt`` string is run through the guardrail;
    image responses carry no text, so the output hook is a pass-through.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Apply the guardrail to ``data['prompt']`` and return the (mutated) data.

        Non-string or missing prompts are left untouched; a debug line is
        logged in either case.
        """
        prompt = data.get("prompt")

        # Guard: nothing to do without a prompt.
        if prompt is None:
            verbose_proxy_logger.debug(
                "OpenAI Image Generation: No prompt found in request data"
            )
            return data

        # Guard: only plain strings are guardrailed.
        if not isinstance(prompt, str):
            verbose_proxy_logger.debug(
                "OpenAI Image Generation: Unexpected prompt type: %s. Expected string.",
                type(prompt),
            )
            return data

        payload = GenericGuardrailAPIInputs(texts=[prompt])
        model = data.get("model")
        if model:
            # Pass model info through so model-aware guardrails can use it.
            payload["model"] = model

        guardrail_result = await guardrail_to_apply.apply_guardrail(
            inputs=payload,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        rewritten_texts = guardrail_result.get("texts", [])
        # Fall back to the original prompt if the guardrail returned nothing.
        data["prompt"] = rewritten_texts[0] if rewritten_texts else prompt

        verbose_proxy_logger.debug(
            "OpenAI Image Generation: Applied guardrail to prompt. "
            "Original length: %d, New length: %d",
            len(prompt),
            len(data["prompt"]),
        )
        return data

    async def process_output_response(
        self,
        response: "ImageResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Return the image response unchanged.

        Image payloads contain no text to guardrail; override this method if
        custom post-processing of image metadata is ever needed.
        """
        verbose_proxy_logger.debug(
            "OpenAI Image Generation: Output processing not needed for image responses"
        )
        return response
|
||||
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
OpenAI Image Variations Handler
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import FileTypes, ImageResponse, LlmProviders
|
||||
from litellm.utils import ProviderConfigManager
|
||||
|
||||
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
|
||||
from ...custom_httpx.llm_http_handler import LiteLLMLoggingObj
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
|
||||
class OpenAIImageVariationsHandler:
    """Calls OpenAI's image-variations endpoint via the OpenAI SDK.

    ``image_variations`` is the entry point; when ``litellm_params['async_call']``
    is truthy it returns the un-awaited coroutine from
    ``async_image_variations`` (note the ``# type: ignore`` on that return).
    """

    def get_sync_client(
        self,
        client: Optional[OpenAI],
        init_client_params: dict,
    ):
        """Return the given sync client, or construct one from the params."""
        if client is None:
            openai_client = OpenAI(
                **init_client_params,
            )
        else:
            openai_client = client
        return openai_client

    def get_async_client(
        self, client: Optional[AsyncOpenAI], init_client_params: dict
    ) -> AsyncOpenAI:
        """Return the given async client, or construct one from the params."""
        if client is None:
            openai_client = AsyncOpenAI(
                **init_client_params,
            )
        else:
            openai_client = client
        return openai_client

    async def async_image_variations(
        self,
        api_key: str,
        api_base: str,
        organization: Optional[str],
        client: Optional[AsyncOpenAI],
        data: dict,
        headers: dict,
        model: Optional[str],
        timeout: Optional[float],
        max_retries: int,
        logging_obj: LiteLLMLoggingObj,
        model_response: ImageResponse,
        optional_params: dict,
        litellm_params: dict,
        image: FileTypes,
        provider_config: BaseImageVariationConfig,
    ) -> ImageResponse:
        """Async path: call create_variation, log, and transform the response.

        Any exception is unwrapped (status_code/headers/text pulled off the
        error, falling back to the attached response's headers) and re-raised
        as ``OpenAIError``.
        """
        try:
            init_client_params = {
                "api_key": api_key,
                "base_url": api_base,
                "http_client": litellm.client_session,
                "timeout": timeout,
                "max_retries": max_retries,  # type: ignore
                "organization": organization,
            }

            client = self.get_async_client(
                client=client, init_client_params=init_client_params
            )

            raw_response = await client.images.with_raw_response.create_variation(**data)  # type: ignore
            response = raw_response.parse()
            response_json = response.model_dump()

            ## LOGGING
            logging_obj.post_call(
                api_key=api_key,
                original_response=response_json,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                },
            )

            ## RESPONSE OBJECT
            # The SDK already consumed the HTTP response, so a mock 200
            # httpx.Response is passed to satisfy the transform interface.
            return provider_config.transform_response_image_variation(
                model=model,
                model_response=ImageResponse(**response_json),
                raw_response=httpx.Response(
                    status_code=200,
                    request=httpx.Request(
                        method="GET", url="https://litellm.ai"
                    ),  # mock request object
                ),
                logging_obj=logging_obj,
                request_data=data,
                image=image,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=None,
                api_key=api_key,
            )
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )

    def image_variations(
        self,
        model_response: ImageResponse,
        api_key: str,
        api_base: str,
        model: Optional[str],
        image: FileTypes,
        timeout: Optional[float],
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        optional_params: dict,
        litellm_params: dict,
        print_verbose: Optional[Callable] = None,
        logger_fn=None,
        client=None,
        organization: Optional[str] = None,
        headers: Optional[dict] = None,
    ) -> ImageResponse:
        """Entry point for image variations (sync, or dispatch to async).

        Resolves the provider config, transforms the request, logs the
        pre-call, then either returns the async coroutine (when
        ``litellm_params['async_call']`` is truthy) or performs the sync
        SDK call and transforms the response.

        Raises:
            ValueError: if no provider config is found or the transformed
                request has no ``data`` field.
            OpenAIError: for any other failure (unwrapped as in the async path).
        """
        try:
            provider_config = ProviderConfigManager.get_provider_image_variation_config(
                model=model or "",  # openai defaults to dall-e-2
                provider=LlmProviders.OPENAI,
            )

            if provider_config is None:
                raise ValueError(
                    f"image variation provider not found: {custom_llm_provider}."
                )

            # max_retries is popped so it isn't forwarded as a request param.
            max_retries = optional_params.pop("max_retries", 2)

            data = provider_config.transform_request_image_variation(
                model=model,
                image=image,
                optional_params=optional_params,
                headers=headers or {},
            )
            json_data = data.get("data")
            if not json_data:
                raise ValueError(
                    f"data field is required, for openai image variations. Got={data}"
                )
            ## LOGGING
            logging_obj.pre_call(
                input="",
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )
            if litellm_params.get("async_call", False):
                # Returns the coroutine un-awaited; the caller awaits it.
                return self.async_image_variations(
                    api_base=api_base,
                    data=json_data,
                    headers=headers or {},
                    model_response=model_response,
                    api_key=api_key,
                    logging_obj=logging_obj,
                    model=model,
                    timeout=timeout,
                    max_retries=max_retries,
                    organization=organization,
                    client=client,
                    provider_config=provider_config,
                    image=image,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                )  # type: ignore

            init_client_params = {
                "api_key": api_key,
                "base_url": api_base,
                "http_client": litellm.client_session,
                "timeout": timeout,
                "max_retries": max_retries,  # type: ignore
                "organization": organization,
            }

            client = self.get_sync_client(
                client=client, init_client_params=init_client_params
            )

            raw_response = client.images.with_raw_response.create_variation(**json_data)  # type: ignore
            response = raw_response.parse()
            response_json = response.model_dump()

            ## LOGGING
            logging_obj.post_call(
                api_key=api_key,
                original_response=response_json,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                },
            )

            ## RESPONSE OBJECT
            # Same mock-response pattern as the async path (SDK already
            # consumed the real HTTP response).
            return provider_config.transform_response_image_variation(
                model=model,
                model_response=ImageResponse(**response_json),
                raw_response=httpx.Response(
                    status_code=200,
                    request=httpx.Request(
                        method="GET", url="https://litellm.ai"
                    ),  # mock request object
                ),
                logging_obj=logging_obj,
                request_data=json_data,
                image=image,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=None,
                api_key=api_key,
            )
        except Exception as e:
            status_code = getattr(e, "status_code", 500)
            error_headers = getattr(e, "headers", None)
            error_text = getattr(e, "text", str(e))
            error_response = getattr(e, "response", None)
            if error_headers is None and error_response:
                error_headers = getattr(error_response, "headers", None)
            raise OpenAIError(
                status_code=status_code, message=error_text, headers=error_headers
            )
|
||||
@@ -0,0 +1,82 @@
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from aiohttp import ClientResponse
|
||||
from httpx import Headers, Response
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.image_variations.transformation import LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import OpenAIImageVariationOptionalParams
|
||||
from litellm.types.utils import FileTypes, HttpHandlerRequestFields, ImageResponse
|
||||
|
||||
from ...base_llm.image_variations.transformation import BaseImageVariationConfig
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
|
||||
class OpenAIImageVariationConfig(BaseImageVariationConfig):
    """Request/response transformation for OpenAI's image-variations endpoint."""

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIImageVariationOptionalParams]:
        """List the optional params OpenAI accepts for image variations."""
        supported: List[OpenAIImageVariationOptionalParams] = [
            "n",
            "size",
            "response_format",
            "user",
        ]
        return supported

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Merge every non-default param into ``optional_params`` (in place)."""
        for param_name, param_value in non_default_params.items():
            optional_params[param_name] = param_value
        return optional_params

    def transform_request_image_variation(
        self,
        model: Optional[str],
        image: FileTypes,
        optional_params: dict,
        headers: dict,
    ) -> HttpHandlerRequestFields:
        """Build the request payload: the image plus any optional params."""
        request_fields = {"image": image}
        request_fields.update(optional_params)
        return {"data": request_fields}

    async def async_transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: ClientResponse,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """Pass-through: the caller-supplied response object is returned as-is."""
        return model_response

    def transform_response_image_variation(
        self,
        model: Optional[str],
        raw_response: Response,
        model_response: ImageResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        image: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
    ) -> ImageResponse:
        """Pass-through: the caller-supplied response object is returned as-is."""
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, Headers]
    ) -> BaseLLMException:
        """Wrap an error in the OpenAI-specific exception type."""
        return OpenAIError(
            status_code=status_code, message=error_message, headers=headers
        )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
This file contains the calling OpenAI's `/v1/realtime` endpoint.
|
||||
|
||||
This requires websockets, and is currently only supported on LiteLLM Proxy.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from litellm.constants import REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES
|
||||
from litellm.types.realtime import RealtimeQueryParams
|
||||
|
||||
from ....litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
from ....litellm_core_utils.realtime_streaming import RealTimeStreaming
|
||||
from ....llms.custom_httpx.http_handler import get_shared_realtime_ssl_context
|
||||
from ..openai import OpenAIChatCompletion
|
||||
|
||||
|
||||
class OpenAIRealtime(OpenAIChatCompletion):
    """
    Base handler for OpenAI-compatible realtime WebSocket connections.

    Subclasses can override template methods to customize:
    - _get_default_api_base(): Default API base URL
    - _get_additional_headers(): Extra headers beyond Authorization
    - _get_ssl_config(): SSL configuration for WebSocket connection
    """

    def _get_default_api_base(self) -> str:
        """
        Get the default API base URL for this provider.
        Override this in subclasses to set provider-specific defaults.
        """
        return "https://api.openai.com/"

    def _get_additional_headers(self, api_key: str) -> dict:
        """
        Get additional headers beyond Authorization.
        Override this in subclasses to customize headers (e.g., remove OpenAI-Beta).

        Args:
            api_key: API key for authentication

        Returns:
            Dictionary of additional headers
        """
        return {
            "Authorization": f"Bearer {api_key}",
            "OpenAI-Beta": "realtime=v1",
        }

    def _get_ssl_config(self, url: str) -> Any:
        """
        Get SSL configuration for WebSocket connection.
        Override this in subclasses to customize SSL behavior.

        Args:
            url: WebSocket URL (ws:// or wss://)

        Returns:
            SSL configuration (None, True, or SSLContext)
        """
        # Plain ws:// connections use no SSL at all.
        if url.startswith("ws://"):
            return None

        # Use the shared SSL context which respects custom CA certs and SSL settings
        ssl_config = get_shared_realtime_ssl_context()

        # If ssl_config is False (ssl_verify=False), websockets library needs True instead
        # to establish connection without verification (False would fail)
        if ssl_config is False:
            return True

        return ssl_config

    def _construct_url(self, api_base: str, query_params: RealtimeQueryParams) -> str:
        """
        Construct the backend websocket URL with all query parameters (including 'model').
        """
        from httpx import URL

        # Swap HTTP(S) schemes for their websocket equivalents.
        api_base = api_base.replace("https://", "wss://")
        api_base = api_base.replace("http://", "ws://")
        url = URL(api_base)
        # Set the correct path
        url = url.copy_with(path="/v1/realtime")
        # Include all query parameters including 'model'
        if query_params:
            url = url.copy_with(params=query_params)
        return str(url)

    async def async_realtime(
        self,
        model: str,
        websocket: Any,
        logging_obj: LiteLLMLogging,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        client: Optional[Any] = None,
        timeout: Optional[float] = None,
        query_params: Optional[RealtimeQueryParams] = None,
        user_api_key_dict: Optional[Any] = None,
        litellm_metadata: Optional[dict] = None,
        **kwargs: Any,
    ):
        """
        Proxy a client websocket to the OpenAI realtime backend.

        Opens a backend websocket to ``/v1/realtime`` and forwards traffic
        in both directions until either side closes.

        Args:
            model: Model name; used as the fallback ``model`` query param.
            websocket: The client-facing (proxy) websocket connection.
            logging_obj: LiteLLM logging object for the pre-call log.
            api_base: Backend base URL; defaults to ``_get_default_api_base()``.
            api_key: Required bearer token for the backend.
            client: Accepted for interface compatibility; not used in this body.
            timeout: Accepted for interface compatibility; not used in this body.
            query_params: Full query params for the backend URL; defaults to
                ``{"model": model}`` when not provided.
            user_api_key_dict: Forwarded to RealTimeStreaming.
            litellm_metadata: Forwarded to RealTimeStreaming as request data.

        Raises:
            ValueError: If ``api_key`` is None.
        """
        import websockets
        from websockets.asyncio.client import ClientConnection

        if api_base is None:
            api_base = self._get_default_api_base()
        if api_key is None:
            raise ValueError("api_key is required for OpenAI realtime calls")

        # Use all query params if provided, else fallback to just model
        if query_params is None:
            query_params = {"model": model}
        url = self._construct_url(api_base, query_params)

        try:
            # Get provider-specific SSL configuration
            ssl_config = self._get_ssl_config(url)

            # Get provider-specific headers
            headers = self._get_additional_headers(api_key)

            # Log a masked request preview consistent with other endpoints.
            logging_obj.pre_call(
                input=None,
                api_key=api_key,
                additional_args={
                    "api_base": url,
                    "headers": headers,
                    "complete_input_dict": {"query_params": query_params},
                },
            )
            async with websockets.connect(  # type: ignore
                url,
                additional_headers=headers,  # type: ignore
                max_size=REALTIME_WEBSOCKET_MAX_MESSAGE_SIZE_BYTES,
                ssl=ssl_config,
            ) as backend_ws:
                realtime_streaming = RealTimeStreaming(
                    websocket,
                    cast(ClientConnection, backend_ws),
                    logging_obj,
                    user_api_key_dict=user_api_key_dict,
                    request_data={"litellm_metadata": litellm_metadata or {}},
                )
                await realtime_streaming.bidirectional_forward()

        except websockets.exceptions.InvalidStatusCode as e:  # type: ignore
            # Backend rejected the handshake: mirror its status to the client.
            await websocket.close(code=e.status_code, reason=str(e))
        except Exception as e:
            try:
                await websocket.close(
                    code=1011, reason=f"Internal server error: {str(e)}"
                )
            except RuntimeError as close_error:
                if "already completed" in str(close_error) or "websocket.close" in str(
                    close_error
                ):
                    # The WebSocket is already closed or the response is completed, so we can ignore this error
                    pass
                else:
                    # If it's a different RuntimeError, we might want to log it or handle it differently
                    raise Exception(
                        f"Unexpected error while closing WebSocket: {close_error}"
                    )
|
||||
@@ -0,0 +1,54 @@
|
||||
"""OpenAI realtime HTTP transformation config (client_secrets + realtime_calls)."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.realtime.http_transformation import BaseRealtimeHTTPConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
|
||||
class OpenAIRealtimeHTTPConfig(BaseRealtimeHTTPConfig):
    """HTTP transformation config for OpenAI realtime endpoints
    (``/v1/realtime/client_secrets`` and ``/v1/realtime/calls``)."""

    def get_api_base(self, api_base: Optional[str], **kwargs) -> str:
        """Resolve the API base: explicit arg > litellm.api_base > env > default."""
        return (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com"
        )

    def get_api_key(self, api_key: Optional[str], **kwargs) -> str:
        """Resolve the API key from the arg, litellm globals, or the environment."""
        return (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
            or ""
        )

    def _normalized_api_base(self, api_base: Optional[str]) -> str:
        """Resolve the API base and strip any trailing slash and '/v1' suffix,
        so endpoint paths can be appended uniformly.

        (Extracted to remove the duplication between get_complete_url and
        get_realtime_calls_url.)
        """
        base = self.get_api_base(api_base).rstrip("/")
        if base.endswith("/v1"):
            base = base[: -len("/v1")]
        return base

    def get_complete_url(
        self, api_base: Optional[str], model: str, api_version: Optional[str] = None
    ) -> str:
        """URL for minting realtime client secrets."""
        return f"{self._normalized_api_base(api_base)}/v1/realtime/client_secrets"

    def get_realtime_calls_url(
        self, api_base: Optional[str], model: str, api_version: Optional[str] = None
    ) -> str:
        """URL for creating realtime calls."""
        return f"{self._normalized_api_base(api_base)}/v1/realtime/calls"

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
    ) -> dict:
        """Return headers with Authorization and Content-Type set; caller
        headers are preserved but the auth/content-type keys take precedence."""
        return {
            **headers,
            "Authorization": f"Bearer {api_key or ''}",
            "Content-Type": "application/json",
        }
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
OpenAI Responses API token counting implementation.
|
||||
"""
|
||||
|
||||
from litellm.llms.openai.responses.count_tokens.handler import (
|
||||
OpenAICountTokensHandler,
|
||||
)
|
||||
from litellm.llms.openai.responses.count_tokens.token_counter import (
|
||||
OpenAITokenCounter,
|
||||
)
|
||||
from litellm.llms.openai.responses.count_tokens.transformation import (
|
||||
OpenAICountTokensConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OpenAICountTokensHandler",
|
||||
"OpenAICountTokensConfig",
|
||||
"OpenAITokenCounter",
|
||||
]
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
OpenAI Responses API token counting handler.
|
||||
|
||||
Uses httpx for HTTP requests to OpenAI's /v1/responses/input_tokens endpoint.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.llms.openai.common_utils import OpenAIError
|
||||
from litellm.llms.openai.responses.count_tokens.transformation import (
|
||||
OpenAICountTokensConfig,
|
||||
)
|
||||
|
||||
|
||||
class OpenAICountTokensHandler(OpenAICountTokensConfig):
    """
    Handler for OpenAI Responses API token counting requests.
    """

    async def handle_count_tokens_request(
        self,
        model: str,
        input: Union[str, List[Any]],
        api_key: str,
        api_base: Optional[str] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        instructions: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Handle a token counting request to OpenAI's Responses API.

        Returns:
            Dictionary containing {"input_tokens": <number>}

        Raises:
            OpenAIError: If the API request fails
        """
        try:
            self.validate_request(model, input)

            verbose_logger.debug(
                f"Processing OpenAI CountTokens request for model: {model}"
            )

            request_body = self.transform_request_to_count_tokens(
                model=model,
                input=input,
                tools=tools,
                instructions=instructions,
            )

            endpoint_url = self.get_openai_count_tokens_endpoint(api_base)

            verbose_logger.debug(f"Making request to: {endpoint_url}")

            headers = self.get_required_headers(api_key)

            async_client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.OPENAI
            )

            # Per-call timeout overrides the global litellm default.
            request_timeout = (
                timeout if timeout is not None else litellm.request_timeout
            )

            response = await async_client.post(
                endpoint_url,
                headers=headers,
                json=request_body,
                timeout=request_timeout,
            )

            verbose_logger.debug(f"Response status: {response.status_code}")

            # Non-200 responses are surfaced with the API's own status/body.
            if response.status_code != 200:
                error_text = response.text
                verbose_logger.error(f"OpenAI API error: {error_text}")
                raise OpenAIError(
                    status_code=response.status_code,
                    message=error_text,
                )

            openai_response = response.json()
            verbose_logger.debug(f"OpenAI response: {openai_response}")
            return openai_response

        # Re-raise OpenAIError untouched so the status code from the non-200
        # branch above isn't rewrapped by the broader handlers below.
        except OpenAIError:
            raise
        except httpx.HTTPStatusError as e:
            verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
            raise OpenAIError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        # Transport errors, bad JSON bodies, and validation failures all become
        # a generic 500 OpenAIError.
        except (httpx.RequestError, json.JSONDecodeError, ValueError) as e:
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise OpenAIError(
                status_code=500,
                message=f"CountTokens processing error: {str(e)}",
            )
|
||||
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
OpenAI Token Counter implementation using the Responses API /input_tokens endpoint.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.llms.openai.common_utils import OpenAIError
|
||||
from litellm.llms.openai.responses.count_tokens.handler import (
|
||||
OpenAICountTokensHandler,
|
||||
)
|
||||
from litellm.llms.openai.responses.count_tokens.transformation import (
|
||||
OpenAICountTokensConfig,
|
||||
)
|
||||
from litellm.types.utils import LlmProviders, TokenCountResponse
|
||||
|
||||
# Global handler instance - reuse across all token counting requests
|
||||
openai_count_tokens_handler = OpenAICountTokensHandler()
|
||||
|
||||
|
||||
class OpenAITokenCounter(BaseTokenCounter):
    """Token counter implementation for OpenAI provider using the Responses API.

    Delegates the HTTP call to the module-level ``openai_count_tokens_handler``
    so the underlying client is reused across requests.
    """

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Return True only for the native OpenAI provider."""
        return custom_llm_provider == LlmProviders.OPENAI.value

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count tokens using OpenAI's Responses API /input_tokens endpoint.

        Returns None when the request cannot be attempted (no messages, no API
        key, or no convertible input items) so the caller can fall back to
        local token counting. API failures are reported via an error-flagged
        TokenCountResponse rather than raised.
        """
        if not messages:
            return None

        deployment = deployment or {}
        litellm_params = deployment.get("litellm_params", {})

        # Prefer the deployment-configured key; fall back to the environment.
        api_key = litellm_params.get("api_key") or os.getenv("OPENAI_API_KEY")
        if not api_key:
            verbose_logger.warning("No OpenAI API key found for token counting")
            return None

        api_base = litellm_params.get("api_base")

        # Convert chat messages to Responses API input format
        input_items, instructions = OpenAICountTokensConfig.messages_to_responses_input(
            messages
        )

        # Use system param if instructions not extracted from messages
        if instructions is None and system is not None:
            instructions = system if isinstance(system, str) else str(system)

        # If no input items were produced (e.g., system-only messages), fall back to local counting
        if not input_items:
            return None

        try:
            result = await openai_count_tokens_handler.handle_count_tokens_request(
                model=model_to_use,
                # input_items is guaranteed non-empty here (checked above), so
                # no fallback expression is needed.
                input=input_items,
                api_key=api_key,
                api_base=api_base,
                tools=tools,
                instructions=instructions,
            )

            if result is not None:
                return TokenCountResponse(
                    total_tokens=result.get("input_tokens", 0),
                    request_model=request_model,
                    model_used=model_to_use,
                    tokenizer_type="openai_api",
                    original_response=result,
                )
        except OpenAIError as e:
            verbose_logger.warning(
                f"OpenAI CountTokens API error: status={e.status_code}, message={e.message}"
            )
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="openai_api",
                error=True,
                error_message=e.message,
                status_code=e.status_code,
            )
        except Exception as e:
            verbose_logger.warning(f"Error calling OpenAI CountTokens API: {e}")
            return TokenCountResponse(
                total_tokens=0,
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="openai_api",
                error=True,
                error_message=str(e),
                status_code=500,
            )

        # Handler returned None without raising — let caller fall back.
        return None
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
OpenAI Responses API token counting transformation logic.
|
||||
|
||||
This module handles the transformation of requests to OpenAI's /v1/responses/input_tokens endpoint.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
|
||||
class OpenAICountTokensConfig:
    """
    Configuration and transformation logic for OpenAI Responses API token counting.

    OpenAI Responses API Token Counting Specification:
    - Endpoint: POST https://api.openai.com/v1/responses/input_tokens
    - Response: {"input_tokens": <number>}
    """

    def get_openai_count_tokens_endpoint(self, api_base: Optional[str] = None) -> str:
        """Build the full /responses/input_tokens URL for the given API base."""
        root = (api_base or "https://api.openai.com/v1").rstrip("/")
        return f"{root}/responses/input_tokens"

    def transform_request_to_count_tokens(
        self,
        model: str,
        input: Union[str, List[Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        instructions: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Transform request to OpenAI Responses API token counting format.

        The Responses API uses `input` (not `messages`) and `instructions` (not `system`).
        """
        payload: Dict[str, Any] = {"model": model, "input": input}
        if instructions is not None:
            payload["instructions"] = instructions
        if tools is not None:
            payload["tools"] = self._transform_tools_for_responses_api(tools)
        return payload

    def get_required_headers(self, api_key: str) -> Dict[str, str]:
        """Return the JSON content-type and bearer-auth headers."""
        headers: Dict[str, str] = {"Content-Type": "application/json"}
        headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def validate_request(self, model: str, input: Union[str, List[Any]]) -> None:
        """Raise ValueError when either required field is missing or empty."""
        if not model:
            raise ValueError("model parameter is required")
        if not input:
            raise ValueError("input parameter is required")

    @staticmethod
    def _transform_tools_for_responses_api(
        tools: List[Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """
        Transform OpenAI chat tools format to Responses API tools format.

        Chat format: {"type": "function", "function": {"name": "...", "parameters": {...}}}
        Responses format: {"type": "function", "name": "...", "parameters": {...}}
        """
        converted: List[Dict[str, Any]] = []
        for entry in tools:
            if entry.get("type") != "function" or "function" not in entry:
                # Pass through non-function tools (e.g., web_search, file_search)
                converted.append(entry)
                continue
            fn = entry["function"]
            flattened: Dict[str, Any] = {
                "type": "function",
                "name": fn.get("name", ""),
                "description": fn.get("description", ""),
                "parameters": fn.get("parameters", {}),
            }
            if "strict" in fn:
                flattened["strict"] = fn["strict"]
            converted.append(flattened)
        return converted

    @staticmethod
    def messages_to_responses_input(
        messages: List[Dict[str, Any]],
    ) -> tuple:
        """
        Convert standard chat messages format to OpenAI Responses API input format.

        Returns:
            (input_items, instructions) tuple where instructions is extracted
            from system/developer messages.
        """

        def _flatten_blocks(blocks: List[Any]) -> str:
            # Concatenate text from content blocks; bare strings pass through.
            pieces = []
            for blk in blocks:
                if isinstance(blk, dict) and blk.get("type") == "text":
                    pieces.append(blk.get("text", ""))
                elif isinstance(blk, str):
                    pieces.append(blk)
            return "\n".join(pieces)

        items: List[Dict[str, Any]] = []
        instruction_chunks: List[str] = []

        for message in messages:
            role = message.get("role", "")
            content = message.get("content") or ""

            if role in ("system", "developer"):
                # System/developer messages become Responses API instructions.
                if isinstance(content, str):
                    instruction_chunks.append(content)
                elif isinstance(content, list):
                    instruction_chunks.append(_flatten_blocks(content))
            elif role == "user":
                if isinstance(content, list):
                    content = _flatten_blocks(content)
                items.append({"role": "user", "content": content})
            elif role == "assistant":
                tool_calls = message.get("tool_calls")
                if tool_calls:
                    # Emit the assistant text (when present) before its calls.
                    if content:
                        items.append({"role": "assistant", "content": content})
                    for call in tool_calls:
                        fn = call.get("function", {})
                        items.append(
                            {
                                "type": "function_call",
                                "call_id": call.get("id", ""),
                                "name": fn.get("name", ""),
                                "arguments": fn.get("arguments", ""),
                            }
                        )
                else:
                    items.append({"role": "assistant", "content": content})
            elif role == "tool":
                items.append(
                    {
                        "type": "function_call_output",
                        "call_id": message.get("tool_call_id", ""),
                        "output": content if isinstance(content, str) else str(content),
                    }
                )

        instructions = "\n".join(instruction_chunks) if instruction_chunks else None
        return items, instructions
||||
@@ -0,0 +1,119 @@
|
||||
# OpenAI Responses API Guardrail Translation Handler
|
||||
|
||||
This module provides guardrail translation support for the OpenAI Responses API format.
|
||||
|
||||
## Overview
|
||||
|
||||
The `OpenAIResponsesHandler` class handles the translation of guardrail operations for both input and output of the Responses API. It follows the same pattern as the Chat Completions handler but is adapted for the Responses API's specific data structures.
|
||||
|
||||
## Responses API Format
|
||||
|
||||
### Input Format
|
||||
The Responses API accepts input in two formats:
|
||||
|
||||
1. **String input**: Simple text string
|
||||
```python
|
||||
{"input": "Hello world", "model": "gpt-4"}
|
||||
```
|
||||
|
||||
2. **List input**: Array of message objects (ResponseInputParam)
|
||||
```python
|
||||
{
|
||||
"input": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hello", # Can be string or list of content items
|
||||
"type": "message"
|
||||
}
|
||||
],
|
||||
"model": "gpt-4"
|
||||
}
|
||||
```
|
||||
|
||||
### Output Format
|
||||
The Responses API returns a `ResponsesAPIResponse` object with:
|
||||
|
||||
```python
|
||||
{
|
||||
"id": "resp_123",
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"id": "msg_123",
|
||||
"status": "completed",
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "output_text",
|
||||
"text": "Assistant response",
|
||||
"annotations": []
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The handler is automatically discovered and registered for `CallTypes.responses` and `CallTypes.aresponses`.
|
||||
|
||||
### Example
|
||||
|
||||
```python
|
||||
from litellm.llms import get_guardrail_translation_mapping
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
# Get the handler
|
||||
handler_class = get_guardrail_translation_mapping(CallTypes.responses)
|
||||
handler = handler_class()
|
||||
|
||||
# Process input
|
||||
data = {"input": "User message", "model": "gpt-4"}
|
||||
processed_data = await handler.process_input_messages(data, guardrail_instance)
|
||||
|
||||
# Process output
|
||||
response = await litellm.aresponses(**processed_data)
|
||||
processed_response = await handler.process_output_response(response, guardrail_instance)
|
||||
```
|
||||
|
||||
## Key Methods
|
||||
|
||||
### `process_input_messages(data, guardrail_to_apply)`
|
||||
Processes input data by:
|
||||
1. Handling both string and list input formats
|
||||
2. Extracting text content from messages
|
||||
3. Applying guardrails to text content in parallel
|
||||
4. Mapping guardrail responses back to the original structure
|
||||
|
||||
### `process_output_response(response, guardrail_to_apply)`
|
||||
Processes output response by:
|
||||
1. Extracting text from output items' content
|
||||
2. Applying guardrails to all text content in parallel
|
||||
3. Replacing original text with guardrailed versions
|
||||
|
||||
## Extending the Handler
|
||||
|
||||
The handler can be customized by overriding these methods:
|
||||
|
||||
- `_extract_input_text_and_create_tasks()`: Customize input text extraction logic
|
||||
- `_apply_guardrail_responses_to_input()`: Customize how guardrail responses are applied to input
|
||||
- `_extract_output_text_and_create_tasks()`: Customize output text extraction logic
|
||||
- `_apply_guardrail_responses_to_output()`: Customize how guardrail responses are applied to output
|
||||
- `_has_text_content()`: Customize text content detection
|
||||
|
||||
## Testing
|
||||
|
||||
Comprehensive tests are available in `tests/llm_translation/test_openai_responses_guardrail_handler.py`:
|
||||
|
||||
```bash
|
||||
pytest tests/llm_translation/test_openai_responses_guardrail_handler.py -v
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
- **Parallel Processing**: All text content is processed in parallel using `asyncio.gather()`
|
||||
- **Mapping Tracking**: Uses tuples to track the location of each text segment for accurate replacement
|
||||
- **Type Safety**: Handles both Pydantic objects and dict representations
|
||||
- **Multimodal Support**: Properly handles mixed content with text and other media types
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
"""OpenAI Responses API handler for Unified Guardrails."""
|
||||
|
||||
from litellm.llms.openai.responses.guardrail_translation.handler import (
|
||||
OpenAIResponsesHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
# Map both the sync and async Responses API call types to the same handler so
# guardrail translation behaves identically regardless of invocation style.
guardrail_translation_mappings = {
    CallTypes.responses: OpenAIResponsesHandler,
    CallTypes.aresponses: OpenAIResponsesHandler,
}
__all__ = ["guardrail_translation_mappings"]
|
||||
@@ -0,0 +1,760 @@
|
||||
"""
|
||||
OpenAI Responses API Handler for Unified Guardrails
|
||||
|
||||
This module provides a class-based handler for OpenAI Responses API format.
|
||||
The class methods can be overridden for custom behavior.
|
||||
|
||||
Pattern Overview:
|
||||
-----------------
|
||||
1. Extract text content from input/output (both string and list formats)
|
||||
2. Create async tasks to apply guardrails to each text segment
|
||||
3. Track mappings to know where each response belongs
|
||||
4. Apply guardrail responses back to the original structure
|
||||
|
||||
Responses API Format:
|
||||
---------------------
|
||||
Input: Union[str, List[Dict]] where each dict has:
|
||||
- role: str
|
||||
- content: Union[str, List[Dict]] (can have text items)
|
||||
- type: str (e.g., "message")
|
||||
|
||||
Output: response.output is List[GenericResponseOutputItem] where each has:
|
||||
- type: str (e.g., "message")
|
||||
- id: str
|
||||
- status: str
|
||||
- role: str
|
||||
- content: List[OutputText] where OutputText has:
|
||||
- type: str (e.g., "output_text")
|
||||
- text: str
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.completion_extras.litellm_responses_transformation.transformation import (
|
||||
LiteLLMResponsesTransformationHandler,
|
||||
OpenAiResponsesToChatCompletionStreamIterator,
|
||||
)
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.responses.litellm_completion_transformation.transformation import (
|
||||
LiteLLMCompletionResponsesConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
from litellm.types.responses.main import (
|
||||
GenericResponseOutputItem,
|
||||
OutputFunctionToolCall,
|
||||
OutputText,
|
||||
)
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.types.llms.openai import ResponseInputParam
|
||||
from litellm.types.utils import ResponsesAPIResponse
|
||||
|
||||
|
||||
class OpenAIResponsesHandler(BaseTranslation):
|
||||
"""
|
||||
Handler for processing OpenAI Responses API with guardrails.
|
||||
|
||||
This class provides methods to:
|
||||
1. Process input (pre-call hook)
|
||||
2. Process output response (post-call hook)
|
||||
|
||||
Methods can be overridden to customize behavior for different message formats.
|
||||
"""
|
||||
|
||||
    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Process input by applying guardrails to text content.

        Handles both string input and list of message objects. Mutates and
        returns ``data``: replaces ``data["input"]`` text with guardrailed
        text and, when the guardrail returns tools, rewrites ``data["tools"]``.
        """
        input_data: Optional[Union[str, "ResponseInputParam"]] = data.get("input")
        tools_to_check: List[ChatCompletionToolParam] = []
        if input_data is None:
            return data

        # Chat-completion-shaped view of the input, passed to guardrails that
        # want structured messages rather than raw text.
        structured_messages = (
            LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages(
                input=input_data,
                responses_api_request=data,
            )
        )

        # Handle simple string input
        if isinstance(input_data, str):
            inputs = GenericGuardrailAPIInputs(texts=[input_data])
            original_tools: List[Dict[str, Any]] = []

            # Extract and transform tools if present
            if "tools" in data and data["tools"]:
                original_tools = list(data["tools"])
                self._extract_and_transform_tools(data["tools"], tools_to_check)
            if tools_to_check:
                inputs["tools"] = tools_to_check
            if structured_messages:
                inputs["structured_messages"] = structured_messages  # type: ignore
            # Include model information if available
            model = data.get("model")
            if model:
                inputs["model"] = model

            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            # Fall back to the original text if the guardrail returned none.
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            data["input"] = guardrailed_texts[0] if guardrailed_texts else input_data
            self._apply_guardrailed_tools_to_data(
                data, original_tools, guardrailed_inputs.get("tools")
            )
            verbose_proxy_logger.debug("OpenAI Responses API: Processed string input")
            return data

        # Handle list input (ResponseInputParam)
        if not isinstance(input_data, list):
            # Unknown input shape — pass through untouched.
            return data

        texts_to_check: List[str] = []
        images_to_check: List[str] = []
        task_mappings: List[Tuple[int, Optional[int]]] = []
        original_tools_list: List[Dict[str, Any]] = list(data.get("tools") or [])

        # Step 1: Extract all text content, images, and tools
        for msg_idx, message in enumerate(input_data):
            self._extract_input_text_and_images(
                message=message,
                msg_idx=msg_idx,
                texts_to_check=texts_to_check,
                images_to_check=images_to_check,
                task_mappings=task_mappings,
            )

        # Extract and transform tools if present
        if "tools" in data and data["tools"]:
            self._extract_and_transform_tools(data["tools"], tools_to_check)

        # Step 2: Apply guardrail to all texts in batch
        if texts_to_check:
            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            if images_to_check:
                inputs["images"] = images_to_check
            if tools_to_check:
                inputs["tools"] = tools_to_check
            if structured_messages:
                inputs["structured_messages"] = structured_messages  # type: ignore
            # Include model information if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )

            guardrailed_texts = guardrailed_inputs.get("texts", [])
            self._apply_guardrailed_tools_to_data(
                data,
                original_tools_list,
                guardrailed_inputs.get("tools"),
            )

            # Step 3: Map guardrail responses back to original input structure
            await self._apply_guardrail_responses_to_input(
                messages=input_data,
                responses=guardrailed_texts,
                task_mappings=task_mappings,
            )

        verbose_proxy_logger.debug(
            "OpenAI Responses API: Processed input messages: %s", input_data
        )

        return data
def extract_request_tool_names(self, data: dict) -> List[str]:
|
||||
"""Extract tool names from Responses API request (tools[].name for function, tools[].server_label for mcp)."""
|
||||
names: List[str] = []
|
||||
for tool in data.get("tools") or []:
|
||||
if not isinstance(tool, dict):
|
||||
continue
|
||||
if tool.get("type") == "function" and tool.get("name"):
|
||||
names.append(str(tool["name"]))
|
||||
elif tool.get("type") == "mcp" and tool.get("server_label"):
|
||||
names.append(str(tool["server_label"]))
|
||||
return names
|
||||
|
||||
def _extract_and_transform_tools(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
tools_to_check: List[ChatCompletionToolParam],
|
||||
) -> None:
|
||||
"""
|
||||
Extract and transform tools from Responses API format to Chat Completion format.
|
||||
|
||||
Uses the LiteLLM transformation function to convert Responses API tools
|
||||
to Chat Completion tools that can be passed to guardrails.
|
||||
"""
|
||||
if tools is not None and isinstance(tools, list):
|
||||
# Transform Responses API tools to Chat Completion tools
|
||||
(
|
||||
transformed_tools,
|
||||
_,
|
||||
) = LiteLLMCompletionResponsesConfig.transform_responses_api_tools_to_chat_completion_tools(
|
||||
tools # type: ignore
|
||||
)
|
||||
tools_to_check.extend(
|
||||
cast(List[ChatCompletionToolParam], transformed_tools)
|
||||
)
|
||||
|
||||
def _remap_tools_to_responses_api_format(
|
||||
self, guardrailed_tools: List[Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Remap guardrail-returned tools (Chat Completion format) back to
|
||||
Responses API request tool format.
|
||||
"""
|
||||
return LiteLLMCompletionResponsesConfig.transform_chat_completion_tool_params_to_responses_api_tools(
|
||||
guardrailed_tools # type: ignore
|
||||
)
|
||||
|
||||
def _merge_tools_after_guardrail(
|
||||
self,
|
||||
original_tools: List[Dict[str, Any]],
|
||||
remapped: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Merge remapped guardrailed tools with original tools that were not sent
|
||||
to the guardrail (e.g. web_search, web_search_preview), preserving order.
|
||||
"""
|
||||
if not original_tools:
|
||||
return remapped
|
||||
result: List[Dict[str, Any]] = []
|
||||
j = 0
|
||||
for tool in original_tools:
|
||||
if isinstance(tool, dict) and tool.get("type") in (
|
||||
"web_search",
|
||||
"web_search_preview",
|
||||
):
|
||||
result.append(tool)
|
||||
else:
|
||||
if j < len(remapped):
|
||||
result.append(remapped[j])
|
||||
j += 1
|
||||
return result
|
||||
|
||||
def _apply_guardrailed_tools_to_data(
|
||||
self,
|
||||
data: dict,
|
||||
original_tools: List[Dict[str, Any]],
|
||||
guardrailed_tools: Optional[List[Any]],
|
||||
) -> None:
|
||||
"""Remap guardrailed tools to Responses API format and merge with original, then set data['tools']."""
|
||||
if guardrailed_tools is not None:
|
||||
remapped = self._remap_tools_to_responses_api_format(guardrailed_tools)
|
||||
data["tools"] = self._merge_tools_after_guardrail(original_tools, remapped)
|
||||
|
||||
def _extract_input_text_and_images(
|
||||
self,
|
||||
message: Any, # Can be Dict[str, Any] or ResponseInputParam
|
||||
msg_idx: int,
|
||||
texts_to_check: List[str],
|
||||
images_to_check: List[str],
|
||||
task_mappings: List[Tuple[int, Optional[int]]],
|
||||
) -> None:
|
||||
"""
|
||||
Extract text content and images from an input message.
|
||||
|
||||
Override this method to customize text/image extraction logic.
|
||||
"""
|
||||
content = message.get("content", None)
|
||||
if content is None:
|
||||
return
|
||||
|
||||
if isinstance(content, str):
|
||||
# Simple string content
|
||||
texts_to_check.append(content)
|
||||
task_mappings.append((msg_idx, None))
|
||||
|
||||
elif isinstance(content, list):
|
||||
# List content (e.g., multimodal with text and images)
|
||||
for content_idx, content_item in enumerate(content):
|
||||
if isinstance(content_item, dict):
|
||||
# Extract text
|
||||
text_str = content_item.get("text", None)
|
||||
if text_str is not None:
|
||||
texts_to_check.append(text_str)
|
||||
task_mappings.append((msg_idx, int(content_idx)))
|
||||
|
||||
# Extract images
|
||||
if content_item.get("type") == "image_url":
|
||||
image_url = content_item.get("image_url", {})
|
||||
if isinstance(image_url, dict):
|
||||
url = image_url.get("url")
|
||||
if url:
|
||||
images_to_check.append(url)
|
||||
|
||||
async def _apply_guardrail_responses_to_input(
|
||||
self,
|
||||
messages: Any, # Can be List[Dict[str, Any]] or ResponseInputParam
|
||||
responses: List[str],
|
||||
task_mappings: List[Tuple[int, Optional[int]]],
|
||||
) -> None:
|
||||
"""
|
||||
Apply guardrail responses back to input messages.
|
||||
|
||||
Override this method to customize how responses are applied.
|
||||
"""
|
||||
for task_idx, guardrail_response in enumerate(responses):
|
||||
mapping = task_mappings[task_idx]
|
||||
msg_idx = cast(int, mapping[0])
|
||||
content_idx_optional = cast(Optional[int], mapping[1])
|
||||
|
||||
content = messages[msg_idx].get("content", None)
|
||||
if content is None:
|
||||
continue
|
||||
|
||||
if isinstance(content, str) and content_idx_optional is None:
|
||||
# Replace string content with guardrail response
|
||||
messages[msg_idx]["content"] = guardrail_response
|
||||
|
||||
elif isinstance(content, list) and content_idx_optional is not None:
|
||||
# Replace specific text item in list content
|
||||
if isinstance(messages[msg_idx]["content"][content_idx_optional], dict):
|
||||
messages[msg_idx]["content"][content_idx_optional][
|
||||
"text"
|
||||
] = guardrail_response
|
||||
|
||||
    async def process_output_response(
        self,
        response: "ResponsesAPIResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Process output response by applying guardrails to text content and tool calls.

        Args:
            response: LiteLLM ResponsesAPIResponse object
            guardrail_to_apply: The guardrail instance to apply
            litellm_logging_obj: Optional logging object
            user_api_key_dict: User API key metadata to pass to guardrails

        Returns:
            Modified response with guardrail applied to content

        Response Format Support:
            - response.output is a list of output items
            - Each output item can be:
              * GenericResponseOutputItem with a content list of OutputText objects
              * ResponseFunctionToolCall with tool call data
            - Each OutputText object has a text field
        """

        texts_to_check: List[str] = []
        images_to_check: List[str] = []
        tool_calls_to_check: List[ChatCompletionToolCallChunk] = []
        task_mappings: List[Tuple[int, int]] = []
        # Track (output_item_index, content_index) for each text

        # Handle both dict and Pydantic object responses
        if isinstance(response, dict):
            response_output = response.get("output", [])
        elif hasattr(response, "output"):
            response_output = response.output or []
        else:
            # Neither shape — nothing we can guardrail; return unchanged.
            verbose_proxy_logger.debug(
                "OpenAI Responses API: No output found in response"
            )
            return response

        if not response_output:
            verbose_proxy_logger.debug("OpenAI Responses API: Empty output in response")
            return response

        # Step 1: Extract all text content and tool calls from response output
        for output_idx, output_item in enumerate(response_output):
            self._extract_output_text_and_images(
                output_item=output_item,
                output_idx=output_idx,
                texts_to_check=texts_to_check,
                images_to_check=images_to_check,
                task_mappings=task_mappings,
                tool_calls_to_check=tool_calls_to_check,
            )

        # Step 2: Apply guardrail to all texts in batch
        if texts_to_check or tool_calls_to_check:
            # Create a request_data dict with response info and user API key metadata
            request_data: dict = {"response": response}

            # Add user API key metadata with prefixed keys
            user_metadata = self.transform_user_api_key_dict_to_metadata(
                user_api_key_dict
            )
            if user_metadata:
                request_data["litellm_metadata"] = user_metadata

            inputs = GenericGuardrailAPIInputs(texts=texts_to_check)
            if images_to_check:
                inputs["images"] = images_to_check
            if tool_calls_to_check:
                inputs["tool_calls"] = tool_calls_to_check
            # Include model information from the response if available
            response_model = None
            if isinstance(response, dict):
                response_model = response.get("model")
            elif hasattr(response, "model"):
                response_model = getattr(response, "model", None)
            if response_model:
                inputs["model"] = response_model

            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=request_data,
                input_type="response",
                logging_obj=litellm_logging_obj,
            )

            guardrailed_texts = guardrailed_inputs.get("texts", [])

            # Step 3: Map guardrail responses back to original response structure
            await self._apply_guardrail_responses_to_output(
                response=response,
                responses=guardrailed_texts,
                task_mappings=task_mappings,
            )

        verbose_proxy_logger.debug(
            "OpenAI Responses API: Processed output response: %s", response
        )

        return response
    async def process_output_streaming_response(
        self,
        responses_so_far: List[Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> List[Any]:
        """
        Process output streaming response by applying guardrails to text content.

        Dispatches on the type of the most recent chunk:
        - "response.output_item.done": guardrail any tool calls in that item.
        - "response.completed": guardrail the full text + tool calls.
        - anything else: guardrail the text accumulated so far.
        Always returns ``responses_so_far`` unchanged; the guardrail is
        expected to raise on violation.
        """

        final_chunk = responses_so_far[-1]

        if final_chunk.get("type") == "response.output_item.done":
            # convert openai response to model response
            model_response_stream = OpenAiResponsesToChatCompletionStreamIterator.translate_responses_chunk_to_openai_stream(
                final_chunk
            )

            tool_calls = model_response_stream.choices[0].delta.tool_calls
            if tool_calls:
                inputs = GenericGuardrailAPIInputs()
                inputs["tool_calls"] = cast(
                    List[ChatCompletionToolCallChunk], tool_calls
                )
                # Include model information if available
                if (
                    hasattr(model_response_stream, "model")
                    and model_response_stream.model
                ):
                    inputs["model"] = model_response_stream.model
                _guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                    inputs=inputs,
                    request_data={},
                    input_type="response",
                    logging_obj=litellm_logging_obj,
                )
            return responses_so_far
        elif final_chunk.get("type") == "response.completed":
            # convert openai response to model response
            outputs = final_chunk.get("response", {}).get("output", [])

            model_response_choices = LiteLLMResponsesTransformationHandler._convert_response_output_to_choices(
                output_items=outputs,
                handle_raw_dict_callback=None,
            )

            if model_response_choices:
                tool_calls = model_response_choices[0].message.tool_calls
                text = model_response_choices[0].message.content
                guardrail_inputs = GenericGuardrailAPIInputs()
                if text:
                    guardrail_inputs["texts"] = [text]
                if tool_calls:
                    guardrail_inputs["tool_calls"] = cast(
                        List[ChatCompletionToolCallChunk], tool_calls
                    )
                # Include model information from the response if available
                response_model = final_chunk.get("response", {}).get("model")
                if response_model:
                    guardrail_inputs["model"] = response_model
                if tool_calls or text:
                    _guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                        inputs=guardrail_inputs,
                        request_data={},
                        input_type="response",
                        logging_obj=litellm_logging_obj,
                    )
                return responses_so_far
            else:
                # NOTE(review): falls through to the accumulated-text guardrail
                # below when conversion yields no choices — confirm intended.
                verbose_proxy_logger.debug(
                    "Skipping output guardrail - model response has no choices"
                )

        # Fallback path: guardrail the text accumulated across all chunks.
        string_so_far = self.get_streaming_string_so_far(responses_so_far)
        inputs = GenericGuardrailAPIInputs(texts=[string_so_far])
        # Try to get model from the final chunk if available
        if isinstance(final_chunk, dict):
            response_model = (
                final_chunk.get("response", {}).get("model")
                if isinstance(final_chunk.get("response"), dict)
                else None
            )
            if response_model:
                inputs["model"] = response_model
        _guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data={},
            input_type="response",
            logging_obj=litellm_logging_obj,
        )
        return responses_so_far
def _check_streaming_has_ended(self, responses_so_far: List[Any]) -> bool:
|
||||
"""
|
||||
Check if the streaming has ended.
|
||||
"""
|
||||
return all(
|
||||
response.choices[0].finish_reason is not None
|
||||
for response in responses_so_far
|
||||
)
|
||||
|
||||
def get_streaming_string_so_far(self, responses_so_far: List[Any]) -> str:
|
||||
"""
|
||||
Get the string so far from the responses so far.
|
||||
"""
|
||||
return "".join([response.get("text", "") for response in responses_so_far])
|
||||
|
||||
    def _has_text_content(self, response: "ResponsesAPIResponse") -> bool:
        """
        Check if response has any text content to process.

        Scans every output item and returns True on the first non-empty text
        payload found; False when there is no output or no text at all.

        Override this method to customize text content detection.
        """
        if not hasattr(response, "output") or response.output is None:
            return False

        for output_item in response.output:
            if isinstance(output_item, BaseModel):
                try:
                    # Normalize arbitrary pydantic output items into the
                    # generic shape so the content list can be inspected.
                    generic_response_output_item = (
                        GenericResponseOutputItem.model_validate(
                            output_item.model_dump()
                        )
                    )
                    if generic_response_output_item.content:
                        output_item = generic_response_output_item
                except Exception:
                    # Item does not match the generic schema; skip it.
                    continue
            if isinstance(output_item, (GenericResponseOutputItem, dict)):
                content = (
                    output_item.content
                    if isinstance(output_item, GenericResponseOutputItem)
                    else output_item.get("content", [])
                )
                if content:
                    for content_item in content:
                        # Check if it's an OutputText with text
                        if isinstance(content_item, OutputText):
                            if content_item.text:
                                return True
                        elif isinstance(content_item, dict):
                            if content_item.get("text"):
                                return True
        return False
|
||||
|
||||
    def _extract_output_text_and_images(
        self,
        output_item: Any,
        output_idx: int,
        texts_to_check: List[str],
        images_to_check: List[str],
        task_mappings: List[Tuple[int, int]],
        tool_calls_to_check: Optional[List[ChatCompletionToolCallChunk]] = None,
    ) -> None:
        """
        Extract text content, images, and tool calls from a response output item.

        Results are appended in place to ``texts_to_check`` /
        ``tool_calls_to_check``; ``task_mappings`` records the
        (output_idx, content_idx) position of every appended text so that
        guardrail responses can be written back later.

        NOTE(review): ``images_to_check`` is not populated in this base
        implementation — presumably only by overrides; confirm.

        Override this method to customize text/image/tool extraction logic.
        """

        # Check if this is a tool call (OutputFunctionToolCall)
        if isinstance(output_item, OutputFunctionToolCall):
            if tool_calls_to_check is not None:
                tool_call_dict = LiteLLMCompletionResponsesConfig.convert_response_function_tool_call_to_chat_completion_tool_call(
                    tool_call_item=output_item,
                    index=output_idx,
                )
                tool_calls_to_check.append(
                    cast(ChatCompletionToolCallChunk, tool_call_dict)
                )
            return
        elif (
            isinstance(output_item, BaseModel)
            and hasattr(output_item, "type")
            and getattr(output_item, "type") == "function_call"
        ):
            # Pydantic models that are not OutputFunctionToolCall but still
            # declare type == "function_call".
            if tool_calls_to_check is not None:
                tool_call_dict = LiteLLMCompletionResponsesConfig.convert_response_function_tool_call_to_chat_completion_tool_call(
                    tool_call_item=output_item,
                    index=output_idx,
                )
                tool_calls_to_check.append(
                    cast(ChatCompletionToolCallChunk, tool_call_dict)
                )
            return
        elif (
            isinstance(output_item, dict) and output_item.get("type") == "function_call"
        ):
            # Handle dict representation of tool call
            if tool_calls_to_check is not None:
                # Convert dict to ResponseFunctionToolCall for processing
                try:
                    tool_call_obj = ResponseFunctionToolCall(**output_item)
                    tool_call_dict = LiteLLMCompletionResponsesConfig.convert_response_function_tool_call_to_chat_completion_tool_call(
                        tool_call_item=tool_call_obj,
                        index=output_idx,
                    )
                    tool_calls_to_check.append(
                        cast(ChatCompletionToolCallChunk, tool_call_dict)
                    )
                except Exception:
                    # Malformed dict tool call: drop it rather than fail.
                    pass
            return

        # Handle both GenericResponseOutputItem and dict
        content: Optional[Union[List[OutputText], List[dict]]] = None
        if isinstance(output_item, BaseModel):
            try:
                output_item_dump = output_item.model_dump()
                generic_response_output_item = GenericResponseOutputItem.model_validate(
                    output_item_dump
                )
                if generic_response_output_item.content:
                    content = generic_response_output_item.content
            except Exception:
                # Try to extract content directly from output_item if validation fails
                if hasattr(output_item, "content") and output_item.content:  # type: ignore
                    content = output_item.content  # type: ignore
                else:
                    return
        elif isinstance(output_item, dict):
            content = output_item.get("content", [])
        else:
            return

        if not content:
            return

        verbose_proxy_logger.debug(
            "OpenAI Responses API: Processing output item: %s", output_item
        )

        # Iterate through content items (list of OutputText objects)
        for content_idx, content_item in enumerate(content):
            # Handle both OutputText objects and dicts
            if isinstance(content_item, OutputText):
                text_content = content_item.text
            elif isinstance(content_item, dict):
                text_content = content_item.get("text")
            else:
                continue

            if text_content:
                texts_to_check.append(text_content)
                task_mappings.append((output_idx, int(content_idx)))
|
||||
|
||||
    async def _apply_guardrail_responses_to_output(
        self,
        response: "ResponsesAPIResponse",
        responses: List[str],
        task_mappings: List[Tuple[int, int]],
    ) -> None:
        """
        Apply guardrail responses back to output response.

        Each entry of ``responses`` replaces the text at the
        (output_idx, content_idx) position recorded in ``task_mappings``.
        Mutates ``response`` in place; out-of-range mappings are skipped.

        Override this method to customize how responses are applied.
        """
        # Handle both dict and Pydantic object responses
        if isinstance(response, dict):
            response_output = response.get("output", [])
        elif hasattr(response, "output"):
            response_output = response.output or []
        else:
            return

        for task_idx, guardrail_response in enumerate(responses):
            mapping = task_mappings[task_idx]
            output_idx = cast(int, mapping[0])
            content_idx = cast(int, mapping[1])

            if output_idx >= len(response_output):
                continue

            output_item = response_output[output_idx]

            # Handle both GenericResponseOutputItem, BaseModel, and dict
            if isinstance(output_item, GenericResponseOutputItem):
                if output_item.content and content_idx < len(output_item.content):
                    content_item = output_item.content[content_idx]
                    if isinstance(content_item, OutputText):
                        content_item.text = guardrail_response
                    elif isinstance(content_item, dict):
                        content_item["text"] = guardrail_response
            elif isinstance(output_item, BaseModel):
                # Handle other Pydantic models by converting to GenericResponseOutputItem
                try:
                    generic_item = GenericResponseOutputItem.model_validate(
                        output_item.model_dump()
                    )
                    if generic_item.content and content_idx < len(generic_item.content):
                        content_item = generic_item.content[content_idx]
                        if isinstance(content_item, OutputText):
                            content_item.text = guardrail_response
                        # Update the original response output
                        # (the validated copy above is discarded; the write
                        # below is what actually lands on the response)
                        if hasattr(output_item, "content") and output_item.content:  # type: ignore
                            original_content = output_item.content[content_idx]  # type: ignore
                            if hasattr(original_content, "text"):
                                original_content.text = guardrail_response  # type: ignore
                except Exception:
                    # Unrecognized output shape: leave the item untouched.
                    pass
            elif isinstance(output_item, dict):
                content = output_item.get("content", [])
                if content and content_idx < len(content):
                    if isinstance(content[content_idx], dict):
                        content[content_idx]["text"] = guardrail_response
                    elif hasattr(content[content_idx], "text"):
                        content[content_idx].text = guardrail_response
|
||||
@@ -0,0 +1,580 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union, cast, get_type_hints
|
||||
|
||||
import httpx
|
||||
from openai.types.responses import ResponseReasoningItem
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import process_response_headers
|
||||
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||
_safe_convert_created_field,
|
||||
)
|
||||
from litellm.llms.base_llm.responses.transformation import BaseResponsesAPIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import *
|
||||
from litellm.types.responses.main import *
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    # Alias used in annotations below; resolves to the real Logging class for
    # static type checkers only.
    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    # At runtime the logging object is passed in opaquely, so the import (and
    # any import cycle it would create) is avoided.
    LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class OpenAIResponsesAPIConfig(BaseResponsesAPIConfig):
|
||||
    @property
    def custom_llm_provider(self) -> LlmProviders:
        """Provider enum this config handles: OpenAI."""
        return LlmProviders.OPENAI
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
All OpenAI Responses API params are supported
|
||||
"""
|
||||
supported_params = get_type_hints(ResponsesAPIRequestParams).keys()
|
||||
return list(
|
||||
set(
|
||||
[
|
||||
"input",
|
||||
"model",
|
||||
"extra_headers",
|
||||
"extra_query",
|
||||
"extra_body",
|
||||
"timeout",
|
||||
]
|
||||
+ list(supported_params)
|
||||
)
|
||||
)
|
||||
|
||||
    def map_openai_params(
        self,
        response_api_optional_params: ResponsesAPIOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """No mapping applied since inputs are in OpenAI spec already.

        ``model`` and ``drop_params`` are accepted for interface parity with
        other provider configs but are unused here.
        """
        return dict(response_api_optional_params)
|
||||
|
||||
    def transform_responses_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """No transform applied since inputs are in OpenAI spec already.

        Only normalizes ``input`` (pydantic items -> plain dicts) before
        packing everything into the request dict.
        """

        input = self._validate_input_param(input)
        final_request_params = dict(
            ResponsesAPIRequestParams(
                model=model, input=input, **response_api_optional_request_params
            )
        )

        return final_request_params
|
||||
|
||||
    def _validate_input_param(
        self, input: Union[str, ResponseInputParam]
    ) -> Union[str, ResponseInputParam]:
        """
        Ensure all input fields if pydantic are converted to dict

        OpenAI API Fails when we try to JSON dumps specific input pydantic fields.
        This function ensures all input fields are converted to dict.

        Strings (and any non-list input) are returned unchanged.
        """
        if isinstance(input, list):
            validated_input = []
            for item in input:
                # if it's pydantic, convert to dict
                if isinstance(item, BaseModel):
                    validated_input.append(item.model_dump(exclude_none=True))
                elif isinstance(item, dict):
                    # Handle reasoning items specifically to filter out status=None
                    if item.get("type") == "reasoning":
                        verbose_logger.debug(f"Handling reasoning item: {item}")
                        # Type assertion since we know it's a dict at this point
                        dict_item = cast(Dict[str, Any], item)
                        filtered_item = self._handle_reasoning_item(dict_item)
                    else:
                        # For other dict items, just pass through
                        filtered_item = cast(Dict[str, Any], item)
                    validated_input.append(filtered_item)
                else:
                    validated_input.append(item)
            return validated_input  # type: ignore
        # Input is expected to be either str or List, no single BaseModel expected
        return input
|
||||
|
||||
    def _handle_reasoning_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """
        Handle reasoning items specifically to filter out status=None using OpenAI's model.
        Issue: https://github.com/BerriAI/litellm/issues/13484
        OpenAI API does not accept ReasoningItem(status=None), so we need to:
        1. Check if the item is a reasoning type
        2. Create a ResponseReasoningItem object with the item data
        3. Convert it back to dict with exclude_none=True to filter None values

        Non-reasoning items are returned unchanged.
        """
        if item.get("type") == "reasoning":
            try:
                # Ensure required fields are present for ResponseReasoningItem
                item_data = dict(item)
                if "summary" not in item_data:
                    # Derive a summary from reasoning_content, truncated to
                    # 100 chars with an ellipsis when longer.
                    item_data["summary"] = (
                        item_data.get("reasoning_content", "")[:100] + "..."
                        if len(item_data.get("reasoning_content", "")) > 100
                        else item_data.get("reasoning_content", "")
                    )

                # Create ResponseReasoningItem object from the item data
                reasoning_item = ResponseReasoningItem(**item_data)

                # Convert back to dict with exclude_none=True to exclude None fields
                dict_reasoning_item = reasoning_item.model_dump(exclude_none=True)

                return dict_reasoning_item
            except Exception as e:
                verbose_logger.debug(
                    f"Failed to create ResponseReasoningItem, falling back to manual filtering: {e}"
                )
                # Fallback: manually filter out known None fields
                # (an entry is dropped only when its value is None AND the key
                # is one OpenAI rejects when null; other None values are kept)
                filtered_item = {
                    k: v
                    for k, v in item.items()
                    if v is not None
                    or k not in {"status", "content", "encrypted_content"}
                }
                return filtered_item
        return item
|
||||
|
||||
    def transform_response_api_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """No transform applied since outputs are in OpenAI spec already.

        Normalizes ``created_at`` and preserves the raw/processed response
        headers in ``_hidden_params`` so they are surfaced to the client.

        Raises:
            OpenAIError: if the body is not valid JSON (or post_call fails).
        """
        try:
            logging_obj.post_call(
                original_response=raw_response.text,
                additional_args={"complete_input_dict": {}},
            )
            raw_response_json = raw_response.json()
            raw_response_json["created_at"] = _safe_convert_created_field(
                raw_response_json["created_at"]
            )
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        raw_response_headers = dict(raw_response.headers)
        processed_headers = process_response_headers(raw_response_headers)
        try:
            response = ResponsesAPIResponse(**raw_response_json)
        except Exception:
            # Tolerate schema drift: build the model without validation.
            verbose_logger.debug(
                f"Error constructing ResponsesAPIResponse: {raw_response_json}, using model_construct"
            )
            response = ResponsesAPIResponse.model_construct(**raw_response_json)

        # Store processed headers in additional_headers so they get returned to the client
        response._hidden_params["additional_headers"] = processed_headers
        response._hidden_params["headers"] = raw_response_headers
        return response
|
||||
|
||||
    def validate_environment(
        self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Attach the OpenAI Authorization header.

        Resolves the API key from (in order) litellm_params, the global
        litellm settings, or the OPENAI_API_KEY secret. Mutates and returns
        ``headers``.
        """
        litellm_params = litellm_params or GenericLiteLLMParams()
        # First non-empty source wins.
        api_key = (
            litellm_params.api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Get the endpoint for OpenAI responses API
|
||||
"""
|
||||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("OPENAI_BASE_URL")
|
||||
or get_secret_str("OPENAI_API_BASE")
|
||||
or "https://api.openai.com/v1"
|
||||
)
|
||||
|
||||
# Remove trailing slashes
|
||||
api_base = api_base.rstrip("/")
|
||||
|
||||
return f"{api_base}/responses"
|
||||
|
||||
    def transform_streaming_response(
        self,
        model: str,
        parsed_chunk: dict,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIStreamingResponse:
        """
        Transform a parsed streaming response chunk into a ResponsesAPIStreamingResponse
        """
        # Convert the dictionary to a properly typed ResponsesAPIStreamingResponse
        verbose_logger.debug("Raw OpenAI Chunk=%s", parsed_chunk)
        event_type = str(parsed_chunk.get("type"))
        event_pydantic_model = OpenAIResponsesAPIConfig.get_event_model_class(
            event_type=event_type
        )
        # Some OpenAI-compatible providers send error.code: null; coalesce so validation succeeds.
        try:
            error_obj = parsed_chunk.get("error")
            if isinstance(error_obj, dict) and error_obj.get("code") is None:
                # Copy before mutating so the caller's dict is untouched.
                parsed_chunk = dict(parsed_chunk)
                parsed_chunk["error"] = dict(error_obj)
                parsed_chunk["error"]["code"] = "unknown_error"
        except Exception:
            verbose_logger.debug("Failed to coalesce error.code in parsed_chunk")

        try:
            return event_pydantic_model(**parsed_chunk)
        except ValidationError:
            # Tolerate schema drift: build the event without validation.
            verbose_logger.debug(
                "Pydantic validation failed for %s with chunk %s, "
                "falling back to model_construct",
                event_pydantic_model.__name__,
                parsed_chunk,
            )
            return event_pydantic_model.model_construct(**parsed_chunk)
|
||||
|
||||
    @staticmethod
    def get_event_model_class(event_type: str) -> Any:
        """
        Returns the appropriate event model class based on the event type.

        Args:
            event_type (str): The type of event from the response chunk

        Returns:
            Any: The corresponding event model class; GenericEvent for any
            event type not in the mapping (so unknown events pass through).
        """
        event_models = {
            ResponsesAPIStreamEvents.RESPONSE_CREATED: ResponseCreatedEvent,
            ResponsesAPIStreamEvents.RESPONSE_IN_PROGRESS: ResponseInProgressEvent,
            ResponsesAPIStreamEvents.RESPONSE_COMPLETED: ResponseCompletedEvent,
            ResponsesAPIStreamEvents.RESPONSE_FAILED: ResponseFailedEvent,
            ResponsesAPIStreamEvents.RESPONSE_INCOMPLETE: ResponseIncompleteEvent,
            ResponsesAPIStreamEvents.OUTPUT_ITEM_ADDED: OutputItemAddedEvent,
            ResponsesAPIStreamEvents.OUTPUT_ITEM_DONE: OutputItemDoneEvent,
            ResponsesAPIStreamEvents.CONTENT_PART_ADDED: ContentPartAddedEvent,
            ResponsesAPIStreamEvents.CONTENT_PART_DONE: ContentPartDoneEvent,
            ResponsesAPIStreamEvents.OUTPUT_TEXT_DELTA: OutputTextDeltaEvent,
            ResponsesAPIStreamEvents.OUTPUT_TEXT_ANNOTATION_ADDED: OutputTextAnnotationAddedEvent,
            ResponsesAPIStreamEvents.OUTPUT_TEXT_DONE: OutputTextDoneEvent,
            ResponsesAPIStreamEvents.REFUSAL_DELTA: RefusalDeltaEvent,
            ResponsesAPIStreamEvents.REFUSAL_DONE: RefusalDoneEvent,
            ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DELTA: FunctionCallArgumentsDeltaEvent,
            ResponsesAPIStreamEvents.FUNCTION_CALL_ARGUMENTS_DONE: FunctionCallArgumentsDoneEvent,
            ResponsesAPIStreamEvents.FILE_SEARCH_CALL_IN_PROGRESS: FileSearchCallInProgressEvent,
            ResponsesAPIStreamEvents.FILE_SEARCH_CALL_SEARCHING: FileSearchCallSearchingEvent,
            ResponsesAPIStreamEvents.FILE_SEARCH_CALL_COMPLETED: FileSearchCallCompletedEvent,
            ResponsesAPIStreamEvents.WEB_SEARCH_CALL_IN_PROGRESS: WebSearchCallInProgressEvent,
            ResponsesAPIStreamEvents.WEB_SEARCH_CALL_SEARCHING: WebSearchCallSearchingEvent,
            ResponsesAPIStreamEvents.WEB_SEARCH_CALL_COMPLETED: WebSearchCallCompletedEvent,
            ResponsesAPIStreamEvents.MCP_LIST_TOOLS_IN_PROGRESS: MCPListToolsInProgressEvent,
            ResponsesAPIStreamEvents.MCP_LIST_TOOLS_COMPLETED: MCPListToolsCompletedEvent,
            ResponsesAPIStreamEvents.MCP_LIST_TOOLS_FAILED: MCPListToolsFailedEvent,
            ResponsesAPIStreamEvents.MCP_CALL_IN_PROGRESS: MCPCallInProgressEvent,
            ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DELTA: MCPCallArgumentsDeltaEvent,
            ResponsesAPIStreamEvents.MCP_CALL_ARGUMENTS_DONE: MCPCallArgumentsDoneEvent,
            ResponsesAPIStreamEvents.MCP_CALL_COMPLETED: MCPCallCompletedEvent,
            ResponsesAPIStreamEvents.MCP_CALL_FAILED: MCPCallFailedEvent,
            ResponsesAPIStreamEvents.IMAGE_GENERATION_PARTIAL_IMAGE: ImageGenerationPartialImageEvent,
            ResponsesAPIStreamEvents.ERROR: ErrorEvent,
            # Shell tool events: passthrough as GenericEvent so payload is preserved
            ResponsesAPIStreamEvents.SHELL_CALL_IN_PROGRESS: GenericEvent,
            ResponsesAPIStreamEvents.SHELL_CALL_COMPLETED: GenericEvent,
            ResponsesAPIStreamEvents.SHELL_CALL_OUTPUT: GenericEvent,
        }

        model_class = event_models.get(cast(ResponsesAPIStreamEvents, event_type))
        if not model_class:
            return GenericEvent

        return model_class
|
||||
|
||||
    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Return True when streaming was requested but the model cannot
        stream natively, so the stream must be simulated from a full response.

        Defaults to False when no model is given or the capability lookup
        fails.
        """
        if stream is not True:
            return False
        if model is not None:
            try:
                if (
                    litellm.utils.supports_native_streaming(
                        model=model,
                        custom_llm_provider=custom_llm_provider,
                    )
                    is False
                ):
                    return True
            except Exception as e:
                # Unknown model info: assume native streaming works.
                verbose_logger.debug(
                    f"Error getting model info in OpenAIResponsesAPIConfig: {e}"
                )
        return False
|
||||
|
||||
    def supports_native_websocket(self) -> bool:
        """OpenAI supports native WebSocket for Responses API"""
        return True
|
||||
|
||||
#########################################################
|
||||
########## DELETE RESPONSE API TRANSFORMATION ##############
|
||||
#########################################################
|
||||
def transform_delete_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Transform the delete response API request into a URL and data
|
||||
|
||||
OpenAI API expects the following request
|
||||
- DELETE /v1/responses/{response_id}
|
||||
"""
|
||||
url = f"{api_base}/{response_id}"
|
||||
data: Dict = {}
|
||||
return url, data
|
||||
|
||||
    def transform_delete_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> DeleteResponseResult:
        """
        Transform the delete response API response into a DeleteResponseResult

        Raises:
            OpenAIError: if the body is not valid JSON.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        return DeleteResponseResult(**raw_response_json)
|
||||
|
||||
#########################################################
|
||||
########## GET RESPONSE API TRANSFORMATION ###############
|
||||
#########################################################
|
||||
def transform_get_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Transform the get response API request into a URL and data
|
||||
|
||||
OpenAI API expects the following request
|
||||
- GET /v1/responses/{response_id}
|
||||
"""
|
||||
url = f"{api_base}/{response_id}"
|
||||
data: Dict = {}
|
||||
return url, data
|
||||
|
||||
    def transform_get_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """
        Transform the get response API response into a ResponsesAPIResponse

        Response headers are preserved in ``_hidden_params``.

        Raises:
            OpenAIError: if the body is not valid JSON.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        raw_response_headers = dict(raw_response.headers)
        processed_headers = process_response_headers(raw_response_headers)
        # NOTE(review): unlike transform_response_api_response there is no
        # model_construct fallback here, so a schema mismatch will raise —
        # confirm this asymmetry is intended.
        response = ResponsesAPIResponse(**raw_response_json)
        response._hidden_params["additional_headers"] = processed_headers
        response._hidden_params["headers"] = raw_response_headers

        return response
|
||||
|
||||
#########################################################
|
||||
########## LIST INPUT ITEMS TRANSFORMATION #############
|
||||
#########################################################
|
||||
def transform_list_input_items_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
include: Optional[List[str]] = None,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
) -> Tuple[str, Dict]:
|
||||
url = f"{api_base}/{response_id}/input_items"
|
||||
params: Dict[str, Any] = {}
|
||||
if after is not None:
|
||||
params["after"] = after
|
||||
if before is not None:
|
||||
params["before"] = before
|
||||
if include:
|
||||
params["include"] = ",".join(include)
|
||||
if limit is not None:
|
||||
params["limit"] = limit
|
||||
if order is not None:
|
||||
params["order"] = order
|
||||
return url, params
|
||||
|
||||
    def transform_list_input_items_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> Dict:
        """Return the list-input-items response body as a plain dict.

        Raises:
            OpenAIError: if the body is not valid JSON.
        """
        try:
            return raw_response.json()
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )
|
||||
|
||||
#########################################################
|
||||
########## CANCEL RESPONSE API TRANSFORMATION ##########
|
||||
#########################################################
|
||||
def transform_cancel_response_api_request(
|
||||
self,
|
||||
response_id: str,
|
||||
api_base: str,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Transform the cancel response API request into a URL and data
|
||||
|
||||
OpenAI API expects the following request
|
||||
- POST /v1/responses/{response_id}/cancel
|
||||
"""
|
||||
url = f"{api_base}/{response_id}/cancel"
|
||||
data: Dict = {}
|
||||
return url, data
|
||||
|
||||
    def transform_cancel_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """
        Transform the cancel response API response into a ResponsesAPIResponse

        Response headers are preserved in ``_hidden_params``.

        Raises:
            OpenAIError: if the body is not valid JSON.
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        raw_response_headers = dict(raw_response.headers)
        processed_headers = process_response_headers(raw_response_headers)

        response = ResponsesAPIResponse(**raw_response_json)
        response._hidden_params["additional_headers"] = processed_headers
        response._hidden_params["headers"] = raw_response_headers

        return response
|
||||
|
||||
#########################################################
|
||||
########## COMPACT RESPONSE API TRANSFORMATION ##########
|
||||
#########################################################
|
||||
    def transform_compact_response_api_request(
        self,
        model: str,
        input: Union[str, ResponseInputParam],
        response_api_optional_request_params: Dict,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the compact response API request into a URL and data

        OpenAI API expects the following request
        - POST /v1/responses/compact
        """
        # Preserve query params (e.g., api-version) while appending /compact.
        parsed_url = httpx.URL(api_base)
        compact_path = parsed_url.path.rstrip("/") + "/compact"
        url = str(parsed_url.copy_with(path=compact_path))

        # Same input normalization as the regular request path.
        input = self._validate_input_param(input)
        data = dict(
            ResponsesAPIRequestParams(
                model=model, input=input, **response_api_optional_request_params
            )
        )

        return url, data
|
||||
|
||||
    def transform_compact_response_api_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> ResponsesAPIResponse:
        """
        Transform the compact response API response into a ResponsesAPIResponse

        Mirrors transform_response_api_response: normalizes ``created_at``,
        falls back to model_construct on validation failure, and preserves
        headers in ``_hidden_params``.

        Raises:
            OpenAIError: if the body is not valid JSON (or post_call fails).
        """
        try:
            logging_obj.post_call(
                original_response=raw_response.text,
                additional_args={"complete_input_dict": {}},
            )
            raw_response_json = raw_response.json()
            raw_response_json["created_at"] = _safe_convert_created_field(
                raw_response_json["created_at"]
            )
        except Exception:
            raise OpenAIError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        raw_response_headers = dict(raw_response.headers)
        processed_headers = process_response_headers(raw_response_headers)

        try:
            response = ResponsesAPIResponse(**raw_response_json)
        except Exception:
            # Tolerate schema drift: build the model without validation.
            verbose_logger.debug(
                f"Error constructing ResponsesAPIResponse: {raw_response_json}, using model_construct"
            )
            response = ResponsesAPIResponse.model_construct(**raw_response_json)

        response._hidden_params["additional_headers"] = processed_headers
        response._hidden_params["headers"] = raw_response_headers

        return response
|
||||
@@ -0,0 +1,178 @@
|
||||
# OpenAI Text-to-Speech Guardrail Translation Handler
|
||||
|
||||
Handler for processing OpenAI's text-to-speech endpoint (`/v1/audio/speech`) with guardrails.
|
||||
|
||||
## Overview
|
||||
|
||||
This handler processes text-to-speech requests by:
|
||||
1. Extracting the input text from the request
|
||||
2. Applying guardrails to the input text
|
||||
3. Updating the request with the guardrailed text
|
||||
4. Returning the output unchanged (audio is binary, not text)
|
||||
|
||||
## Data Format
|
||||
|
||||
### Input Format
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "tts-1",
|
||||
"input": "The quick brown fox jumped over the lazy dog.",
|
||||
"voice": "alloy",
|
||||
"response_format": "mp3",
|
||||
"speed": 1.0
|
||||
}
|
||||
```
|
||||
|
||||
### Output Format
|
||||
|
||||
The output is binary audio data (MP3, WAV, etc.), not text, so it cannot be guardrailed.
|
||||
|
||||
## Usage
|
||||
|
||||
The handler is automatically discovered and applied when guardrails are used with the text-to-speech endpoint.
|
||||
|
||||
### Example: Using Guardrails with Text-to-Speech
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/audio/speech' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "tts-1",
|
||||
"input": "The quick brown fox jumped over the lazy dog.",
|
||||
"voice": "alloy",
|
||||
"guardrails": ["content_moderation"]
|
||||
}' \
|
||||
--output speech.mp3
|
||||
```
|
||||
|
||||
The guardrail will be applied to the input text before the text-to-speech conversion.
|
||||
|
||||
### Example: PII Masking in TTS Input
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/audio/speech' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "tts-1",
|
||||
"input": "Please call John Doe at john@example.com",
|
||||
"voice": "nova",
|
||||
"guardrails": ["mask_pii"]
|
||||
}' \
|
||||
--output speech.mp3
|
||||
```
|
||||
|
||||
The audio will say: "Please call [NAME_REDACTED] at [EMAIL_REDACTED]"
|
||||
|
||||
### Example: Content Filtering Before TTS
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/audio/speech' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-d '{
|
||||
"model": "tts-1-hd",
|
||||
"input": "This is the text that will be spoken",
|
||||
"voice": "shimmer",
|
||||
"guardrails": ["content_filter"]
|
||||
}' \
|
||||
--output speech.mp3
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Input Processing
|
||||
|
||||
- **Field**: `input` (string)
|
||||
- **Processing**: Applies guardrail to input text
|
||||
- **Result**: Updated input text in request
|
||||
|
||||
### Output Processing
|
||||
|
||||
- **Processing**: Not applicable (audio is binary data)
|
||||
- **Result**: Response returned unchanged
|
||||
|
||||
## Use Cases
|
||||
|
||||
1. **PII Protection**: Remove personally identifiable information before converting to speech
|
||||
2. **Content Filtering**: Remove inappropriate content before TTS conversion
|
||||
3. **Compliance**: Ensure text meets requirements before voice synthesis
|
||||
4. **Text Sanitization**: Clean up text before audio generation
|
||||
|
||||
## Extension
|
||||
|
||||
Override these methods to customize behavior:
|
||||
|
||||
- `process_input_messages()`: Customize how input text is processed
|
||||
- `process_output_response()`: Currently a no-op, but can be overridden if needed
|
||||
|
||||
## Supported Call Types
|
||||
|
||||
- `CallTypes.speech` - Synchronous text-to-speech
|
||||
- `CallTypes.aspeech` - Asynchronous text-to-speech
|
||||
|
||||
## Notes
|
||||
|
||||
- Only the input text is processed by guardrails
|
||||
- Output processing is a no-op since audio cannot be text-guardrailed
|
||||
- Both sync and async call types use the same handler
|
||||
- Works with all TTS models (tts-1, tts-1-hd, etc.)
|
||||
- Works with all voice options
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Remove PII Before TTS
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from pathlib import Path
|
||||
|
||||
speech_file_path = Path(__file__).parent / "speech.mp3"
|
||||
response = litellm.speech(
|
||||
model="tts-1",
|
||||
voice="alloy",
|
||||
input="Hi, this is John Doe calling from john@company.com",
|
||||
guardrails=["mask_pii"],
|
||||
)
|
||||
response.stream_to_file(speech_file_path)
|
||||
# Audio will have PII masked
|
||||
```
|
||||
|
||||
### Content Moderation Before TTS
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from pathlib import Path
|
||||
|
||||
speech_file_path = Path(__file__).parent / "speech.mp3"
|
||||
response = litellm.speech(
|
||||
model="tts-1-hd",
|
||||
voice="nova",
|
||||
input="Your text here",
|
||||
guardrails=["content_moderation"],
|
||||
)
|
||||
response.stream_to_file(speech_file_path)
|
||||
```
|
||||
|
||||
### Async TTS with Guardrails
|
||||
|
||||
```python
|
||||
import litellm
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
async def generate_speech():
|
||||
speech_file_path = Path(__file__).parent / "speech.mp3"
|
||||
response = await litellm.aspeech(
|
||||
model="tts-1",
|
||||
voice="echo",
|
||||
input="Text to convert to speech",
|
||||
guardrails=["pii_mask"],
|
||||
)
|
||||
response.stream_to_file(speech_file_path)
|
||||
|
||||
asyncio.run(generate_speech())
|
||||
```
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""OpenAI Text-to-Speech handler for Unified Guardrails."""

from litellm.llms.openai.speech.guardrail_translation.handler import (
    OpenAITextToSpeechHandler,
)
from litellm.types.utils import CallTypes

# Both the sync (`speech`) and async (`aspeech`) call types map to the same
# handler: the guardrail work (rewriting the request's input text) is
# identical regardless of how the call is made.
guardrail_translation_mappings = {
    CallTypes.speech: OpenAITextToSpeechHandler,
    CallTypes.aspeech: OpenAITextToSpeechHandler,
}

__all__ = ["guardrail_translation_mappings", "OpenAITextToSpeechHandler"]
|
||||
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
OpenAI Text-to-Speech Handler for Unified Guardrails
|
||||
|
||||
This module provides guardrail translation support for OpenAI's text-to-speech endpoint.
|
||||
The handler processes the 'input' text parameter (output is audio, so no text to guardrail).
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.types.llms.openai import HttpxBinaryResponseContent
|
||||
|
||||
|
||||
class OpenAITextToSpeechHandler(BaseTranslation):
    """
    Guardrail translation handler for OpenAI text-to-speech requests.

    Only the request's ``input`` text can be guardrailed; the response body is
    binary audio, so the output hook passes it through untouched.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Apply the guardrail to the ``input`` text of a TTS request.

        Args:
            data: Request payload; the ``input`` key holds the text to speak.
            guardrail_to_apply: The guardrail instance to apply.
            litellm_logging_obj: Optional logging object forwarded to the guardrail.

        Returns:
            The request payload, with ``input`` replaced by the guardrailed text
            when the guardrail produced one.
        """
        text_to_speak = data.get("input")

        # Guard: nothing to do when there is no input text at all.
        if text_to_speak is None:
            verbose_proxy_logger.debug(
                "OpenAI Text-to-Speech: No input text found in request data"
            )
            return data

        # Guard: only plain strings are guardrailed; anything else is logged
        # and passed through unchanged.
        if not isinstance(text_to_speak, str):
            verbose_proxy_logger.debug(
                "OpenAI Text-to-Speech: Unexpected input type: %s. Expected string.",
                type(text_to_speak),
            )
            return data

        guardrail_inputs = GenericGuardrailAPIInputs(texts=[text_to_speak])
        # Include model information if available (voice model)
        requested_model = data.get("model")
        if requested_model:
            guardrail_inputs["model"] = requested_model

        guardrail_result = await guardrail_to_apply.apply_guardrail(
            inputs=guardrail_inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        updated_texts = guardrail_result.get("texts", [])
        # Keep the original text if the guardrail returned no replacement.
        data["input"] = updated_texts[0] if updated_texts else text_to_speak

        verbose_proxy_logger.debug(
            "OpenAI Text-to-Speech: Applied guardrail to input text. "
            "Original length: %d, New length: %d",
            len(text_to_speak),
            len(data["input"]),
        )

        return data

    async def process_output_response(
        self,
        response: "HttpxBinaryResponseContent",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Return the TTS response unchanged.

        The response body is binary audio, so there is no text for a guardrail
        to inspect or rewrite.

        Args:
            response: Binary audio response.
            guardrail_to_apply: The guardrail instance (unused).
            litellm_logging_obj: Optional logging object (unused).
            user_api_key_dict: User API key metadata (unused).

        Returns:
            The response, unmodified.
        """
        verbose_proxy_logger.debug(
            "OpenAI Text-to-Speech: Output processing not applicable "
            "(output is audio data, not text)"
        )
        return response
|
||||
@@ -0,0 +1,41 @@
|
||||
from typing import List
|
||||
|
||||
from litellm.llms.base_llm.audio_transcription.transformation import (
|
||||
AudioTranscriptionRequestData,
|
||||
)
|
||||
from litellm.types.llms.openai import OpenAIAudioTranscriptionOptionalParams
|
||||
from litellm.types.utils import FileTypes
|
||||
|
||||
from .whisper_transformation import OpenAIWhisperAudioTranscriptionConfig
|
||||
|
||||
|
||||
class OpenAIGPTAudioTranscriptionConfig(OpenAIWhisperAudioTranscriptionConfig):
    """Transcription config for the `gpt-4o-transcribe` model family."""

    def get_supported_openai_params(
        self, model: str
    ) -> List[OpenAIAudioTranscriptionOptionalParams]:
        """
        Get the supported OpenAI params for the `gpt-4o-transcribe` models
        """
        supported: List[OpenAIAudioTranscriptionOptionalParams] = [
            "language",
            "prompt",
            "response_format",
            "temperature",
            "include",
        ]
        return supported

    def transform_audio_transcription_request(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        litellm_params: dict,
    ) -> AudioTranscriptionRequestData:
        """
        Transform the audio transcription request
        """
        # Base payload plus any caller-supplied optional params.
        request_body = {"model": model, "file": audio_file}
        request_body.update(optional_params)

        return AudioTranscriptionRequestData(data=request_body)
|
||||
@@ -0,0 +1,159 @@
|
||||
# OpenAI Audio Transcription Guardrail Translation Handler
|
||||
|
||||
Handler for processing OpenAI's audio transcription endpoint (`/v1/audio/transcriptions`) with guardrails.
|
||||
|
||||
## Overview
|
||||
|
||||
This handler processes audio transcription responses by:
|
||||
1. Applying guardrails to the transcribed text output
|
||||
2. Returning the input unchanged (since input is an audio file, not text)
|
||||
|
||||
## Data Format
|
||||
|
||||
### Input Format
|
||||
|
||||
The input is an audio file, which cannot be guardrailed (it's binary data, not text).
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "whisper-1",
|
||||
"file": "<audio file>",
|
||||
"response_format": "json",
|
||||
"language": "en"
|
||||
}
|
||||
```
|
||||
|
||||
### Output Format
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "This is the transcribed text from the audio file."
|
||||
}
|
||||
```
|
||||
|
||||
Or with additional metadata:
|
||||
|
||||
```json
|
||||
{
|
||||
"text": "This is the transcribed text from the audio file.",
|
||||
"duration": 3.5,
|
||||
"language": "en"
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
The handler is automatically discovered and applied when guardrails are used with the audio transcription endpoint.
|
||||
|
||||
### Example: Using Guardrails with Audio Transcription
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/audio/transcriptions' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-F 'file=@audio.mp3' \
|
||||
-F 'model=whisper-1' \
|
||||
-F 'guardrails=["pii_mask"]'
|
||||
```
|
||||
|
||||
The guardrail will be applied to the **output** transcribed text only.
|
||||
|
||||
### Example: PII Masking in Transcribed Text
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/audio/transcriptions' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-F 'file=@meeting_recording.mp3' \
|
||||
-F 'model=whisper-1' \
|
||||
-F 'guardrails=["mask_pii"]' \
|
||||
-F 'response_format=json'
|
||||
```
|
||||
|
||||
If the audio contains: "My name is John Doe and my email is john@example.com"
|
||||
|
||||
The transcription output will be: "My name is [NAME_REDACTED] and my email is [EMAIL_REDACTED]"
|
||||
|
||||
### Example: Content Moderation on Transcriptions
|
||||
|
||||
```bash
|
||||
curl -X POST 'http://localhost:4000/v1/audio/transcriptions' \
|
||||
-H 'Authorization: Bearer your-api-key' \
|
||||
-F 'file=@audio.wav' \
|
||||
-F 'model=whisper-1' \
|
||||
-F 'guardrails=["content_moderation"]'
|
||||
```
|
||||
|
||||
## Implementation Details
|
||||
|
||||
### Input Processing
|
||||
|
||||
- **Status**: Not applicable
|
||||
- **Reason**: Input is an audio file (binary data), not text
|
||||
- **Result**: Request data returned unchanged
|
||||
|
||||
### Output Processing
|
||||
|
||||
- **Field**: `text` (string)
|
||||
- **Processing**: Applies guardrail to the transcribed text
|
||||
- **Result**: Updated text in response
|
||||
|
||||
## Use Cases
|
||||
|
||||
1. **PII Protection**: Automatically redact personally identifiable information from transcriptions
|
||||
2. **Content Filtering**: Remove or flag inappropriate content in transcribed audio
|
||||
3. **Compliance**: Ensure transcriptions meet regulatory requirements
|
||||
4. **Data Sanitization**: Clean up transcriptions before storage or further processing
|
||||
|
||||
## Extension
|
||||
|
||||
Override these methods to customize behavior:
|
||||
|
||||
- `process_output_response()`: Customize how transcribed text is processed
|
||||
- `process_input_messages()`: Currently a no-op, but can be overridden if needed
|
||||
|
||||
## Supported Call Types
|
||||
|
||||
- `CallTypes.transcription` - Synchronous audio transcription
|
||||
- `CallTypes.atranscription` - Asynchronous audio transcription
|
||||
|
||||
## Notes
|
||||
|
||||
- Input processing is a no-op since audio files cannot be text-guardrailed
|
||||
- Only the transcribed text output is processed
|
||||
- Guardrails apply after transcription is complete
|
||||
- Both sync and async call types use the same handler
|
||||
- Works with all Whisper models and response formats
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Transcribe and Redact PII
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
response = litellm.transcription(
|
||||
model="whisper-1",
|
||||
file=open("interview.mp3", "rb"),
|
||||
guardrails=["mask_pii"],
|
||||
)
|
||||
|
||||
# response.text will have PII redacted
|
||||
print(response.text)
|
||||
```
|
||||
|
||||
### Async Transcription with Guardrails
|
||||
|
||||
```python
|
||||
import litellm
|
||||
import asyncio
|
||||
|
||||
async def transcribe_with_guardrails():
|
||||
response = await litellm.atranscription(
|
||||
model="whisper-1",
|
||||
file=open("audio.mp3", "rb"),
|
||||
guardrails=["content_filter"],
|
||||
)
|
||||
return response.text
|
||||
|
||||
text = asyncio.run(transcribe_with_guardrails())
|
||||
```
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
"""OpenAI Audio Transcription handler for Unified Guardrails."""

from litellm.llms.openai.transcriptions.guardrail_translation.handler import (
    OpenAIAudioTranscriptionHandler,
)
from litellm.types.utils import CallTypes

# Both the sync (`transcription`) and async (`atranscription`) call types map
# to the same handler: the guardrail work (rewriting the transcribed text in
# the response) is identical in either case.
guardrail_translation_mappings = {
    CallTypes.transcription: OpenAIAudioTranscriptionHandler,
    CallTypes.atranscription: OpenAIAudioTranscriptionHandler,
}

__all__ = ["guardrail_translation_mappings", "OpenAIAudioTranscriptionHandler"]
|
||||
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
OpenAI Audio Transcription Handler for Unified Guardrails
|
||||
|
||||
This module provides guardrail translation support for OpenAI's audio transcription endpoint.
|
||||
The handler processes the output transcribed text (input is audio, so no text to guardrail).
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.utils import TranscriptionResponse
|
||||
|
||||
|
||||
class OpenAIAudioTranscriptionHandler(BaseTranslation):
    """
    Guardrail translation handler for OpenAI audio transcription responses.

    The request carries an audio file (binary data), so only the transcribed
    text in the response is run through guardrails; the input hook is a no-op.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        No-op input hook: the transcription input is an audio file, not text.

        Args:
            data: Request data dictionary containing the audio file.
            guardrail_to_apply: The guardrail instance (unused).
            litellm_logging_obj: Optional logging object (unused).

        Returns:
            The request data, unmodified.
        """
        verbose_proxy_logger.debug(
            "OpenAI Audio Transcription: Input processing not applicable "
            "(input is audio file, not text)"
        )
        return data

    async def process_output_response(
        self,
        response: "TranscriptionResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Run the guardrail over the transcribed text and write the result back.

        Args:
            response: Transcription response whose ``text`` field is guardrailed.
            guardrail_to_apply: The guardrail instance to apply.
            litellm_logging_obj: Optional logging object forwarded to the guardrail.
            user_api_key_dict: User API key metadata forwarded as request metadata.

        Returns:
            The response, with ``text`` replaced when the guardrail produced one.
        """
        # Guard: nothing to do when the response has no text payload.
        if not hasattr(response, "text") or response.text is None:
            verbose_proxy_logger.debug(
                "OpenAI Audio Transcription: No text in response to process"
            )
            return response

        # Guard: only plain strings are guardrailed.
        if not isinstance(response.text, str):
            verbose_proxy_logger.debug(
                "OpenAI Audio Transcription: Unexpected text type: %s. Expected string.",
                type(response.text),
            )
            return response

        transcribed_text = response.text

        # Build a request_data payload carrying the response plus any user API
        # key metadata (with prefixed keys) for the guardrail call.
        guardrail_request_data: dict = {"response": response}
        key_metadata = self.transform_user_api_key_dict_to_metadata(user_api_key_dict)
        if key_metadata:
            guardrail_request_data["litellm_metadata"] = key_metadata

        guardrail_inputs = GenericGuardrailAPIInputs(texts=[transcribed_text])
        # Include model information from the response if available
        if hasattr(response, "model") and response.model:
            guardrail_inputs["model"] = response.model

        guardrail_result = await guardrail_to_apply.apply_guardrail(
            inputs=guardrail_inputs,
            request_data=guardrail_request_data,
            input_type="response",
            logging_obj=litellm_logging_obj,
        )
        updated_texts = guardrail_result.get("texts", [])
        # Keep the original transcription if the guardrail returned nothing.
        response.text = updated_texts[0] if updated_texts else transcribed_text

        verbose_proxy_logger.debug(
            "OpenAI Audio Transcription: Applied guardrail to transcribed text. "
            "Original length: %d, New length: %d",
            len(transcribed_text),
            len(response.text),
        )

        return response
|
||||
@@ -0,0 +1,231 @@
|
||||
from typing import TYPE_CHECKING, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from aiohttp import ClientSession
|
||||
from litellm.litellm_core_utils.audio_utils.utils import get_audio_file_name
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.audio_transcription.transformation import (
|
||||
BaseAudioTranscriptionConfig,
|
||||
)
|
||||
from litellm.types.utils import FileTypes
|
||||
from litellm.utils import (
|
||||
TranscriptionResponse,
|
||||
convert_to_model_response_object,
|
||||
extract_duration_from_srt_or_vtt,
|
||||
)
|
||||
|
||||
from ..openai import OpenAIChatCompletion
|
||||
|
||||
|
||||
class OpenAIAudioTranscription(OpenAIChatCompletion):
    """Handler for OpenAI `/v1/audio/transcriptions` (sync and async paths)."""

    # Audio Transcriptions
    async def make_openai_audio_transcriptions_request(
        self,
        openai_aclient: AsyncOpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
    ):
        """
        Issue the async transcription request.

        Always goes through ``with_raw_response`` so the HTTP response headers
        can be captured alongside the parsed body. (The previous docstring
        claimed this was conditional on ``litellm.return_response_headers``;
        the code has always been unconditional.)

        Returns:
            Tuple of (response headers dict, parsed transcription response).
        """
        try:
            raw_response = (
                await openai_aclient.audio.transcriptions.with_raw_response.create(
                    **data, timeout=timeout
                )
            )  # type: ignore
            headers = dict(raw_response.headers)
            response = raw_response.parse()

            return headers, response
        except Exception as e:
            raise e

    def make_sync_openai_audio_transcriptions_request(
        self,
        openai_client: OpenAI,
        data: dict,
        timeout: Union[float, httpx.Timeout],
    ):
        """
        Issue the sync transcription request.

        Uses ``with_raw_response`` (capturing response headers) only when
        ``litellm.return_response_headers`` is True; otherwise calls
        ``audio.transcriptions.create`` directly and returns ``None`` headers.

        Returns:
            Tuple of (headers dict or None, parsed transcription response).
        """
        try:
            if litellm.return_response_headers is True:
                raw_response = (
                    openai_client.audio.transcriptions.with_raw_response.create(
                        **data, timeout=timeout
                    )
                )  # type: ignore
                headers = dict(raw_response.headers)
                response = raw_response.parse()
                return headers, response
            else:
                response = openai_client.audio.transcriptions.create(**data, timeout=timeout)  # type: ignore
                return None, response
        except Exception as e:
            raise e

    def audio_transcriptions(
        self,
        model: str,
        audio_file: FileTypes,
        optional_params: dict,
        litellm_params: dict,
        model_response: TranscriptionResponse,
        timeout: float,
        max_retries: int,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        api_base: Optional[str],
        client=None,
        atranscription: bool = False,
        provider_config: Optional[BaseAudioTranscriptionConfig] = None,
        shared_session: Optional["ClientSession"] = None,
    ) -> TranscriptionResponse:
        """
        Handle an audio transcription request.

        Builds the request payload (via ``provider_config`` when supplied),
        dispatches to ``async_audio_transcriptions`` when ``atranscription``
        is True, otherwise performs the call synchronously and converts the
        result into ``model_response``.
        """
        if provider_config is not None:
            transformed_data = provider_config.transform_audio_transcription_request(
                model=model,
                audio_file=audio_file,
                optional_params=optional_params,
                litellm_params=litellm_params,
            )

            data = cast(dict, transformed_data.data)
        else:
            data = {"model": model, "file": audio_file, **optional_params}

        if atranscription is True:
            # Returns a coroutine; the caller awaits it.
            return self.async_audio_transcriptions(  # type: ignore
                audio_file=audio_file,
                data=data,
                model_response=model_response,
                timeout=timeout,
                api_key=api_key,
                api_base=api_base,
                client=client,
                max_retries=max_retries,
                logging_obj=logging_obj,
                shared_session=shared_session,
            )

        openai_client: OpenAI = self._get_openai_client(  # type: ignore
            is_async=False,
            api_key=api_key,
            api_base=api_base,
            timeout=timeout,
            max_retries=max_retries,
            client=client,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=None,
            api_key=openai_client.api_key,
            additional_args={
                "api_base": openai_client._base_url._uri_reference,
                # Fix: this is the synchronous path; it was previously
                # mislogged as "atranscription": True.
                "atranscription": False,
                "complete_input_dict": data,
            },
        )
        # Headers are intentionally discarded on the sync path.
        _, response = self.make_sync_openai_audio_transcriptions_request(
            openai_client=openai_client,
            data=data,
            timeout=timeout,
        )

        if isinstance(response, BaseModel):
            stringified_response = response.model_dump()
        else:
            # NOTE(review): unlike the async path, no duration is extracted
            # from srt/vtt string responses here — confirm whether cost
            # tracking needs `_audio_transcription_duration` on sync calls.
            stringified_response = TranscriptionResponse(text=response).model_dump()

        ## LOGGING
        logging_obj.post_call(
            input=get_audio_file_name(audio_file),
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=stringified_response,
        )
        hidden_params = {"model": model, "custom_llm_provider": "openai"}
        final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
        return final_response

    async def async_audio_transcriptions(
        self,
        audio_file: FileTypes,
        data: dict,
        model_response: TranscriptionResponse,
        timeout: float,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        shared_session: Optional["ClientSession"] = None,
    ):
        """
        Async transcription call.

        Captures response headers, normalizes the response into a dict
        (extracting audio duration from srt/vtt string responses for cost
        calculation), and converts it into ``model_response``. On failure the
        error is logged via ``post_call`` and re-raised.
        """
        try:
            openai_aclient: AsyncOpenAI = self._get_openai_client(  # type: ignore
                is_async=True,
                api_key=api_key,
                api_base=api_base,
                timeout=timeout,
                max_retries=max_retries,
                client=client,
                shared_session=shared_session,
            )

            ## LOGGING
            logging_obj.pre_call(
                input=None,
                api_key=openai_aclient.api_key,
                additional_args={
                    "api_base": openai_aclient._base_url._uri_reference,
                    "atranscription": True,
                    "complete_input_dict": data,
                },
            )
            headers, response = await self.make_openai_audio_transcriptions_request(
                openai_aclient=openai_aclient,
                data=data,
                timeout=timeout,
            )
            logging_obj.model_call_details["response_headers"] = headers
            if isinstance(response, BaseModel):
                stringified_response = response.model_dump()
            else:
                duration = extract_duration_from_srt_or_vtt(response)
                stringified_response = TranscriptionResponse(text=response).model_dump()
                stringified_response["_audio_transcription_duration"] = duration
            ## LOGGING
            logging_obj.post_call(
                input=get_audio_file_name(audio_file),
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            # Extract the actual model from data instead of hardcoding "whisper-1"
            actual_model = data.get("model", "whisper-1")
            hidden_params = {"model": actual_model, "custom_llm_provider": "openai"}

            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription")  # type: ignore
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                # Fix: previously logged the Python builtin `input` function
                # here instead of the audio file name.
                input=get_audio_file_name(audio_file),
                api_key=api_key,
                original_response=str(e),
            )
            raise e
|
||||
@@ -0,0 +1,150 @@
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from httpx import Headers, Response
|
||||
|
||||
from litellm.llms.base_llm.audio_transcription.transformation import (
|
||||
AudioTranscriptionRequestData,
|
||||
BaseAudioTranscriptionConfig,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIAudioTranscriptionOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import FileTypes, TranscriptionResponse
|
||||
|
||||
from ..common_utils import OpenAIError
|
||||
|
||||
|
||||
class OpenAIWhisperAudioTranscriptionConfig(BaseAudioTranscriptionConfig):
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
OPTIONAL
|
||||
|
||||
Get the complete url for the request
|
||||
|
||||
Some providers need `model` in `api_base`
|
||||
"""
|
||||
## get the api base, attach the endpoint - v1/audio/transcriptions
|
||||
# strip trailing slash if present
|
||||
api_base = api_base.rstrip("/") if api_base else ""
|
||||
|
||||
# if endswith "/v1"
|
||||
if api_base and api_base.endswith("/v1"):
|
||||
api_base = f"{api_base}/audio/transcriptions"
|
||||
else:
|
||||
api_base = f"{api_base}/v1/audio/transcriptions"
|
||||
|
||||
return api_base or ""
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[OpenAIAudioTranscriptionOptionalParams]:
|
||||
"""
|
||||
Get the supported OpenAI params for the `whisper-1` models
|
||||
"""
|
||||
return [
|
||||
"language",
|
||||
"prompt",
|
||||
"response_format",
|
||||
"temperature",
|
||||
"timestamp_granularities",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map the OpenAI params to the Whisper params
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
for k, v in non_default_params.items():
|
||||
if k in supported_params:
|
||||
optional_params[k] = v
|
||||
return optional_params
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
api_key = api_key or get_secret_str("OPENAI_API_KEY")
|
||||
|
||||
auth_header = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
}
|
||||
|
||||
headers.update(auth_header)
|
||||
return headers
|
||||
|
||||
def transform_audio_transcription_request(
    self,
    model: str,
    audio_file: FileTypes,
    optional_params: dict,
    litellm_params: dict,
) -> AudioTranscriptionRequestData:
    """Build the request payload for OpenAI's audio transcription endpoint."""
    payload = {"model": model, "file": audio_file, **optional_params}

    # 'verbose_json' includes 'duration', which downstream cost calculation
    # needs — upgrade an absent or plain text/json response_format to it.
    if "response_format" not in payload or payload["response_format"] in ("text", "json"):
        payload["response_format"] = "verbose_json"

    return AudioTranscriptionRequestData(data=payload)
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, Headers]
) -> BaseLLMException:
    """Wrap an upstream failure in the provider-specific OpenAIError."""
    return OpenAIError(
        message=error_message,
        status_code=status_code,
        headers=headers,
    )
|
||||
|
||||
def transform_audio_transcription_response(
    self,
    raw_response: Response,
) -> TranscriptionResponse:
    """Parse the raw HTTP response into a ``TranscriptionResponse``.

    Raises:
        ValueError: if the body is not valid JSON, or the JSON does not
            contain any field known to ``TranscriptionResponse``.
    """
    try:
        raw_response_json = raw_response.json()
    except Exception as e:
        raise ValueError(
            f"Error transforming response to json: {str(e)}\nResponse: {raw_response.text}"
        )

    # Accept the payload only if it shares at least one field with the
    # expected response model; otherwise it isn't a transcription result.
    if any(
        key in raw_response_json
        for key in TranscriptionResponse.model_fields.keys()
    ):
        return TranscriptionResponse(**raw_response_json)
    # Bug fix: this previously passed two positional args to ValueError,
    # which made str(exc) render as a tuple. Build one readable message.
    raise ValueError(
        "Invalid response format. Received response does not match the "
        f"expected format. Got: {raw_response_json}"
    )
|
||||
@@ -0,0 +1,258 @@
|
||||
from typing import Any, Dict, Optional, Tuple, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.vector_store_files.transformation import (
|
||||
BaseVectorStoreFilesConfig,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_store_files import (
|
||||
VectorStoreFileAuthCredentials,
|
||||
VectorStoreFileContentResponse,
|
||||
VectorStoreFileCreateRequest,
|
||||
VectorStoreFileDeleteResponse,
|
||||
VectorStoreFileListQueryParams,
|
||||
VectorStoreFileListResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileUpdateRequest,
|
||||
)
|
||||
from litellm.utils import add_openai_metadata
|
||||
|
||||
|
||||
def _clean_dict(source: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return {k: v for k, v in source.items() if v is not None}
|
||||
|
||||
|
||||
class OpenAIVectorStoreFilesConfig(BaseVectorStoreFilesConfig):
    """Transformation config for OpenAI's vector-store *files* endpoints.

    Builds URLs, headers and JSON payloads for the
    ``/vector_stores/{vector_store_id}/files`` family of endpoints and
    parses raw ``httpx`` responses into litellm's typed objects.
    """

    # Opt-in header for OpenAI's Assistants v2 API surface, which the
    # vector-store endpoints belong to.
    ASSISTANTS_HEADER_KEY = "OpenAI-Beta"
    ASSISTANTS_HEADER_VALUE = "assistants=v2"

    def get_auth_credentials(
        self, litellm_params: Dict[str, Any]
    ) -> VectorStoreFileAuthCredentials:
        """Return bearer-auth headers derived from ``litellm_params``.

        Raises:
            ValueError: if no ``api_key`` is present in ``litellm_params``.
        """
        api_key = litellm_params.get("api_key")
        if api_key is None:
            raise ValueError("api_key is required")
        return {
            "headers": {
                "Authorization": f"Bearer {api_key}",
            }
        }

    def get_vector_store_file_endpoints_by_type(
        self,
    ) -> Dict[str, Tuple[Tuple[str, str], ...]]:
        """Map 'read'/'write' access types to the (HTTP method, path template) pairs they cover."""
        return {
            "read": (
                ("GET", "/vector_stores/{vector_store_id}/files"),
                ("GET", "/vector_stores/{vector_store_id}/files/{file_id}"),
                (
                    "GET",
                    "/vector_stores/{vector_store_id}/files/{file_id}/content",
                ),
            ),
            "write": (
                ("POST", "/vector_stores/{vector_store_id}/files"),
                ("POST", "/vector_stores/{vector_store_id}/files/{file_id}"),
                ("DELETE", "/vector_stores/{vector_store_id}/files/{file_id}"),
            ),
        }

    def validate_environment(
        self,
        *,
        headers: Dict[str, str],
        litellm_params: Optional[GenericLiteLLMParams],
    ) -> Dict[str, str]:
        """Populate auth, content-type, and assistants headers in place; return *headers*."""
        litellm_params = litellm_params or GenericLiteLLMParams()
        # Resolve the key from the most specific source first.
        api_key = (
            litellm_params.api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        # NOTE(review): if no key is found this sends "Bearer None" —
        # presumably the upstream 401 is the intended failure mode; confirm.
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        )
        # Don't clobber a caller-supplied assistants header.
        if self.ASSISTANTS_HEADER_KEY not in headers:
            headers[self.ASSISTANTS_HEADER_KEY] = self.ASSISTANTS_HEADER_VALUE
        return headers

    def get_complete_url(
        self,
        *,
        api_base: Optional[str],
        vector_store_id: str,
        litellm_params: Dict[str, Any],
    ) -> str:
        """Return the files-collection URL for *vector_store_id*.

        Falls back through litellm globals and env secrets to the public
        OpenAI endpoint when no ``api_base`` is supplied.
        """
        base_url = (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_BASE_URL")
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )
        base_url = base_url.rstrip("/")
        return f"{base_url}/vector_stores/{vector_store_id}/files"

    def transform_create_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        create_request: VectorStoreFileCreateRequest,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, json payload) for attaching a file to a vector store.

        ``None`` values are stripped, and ``attributes`` are filtered through
        ``add_openai_metadata`` so only OpenAI-safe metadata is forwarded.
        """
        payload: Dict[str, Any] = _clean_dict(dict(create_request))
        attributes = payload.get("attributes")
        if isinstance(attributes, dict):
            filtered_attributes = add_openai_metadata(attributes)
            if filtered_attributes is not None:
                payload["attributes"] = filtered_attributes
            else:
                # Nothing survived filtering; drop the key entirely.
                payload.pop("attributes", None)
        url = api_base
        return url, payload

    def transform_create_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse the create-file response JSON; raise the provider error on failure."""
        try:
            return cast(VectorStoreFileObject, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_list_vector_store_files_request(
        self,
        *,
        vector_store_id: str,
        query_params: VectorStoreFileListQueryParams,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, query params) for listing files; None-valued params are dropped."""
        params = _clean_dict(dict(query_params))
        return api_base, params

    def transform_list_vector_store_files_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileListResponse:
        """Parse the list-files response JSON; raise the provider error on failure."""
        try:
            return cast(VectorStoreFileListResponse, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_retrieve_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Return the single-file URL and an empty query-param dict (plain GET)."""
        return f"{api_base}/{file_id}", {}

    def transform_retrieve_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse the retrieve-file response JSON; raise the provider error on failure."""
        try:
            return cast(VectorStoreFileObject, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_retrieve_vector_store_file_content_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Return the file-content URL and an empty query-param dict (plain GET)."""
        return f"{api_base}/{file_id}/content", {}

    def transform_retrieve_vector_store_file_content_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileContentResponse:
        """Parse the file-content response JSON; raise the provider error on failure."""
        try:
            return cast(VectorStoreFileContentResponse, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_update_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        update_request: VectorStoreFileUpdateRequest,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Build (url, json payload) for updating a vector-store file.

        NOTE(review): unlike the create path, this does NOT strip None values
        via ``_clean_dict`` — presumably so explicit ``None`` can clear fields
        server-side; confirm that asymmetry is intentional.
        """
        payload: Dict[str, Any] = dict(update_request)
        attributes = payload.get("attributes")
        if isinstance(attributes, dict):
            filtered_attributes = add_openai_metadata(attributes)
            if filtered_attributes is not None:
                payload["attributes"] = filtered_attributes
            else:
                payload.pop("attributes", None)
        return f"{api_base}/{file_id}", payload

    def transform_update_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileObject:
        """Parse the update-file response JSON; raise the provider error on failure."""
        try:
            return cast(VectorStoreFileObject, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_delete_vector_store_file_request(
        self,
        *,
        vector_store_id: str,
        file_id: str,
        api_base: str,
    ) -> Tuple[str, Dict[str, Any]]:
        """Return the single-file URL and an empty body dict (plain DELETE)."""
        return f"{api_base}/{file_id}", {}

    def transform_delete_vector_store_file_response(
        self,
        *,
        response: httpx.Response,
    ) -> VectorStoreFileDeleteResponse:
        """Parse the delete-file response JSON; raise the provider error on failure."""
        try:
            return cast(VectorStoreFileDeleteResponse, response.json())
        except Exception as exc:  # noqa: BLE001
            raise self.get_error_class(
                error_message=str(exc),
                status_code=response.status_code,
                headers=response.headers,
            )
|
||||
@@ -0,0 +1,176 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
BaseVectorStoreAuthCredentials,
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateRequest,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreIndexEndpoints,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchRequest,
|
||||
VectorStoreSearchResponse,
|
||||
)
|
||||
from litellm.utils import add_openai_metadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class OpenAIVectorStoreConfig(BaseVectorStoreConfig):
    """Transformation config for OpenAI's vector-store search and create endpoints."""

    # Opt-in header for OpenAI's Assistants v2 API surface.
    ASSISTANTS_HEADER_KEY = "OpenAI-Beta"
    ASSISTANTS_HEADER_VALUE = "assistants=v2"

    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """Return bearer-auth headers from ``litellm_params``.

        Raises:
            ValueError: if no ``api_key`` is present.
        """
        api_key = litellm_params.get("api_key")
        if api_key is None:
            raise ValueError("api_key is required")
        return {
            "headers": {
                "Authorization": f"Bearer {api_key}",
            },
        }

    def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
        """Map 'read'/'write' access types to their (HTTP method, path template) endpoints."""
        return {
            "read": [("GET", "/vector_stores/{index_name}/search")],
            "write": [("POST", "/vector_stores")],
        }

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """Populate auth, content-type, and assistants headers in place; return *headers*."""
        litellm_params = litellm_params or GenericLiteLLMParams()
        # Resolve the key from the most specific source first.
        api_key = (
            litellm_params.api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        )

        #########################################################
        # Ensure OpenAI Assistants header is included
        #########################################################
        if self.ASSISTANTS_HEADER_KEY not in headers:
            headers.update(
                {
                    self.ASSISTANTS_HEADER_KEY: self.ASSISTANTS_HEADER_VALUE,
                }
            )

        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the Base endpoint for OpenAI Vector Stores API

        Falls back through litellm globals and env secrets to the public
        OpenAI endpoint when no ``api_base`` is supplied.
        """
        api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("OPENAI_BASE_URL")
            or get_secret_str("OPENAI_API_BASE")
            or "https://api.openai.com/v1"
        )

        # Remove trailing slashes
        api_base = api_base.rstrip("/")

        return f"{api_base}/vector_stores"

    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict]:
        """Build (url, json body) for a vector-store search.

        ``litellm_logging_obj`` and ``litellm_params`` are unused here; the
        signature is dictated by the base-class interface.
        """
        url = f"{api_base}/{vector_store_id}/search"
        # Pull each optional field through the typed request to keep the body
        # limited to what the OpenAI search endpoint accepts.
        typed_request_body = VectorStoreSearchRequest(
            query=query,
            filters=vector_store_search_optional_params.get("filters", None),
            max_num_results=vector_store_search_optional_params.get(
                "max_num_results", None
            ),
            ranking_options=vector_store_search_optional_params.get(
                "ranking_options", None
            ),
            rewrite_query=vector_store_search_optional_params.get(
                "rewrite_query", None
            ),
        )

        dict_request_body = cast(dict, typed_request_body)
        return url, dict_request_body

    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """Parse the search response JSON; raise the provider error on failure."""
        try:
            response_json = response.json()
            return VectorStoreSearchResponse(**response_json)
        except Exception as e:
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict]:
        """Build (url, json body) for creating a vector store.

        Metadata is filtered through ``add_openai_metadata`` so only
        OpenAI-safe entries are forwarded.
        """
        url = api_base  # Base URL for creating vector stores
        metadata = vector_store_create_optional_params.get("metadata", None)
        metadata_payload = add_openai_metadata(metadata)

        typed_request_body = VectorStoreCreateRequest(
            name=vector_store_create_optional_params.get("name", None),
            file_ids=vector_store_create_optional_params.get("file_ids", None),
            expires_after=vector_store_create_optional_params.get(
                "expires_after", None
            ),
            chunking_strategy=vector_store_create_optional_params.get(
                "chunking_strategy", None
            ),
            metadata=metadata_payload,
        )

        dict_request_body = cast(dict, typed_request_body)
        return url, dict_request_body

    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        """Parse the create response JSON; raise the provider error on failure."""
        try:
            response_json = response.json()
            return VectorStoreCreateResponse(**response_json)
        except Exception as e:
            raise self.get_error_class(
                error_message=str(e),
                status_code=response.status_code,
                headers=response.headers,
            )
|
||||
@@ -0,0 +1,447 @@
|
||||
from io import BufferedReader
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.videos.transformation import BaseVideoConfig
|
||||
from litellm.llms.openai.image_edit.transformation import ImageEditRequestUtils
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import CreateVideoRequest
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.videos.main import VideoCreateOptionalRequestParams, VideoObject
|
||||
from litellm.types.videos.utils import (
|
||||
encode_video_id_with_provider,
|
||||
extract_original_video_id,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
from ...base_llm.chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
class OpenAIVideoConfig(BaseVideoConfig):
    """
    Configuration class for OpenAI video generation.

    Builds requests for the ``/videos`` endpoints (create, remix, list,
    retrieve, delete, content download) and parses responses into
    ``VideoObject``s, wrapping/unwrapping provider-encoded video ids.
    """

    def __init__(self):
        super().__init__()

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get the list of supported OpenAI parameters for video generation.
        """
        return [
            "model",
            "prompt",
            "input_reference",
            "seconds",
            "size",
            "user",
            "extra_headers",
        ]

    def map_openai_params(
        self,
        video_create_optional_params: VideoCreateOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """No mapping applied since inputs are in OpenAI spec already"""
        return dict(video_create_optional_params)

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        litellm_params: Optional[GenericLiteLLMParams] = None,
    ) -> dict:
        """Attach the OpenAI bearer token to *headers* and return them."""
        # Use api_key from litellm_params if available, otherwise fall back to other sources
        if litellm_params and litellm_params.api_key:
            api_key = api_key or litellm_params.api_key

        api_key = (
            api_key
            or litellm.api_key
            or litellm.openai_key
            or get_secret_str("OPENAI_API_KEY")
        )
        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers

    def get_complete_url(
        self,
        model: str,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for OpenAI video generation.
        """
        if api_base is None:
            api_base = "https://api.openai.com/v1"

        return f"{api_base.rstrip('/')}/videos"

    def transform_video_create_request(
        self,
        model: str,
        prompt: str,
        api_base: str,
        video_create_optional_request_params: Dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[Dict, RequestFiles, str]:
        """
        Transform the video creation request for OpenAI API.

        Returns (form fields without files, multipart files list, url).
        """
        # Remove model and extra_headers from optional params as they're handled separately
        video_create_optional_request_params = {
            k: v
            for k, v in video_create_optional_request_params.items()
            if k not in ["model", "extra_headers", "prompt"]
        }

        # Create the request data
        video_create_request = CreateVideoRequest(
            model=model, prompt=prompt, **video_create_optional_request_params
        )
        request_dict = cast(Dict, video_create_request)

        # Handle input_reference parameter if provided
        _input_reference = video_create_optional_request_params.get("input_reference")
        # input_reference must go in the multipart files section, not the form fields.
        data_without_files = {
            k: v for k, v in request_dict.items() if k not in ["input_reference"]
        }
        files_list: List[Tuple[str, Any]] = []

        # Handle input_reference parameter
        if _input_reference is not None:
            self._add_image_to_files(
                files_list=files_list,
                image=_input_reference,
                field_name="input_reference",
            )
        return data_without_files, files_list, api_base

    def transform_video_create_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
        request_data: Optional[Dict] = None,
    ) -> VideoObject:
        """Transform the OpenAI video creation response."""
        response_data = raw_response.json()

        video_obj = VideoObject(**response_data)  # type: ignore[arg-type]

        # Wrap the provider name into the id so later calls can route back here.
        if custom_llm_provider and video_obj.id:
            video_obj.id = encode_video_id_with_provider(
                video_obj.id, custom_llm_provider, model
            )

        # The create API provides no usage block; synthesize one carrying the
        # clip duration for downstream cost calculation.
        usage_data = {}
        if video_obj:
            if hasattr(video_obj, "seconds") and video_obj.seconds:
                try:
                    usage_data["duration_seconds"] = float(video_obj.seconds)
                except (ValueError, TypeError):
                    # Non-numeric 'seconds' — omit duration rather than fail.
                    pass
        video_obj.usage = usage_data

        return video_obj

    def transform_video_content_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        variant: Optional[str] = None,
    ) -> Tuple[str, Dict]:
        """
        Transform the video content request for OpenAI API.

        OpenAI API expects the following request:
        - GET /v1/videos/{video_id}/content
        - GET /v1/videos/{video_id}/content?variant=thumbnail
        """
        original_video_id = extract_original_video_id(video_id)

        # Construct the URL for video content download
        url = f"{api_base.rstrip('/')}/{original_video_id}/content"
        if variant is not None:
            url = f"{url}?variant={variant}"

        # No additional data needed for GET content request
        data: Dict[str, Any] = {}

        return url, data

    def transform_video_remix_request(
        self,
        video_id: str,
        prompt: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        extra_body: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict]:
        """
        Transform the video remix request for OpenAI API.

        OpenAI API expects the following request:
        - POST /v1/videos/{video_id}/remix
        """
        original_video_id = extract_original_video_id(video_id)

        # Construct the URL for video remix
        url = f"{api_base.rstrip('/')}/{original_video_id}/remix"

        # Prepare the request data
        data = {"prompt": prompt}

        # Add any extra body parameters
        if extra_body:
            data.update(extra_body)

        return url, data

    def transform_video_content_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> bytes:
        """Transform the OpenAI video content download response."""
        return raw_response.content

    def transform_video_remix_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> VideoObject:
        """
        Transform the OpenAI video remix response.
        """
        response_data = raw_response.json()

        # Transform the response data
        video_obj = VideoObject(**response_data)  # type: ignore[arg-type]

        # NOTE(review): the model arg is None here, unlike the create path
        # which passes the model — confirm the encoded id doesn't need it.
        if custom_llm_provider and video_obj.id:
            video_obj.id = encode_video_id_with_provider(
                video_obj.id, custom_llm_provider, None
            )

        # Create usage object with duration information for cost calculation
        # Video remix API doesn't provide usage, so we create one with duration
        usage_data = {}
        if video_obj:
            if hasattr(video_obj, "seconds") and video_obj.seconds:
                try:
                    usage_data["duration_seconds"] = float(video_obj.seconds)
                except (ValueError, TypeError):
                    pass
        # Create the response
        video_obj.usage = usage_data

        return video_obj

    def transform_video_list_request(
        self,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
        after: Optional[str] = None,
        limit: Optional[int] = None,
        order: Optional[str] = None,
        extra_query: Optional[Dict[str, Any]] = None,
    ) -> Tuple[str, Dict]:
        """
        Transform the video list request for OpenAI API.

        OpenAI API expects the following request:
        - GET /v1/videos
        """
        # Use the api_base directly for video list
        url = api_base

        # Prepare query parameters
        params = {}
        if after is not None:
            # Decode the wrapped video ID back to the original provider ID
            params["after"] = extract_original_video_id(after)
        if limit is not None:
            params["limit"] = str(limit)
        if order is not None:
            params["order"] = order

        # Add any extra query parameters
        if extra_query:
            params.update(extra_query)

        return url, params

    def transform_video_list_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Wrap every video id (and the pagination cursors) with the provider prefix."""
        response_data = raw_response.json()

        if custom_llm_provider and "data" in response_data:
            for video_obj in response_data.get("data", []):
                if isinstance(video_obj, dict) and "id" in video_obj:
                    video_obj["id"] = encode_video_id_with_provider(
                        video_obj["id"],
                        custom_llm_provider,
                        video_obj.get("model"),
                    )

            # Encode pagination cursor IDs so they remain consistent
            # with the wrapped data[].id format
            data_list = response_data.get("data", [])
            if response_data.get("first_id"):
                first_model = None
                if data_list and isinstance(data_list[0], dict):
                    first_model = data_list[0].get("model")
                response_data["first_id"] = encode_video_id_with_provider(
                    response_data["first_id"],
                    custom_llm_provider,
                    first_model,
                )
            if response_data.get("last_id"):
                last_model = None
                if data_list and isinstance(data_list[-1], dict):
                    last_model = data_list[-1].get("model")
                response_data["last_id"] = encode_video_id_with_provider(
                    response_data["last_id"],
                    custom_llm_provider,
                    last_model,
                )

        return response_data

    def transform_video_delete_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the video delete request for OpenAI API.

        OpenAI API expects the following request:
        - DELETE /v1/videos/{video_id}
        """
        original_video_id = extract_original_video_id(video_id)

        # Construct the URL for video delete
        url = f"{api_base.rstrip('/')}/{original_video_id}"

        # No data needed for DELETE request
        data: Dict[str, Any] = {}

        return url, data

    def transform_video_delete_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> VideoObject:
        """
        Transform the OpenAI video delete response.
        """
        response_data = raw_response.json()

        # Transform the response data
        video_obj = VideoObject(**response_data)  # type: ignore[arg-type]

        return video_obj

    def transform_video_status_retrieve_request(
        self,
        video_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform the OpenAI video retrieve request.
        """
        # Extract the original video_id (remove provider encoding if present)
        original_video_id = extract_original_video_id(video_id)

        # For video retrieve, we just need to construct the URL
        url = f"{api_base.rstrip('/')}/{original_video_id}"

        # No additional data needed for GET request
        data: Dict[str, Any] = {}

        return url, data

    def transform_video_status_retrieve_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        custom_llm_provider: Optional[str] = None,
    ) -> VideoObject:
        """
        Transform the OpenAI video retrieve response.
        """
        response_data = raw_response.json()
        # Transform the response data
        video_obj = VideoObject(**response_data)  # type: ignore[arg-type]

        if custom_llm_provider and video_obj.id:
            video_obj.id = encode_video_id_with_provider(
                video_obj.id, custom_llm_provider, None
            )

        return video_obj

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Build the provider exception for an upstream failure.

        NOTE(review): this *raises* the exception instead of returning it,
        despite the ``-> BaseLLMException`` annotation; callers that do
        ``raise self.get_error_class(...)`` still work, but the function
        never actually returns — presumably it should ``return`` like the
        sibling configs; confirm.
        """
        from ...base_llm.chat.transformation import BaseLLMException

        raise BaseLLMException(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def _add_image_to_files(
        self,
        files_list: List[Tuple[str, Any]],
        image: Any,
        field_name: str,
    ) -> None:
        """Add an image to the files list with appropriate content type"""
        image_content_type = ImageEditRequestUtils.get_image_content_type(image)

        # Open file handles carry their own name; other payloads (bytes,
        # buffers) get a fixed placeholder filename.
        if isinstance(image, BufferedReader):
            files_list.append((field_name, (image.name, image, image_content_type)))
        else:
            files_list.append(
                (field_name, ("input_reference.png", image, image_content_type))
            )
|
||||
Reference in New Issue
Block a user