"""
Common helpers / utils across all OpenAI endpoints
"""

import hashlib
import inspect
import json
import os
import ssl
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    List,
    Literal,
    NamedTuple,
    Optional,
    Tuple,
    Union,
)

import httpx
import openai
from openai import AsyncAzureOpenAI, AsyncOpenAI, AzureOpenAI, OpenAI

if TYPE_CHECKING:
    from aiohttp import ClientSession

import litellm
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
    _DEFAULT_TTL_FOR_HTTPX_CLIENTS,
    AsyncHTTPHandler,
    get_ssl_configuration,
)


def _get_client_init_params(cls: type) -> Tuple[str, ...]:
    """Extract __init__ parameter names (excluding 'self') from a class."""
    return tuple(p for p in inspect.signature(cls.__init__).parameters if p != "self")  # type: ignore[misc]


# Computed once at import time and reused by
# BaseOpenAILLM.get_openai_client_initialization_param_fields.
_OPENAI_INIT_PARAMS: Tuple[str, ...] = _get_client_init_params(OpenAI)
_AZURE_OPENAI_INIT_PARAMS: Tuple[str, ...] = _get_client_init_params(AzureOpenAI)


class OpenAIError(BaseLLMException):
    """Exception for errors returned by OpenAI(-compatible) endpoints.

    When ``request`` / ``response`` are not supplied, placeholder httpx objects
    are created (a POST to ``https://api.openai.com/v1`` and a response carrying
    ``status_code``) so that ``self.request`` and ``self.response`` are always set.
    """

    def __init__(
        self,
        status_code: int,
        message: str,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
        headers: Optional[Union[dict, httpx.Headers]] = None,
        body: Optional[dict] = None,
    ):
        self.status_code = status_code
        self.message = message
        self.headers = headers
        if request is not None:
            self.request = request
        else:
            self.request = httpx.Request(method="POST", url="https://api.openai.com/v1")
        if response is not None:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            status_code=status_code,
            message=self.message,
            headers=self.headers,
            request=self.request,
            response=self.response,
            body=body,
        )


