chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
Transformation for Calling Google models in their native format.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.google_genai.transformation import (
|
||||
BaseGoogleGenAIGenerateContentConfig,
|
||||
)
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
    # Only static type checkers evaluate this import; it is skipped at runtime.
    from litellm.types.google_genai.main import (
        GenerateContentConfigDict,
        GenerateContentContentListUnionDict,
        GenerateContentResponse,
        ToolConfigDict,
    )
else:
    # Runtime placeholders so the annotations below still resolve to a name.
    GenerateContentConfigDict = GenerateContentContentListUnionDict = Any
    GenerateContentResponse = ToolConfigDict = Any
|
||||
|
||||
from ..common_utils import get_api_key_from_env
|
||||
|
||||
|
||||
class GoogleGenAIConfig(BaseGoogleGenAIGenerateContentConfig, VertexLLM):
    """
    Configuration for calling Google models in their native format.

    Implements the generate-content configuration interface and reuses the
    credential, project, location and token/URL helpers inherited from
    ``VertexLLM`` to build request headers and URLs.
    """

    ##############################
    # Constants
    ##############################
    # Header name under which the API key is sent.
    XGOOGLE_API_KEY = "x-goog-api-key"
    ##############################

    @property
    def custom_llm_provider(self) -> Literal["gemini", "vertex_ai"]:
        """Provider identifier handed to the auth helpers; ``"gemini"`` for this class."""
        return "gemini"

    def __init__(self):
        super().__init__()
        # NOTE(review): VertexLLM.__init__ is called explicitly in addition to
        # the cooperative super() call, presumably because the first base does
        # not chain to it - confirm before simplifying.
        VertexLLM.__init__(self)

    def get_supported_generate_content_optional_params(self, model: str) -> List[str]:
        """
        Get the list of supported Google GenAI parameters for the model.

        Args:
            model: The model name (not consulted; the same list is returned
                for every model).

        Returns:
            List of supported parameter names, in snake_case.
        """
        return [
            "http_options",
            "system_instruction",
            "temperature",
            "top_p",
            "top_k",
            "candidate_count",
            "max_output_tokens",
            "stop_sequences",
            "response_logprobs",
            "logprobs",
            "presence_penalty",
            "frequency_penalty",
            "seed",
            "response_mime_type",
            "response_schema",
            "response_json_schema",
            "routing_config",
            "model_selection_config",
            "safety_settings",
            "tools",
            "tool_config",
            "labels",
            "cached_content",
            "response_modalities",
            "media_resolution",
            "speech_config",
            "audio_timestamp",
            "automatic_function_calling",
            "thinking_config",
            "image_config",
        ]

    def map_generate_content_optional_params(
        self,
        generate_content_config_dict: GenerateContentConfigDict,
        model: str,
    ) -> Dict[str, Any]:
        """
        Map Google GenAI parameters to provider-specific format.

        A parameter is kept when its name as given, its snake_case form, or
        its camelCase form appears in the supported list. Kept parameters are
        always emitted under their camelCase name; everything else is dropped.

        Args:
            generate_content_config_dict: Optional parameters for generate content
            model: The model name

        Returns:
            Mapped parameters for the provider, keyed in camelCase
        """
        from litellm.llms.vertex_ai.gemini.transformation import (
            _camel_to_snake,
            _snake_to_camel,
        )

        # Build the lookup set once so each membership test below is O(1).
        supported_params = frozenset(
            self.get_supported_generate_content_optional_params(model)
        )

        mapped_params: Dict[str, Any] = {}
        for param, value in generate_content_config_dict.items():
            param_snake = _camel_to_snake(param)
            param_camel = _snake_to_camel(param)

            # Accept the parameter in whichever casing the caller used.
            if (
                param in supported_params
                or param_snake in supported_params
                or param_camel in supported_params
            ):
                # Google GenAI API expects camelCase, so always output camelCase.
                mapped_params[param_camel] = value
        return mapped_params

    def validate_environment(
        self,
        api_key: Optional[str],
        headers: Optional[dict],
        model: str,
        litellm_params: Optional[Union[GenericLiteLLMParams, dict]],
    ) -> dict:
        """
        Build the request headers, including the API key when one is found.

        Args:
            api_key: Explicit key; takes precedence over every other source.
            headers: Extra headers; these override the defaults on collision.
            model: The model name (unused here).
            litellm_params: Consulted for a key when ``api_key`` is falsy. A
                copy is passed to the lookup, so this argument is not mutated.

        Returns:
            Header dict containing at least ``Content-Type``.
        """
        default_headers = {
            "Content-Type": "application/json",
        }
        # Use the passed api_key first, then fall back to litellm_params and environment
        gemini_api_key = api_key or self._get_google_ai_studio_api_key(
            dict(litellm_params or {})
        )
        if isinstance(gemini_api_key, dict):
            # A dict-valued key is merged in as ready-made header entries.
            default_headers.update(gemini_api_key)
        elif gemini_api_key is not None:
            default_headers[self.XGOOGLE_API_KEY] = gemini_api_key

        if headers is not None:
            default_headers.update(headers)

        return default_headers

    def _get_google_ai_studio_api_key(self, litellm_params: dict) -> Optional[str]:
        """
        Resolve the API key from params, then environment, then ``litellm.api_key``.

        Side effect: ``"api_key"`` is popped from ``litellm_params``, and
        ``"gemini_api_key"`` is popped as well when ``"api_key"`` was missing
        or falsy. Callers that must keep their dict intact should pass a copy.
        """
        return (
            litellm_params.pop("api_key", None)
            or litellm_params.pop("gemini_api_key", None)
            or get_api_key_from_env()
            or litellm.api_key
        )

    def _get_common_auth_components(
        self,
        litellm_params: dict,
    ) -> Tuple[Any, Optional[str], Optional[str]]:
        """
        Get common authentication components used by both sync and async methods.

        Returns:
            Tuple of (vertex_credentials, vertex_project, vertex_location)
        """
        return (
            self.get_vertex_ai_credentials(litellm_params),
            self.get_vertex_ai_project(litellm_params),
            self.get_vertex_ai_location(litellm_params),
        )

    def _build_final_headers_and_url(
        self,
        model: str,
        auth_header: Optional[str],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_credentials: Any,
        stream: bool,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> Tuple[dict, str]:
        """
        Build final headers and API URL from auth components.

        Returns:
            Tuple of (headers, api_base) where ``api_base`` is the URL returned
            by ``_get_token_and_url``.
        """
        # This lookup pops the key entries out of ``litellm_params`` (see
        # _get_google_ai_studio_api_key), so the validate_environment call
        # below no longer finds them in the params.
        # NOTE(review): confirm that dropping the key from the header fallback
        # is intended.
        gemini_api_key = self._get_google_ai_studio_api_key(litellm_params)

        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=gemini_api_key,
            auth_header=auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=stream,
            custom_llm_provider=self.custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=True,
        )

        headers = self.validate_environment(
            api_key=auth_header,
            headers=None,
            model=model,
            litellm_params=litellm_params,
        )

        return headers, api_base

    def sync_get_auth_token_and_url(
        self,
        api_base: Optional[str],
        model: str,
        litellm_params: dict,
        stream: bool,
    ) -> Tuple[dict, str]:
        """
        Sync version of get_auth_token_and_url.

        Args:
            api_base: Base API URL
            model: The model name
            litellm_params: LiteLLM parameters
            stream: Whether the request is a streaming request

        Returns:
            Tuple of headers and API base
        """
        (
            vertex_credentials,
            vertex_project,
            vertex_location,
        ) = self._get_common_auth_components(litellm_params)

        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=self.custom_llm_provider,
        )

        return self._build_final_headers_and_url(
            model=model,
            auth_header=_auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=stream,
            api_base=api_base,
            litellm_params=litellm_params,
        )

    async def get_auth_token_and_url(
        self,
        api_base: Optional[str],
        model: str,
        litellm_params: dict,
        stream: bool,
    ) -> Tuple[dict, str]:
        """
        Resolve the request headers and the complete URL for the request.

        Args:
            api_base: Base API URL
            model: The model name
            litellm_params: LiteLLM parameters
            stream: Whether the request is a streaming request

        Returns:
            Tuple of headers and API base
        """
        (
            vertex_credentials,
            vertex_project,
            vertex_location,
        ) = self._get_common_auth_components(litellm_params)

        _auth_header, vertex_project = await self._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=self.custom_llm_provider,
        )

        return self._build_final_headers_and_url(
            model=model,
            auth_header=_auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=stream,
            api_base=api_base,
            litellm_params=litellm_params,
        )

    def transform_generate_content_request(
        self,
        model: str,
        contents: GenerateContentContentListUnionDict,
        tools: Optional[ToolConfigDict],
        generate_content_config_dict: Dict,
        system_instruction: Optional[Any] = None,
    ) -> dict:
        """
        Assemble the request body for the generate content API.

        Args:
            model: The model name
            contents: Request contents, passed through unchanged
            tools: Tool configuration, passed through unchanged
            generate_content_config_dict: Entries placed under ``generationConfig``
            system_instruction: Added as ``systemInstruction`` when not None

        Returns:
            The request body as a plain dict
        """
        from litellm.types.google_genai.main import (
            GenerateContentConfigDict,
            GenerateContentRequestDict,
        )

        typed_generate_content_request = GenerateContentRequestDict(
            model=model,
            contents=contents,
            tools=tools,
            generationConfig=GenerateContentConfigDict(**generate_content_config_dict),
        )

        request_dict = cast(dict, typed_generate_content_request)

        if system_instruction is not None:
            request_dict["systemInstruction"] = system_instruction

        return request_dict

    def transform_generate_content_response(
        self,
        model: str,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> GenerateContentResponse:
        """
        Transform the raw response from the generate content API.

        Args:
            model: The model name
            raw_response: Raw HTTP response
            logging_obj: Logging object; the raw response is stored in its
                ``model_call_details`` under ``"httpx_response"``

        Returns:
            Transformed response data

        Raises:
            The error built by ``get_error_class`` when the response body
            cannot be decoded as JSON.
        """
        from litellm.types.google_genai.main import GenerateContentResponse

        try:
            response = raw_response.json()
        except Exception as e:
            # Chain the original failure so the decoding error stays visible.
            raise self.get_error_class(
                error_message=f"Error transforming generate content response: {e}",
                status_code=raw_response.status_code,
                headers=raw_response.headers,
            ) from e

        logging_obj.model_call_details["httpx_response"] = raw_response
        response = self.convert_citation_sources_to_citations(response)

        return GenerateContentResponse(**response)

    def convert_citation_sources_to_citations(self, response: Dict) -> Dict:
        """
        Convert citation sources to citations.

        API's camelCase citationSources becomes the SDK's snake_case citations.
        The response is modified in place and also returned. A missing or
        ``None`` ``candidates`` entry and non-dict candidates are left alone.
        """
        for candidate in response.get("candidates") or []:
            if not isinstance(candidate, dict):
                continue
            citation_metadata = candidate.get("citationMetadata")
            if not isinstance(citation_metadata, dict):
                continue
            # Transform citationSources to citations to match expected schema
            if "citationSources" in citation_metadata:
                citation_metadata["citations"] = citation_metadata.pop(
                    "citationSources"
                )
        return response
|
||||
Reference in New Issue
Block a user