import json import time import traceback from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import httpx from litellm.litellm_core_utils.exception_mapping_utils import exception_type from litellm.litellm_core_utils.logging_utils import track_llm_api_timing from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, _get_httpx_client, get_async_httpx_client, version, ) from litellm.types.llms.openai import AllMessageValues from litellm.types.utils import LlmProviders from litellm.utils import CustomStreamWrapper, ModelResponse, Usage from ..common_utils import API_BASE, BytezError if TYPE_CHECKING: from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj LiteLLMLoggingObj = _LiteLLMLoggingObj else: LiteLLMLoggingObj = Any # 5 minute timeout (models may need to load) STREAMING_TIMEOUT = 60 * 5 class BytezChatConfig(BaseConfig): """ Configuration class for Bytez's API interface. """ def __init__( self, ) -> None: locals_ = locals().copy() for key, value in locals_.items(): if key != "self" and value is not None: setattr(self.__class__, key, value) # mark the class as using a custom stream wrapper because the default only iterates on lines setattr(self.__class__, "has_custom_stream_wrapper", True) self.openai_to_bytez_param_map = { "stream": "stream", "max_tokens": "max_new_tokens", "max_completion_tokens": "max_new_tokens", "temperature": "temperature", "top_p": "top_p", "n": "num_return_sequences", "max_retries": "max_retries", "seed": False, # TODO requires backend changes "stop": False, # TODO requires backend changes "logit_bias": False, # TODO requires backend changes "logprobs": False, # TODO requires backend changes "frequency_penalty": False, "presence_penalty": False, "top_logprobs": False, "modalities": False, "prediction": False, "stream_options": False, "tools": False, "tool_choice": False, "function_call": False, "functions": False, "extra_headers": False, "parallel_tool_calls": False, "audio": False, "web_search_options": False, } def get_supported_openai_params(self, model: str) -> List[str]: supported_params = [] for key, value in self.openai_to_bytez_param_map.items(): if value: supported_params.append(key) return supported_params def map_openai_params( self, non_default_params: dict, optional_params: dict, model: str, drop_params: bool, ) -> dict: adapted_params = {} all_params = {**non_default_params, **optional_params} for key, value in all_params.items(): alias = self.openai_to_bytez_param_map.get(key) if alias is False: if drop_params: continue raise Exception(f"param `{key}` is not supported on Bytez") if alias is None: adapted_params[key] = value continue adapted_params[alias] = value return adapted_params def validate_environment( self, headers: dict, model: str, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: headers.update( { "content-type": "application/json", "Authorization": f"Key {api_key}", "user-agent": f"litellm/{version}", } ) if not messages: raise Exception( "kwarg `messages` must be an array of messages that follow the openai chat standard" ) if not api_key: raise Exception("Missing api_key, make sure you pass in your api key") return headers def get_complete_url( self, api_base: Optional[str], api_key: Optional[str], model: str, optional_params: dict, litellm_params: dict, stream: Optional[bool] = None, ) -> str: return f"{API_BASE}/{model}" def transform_request( self, model: str, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, headers: dict, ) -> dict: stream = optional_params.get("stream", False) # we add stream not as an additional param, but as a primary prop on the request body, this is always defined if stream == True if optional_params.get("stream"): del optional_params["stream"] messages = adapt_messages_to_bytez_standard(messages=messages) # type: ignore data = { "messages": messages, "stream": stream, "params": optional_params, } return data def transform_response( self, model: str, raw_response: httpx.Response, model_response: ModelResponse, logging_obj: LiteLLMLoggingObj, request_data: dict, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, encoding: Any, api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> ModelResponse: json = raw_response.json() # noqa: F811 error = json.get("error") if error is not None: raise BytezError( message=str(json["error"]), status_code=raw_response.status_code, ) # set meta data here model_response.created = int(time.time()) model_response.model = model # Add the output output = json.get("output") message = model_response.choices[0].message # type: ignore message.content = output["content"][0]["text"] messages = adapt_messages_to_bytez_standard(messages=messages) # type: ignore # NOTE We are approximating tokens, to get the true values we will need to update our BE prompt_tokens = get_tokens_from_messages(messages) # type: ignore output_messages = adapt_messages_to_bytez_standard(messages=[output]) completion_tokens = get_tokens_from_messages(output_messages) total_tokens = prompt_tokens + completion_tokens usage = Usage( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, ) model_response.usage = usage # type: ignore model_response._hidden_params["additional_headers"] = raw_response.headers message.provider_specific_fields = { "ratelimit-limit": raw_response.headers.get("ratelimit-limit"), "ratelimit-remaining": raw_response.headers.get("ratelimit-remaining"), "ratelimit-reset": raw_response.headers.get("ratelimit-reset"), "inference-meter": raw_response.headers.get("inference-meter"), "inference-time": raw_response.headers.get("inference-time"), } # TODO additional data when supported # message.tool_calls # message.function_call return model_response @track_llm_api_timing() def get_sync_custom_stream_wrapper( self, model: str, custom_llm_provider: str, logging_obj: LiteLLMLoggingObj, api_base: str, headers: dict, data: dict, messages: list, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, json_mode: Optional[bool] = None, signed_json_body: Optional[bytes] = None, ) -> "BytezCustomStreamWrapper": if client is None or isinstance(client, AsyncHTTPHandler): client = _get_httpx_client(params={}) try: response = client.post( api_base, headers=headers, data=json.dumps(data), stream=True, logging_obj=logging_obj, timeout=STREAMING_TIMEOUT, ) except httpx.HTTPStatusError as e: raise BytezError( status_code=e.response.status_code, message=e.response.text ) if response.status_code != 200: raise BytezError(status_code=response.status_code, message=response.text) completion_stream = response.iter_text() streaming_response = BytezCustomStreamWrapper( completion_stream=completion_stream, model=model, custom_llm_provider=custom_llm_provider, logging_obj=logging_obj, ) return streaming_response @track_llm_api_timing() async def get_async_custom_stream_wrapper( self, model: str, custom_llm_provider: str, logging_obj: LiteLLMLoggingObj, api_base: str, headers: dict, data: dict, messages: list, client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None, json_mode: Optional[bool] = None, signed_json_body: Optional[bytes] = None, ) -> "BytezCustomStreamWrapper": if client is None or isinstance(client, HTTPHandler): client = get_async_httpx_client(llm_provider=LlmProviders.BYTEZ, params={}) try: response = await client.post( api_base, headers=headers, data=json.dumps(data), stream=True, logging_obj=logging_obj, timeout=STREAMING_TIMEOUT, ) except httpx.HTTPStatusError as e: raise BytezError( status_code=e.response.status_code, message=e.response.text ) if response.status_code != 200: raise BytezError(status_code=response.status_code, message=response.text) completion_stream = response.aiter_text() streaming_response = BytezCustomStreamWrapper( completion_stream=completion_stream, model=model, custom_llm_provider=custom_llm_provider, logging_obj=logging_obj, ) return streaming_response def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: return BytezError(status_code=status_code, message=error_message) class BytezCustomStreamWrapper(CustomStreamWrapper): def chunk_creator(self, chunk: Any): try: model_response = self.model_response_creator() response_obj: Dict[str, Any] = {} response_obj = { "text": chunk, "is_finished": False, "finish_reason": "", } completion_obj: Dict[str, Any] = {"content": chunk} return self.return_processed_chunk_logic( completion_obj=completion_obj, model_response=model_response, # type: ignore response_obj=response_obj, ) except StopIteration: raise StopIteration except Exception as e: traceback.format_exc() setattr(e, "message", str(e)) raise exception_type( model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e, ) # litellm/types/llms/openai.py is a good reference for what is supported open_ai_to_bytez_content_item_map = { "text": {"type": "text", "value_name": "text"}, "image_url": {"type": "image", "value_name": "url"}, "input_audio": {"type": "audio", "value_name": "url"}, "video_url": {"type": "video", "value_name": "url"}, "document": None, "file": None, } def adapt_messages_to_bytez_standard(messages: List[Dict]): messages = _adapt_string_only_content_to_lists(messages) new_messages = [] for message in messages: role = message["role"] content: list = message["content"] new_content = [] for content_item in content: type: Union[str, None] = content_item.get("type") if not type: raise Exception("Prop `type` is not a string") content_item_map = open_ai_to_bytez_content_item_map[type] if not content_item_map: raise Exception(f"Prop `{type}` is not supported") new_type = content_item_map["type"] value_name = content_item_map["value_name"] value: Union[str, None] = content_item.get(value_name) if not value: raise Exception(f"Prop `{value_name}` is not a string") new_content.append({"type": new_type, value_name: value}) new_messages.append({"role": role, "content": new_content}) return new_messages # "content": "The cat ran so fast" # becomes # "content": [{"type": "text", "text": "The cat ran so fast"}] def _adapt_string_only_content_to_lists(messages: List[Dict]): new_messages = [] for message in messages: role = message.get("role") content = message.get("content") new_content = [] if isinstance(content, str): new_content.append({"type": "text", "text": content}) elif isinstance(content, dict): new_content.append(content) elif isinstance(content, list): new_content_items = [] for content_item in content: if isinstance(content_item, str): new_content_items.append({"type": "text", "text": content_item}) elif isinstance(content_item, dict): new_content_items.append(content_item) else: raise Exception( "`content` can only contain strings or openai content dicts" ) new_content += new_content_items else: raise Exception("Content must be a string") new_messages.append({"role": role, "content": new_content}) return new_messages # TODO get this from the api instead of doing it here, will require backend work def get_tokens_from_messages(messages: List[dict]): total = 0 for message in messages: content: List[dict] = message["content"] for content_item in content: type = content_item["type"] if type == "text": value: str = content_item["text"] words = value.split(" ") total += len(words) continue # we'll count media as single tokens for now total += 1 return total