####### Error Handling Utils for OpenAI API #######################
###################################################################
def drop_params_from_unprocessable_entity_error(
    e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
    data: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Helper function to read an OpenAI UnprocessableEntityError (or an
    httpx.HTTPStatusError) and drop the params that raised an error.

    The error's ``message`` (a dict, or a string holding JSON) is expected to
    contain a ``detail`` list of dicts. Every entry whose ``loc`` is a
    two-element list contributes ``loc[1]`` as an invalid parameter name, as in
    ``{"detail": [{"loc": ["body", "<param>"], ...}]}``. A payload that does
    not match this layout (missing, non-JSON, wrong types) is treated as
    "nothing to drop" rather than raising.

    Args:
        e (UnprocessableEntityError | httpx.HTTPStatusError): The 422 exception
        data (Dict[str, Any]): The original data dictionary containing all parameters

    Returns:
        Dict[str, Any]: A new dictionary with invalid parameters removed
    """
    error_body = _get_unprocessable_entity_error_body(e)
    invalid_params = _get_invalid_params_from_error_body(error_body)
    # Always return a fresh dict, even when nothing was dropped.
    return {k: v for k, v in data.items() if k not in invalid_params}


def _get_unprocessable_entity_error_body(
    e: Union[openai.UnprocessableEntityError, httpx.HTTPStatusError],
) -> Any:
    """Return the error payload carried by ``e``, or ``None`` if it cannot be read."""
    if not isinstance(e, httpx.HTTPStatusError):
        return e.body
    try:
        error_json = e.response.json()
    except ValueError:
        # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError
        # subclasses): a non-JSON body carries no parameter information.
        return None
    if not isinstance(error_json, dict):
        return None
    return error_json.get("error", {})


def _get_invalid_params_from_error_body(error_body: Any) -> List[Any]:
    """Collect ``loc[1]`` from every two-element ``loc`` in the body's ``detail`` list."""
    if not isinstance(error_body, dict) or not error_body.get("message"):
        return []

    message = error_body.get("message", {})
    if isinstance(message, str):
        try:
            message = json.loads(message)
        except json.JSONDecodeError:
            message = {"detail": message}
    if not isinstance(message, dict):
        # e.g. the message decoded to a JSON list or number: no "detail" key.
        return []

    detail = message.get("detail")
    if not (isinstance(detail, list) and detail and isinstance(detail[0], dict)):
        return []

    invalid_params: List[Any] = []
    for error_dict in detail:
        if not isinstance(error_dict, dict):
            # Skip malformed entries instead of failing on `.get`.
            continue
        loc = error_dict.get("loc")
        if isinstance(loc, list) and len(loc) == 2:
            invalid_params.append(loc[1])
    return invalid_params


class BaseOpenAILLM:
    """
    Base class for OpenAI LLMs for getting their httpx clients and SSL verification settings
    """

    @staticmethod
    def get_cached_openai_client(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> Optional[Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI]]:
        """Retrieves the OpenAI client from the in-memory cache based on the client initialization parameters.

        Returns whatever ``litellm.in_memory_llm_clients_cache.get_cache`` yields
        for the derived key (presumably ``None`` on a miss).
        """
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        _cached_client = litellm.in_memory_llm_clients_cache.get_cache(_cache_key)
        return _cached_client

    @staticmethod
    def set_cached_openai_client(
        openai_client: Union[OpenAI, AsyncOpenAI, AzureOpenAI, AsyncAzureOpenAI],
        client_type: Literal["openai", "azure"],
        client_initialization_params: dict,
    ):
        """Stores the OpenAI client in the in-memory cache for _DEFAULT_TTL_FOR_HTTPX_CLIENTS SECONDS"""
        _cache_key = BaseOpenAILLM.get_openai_client_cache_key(
            client_initialization_params=client_initialization_params,
            client_type=client_type,
        )
        litellm.in_memory_llm_clients_cache.set_cache(
            key=_cache_key,
            value=openai_client,
            ttl=_DEFAULT_TTL_FOR_HTTPX_CLIENTS,
        )

    @staticmethod
    def get_openai_client_cache_key(
        client_initialization_params: dict, client_type: Literal["openai", "azure"]
    ) -> str:
        """Creates a cache key for the OpenAI client based on the client initialization parameters.

        The key is a comma-joined list of ``name=value`` pairs: the SHA-256 hex
        digest of ``api_key`` (never the raw key), ``is_async``, then every
        client ``__init__`` parameter and litellm-specific parameter, each listed once.
        """
        api_key = client_initialization_params.get("api_key")
        hashed_api_key = None
        if api_key is not None:
            # Hexadecimal representation of the hash
            hashed_api_key = hashlib.sha256(api_key.encode()).hexdigest()

        # Create a more readable cache key using a list of key-value pairs
        key_parts = [
            f"hashed_api_key={hashed_api_key}",
            f"is_async={client_initialization_params.get('is_async')}",
        ]

        LITELLM_CLIENT_SPECIFIC_PARAMS = (
            "timeout",
            "max_retries",
            "organization",
            "api_base",
        )
        openai_client_fields = (
            BaseOpenAILLM.get_openai_client_initialization_param_fields(
                client_type=client_type
            )
            + LITELLM_CLIENT_SPECIFIC_PARAMS
        )

        # dict.fromkeys de-duplicates while preserving order; litellm-specific
        # params may also appear among the client __init__ params.
        for param in dict.fromkeys(openai_client_fields):
            if param == "api_key":
                # Already represented by hashed_api_key; writing it here would
                # put the raw secret into the cache key.
                # NOTE(review): other credential-like params (e.g. Azure AD
                # tokens) are still included verbatim -- consider hashing them too.
                continue
            key_parts.append(f"{param}={client_initialization_params.get(param)}")

        _cache_key = ",".join(key_parts)
        return _cache_key

    @staticmethod
    def get_openai_client_initialization_param_fields(
        client_type: Literal["openai", "azure"]
    ) -> Tuple[str, ...]:
        """Returns a tuple of fields that are used to initialize the OpenAI client"""
        if client_type == "openai":
            return _OPENAI_INIT_PARAMS
        else:
            return _AZURE_OPENAI_INIT_PARAMS

    @staticmethod
    def _get_async_http_client(
        shared_session: Optional["ClientSession"] = None,
    ) -> Optional[httpx.AsyncClient]:
        """Return the httpx.AsyncClient to use for async OpenAI SDK clients.

        Precedence: ``litellm.aclient_session`` if set; a client backed by
        ``MockOpenAITransport`` when ``litellm.network_mock`` is truthy; otherwise
        a new client using the unified SSL configuration, the transport built by
        ``AsyncHTTPHandler._create_async_transport`` (optionally reusing
        ``shared_session``), and redirect following.
        """
        if litellm.aclient_session is not None:
            return litellm.aclient_session

        if getattr(litellm, "network_mock", False):
            from litellm.llms.custom_httpx.mock_transport import MockOpenAITransport

            return httpx.AsyncClient(transport=MockOpenAITransport())

        # Get unified SSL configuration
        ssl_config = get_ssl_configuration()

        return httpx.AsyncClient(
            verify=ssl_config,
            transport=AsyncHTTPHandler._create_async_transport(
                ssl_context=ssl_config if isinstance(ssl_config, ssl.SSLContext) else None,
                ssl_verify=ssl_config if isinstance(ssl_config, bool) else None,
                shared_session=shared_session,
            ),
            follow_redirects=True,
        )

    @staticmethod
    def _get_sync_http_client() -> Optional[httpx.Client]:
        """Return the httpx.Client to use for sync OpenAI SDK clients.

        Precedence: ``litellm.client_session`` if set; a client backed by
        ``MockOpenAITransport`` when ``litellm.network_mock`` is truthy; otherwise
        a new client using the unified SSL configuration with redirect following.
        """
        if litellm.client_session is not None:
            return litellm.client_session

        if getattr(litellm, "network_mock", False):
            from litellm.llms.custom_httpx.mock_transport import MockOpenAITransport

            return httpx.Client(transport=MockOpenAITransport())

        # Get unified SSL configuration
        ssl_config = get_ssl_configuration()

        return httpx.Client(
            verify=ssl_config,
            follow_redirects=True,
        )


class OpenAICredentials(NamedTuple):
    """Resolved OpenAI connection settings returned by get_openai_credentials."""

    api_base: str
    api_key: Optional[str]
    organization: Optional[str]


def get_openai_credentials(
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    organization: Optional[str] = None,
) -> OpenAICredentials:
    """Resolve OpenAI credentials from params, litellm globals, and env vars.

    For each field the first truthy source wins: explicit argument, then
    litellm module globals, then environment variables. ``api_base`` falls back
    to ``https://api.openai.com/v1``; an empty or missing organization becomes
    ``None``.
    """
    resolved_api_base = (
        api_base
        or litellm.api_base
        or os.getenv("OPENAI_BASE_URL")
        or os.getenv("OPENAI_API_BASE")
        or "https://api.openai.com/v1"
    )
    resolved_organization = (
        organization
        or litellm.organization
        or os.getenv("OPENAI_ORGANIZATION")
        or None
    )
    resolved_api_key = (
        api_key
        or litellm.api_key
        or litellm.openai_key
        or os.getenv("OPENAI_API_KEY")
    )
    return OpenAICredentials(
        api_base=resolved_api_base,
        api_key=resolved_api_key,
        organization=resolved_organization,
    )