""" GigaChat Embedding Transformation Transforms OpenAI /v1/embeddings format to GigaChat format. API Documentation: https://developers.sber.ru/docs/ru/gigachat/api/reference/rest/post-embeddings """ import types from typing import List, Optional, Tuple, Union import httpx from litellm import LlmProviders from litellm.llms.base_llm.chat.transformation import BaseLLMException from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues from litellm.types.utils import EmbeddingResponse from ..authenticator import get_access_token # GigaChat API endpoint GIGACHAT_BASE_URL = "https://gigachat.devices.sberbank.ru/api/v1" class GigaChatEmbeddingError(BaseLLMException): """GigaChat Embedding API error.""" pass class GigaChatEmbeddingConfig(BaseEmbeddingConfig): """ Configuration class for GigaChat Embeddings API. GigaChat embeddings endpoint: POST /api/v1/embeddings """ def __init__(self) -> None: pass @classmethod def get_config(cls): return { k: v for k, v in cls.__dict__.items() if not k.startswith("__") and not isinstance( v, ( types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod, ), ) and v is not None } def get_supported_openai_params(self, model: str) -> List[str]: """GigaChat embeddings don't support additional parameters.""" return [] def map_openai_params( self, non_default_params: dict, optional_params: dict, model: str, drop_params: bool, ) -> dict: """Map OpenAI params to GigaChat format (no special mapping needed).""" return optional_params def _get_openai_compatible_provider_info( self, api_base: Optional[str], api_key: Optional[str], ) -> Tuple[str, Optional[str], Optional[str]]: """ Returns provider info for GigaChat. Returns: Tuple of (custom_llm_provider, api_base, dynamic_api_key) """ api_base = api_base or GIGACHAT_BASE_URL return LlmProviders.GIGACHAT.value, api_base, api_key def get_complete_url( self, api_base: Optional[str], api_key: Optional[str], model: str, optional_params: dict, litellm_params: dict, stream: Optional[bool] = None, ) -> str: """Get the complete URL for embeddings endpoint.""" base = api_base or GIGACHAT_BASE_URL return f"{base}/embeddings" def transform_embedding_request( self, model: str, input: AllEmbeddingInputValues, optional_params: dict, headers: dict, ) -> dict: """ Transform OpenAI embedding request to GigaChat format. GigaChat format: { "model": "Embeddings", "input": ["text1", "text2", ...] } """ # Normalize input to list if isinstance(input, str): input_list: list = [input] elif isinstance(input, list): input_list = input else: input_list = [input] # Remove gigachat/ prefix from model if present if model.startswith("gigachat/"): model = model[9:] return { "model": model, "input": input_list, } def transform_embedding_response( self, model: str, raw_response: httpx.Response, model_response: EmbeddingResponse, logging_obj: LiteLLMLoggingObj, api_key: Optional[str], request_data: dict, optional_params: dict, litellm_params: dict, ) -> EmbeddingResponse: """ Transform GigaChat embedding response to OpenAI format. GigaChat returns: { "object": "list", "data": [{"object": "embedding", "embedding": [...], "index": 0, "usage": {...}}], "model": "Embeddings" } """ response_json = raw_response.json() # Log response logging_obj.post_call( input=request_data.get("input"), api_key=api_key, additional_args={"complete_input_dict": request_data}, original_response=response_json, ) # Calculate total tokens from individual embeddings total_tokens = 0 if "data" in response_json: for emb in response_json["data"]: if "usage" in emb and "prompt_tokens" in emb["usage"]: total_tokens += emb["usage"]["prompt_tokens"] # Remove usage from individual embeddings (not part of OpenAI format) if "usage" in emb: del emb["usage"] # Set overall usage response_json["usage"] = { "prompt_tokens": total_tokens, "total_tokens": total_tokens, } return EmbeddingResponse(**response_json) def validate_environment( self, headers: dict, model: str, messages: List[AllMessageValues], optional_params: dict, litellm_params: dict, api_key: Optional[str] = None, api_base: Optional[str] = None, ) -> dict: """ Set up headers with OAuth token for GigaChat. """ # Get access token via OAuth access_token = get_access_token(api_key) default_headers = { "Content-Type": "application/json", "Authorization": f"Bearer {access_token}", } return {**default_headers, **headers} def get_error_class( self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers] ) -> BaseLLMException: """Return GigaChat-specific error class.""" return GigaChatEmbeddingError( status_code=status_code, message=error_message, )