chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,153 @@
|
||||
from typing import List, Optional, cast
|
||||
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
convert_generic_image_chunk_to_openai_image_obj,
|
||||
convert_to_anthropic_image_obj,
|
||||
)
|
||||
from litellm.litellm_core_utils.prompt_templates.image_handling import (
|
||||
convert_url_to_base64,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionFileObject
|
||||
from litellm.types.llms.vertex_ai import ContentType, PartType
|
||||
from litellm.utils import supports_reasoning
|
||||
|
||||
from ...vertex_ai.gemini.transformation import _gemini_convert_messages_with_history
|
||||
from ...vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexGeminiConfig
|
||||
|
||||
|
||||
class GoogleAIStudioGeminiConfig(VertexGeminiConfig):
    """
    Configuration for Google AI Studio's Gemini API.

    Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig

    Supported generation parameters:

    - `temperature` (float): Controls the degree of randomness in token selection.
    - `max_output_tokens` (int): Maximum number of tokens in the text output.
    - `top_p` (float): Tokens are selected from most to least probable until the
      sum of their probabilities equals `top_p`. Default is 0.95.
    - `top_k` (int): How many of the most probable tokens are considered at each
      selection step (1 = greedy). Default is 40.
    - `response_mime_type` (str): MIME type of the response. Default is
      'text/plain'; `application/json` is also supported.
    - `response_schema` (dict): Optional output schema (a subset of OpenAPI
      schema) for the generated text. If set, a compatible
      `response_mime_type` (`application/json`) must also be set.
    - `candidate_count` (int): Number of generated responses to return.
    - `stop_sequences` (List[str]): Up to 5 character sequences that stop
      output generation; the stop sequence is not included in the response.

    Note: Please make sure to modify the default parameters as required for
    your use case.
    """

    temperature: Optional[float] = None
    max_output_tokens: Optional[int] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    response_mime_type: Optional[str] = None
    response_schema: Optional[dict] = None
    candidate_count: Optional[int] = None
    stop_sequences: Optional[list] = None

    def __init__(
        self,
        temperature: Optional[float] = None,
        max_output_tokens: Optional[int] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        response_mime_type: Optional[str] = None,
        response_schema: Optional[dict] = None,
        candidate_count: Optional[int] = None,
        stop_sequences: Optional[list] = None,
    ) -> None:
        # NOTE: follows the litellm config pattern — non-None arguments are
        # stored as *class* attributes, acting as process-wide defaults.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the config dict (delegates to VertexGeminiConfig)."""
        return super().get_config()

    def is_model_gemini_audio_model(self, model: str) -> bool:
        # Gemini TTS (audio output) models carry "tts" in the model name.
        return "tts" in model

    def get_supported_openai_params(self, model: str) -> List[str]:
        """List the OpenAI params this provider can map for `model`."""
        supported_params = [
            "temperature",
            "top_p",
            "max_tokens",
            "max_completion_tokens",
            "stream",
            "tools",
            "tool_choice",
            "functions",
            "response_format",
            "n",
            "stop",
            "logprobs",
            "frequency_penalty",
            "presence_penalty",
            "modalities",
            "parallel_tool_calls",
            "web_search_options",
        ]
        if supports_reasoning(model, custom_llm_provider="gemini"):
            supported_params.append("reasoning_effort")
            supported_params.append("thinking")
        if self.is_model_gemini_audio_model(model):
            supported_params.append("audio")
        return supported_params

    def _transform_messages(
        self, messages: List[AllMessageValues], model: Optional[str] = None
    ) -> List[ContentType]:
        """
        Google AI Studio Gemini does not support HTTP/HTTPS URLs for files.
        Convert them to base64 data instead, then delegate to the shared
        Gemini message transformation.
        """
        for message in messages:
            _message_content = message.get("content")
            if _message_content is not None and isinstance(_message_content, list):
                for element in _message_content:
                    if element.get("type") == "image_url":
                        img_element = element
                        _image_url: Optional[str] = None
                        format: Optional[str] = None
                        detail: Optional[str] = None
                        if isinstance(img_element.get("image_url"), dict):
                            _image_url = img_element["image_url"].get("url")  # type: ignore
                            format = img_element["image_url"].get("format")  # type: ignore
                            detail = img_element["image_url"].get("detail")  # type: ignore
                        else:
                            _image_url = img_element.get("image_url")  # type: ignore
                        # FIX: also convert plain http:// URLs (previously only
                        # https:// was handled), consistent with the "file"
                        # branch below and with the docstring's contract.
                        if _image_url and (
                            "http://" in _image_url or "https://" in _image_url
                        ):
                            image_obj = convert_to_anthropic_image_obj(
                                _image_url, format=format
                            )
                            converted_image_url = (
                                convert_generic_image_chunk_to_openai_image_obj(
                                    image_obj
                                )
                            )
                            if detail is not None:
                                # Preserve the caller-supplied detail level.
                                img_element["image_url"] = {  # type: ignore
                                    "url": converted_image_url,
                                    "detail": detail,
                                }
                            else:
                                img_element["image_url"] = converted_image_url  # type: ignore
                    elif element.get("type") == "file":
                        file_element = cast(ChatCompletionFileObject, element)
                        file_id = file_element["file"].get("file_id")
                        if file_id and ("http://" in file_id or "https://" in file_id):
                            # Convert HTTP/HTTPS file URL to base64 data.
                            try:
                                base64_data = convert_url_to_base64(file_id)
                                file_element["file"]["file_data"] = base64_data  # type: ignore
                                file_element["file"].pop("file_id", None)  # type: ignore
                            except Exception:
                                # Best-effort: leave as-is and let the API
                                # surface the error.
                                pass
        return _gemini_convert_messages_with_history(messages=messages, model=model)
|
||||
@@ -0,0 +1,204 @@
|
||||
import base64
|
||||
import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo, BaseTokenCounter
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import TokenCountResponse
|
||||
|
||||
|
||||
class GeminiError(BaseLLMException):
    """Error raised for failures returned by the Gemini (Google AI Studio) API."""
|
||||
|
||||
|
||||
class GeminiModelInfo(BaseLLMModelInfo):
    """Model-info and utility helpers for the Google AI Studio (Gemini) provider."""

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Google AI Studio sends api key in query params"""
        return headers

    @property
    def api_version(self) -> str:
        """API version segment used when building Google AI Studio URLs."""
        return "v1beta"

    @staticmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        """Resolve the API base: explicit arg, then GEMINI_API_BASE, then the public default."""
        if api_base:
            return api_base
        return (
            get_secret_str("GEMINI_API_BASE")
            or "https://generativelanguage.googleapis.com"
        )

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        """Resolve the API key: explicit arg, then GOOGLE_API_KEY, then GEMINI_API_KEY."""
        if api_key:
            return api_key
        return get_secret_str("GOOGLE_API_KEY") or get_secret_str("GEMINI_API_KEY")

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        """Strip the "gemini/" provider prefix from a litellm model name."""
        return model.replace("gemini/", "")

    def process_model_name(self, models: List[Dict[str, str]]) -> List[str]:
        """Map raw Gemini entries ({"name": "models/x"}) to litellm names ("gemini/x")."""
        return [
            "gemini/" + entry["name"].replace("models/", "") for entry in models
        ]

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """
        Fetch the available model list from Gemini's `/models` endpoint.

        Raises:
            ValueError: When credentials/base cannot be resolved, or on a
                non-200 response.
        """
        resolved_base = GeminiModelInfo.get_api_base(api_base)
        resolved_key = GeminiModelInfo.get_api_key(api_key)
        if resolved_base is None or resolved_key is None:
            raise ValueError(
                "GEMINI_API_BASE or GEMINI_API_KEY/GOOGLE_API_KEY is not set. Please set the environment variable, to query Gemini's `/models` endpoint."
            )

        endpoint = f"/{self.api_version}/models"
        response = litellm.module_level_client.get(
            url=f"{resolved_base}{endpoint}?key={resolved_key}",
        )

        if response.status_code != 200:
            raise ValueError(
                f"Failed to fetch models from Gemini. Status code: {response.status_code}, Response: {response.json()}"
            )

        return self.process_model_name(response.json()["models"])

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Wrap a provider error in a GeminiError."""
        return GeminiError(
            status_code=status_code, message=error_message, headers=headers
        )

    def get_token_counter(self) -> Optional[BaseTokenCounter]:
        """
        Factory method for this provider's token counter.

        Returns a GoogleAIStudioTokenCounter (never None for this provider).
        """
        return GoogleAIStudioTokenCounter()
|
||||
|
||||
|
||||
def encode_unserializable_types(
    data: Dict[str, object], depth: int = 0
) -> Dict[str, object]:
    """Converts unserializable types in dict to json.dumps() compatible types.

    This function is called in models.py after calling convert_to_dict(). The
    convert_to_dict() can convert pydantic object to dict. However, the input to
    convert_to_dict() is dict mixed of pydantic object and nested dict(the output
    of converters). So they may be bytes in the dict and they are out of
    `ser_json_bytes` control in model_dump(mode='json') called in
    `convert_to_dict`, as well as datetime deserialization in Pydantic json mode.

    Returns:
        A dictionary with json.dumps() incompatible type (e.g. bytes datetime)
        to compatible type (e.g. base64 encoded string, isoformat date string).
    """
    # Guard against pathological nesting / self-referencing structures.
    if depth > DEFAULT_MAX_RECURSE_DEPTH:
        return data
    if not isinstance(data, dict):
        return data
    processed_data: dict[str, object] = {}
    for key, value in data.items():
        if isinstance(value, bytes):
            processed_data[key] = base64.urlsafe_b64encode(value).decode("ascii")
        elif isinstance(value, datetime.datetime):
            processed_data[key] = value.isoformat()
        elif isinstance(value, dict):
            processed_data[key] = encode_unserializable_types(value, depth + 1)
        elif isinstance(value, list):
            # FIX: these branches must be mutually exclusive. Previously the
            # datetime check was a second independent `if` with an `else`, so a
            # list of bytes was base64-encoded and then immediately overwritten
            # by the `else` branch with the raw (unserializable) bytes.
            # `value and ...` keeps empty lists out of the typed branches
            # (all() is vacuously True on []); they pass through unchanged.
            if value and all(isinstance(v, bytes) for v in value):
                processed_data[key] = [
                    base64.urlsafe_b64encode(v).decode("ascii") for v in value
                ]
            elif value and all(isinstance(v, datetime.datetime) for v in value):
                processed_data[key] = [v.isoformat() for v in value]
            else:
                processed_data[key] = [
                    encode_unserializable_types(v, depth + 1) for v in value
                ]
        else:
            processed_data[key] = value
    return processed_data
|
||||
|
||||
|
||||
def get_api_key_from_env() -> Optional[str]:
    """Resolve the Gemini API key from the environment (GOOGLE_API_KEY first)."""
    primary = get_secret_str("GOOGLE_API_KEY")
    return primary or get_secret_str("GEMINI_API_KEY")
|
||||
|
||||
|
||||
class GoogleAIStudioTokenCounter(BaseTokenCounter):
    """Token counter implementation for Google AI Studio provider."""

    def should_use_token_counting_api(
        self,
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """True only for the native `gemini` (Google AI Studio) provider."""
        from litellm.types.utils import LlmProviders

        return custom_llm_provider == LlmProviders.GEMINI.value

    async def count_tokens(
        self,
        model_to_use: str,
        messages: Optional[List[Dict[str, Any]]],
        contents: Optional[List[Dict[str, Any]]],
        deployment: Optional[Dict[str, Any]] = None,
        request_model: str = "",
        tools: Optional[List[Dict[str, Any]]] = None,
        system: Optional[Any] = None,
    ) -> Optional[TokenCountResponse]:
        """
        Count tokens via the Google AI Studio countTokens endpoint.

        Merges the deployment's litellm_params with this call's model/contents,
        delegates to the HTTP handler, and maps the raw response into a
        TokenCountResponse. Returns None when the handler yields no result.
        """
        import copy

        # Local import; note the handler class intentionally shares this
        # wrapper's name.
        from litellm.llms.gemini.count_tokens.handler import GoogleAIStudioTokenCounter

        call_params = copy.deepcopy((deployment or {}).get("litellm_params", {}))
        call_params.update({"model": model_to_use, "contents": contents})

        result = await GoogleAIStudioTokenCounter().acount_tokens(**call_params)
        if result is None:
            return None

        return TokenCountResponse(
            total_tokens=result.get("totalTokens", 0),
            request_model=request_model,
            model_used=model_to_use,
            tokenizer_type=result.get("tokenizer_used", ""),
            original_response=result,
        )
|
||||
@@ -0,0 +1 @@
|
||||
[Go here for the Gemini Context Caching code](../../vertex_ai/context_caching/)
|
||||
@@ -0,0 +1,56 @@
|
||||
"""
|
||||
This file is used to calculate the cost of the Gemini API.
|
||||
|
||||
Handles the context caching for Gemini API.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.utils import ModelInfo, Usage
|
||||
|
||||
|
||||
def cost_per_token(
    model: str, usage: "Usage", service_tier: Optional[str] = None
) -> Tuple[float, float]:
    """
    Return (prompt_cost, completion_cost) in USD for the given model and usage.

    Delegates to the shared generic calculator with the "gemini" provider;
    follows the same logic as Anthropic's cost-per-token calculation.
    """
    from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token

    return generic_cost_per_token(
        model=model,
        usage=usage,
        custom_llm_provider="gemini",
        service_tier=service_tier,
    )
|
||||
|
||||
|
||||
def cost_per_web_search_request(usage: "Usage", model_info: "ModelInfo") -> float:
    """
    Return the total USD cost of the web search requests recorded in `usage`.

    Reads `usage.prompt_tokens_details.web_search_requests` when present;
    otherwise the count (and therefore the cost) is zero.
    """
    from litellm.types.utils import PromptTokensDetailsWrapper

    # Flat cost per individual web search request (USD).
    per_request_cost = 35e-3

    request_count = 0
    details = usage.prompt_tokens_details if usage is not None else None
    if (
        details is not None
        and isinstance(details, PromptTokensDetailsWrapper)
        and getattr(details, "web_search_requests", None) is not None
    ):
        request_count = details.web_search_requests

    return per_request_cost * request_count
|
||||
@@ -0,0 +1,168 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.google_genai.main import GenerateContentContentListUnionDict
|
||||
else:
|
||||
GenerateContentContentListUnionDict = Any
|
||||
|
||||
|
||||
class GoogleAIStudioTokenCounter:
    """Calls Google AI Studio's `countTokens` endpoint for Gemini models."""

    def _clean_contents_for_gemini_api(self, contents: Any) -> Any:
        """
        Clean up contents to remove unsupported fields for the Gemini API.

        The Google Gemini API doesn't recognize the 'id' field in function
        responses, so we need to remove it to prevent 400 Bad Request errors.

        Args:
            contents: The contents to clean up

        Returns:
            Cleaned contents with unsupported fields removed
        """
        import copy

        from google.genai.types import FunctionResponse

        # Nothing to clean for None/empty contents.
        if not contents:
            return contents

        sanitized = copy.deepcopy(contents)
        for content in sanitized:
            for part in content["parts"]:
                if "functionResponse" not in part:
                    continue
                # Round-trip through the typed model, dropping the 'id' field.
                fn_response = FunctionResponse(**part["functionResponse"])
                fn_response.id = None
                part["functionResponse"] = fn_response.model_dump(exclude_none=True)

        return sanitized

    def _construct_url(self, model: str, api_base: Optional[str] = None) -> str:
        """
        Construct the URL for the Google Gen AI Studio countTokens endpoint.
        """
        base_url = api_base or "https://generativelanguage.googleapis.com"
        return f"{base_url}/v1beta/models/{model}:countTokens"

    async def validate_environment(
        self,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        headers: Optional[Dict[str, Any]] = None,
        model: str = "",
        litellm_params: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Dict[str, Any], str]:
        """
        Returns a Tuple of headers and url for the Google Gen AI Studio
        countTokens endpoint.
        """
        from litellm.llms.gemini.google_genai.transformation import GoogleGenAIConfig

        resolved_headers = GoogleGenAIConfig().validate_environment(
            api_key=api_key,
            headers=headers,
            model=model,
            litellm_params=litellm_params,
        )
        return resolved_headers, self._construct_url(model=model, api_base=api_base)

    async def acount_tokens(
        self,
        contents: Any,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Count tokens using Google Gen AI Studio countTokens endpoint.

        Args:
            contents: The content to count tokens for (Google Gen AI format),
                e.g. [{"parts": [{"text": "Hello world"}]}]
            model: The model name (e.g. "gemini-1.5-flash")
            api_key: Optional Google API key (will fall back to environment)
            api_base: Optional API base URL (defaults to Google Gen AI Studio)
            timeout: Optional timeout for the request
            **kwargs: Additional parameters (forwarded as litellm_params)

        Returns:
            Dict with token count information, e.g.
            {"totalTokens": 31, "totalBillableCharacters": 96,
             "promptTokensDetails": [{"modality": "TEXT", "tokenCount": 31}]}

        Raises:
            litellm.APIError: If the API call fails
            litellm.APIConnectionError: If the connection fails
            Exception: For any other unexpected errors
        """
        headers, url = await self.validate_environment(
            api_key=api_key,
            api_base=api_base,
            headers={},
            model=model,
            litellm_params=kwargs,
        )

        # Strip fields the Gemini API rejects before building the body.
        request_body = {"contents": self._clean_contents_for_gemini_api(contents)}

        client = get_async_httpx_client(
            llm_provider=LlmProviders.GEMINI,
        )

        try:
            response = await client.post(url=url, headers=headers, json=request_body)
            response.raise_for_status()
            return response.json()
        except httpx.HTTPStatusError as e:
            error_msg = f"Google Gen AI Studio API error: {e.response.status_code} - {e.response.text}"
            raise litellm.APIError(
                message=error_msg,
                llm_provider="gemini",
                model=model,
                status_code=e.response.status_code,
            ) from e
        except httpx.RequestError as e:
            error_msg = f"Request to Google Gen AI Studio failed: {str(e)}"
            raise litellm.APIConnectionError(
                message=error_msg, llm_provider="gemini", model=model
            ) from e
        except Exception as e:
            error_msg = f"Unexpected error during token counting: {str(e)}"
            raise Exception(error_msg) from e
|
||||
@@ -0,0 +1,389 @@
|
||||
"""
|
||||
Supports writing files to Google AI Studio Files API.
|
||||
|
||||
For vertex ai, check out the vertex_ai/files/handler.py file.
|
||||
"""
|
||||
import time
|
||||
from typing import Any, List, Literal, Optional
|
||||
|
||||
import httpx
|
||||
from openai.types.file_deleted import FileDeleted
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
|
||||
from litellm.llms.base_llm.files.transformation import (
|
||||
BaseFilesConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.gemini import GeminiCreateFilesResponseObject
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
CreateFileRequest,
|
||||
HttpxBinaryResponseContent,
|
||||
OpenAICreateFileRequestOptionalParams,
|
||||
OpenAIFileObject,
|
||||
)
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
from ..common_utils import GeminiModelInfo
|
||||
|
||||
|
||||
class GoogleAIStudioFilesHandler(GeminiModelInfo, BaseFilesConfig):
|
||||
def __init__(self):
    """No per-instance state; credentials/URLs are resolved per call."""
    pass
|
||||
|
||||
@property
def custom_llm_provider(self) -> LlmProviders:
    """Provider enum for this handler (Google AI Studio / Gemini)."""
    return LlmProviders.GEMINI
|
||||
|
||||
def validate_environment(
    self,
    headers: dict[Any, Any],
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict[Any, Any],
    litellm_params: dict[Any, Any],
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict[Any, Any]:
    """
    Add the Gemini API key to the request headers.

    Google AI Studio authenticates file operations via the x-goog-api-key
    header.

    Raises:
        ValueError: If no API key can be resolved.
    """
    key = self.get_api_key(api_key)
    if not key:
        raise ValueError(
            "GEMINI_API_KEY is required for Google AI Studio file operations"
        )
    headers["x-goog-api-key"] = key
    return headers
|
||||
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """
    Build the resumable-upload URL for the Google AI Studio Files API.

    Raises:
        ValueError: If no api_base or api_key can be resolved.
    """
    base = self.get_api_base(api_base)
    if not base:
        raise ValueError("api_base is required")

    # API key may come from the argument, litellm params, or the environment.
    key = api_key or litellm_params.get("api_key") or self.get_api_key()
    if not key:
        raise ValueError("api_key is required")

    return f"{base}/upload/v1beta/files?key={key}"
|
||||
|
||||
def get_supported_openai_params(
    self, model: str
) -> List[OpenAICreateFileRequestOptionalParams]:
    """Gemini file uploads accept no optional OpenAI create-file params."""
    return []
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """No mapping needed for file operations; pass optional_params through."""
    return optional_params
|
||||
|
||||
def transform_create_file_request(
    self,
    model: str,
    create_file_data: CreateFileRequest,
    optional_params: dict,
    litellm_params: dict,
) -> dict:
    """
    Transform the OpenAI-style file creation request into Gemini's two-step
    resumable upload format.

    Returns:
        dict: "initial_request" (headers + metadata that start the resumable
        session) and "upload_request" (headers + raw content that upload and
        finalize the bytes).

    Raises:
        ValueError: If the request carries no file data.
    """
    file_data = create_file_data.get("file")
    if file_data is None:
        raise ValueError("File data is required")

    # Normalize the incoming file into filename/content/content_type/headers.
    extracted = extract_file_data(file_data)
    content_length = len(extracted["content"])

    # Step 1: headers initiating the resumable upload session.
    start_headers = {
        "X-Goog-Upload-Protocol": "resumable",
        "X-Goog-Upload-Command": "start",
        "X-Goog-Upload-Header-Content-Length": str(content_length),
        "X-Goog-Upload-Header-Content-Type": extracted["content_type"],
        "Content-Type": "application/json",
    }
    start_headers.update(extracted["headers"])  # add any custom headers

    # Metadata body; fall back to a timestamp when no filename was provided.
    metadata = {
        "file": {
            "display_name": extracted["filename"] or str(int(time.time()))
        }
    }

    # Step 2: headers that upload and finalize the actual bytes.
    upload_headers = {
        "Content-Length": str(content_length),
        "X-Goog-Upload-Offset": "0",
        "X-Goog-Upload-Command": "upload, finalize",
    }

    return {
        "initial_request": {"headers": start_headers, "data": metadata},
        "upload_request": {
            "headers": upload_headers,
            "data": extracted["content"],
        },
    }
|
||||
|
||||
def transform_create_file_response(
    self,
    model: Optional[str],
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> OpenAIFileObject:
    """
    Transform Gemini's file upload response into an OpenAI-style FileObject.

    Raises:
        ValueError: If the response cannot be parsed.
    """
    try:
        payload = raw_response.json()

        file_info = GeminiCreateFilesResponseObject(
            **payload.get("file", {})  # type: ignore
        )

        # "createTime" is RFC3339 (e.g. 2024-01-01T00:00:00.000000Z).
        # NOTE(review): time.mktime interprets the struct_time in *local*
        # time — confirm whether a UTC-based conversion is intended.
        created_at = int(
            time.mktime(
                time.strptime(
                    file_info["createTime"].replace("Z", "+00:00"),
                    "%Y-%m-%dT%H:%M:%S.%f%z",
                )
            )
        )

        return OpenAIFileObject(
            id=file_info["uri"],  # Gemini uses the URI as the identifier
            bytes=int(file_info["sizeBytes"]),
            created_at=created_at,
            filename=file_info["displayName"],
            object="file",
            purpose="user_data",
            status="uploaded",
            status_details=None,
        )
    except Exception as e:
        verbose_logger.exception(f"Error parsing file upload response: {str(e)}")
        raise ValueError(f"Error parsing file upload response: {str(e)}")
|
||||
|
||||
def transform_retrieve_file_request(
    self,
    file_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """
    Build the URL to retrieve a file from Google AI Studio.

    `file_id` is normally the full URI returned by the upload response
    (https://generativelanguage.googleapis.com/v1beta/files/...); a bare
    "files/..." name is also accepted.

    Returns:
        (url, params) — params is empty since the API key is in the URL.

    Raises:
        ValueError: If no API key can be resolved.
    """
    key = litellm_params.get("api_key") or self.get_api_key()
    if not key:
        raise ValueError("api_key is required")

    if file_id.startswith("http"):
        return f"{file_id}?key={key}", {}

    # Fallback for a bare file name ("files/..."): prepend the API base.
    base = (
        self.get_api_base(litellm_params.get("api_base"))
        or "https://generativelanguage.googleapis.com"
    )
    base = base.rstrip("/")
    return f"{base}/v1beta/{file_id}?key={key}", {}
|
||||
|
||||
def transform_retrieve_file_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> OpenAIFileObject:
    """
    Transform Gemini's file retrieval response into an OpenAI-style FileObject.

    Raises:
        ValueError: If the response cannot be parsed.
    """
    try:
        payload = raw_response.json()

        # Map Gemini file state onto the OpenAI status literal:
        # ACTIVE -> processed, FAILED -> error, anything else -> uploaded.
        state = payload.get("state", "STATE_UNSPECIFIED")
        status: Literal["uploaded", "processed", "error"]
        if state == "ACTIVE":
            status = "processed"
        elif state == "FAILED":
            status = "error"
        else:
            status = "uploaded"

        created_at = int(
            time.mktime(
                time.strptime(
                    payload["createTime"].replace("Z", "+00:00"),
                    "%Y-%m-%dT%H:%M:%S.%f%z",
                )
            )
        )

        return OpenAIFileObject(
            id=payload.get("uri", ""),
            bytes=int(payload.get("sizeBytes", 0)),
            created_at=created_at,
            filename=payload.get("displayName", ""),
            object="file",
            purpose="user_data",
            status=status,
            status_details=(
                str(payload.get("error", "")) if state == "FAILED" else None
            ),
        )
    except Exception as e:
        verbose_logger.exception(f"Error parsing file retrieve response: {str(e)}")
        raise ValueError(f"Error parsing file retrieve response: {str(e)}")
|
||||
|
||||
def transform_delete_file_request(
    self,
    file_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """
    Build the DELETE request for a Google AI Studio file.

    Args:
        file_id: Either a bare id ("abc123"), a "files/abc123" name, or the
            full URI ("https://generativelanguage.googleapis.com/v1beta/files/abc123").
        optional_params: Optional parameters (unused).
        litellm_params: LiteLLM parameters; may carry api_key / api_base.

    Returns:
        tuple[str, dict]: (url, params) for the DELETE request. params stays
        empty — authentication uses the x-goog-api-key header.

    Raises:
        ValueError: If api_base or an API key cannot be resolved.
    """
    base = self.get_api_base(litellm_params.get("api_base"))
    if not base:
        raise ValueError("api_base is required")

    # Same key-resolution order as get_complete_url; resolved only to fail
    # fast when no credentials are available.
    key = litellm_params.get("api_key") or self.get_api_key()
    if not key:
        raise ValueError("api_key is required")

    # Normalize every accepted input form to a "files/<id>" path.
    if file_id.startswith("http"):
        file_name = file_id.split("/v1beta/")[-1]
    elif file_id.startswith("files/"):
        file_name = file_id
    else:
        file_name = f"files/{file_id}"

    return f"{base}/v1beta/{file_name}", {}
|
||||
|
||||
def transform_delete_file_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> FileDeleted:
    """
    Transform Gemini's file delete response into OpenAI-style FileDeleted.

    Google AI Studio returns an empty JSON object {} on successful deletion,
    so the file id is recovered from the request URL when possible.

    Raises:
        ValueError: when deletion failed (non-200) or the response cannot
            be parsed.
    """
    try:
        # Google AI Studio returns {} on successful deletion
        if raw_response.status_code == 200:
            # Extract file ID from the request URL if possible
            file_id = "deleted"
            if hasattr(raw_response, "request") and raw_response.request:
                url = str(raw_response.request.url)
                if "/files/" in url:
                    file_id = url.split("/files/")[-1].split("?")[0]
                    # Add the files/ prefix if not present
                    if not file_id.startswith("files/"):
                        file_id = f"files/{file_id}"

            return FileDeleted(id=file_id, deleted=True, object="file")
        else:
            raise ValueError(f"Failed to delete file: {raw_response.text}")
    except ValueError:
        # Bug fix: the deliberate "Failed to delete file" error was being
        # swallowed by the generic handler below and re-raised with a
        # misleading "Error parsing..." message. Propagate it unchanged.
        raise
    except Exception as e:
        verbose_logger.exception(f"Error parsing file delete response: {str(e)}")
        raise ValueError(f"Error parsing file delete response: {str(e)}")
|
||||
|
||||
def transform_list_files_request(
    self,
    purpose: Optional[str],
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """Not supported: this handler exposes no list-files endpoint.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "GoogleAIStudioFilesHandler does not support file listing"
    )
|
||||
|
||||
def transform_list_files_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> List[OpenAIFileObject]:
    """Not supported: this handler exposes no list-files endpoint.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "GoogleAIStudioFilesHandler does not support file listing"
    )
|
||||
|
||||
def transform_file_content_request(
    self,
    file_content_request,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """Not supported: this handler exposes no file-content endpoint.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "GoogleAIStudioFilesHandler does not support file content retrieval"
    )
|
||||
|
||||
def transform_file_content_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> HttpxBinaryResponseContent:
    """Not supported: this handler exposes no file-content endpoint.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "GoogleAIStudioFilesHandler does not support file content retrieval"
    )
|
||||
@@ -0,0 +1,379 @@
|
||||
"""
|
||||
Transformation for Calling Google models in their native format.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm.google_genai.transformation import (
|
||||
BaseGoogleGenAIGenerateContentConfig,
|
||||
)
|
||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.google_genai.main import (
|
||||
GenerateContentConfigDict,
|
||||
GenerateContentContentListUnionDict,
|
||||
GenerateContentResponse,
|
||||
ToolConfigDict,
|
||||
)
|
||||
else:
|
||||
GenerateContentConfigDict = Any
|
||||
GenerateContentContentListUnionDict = Any
|
||||
GenerateContentResponse = Any
|
||||
ToolConfigDict = Any
|
||||
|
||||
from ..common_utils import get_api_key_from_env
|
||||
|
||||
|
||||
class GoogleGenAIConfig(BaseGoogleGenAIGenerateContentConfig, VertexLLM):
|
||||
"""
|
||||
Configuration for calling Google models in their native format.
|
||||
"""
|
||||
|
||||
##############################
|
||||
# Constants
|
||||
##############################
|
||||
XGOOGLE_API_KEY = "x-goog-api-key"
|
||||
##############################
|
||||
|
||||
@property
def custom_llm_provider(self) -> Literal["gemini", "vertex_ai"]:
    # This config targets Google AI Studio ("gemini"), never Vertex AI.
    return "gemini"

def __init__(self):
    super().__init__()
    # Explicitly initialize the VertexLLM side of the multiple-inheritance
    # chain so its auth/token helpers are set up as well.
    VertexLLM.__init__(self)
|
||||
|
||||
def get_supported_generate_content_optional_params(self, model: str) -> List[str]:
    """
    Get the list of supported Google GenAI parameters for the model.

    Args:
        model: The model name (currently unused — the set is model-agnostic)

    Returns:
        List of supported parameter names, in snake_case
    """
    supported = (
        "http_options",
        "system_instruction",
        "temperature",
        "top_p",
        "top_k",
        "candidate_count",
        "max_output_tokens",
        "stop_sequences",
        "response_logprobs",
        "logprobs",
        "presence_penalty",
        "frequency_penalty",
        "seed",
        "response_mime_type",
        "response_schema",
        "response_json_schema",
        "routing_config",
        "model_selection_config",
        "safety_settings",
        "tools",
        "tool_config",
        "labels",
        "cached_content",
        "response_modalities",
        "media_resolution",
        "speech_config",
        "audio_timestamp",
        "automatic_function_calling",
        "thinking_config",
        "image_config",
    )
    return list(supported)
|
||||
|
||||
def map_generate_content_optional_params(
    self,
    generate_content_config_dict: GenerateContentConfigDict,
    model: str,
) -> Dict[str, Any]:
    """
    Map Google GenAI parameters to provider-specific format.

    Each input parameter is accepted in snake_case or camelCase; supported
    parameters are emitted in the camelCase form the Google GenAI API
    expects, and unsupported parameters are dropped.

    Args:
        generate_content_config_dict: Optional parameters for generate content
        model: The model name

    Returns:
        Mapped parameters for the provider
    """
    from litellm.llms.vertex_ai.gemini.transformation import (
        _camel_to_snake,
        _snake_to_camel,
    )

    supported_google_genai_params = (
        self.get_supported_generate_content_optional_params(model)
    )
    # Bug fix: a `supported_params_set` used to be built here (three
    # update passes) but was never consulted — the membership test below
    # already covers both naming conventions. The dead work is removed.

    mapped: Dict[str, Any] = {}
    for param, value in generate_content_config_dict.items():
        # Accept either naming convention on input...
        param_snake = _camel_to_snake(param)
        param_camel = _snake_to_camel(param)

        is_supported = (
            param in supported_google_genai_params
            or param_snake in supported_google_genai_params
            or param_camel in supported_google_genai_params
        )

        if is_supported:
            # ...but always emit camelCase for the Google GenAI API.
            output_key = param_camel if param != param_camel else param
            mapped[output_key] = value
    return mapped
|
||||
|
||||
def validate_environment(
    self,
    api_key: Optional[str],
    headers: Optional[dict],
    model: str,
    litellm_params: Optional[Union[GenericLiteLLMParams, dict]],
) -> dict:
    """Assemble request headers, resolving the Gemini API key as needed.

    Precedence: the explicit ``api_key`` argument, then litellm_params /
    environment via ``_get_google_ai_studio_api_key``. Caller-supplied
    ``headers`` are applied last and win over the defaults.
    """
    resolved_headers: dict = {
        "Content-Type": "application/json",
    }

    # Fall back to litellm_params / env only when no key was passed in.
    key = api_key or self._get_google_ai_studio_api_key(
        dict(litellm_params or {})
    )
    if isinstance(key, dict):
        # Some callers hand back a ready-made header mapping.
        resolved_headers.update(key)
    elif key is not None:
        resolved_headers[self.XGOOGLE_API_KEY] = key

    if headers is not None:
        resolved_headers.update(headers)

    return resolved_headers
|
||||
|
||||
def _get_google_ai_studio_api_key(self, litellm_params: dict) -> Optional[str]:
    """Resolve the Google AI Studio API key.

    Precedence: litellm_params["api_key"], litellm_params["gemini_api_key"],
    the environment (``get_api_key_from_env``), then the global
    ``litellm.api_key``.

    NOTE(review): ``pop`` mutates the supplied dict — ``validate_environment``
    passes a copy, but ``_build_final_headers_and_url`` passes its dict
    directly; confirm that removal is intended there.
    """
    return (
        litellm_params.pop("api_key", None)
        or litellm_params.pop("gemini_api_key", None)
        or get_api_key_from_env()
        or litellm.api_key
    )
|
||||
|
||||
def _get_common_auth_components(
|
||||
self,
|
||||
litellm_params: dict,
|
||||
) -> Tuple[Any, Optional[str], Optional[str]]:
|
||||
"""
|
||||
Get common authentication components used by both sync and async methods.
|
||||
|
||||
Returns:
|
||||
Tuple of (vertex_credentials, vertex_project, vertex_location)
|
||||
"""
|
||||
vertex_credentials = self.get_vertex_ai_credentials(litellm_params)
|
||||
vertex_project = self.get_vertex_ai_project(litellm_params)
|
||||
vertex_location = self.get_vertex_ai_location(litellm_params)
|
||||
return vertex_credentials, vertex_project, vertex_location
|
||||
|
||||
def _build_final_headers_and_url(
    self,
    model: str,
    auth_header: Optional[str],
    vertex_project: Optional[str],
    vertex_location: Optional[str],
    vertex_credentials: Any,
    stream: bool,
    api_base: Optional[str],
    litellm_params: dict,
) -> Tuple[dict, str]:
    """
    Build final headers and API URL from auth components.

    NOTE(review): ``_get_google_ai_studio_api_key`` pops "api_key" /
    "gemini_api_key" out of ``litellm_params`` before ``validate_environment``
    sees the same dict below — the call order here matters.
    """
    gemini_api_key = self._get_google_ai_studio_api_key(litellm_params)

    # Resolve the auth header and fully-qualified endpoint for this
    # provider ("gemini"), pinning v1beta features.
    auth_header, api_base = self._get_token_and_url(
        model=model,
        gemini_api_key=gemini_api_key,
        auth_header=auth_header,
        vertex_project=vertex_project,
        vertex_location=vertex_location,
        vertex_credentials=vertex_credentials,
        stream=stream,
        custom_llm_provider=self.custom_llm_provider,
        api_base=api_base,
        should_use_v1beta1_features=True,
    )

    # The resolved auth header doubles as the API key for header assembly.
    headers = self.validate_environment(
        api_key=auth_header,
        headers=None,
        model=model,
        litellm_params=litellm_params,
    )

    return headers, api_base
|
||||
|
||||
def sync_get_auth_token_and_url(
    self,
    api_base: Optional[str],
    model: str,
    litellm_params: dict,
    stream: bool,
) -> Tuple[dict, str]:
    """Synchronous variant of ``get_auth_token_and_url``: resolve auth
    components, mint an access token, and build headers + endpoint."""
    creds, project, location = self._get_common_auth_components(litellm_params)

    # Blocking token acquisition; may also refine the project id.
    token, project = self._ensure_access_token(
        credentials=creds,
        project_id=project,
        custom_llm_provider=self.custom_llm_provider,
    )

    return self._build_final_headers_and_url(
        model=model,
        auth_header=token,
        vertex_project=project,
        vertex_location=location,
        vertex_credentials=creds,
        stream=stream,
        api_base=api_base,
        litellm_params=litellm_params,
    )
|
||||
|
||||
async def get_auth_token_and_url(
    self,
    api_base: Optional[str],
    model: str,
    litellm_params: dict,
    stream: bool,
) -> Tuple[dict, str]:
    """
    Get the complete URL for the request (async).

    Args:
        api_base: Base API URL
        model: The model name
        litellm_params: LiteLLM parameters
        stream: Whether the request will stream

    Returns:
        Tuple of headers and API base
    """
    creds, project, location = self._get_common_auth_components(litellm_params)

    # Non-blocking token acquisition; may also refine the project id.
    token, project = await self._ensure_access_token_async(
        credentials=creds,
        project_id=project,
        custom_llm_provider=self.custom_llm_provider,
    )

    return self._build_final_headers_and_url(
        model=model,
        auth_header=token,
        vertex_project=project,
        vertex_location=location,
        vertex_credentials=creds,
        stream=stream,
        api_base=api_base,
        litellm_params=litellm_params,
    )
|
||||
|
||||
def transform_generate_content_request(
    self,
    model: str,
    contents: GenerateContentContentListUnionDict,
    tools: Optional[ToolConfigDict],
    generate_content_config_dict: Dict,
    system_instruction: Optional[Any] = None,
) -> dict:
    """Assemble the generateContent request body in Google's wire format."""
    from litellm.types.google_genai.main import (
        GenerateContentConfigDict,
        GenerateContentRequestDict,
    )

    request_dict = cast(
        dict,
        GenerateContentRequestDict(
            model=model,
            contents=contents,
            tools=tools,
            generationConfig=GenerateContentConfigDict(
                **generate_content_config_dict
            ),
        ),
    )

    # systemInstruction is only attached when provided, matching the
    # API's optional-field semantics.
    if system_instruction is not None:
        request_dict["systemInstruction"] = system_instruction
    return request_dict
|
||||
|
||||
def transform_generate_content_response(
    self,
    model: str,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> GenerateContentResponse:
    """
    Transform the raw response from the generate content API.

    Args:
        model: The model name
        raw_response: Raw HTTP response

    Returns:
        Transformed response data

    Raises:
        A provider error (via get_error_class) when the body is not JSON.
    """
    from litellm.types.google_genai.main import GenerateContentResponse

    try:
        payload = raw_response.json()
    except Exception as e:
        raise self.get_error_class(
            error_message=f"Error transforming generate content response: {e}",
            status_code=raw_response.status_code,
            headers=raw_response.headers,
        )

    # Keep the raw response around for logging/observability.
    logging_obj.model_call_details["httpx_response"] = raw_response
    # citationSources (wire format) -> citations (SDK schema).
    payload = self.convert_citation_sources_to_citations(payload)

    return GenerateContentResponse(**payload)
|
||||
|
||||
def convert_citation_sources_to_citations(self, response: Dict) -> Dict:
    """
    Convert citation sources to citations, in place.

    API's camelCase citationSources becomes the SDK's snake_case citations.
    """
    for candidate in response.get("candidates", []):
        metadata = candidate.get("citationMetadata")
        # Only rename when the metadata is a dict carrying the wire key.
        if isinstance(metadata, dict) and "citationSources" in metadata:
            metadata["citations"] = metadata.pop("citationSources")
    return response
|
||||
@@ -0,0 +1,10 @@
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
|
||||
from .transformation import GeminiImageEditConfig
|
||||
from .cost_calculator import cost_calculator
|
||||
|
||||
__all__ = ["GeminiImageEditConfig", "get_gemini_image_edit_config", "cost_calculator"]
|
||||
|
||||
|
||||
def get_gemini_image_edit_config(model: str) -> BaseImageEditConfig:
    """Return the image-edit config for Gemini models.

    ``model`` is accepted for interface parity with other providers but is
    currently unused — a single config serves all Gemini models.
    """
    return GeminiImageEditConfig()
|
||||
@@ -0,0 +1,34 @@
|
||||
"""
|
||||
Gemini Image Edit Cost Calculator
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
    model: str,
    image_response: Any,
) -> float:
    """
    Gemini image edit cost calculator.

    Mirrors image generation pricing: charge per returned image based on
    model metadata (`output_cost_per_image`).

    Raises:
        ValueError: if ``image_response`` is not an ImageResponse.
    """
    model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider="gemini",
    )

    # Missing/None metadata is treated as zero cost rather than an error.
    per_image_cost: float = model_info.get("output_cost_per_image") or 0.0

    if not isinstance(image_response, ImageResponse):
        raise ValueError(
            f"image_response must be of type ImageResponse got type={type(image_response)}"
        )

    image_count = len(image_response.data or [])
    return per_image_cost * image_count
|
||||
@@ -0,0 +1,211 @@
|
||||
import base64
|
||||
from io import BufferedReader, BytesIO
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
from litellm.images.utils import ImageEditRequestUtils
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageObject, ImageResponse, OpenAIImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class GeminiImageEditConfig(BaseImageEditConfig):
|
||||
DEFAULT_BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
|
||||
SUPPORTED_PARAMS: List[str] = ["size"]
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
    """Return the OpenAI image-edit params Gemini understands, as a fresh
    copy so callers cannot mutate the class-level list."""
    return [*self.SUPPORTED_PARAMS]
|
||||
|
||||
def map_openai_params(
    self,
    image_edit_optional_params: ImageEditOptionalRequestParams,
    model: str,
    drop_params: bool,
) -> Dict[str, Any]:
    """Translate supported OpenAI image-edit params to Gemini's names.

    Unsupported params are silently dropped; ``size`` is converted to
    ``aspectRatio`` via ``_map_size_to_aspect_ratio``.
    """
    allowed = set(self.get_supported_openai_params(model))
    recognized = {
        name: val
        for name, val in image_edit_optional_params.items()
        if name in allowed
    }

    gemini_params: Dict[str, Any] = {}
    if "size" in recognized:
        gemini_params["aspectRatio"] = self._map_size_to_aspect_ratio(
            recognized["size"]  # type: ignore[arg-type]
        )
    return gemini_params
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    api_key: Optional[str] = None,
) -> dict:
    """Attach Gemini auth + content-type headers, mutating ``headers``.

    Falls back to the GEMINI_API_KEY secret when no key is passed.

    Raises:
        ValueError: when no API key can be resolved.
    """
    key = api_key or get_secret_str("GEMINI_API_KEY")
    if not key:
        raise ValueError("GEMINI_API_KEY is not set")

    headers.update(
        {
            "x-goog-api-key": key,
            "Content-Type": "application/json",
        }
    )
    return headers
|
||||
|
||||
def use_multipart_form_data(self) -> bool:
    """Gemini image edit sends a JSON body, never multipart/form-data."""
    return False
|
||||
|
||||
def get_complete_url(
    self,
    model: str,
    api_base: Optional[str],
    litellm_params: dict,
) -> str:
    """Build the generateContent endpoint URL for image editing.

    Base resolution order: explicit ``api_base``, the GEMINI_API_BASE
    secret, then the class default.
    """
    root = api_base or get_secret_str("GEMINI_API_BASE") or self.DEFAULT_BASE_URL
    # Trailing slashes are trimmed so the path joins cleanly.
    return f"{root.rstrip('/')}/models/{model}:generateContent"
|
||||
|
||||
def transform_image_edit_request(  # type: ignore[override]
    self,
    model: str,
    prompt: Optional[str],
    image: Optional[FileTypes],
    image_edit_optional_request_params: Dict[str, Any],
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[Dict[str, Any], Optional[RequestFiles]]:
    """Build the JSON generateContent body for a Gemini image edit.

    Images are inlined as base64 parts followed by the text prompt (when
    present); ``aspectRatio`` is nested under generationConfig.imageConfig.

    Raises:
        ValueError: when no usable image part could be produced.
    """
    image_parts = self._prepare_inline_image_parts(image) if image else []
    if not image_parts:
        raise ValueError("Gemini image edit requires at least one image.")

    # Image parts first, then the prompt text when one was supplied.
    parts = list(image_parts)
    if prompt is not None and prompt != "":
        parts.append({"text": prompt})

    body: Dict[str, Any] = {"contents": [{"parts": parts}]}

    if "aspectRatio" in image_edit_optional_request_params:
        # Gemini expects aspectRatio nested in generationConfig.imageConfig.
        body["generationConfig"] = {
            "imageConfig": {
                "aspectRatio": image_edit_optional_request_params["aspectRatio"]
            }
        }

    # No multipart upload: the second element mirrors the base-class
    # contract but is always an empty file list for Gemini.
    return body, cast(RequestFiles, [])
|
||||
|
||||
def transform_image_edit_response(
    self,
    model: str,
    raw_response: httpx.Response,
    logging_obj: Any,
) -> ImageResponse:
    """Collect inline base64 images from a generateContent response into
    an OpenAI-style ImageResponse."""
    model_response = ImageResponse()
    try:
        payload = raw_response.json()
    except Exception as exc:
        raise self.get_error_class(
            error_message=f"Error transforming image edit response: {exc}",
            status_code=raw_response.status_code,
            headers=raw_response.headers,
        )

    # Flatten candidates -> content.parts -> inlineData entries with data.
    images: List[ImageObject] = [
        ImageObject(b64_json=part["inlineData"]["data"], url=None)
        for candidate in payload.get("candidates", [])
        for part in candidate.get("content", {}).get("parts", [])
        if part.get("inlineData") and part["inlineData"].get("data")
    ]

    model_response.data = cast(List[OpenAIImage], images)
    return model_response
|
||||
|
||||
def _map_size_to_aspect_ratio(self, size: str) -> str:
|
||||
aspect_ratio_map = {
|
||||
"1024x1024": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1280x896": "4:3",
|
||||
"896x1280": "3:4",
|
||||
}
|
||||
return aspect_ratio_map.get(size, "1:1")
|
||||
|
||||
def _prepare_inline_image_parts(
    self, image: Union[FileTypes, List[FileTypes]]
) -> List[Dict[str, Any]]:
    """Encode one image (or a list of images) as Gemini inlineData parts.

    ``None`` entries are skipped; each part carries the detected MIME type
    and the base64-encoded bytes.
    """
    candidates = image if isinstance(image, list) else [image]

    parts: List[Dict[str, Any]] = []
    for entry in candidates:
        if entry is None:
            continue

        # MIME detection first, then a position-preserving full read.
        mime = ImageEditRequestUtils.get_image_content_type(entry)
        payload = base64.b64encode(self._read_all_bytes(entry)).decode("utf-8")
        parts.append(
            {
                "inlineData": {
                    "mimeType": mime,
                    "data": payload,
                }
            }
        )

    return parts
|
||||
|
||||
def _read_all_bytes(self, image: FileTypes) -> bytes:
|
||||
if isinstance(image, bytes):
|
||||
return image
|
||||
if isinstance(image, BytesIO):
|
||||
current_pos = image.tell()
|
||||
image.seek(0)
|
||||
data = image.read()
|
||||
image.seek(current_pos)
|
||||
return data
|
||||
if isinstance(image, BufferedReader):
|
||||
current_pos = image.tell()
|
||||
image.seek(0)
|
||||
data = image.read()
|
||||
image.seek(current_pos)
|
||||
return data
|
||||
raise ValueError("Unsupported image type for Gemini image edit.")
|
||||
@@ -0,0 +1,13 @@
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
|
||||
from .transformation import GoogleImageGenConfig
|
||||
|
||||
__all__ = [
|
||||
"GoogleImageGenConfig",
|
||||
]
|
||||
|
||||
|
||||
def get_gemini_image_generation_config(model: str) -> BaseImageGenerationConfig:
    """Return the image-generation config for Google AI Studio models.

    ``model`` is accepted for interface parity but is currently unused —
    one config instance serves both Gemini and Imagen models.
    """
    return GoogleImageGenConfig()
|
||||
@@ -0,0 +1,44 @@
|
||||
"""
|
||||
Google AI Image Generation Cost Calculator
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import (
|
||||
calculate_image_response_cost_from_usage,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
    model: str,
    image_response: Any,
) -> float:
    """
    Google AI Image Generation Cost Calculator.

    Prefers token-based usage pricing when available, falling back to a
    flat per-image cost from model metadata (`output_cost_per_image`).

    Raises:
        ValueError: if ``image_response`` is not an ImageResponse.
    """
    # Validate up-front: the original checked the type twice and only
    # rejected invalid input after fetching model info.
    if not isinstance(image_response, ImageResponse):
        raise ValueError(
            f"image_response must be of type ImageResponse got type={type(image_response)}"
        )

    _model_info = litellm.get_model_info(
        model=model,
        custom_llm_provider="gemini",
    )

    # Token-based pricing wins when usage metadata supports it.
    token_based_cost = calculate_image_response_cost_from_usage(
        model=model,
        image_response=image_response,
        custom_llm_provider="gemini",
    )
    if token_based_cost is not None:
        return token_based_cost

    # Fallback: flat per-image pricing from model metadata.
    output_cost_per_image: float = _model_info.get("output_cost_per_image") or 0.0
    num_images: int = len(image_response.data) if image_response.data else 0
    return output_cost_per_image * num_images
|
||||
@@ -0,0 +1,269 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.image_generation.transformation import (
|
||||
BaseImageGenerationConfig,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.gemini import GeminiImageGenerationRequest
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIImageGenerationOptionalParams,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
ImageObject,
|
||||
ImageResponse,
|
||||
ImageUsage,
|
||||
ImageUsageInputTokensDetails,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class GoogleImageGenConfig(BaseImageGenerationConfig):
|
||||
DEFAULT_BASE_URL: str = "https://generativelanguage.googleapis.com/v1beta"
|
||||
|
||||
def get_supported_openai_params(
    self, model: str
) -> List[OpenAIImageGenerationOptionalParams]:
    """
    Google AI Imagen API supported parameters: image count and size.
    https://ai.google.dev/gemini-api/docs/imagen
    """
    supported: List[OpenAIImageGenerationOptionalParams] = ["n", "size"]
    return supported
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Map OpenAI image-generation params onto Google's request fields.

    ``n`` becomes sampleCount and ``size`` becomes aspectRatio; anything
    not already present in ``optional_params`` and not recognized is
    passed through untouched.
    """
    recognized = self.get_supported_openai_params(model)
    result: dict = {}

    for name, value in non_default_params.items():
        if name in optional_params.keys():
            # Already supplied explicitly — leave it alone.
            continue
        if name not in recognized:
            # NOTE: unrecognized params are forwarded as-is (drop_params
            # is not consulted here).
            result[name] = value
        elif name == "n":
            result["sampleCount"] = value
        elif name == "size":
            # Map OpenAI size format to Google aspectRatio
            result["aspectRatio"] = self._map_size_to_aspect_ratio(value)
    return result
|
||||
|
||||
def _map_size_to_aspect_ratio(self, size: str) -> str:
|
||||
"""
|
||||
https://ai.google.dev/gemini-api/docs/image-generation
|
||||
|
||||
"""
|
||||
aspect_ratio_map = {
|
||||
"1024x1024": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1280x896": "4:3",
|
||||
"896x1280": "3:4",
|
||||
}
|
||||
return aspect_ratio_map.get(size, "1:1")
|
||||
|
||||
def _transform_image_usage(self, usage_metadata: dict) -> ImageUsage:
    """
    Transform Gemini usageMetadata to ImageUsage format.

    promptTokensDetails entries split the input tokens by modality
    (TEXT vs IMAGE); missing counts default to zero.
    """
    details = ImageUsageInputTokensDetails(
        image_tokens=0,
        text_tokens=0,
    )

    # Fold each modality entry into the per-type input breakdown.
    for entry in usage_metadata.get("promptTokensDetails", []):
        if not isinstance(entry, dict):
            continue
        modality = entry.get("modality")
        count = entry.get("tokenCount", 0)
        if modality == "TEXT":
            details.text_tokens = count
        elif modality == "IMAGE":
            details.image_tokens = count

    return ImageUsage(
        input_tokens=usage_metadata.get("promptTokenCount", 0),
        input_tokens_details=details,
        output_tokens=usage_metadata.get("candidatesTokenCount", 0),
        total_tokens=usage_metadata.get("totalTokenCount", 0),
    )
|
||||
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """
    Get the complete url for the request.

    Gemini 2.5 Flash Image Preview: :generateContent
    Other Imagen models: :predict
    """
    root = (
        api_base or get_secret_str("GEMINI_API_BASE") or self.DEFAULT_BASE_URL
    ).rstrip("/")

    # The endpoint action differs by model family.
    action = "generateContent" if "gemini" in model else "predict"
    return f"{root}/models/{model}:{action}"
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Attach the Gemini API key and JSON content-type, mutating ``headers``.

    Falls back to the GEMINI_API_KEY secret when no key is passed.

    Raises:
        ValueError: when neither ``api_key`` nor GEMINI_API_KEY resolves.
    """
    key = api_key or get_secret_str("GEMINI_API_KEY")
    if not key:
        raise ValueError("GEMINI_API_KEY is not set")

    headers.update(
        {
            "x-goog-api-key": key,
            "Content-Type": "application/json",
        }
    )
    return headers
|
||||
|
||||
def transform_image_generation_request(
    self,
    model: str,
    prompt: str,
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """
    Transform the image generation request to Gemini format.

    Gemini-family models (e.g. 2.5 Flash Image Preview) use the standard
    generateContent payload with image response modalities; all other
    Imagen models use the instances/parameters :predict format.
    """
    if "gemini" in model:
        # Standard Gemini generateContent body requesting image output:
        # {"contents": [{"parts": [{"text": ...}]}],
        #  "generationConfig": {"response_modalities": ["IMAGE", "TEXT"]}}
        return {
            "contents": [{"parts": [{"text": prompt}]}],
            "generationConfig": {"response_modalities": ["IMAGE", "TEXT"]},
        }

    # Classic Imagen predict format.
    from litellm.types.llms.gemini import (
        GeminiImageGenerationInstance,
        GeminiImageGenerationParameters,
    )

    imagen_request: GeminiImageGenerationRequest = GeminiImageGenerationRequest(
        instances=[GeminiImageGenerationInstance(prompt=prompt)],
        parameters=GeminiImageGenerationParameters(**optional_params),
    )
    return imagen_request.model_dump(exclude_none=True)
|
||||
|
||||
def transform_image_generation_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ImageResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ImageResponse:
    """
    Transform a Google AI image generation response into litellm's
    ImageResponse format.

    Two wire formats are handled, selected by the model name:
      - "gemini" models: generateContent-style ``candidates`` with base64
        image bytes in ``parts[].inlineData.data``.
      - Imagen models: predict-style ``predictions`` with
        ``bytesBase64Encoded``.

    Raises:
        The provider error class (via ``self.get_error_class``) when the
        response body is not valid JSON.
    """
    try:
        response_data = raw_response.json()
    except Exception as e:
        # Non-JSON body (HTML error page, truncated payload, ...) — surface
        # it as a provider error carrying the original status/headers.
        raise self.get_error_class(
            error_message=f"Error transforming image generation response: {e}",
            status_code=raw_response.status_code,
            headers=raw_response.headers,
        )

    # Ensure there is a list to append images to.
    if not model_response.data:
        model_response.data = []

    # Handle different response formats based on model
    if "gemini" in model:
        # Gemini Flash Image Preview models return in candidates format
        candidates = response_data.get("candidates", [])
        for candidate in candidates:
            content = candidate.get("content", {})
            parts = content.get("parts", [])
            for part in parts:
                # Look for inlineData with image bytes; text parts are skipped.
                if "inlineData" in part:
                    inline_data = part["inlineData"]
                    if "data" in inline_data:
                        # thoughtSignature (when present) is preserved in
                        # provider_specific_fields for downstream callers.
                        thought_sig = part.get("thoughtSignature")
                        model_response.data.append(
                            ImageObject(
                                b64_json=inline_data["data"],
                                url=None,
                                provider_specific_fields={
                                    "thought_signature": thought_sig
                                }
                                if thought_sig
                                else None,
                            )
                        )

        # Extract usage metadata for Gemini models
        if "usageMetadata" in response_data:
            model_response.usage = self._transform_image_usage(
                response_data["usageMetadata"]
            )
    else:
        # Original Imagen format - predictions with generated images
        predictions = response_data.get("predictions", [])
        for prediction in predictions:
            # Google AI returns base64 encoded images in the prediction
            model_response.data.append(
                ImageObject(
                    b64_json=prediction.get("bytesBase64Encoded", None),
                    url=None,  # Google AI returns base64, not URLs
                )
            )
    return model_response
|
||||
@@ -0,0 +1,7 @@
|
||||
"""Google AI Studio Interactions API implementation."""
|
||||
|
||||
from litellm.llms.gemini.interactions.transformation import (
|
||||
GoogleAIStudioInteractionsConfig,
|
||||
)
|
||||
|
||||
__all__ = ["GoogleAIStudioInteractionsConfig"]
|
||||
@@ -0,0 +1,289 @@
|
||||
"""
|
||||
Google AI Studio Interactions API configuration.
|
||||
|
||||
Per OpenAPI spec (https://ai.google.dev/static/api/interactions.openapi.json):
|
||||
- Create: POST https://generativelanguage.googleapis.com/{api_version}/interactions
|
||||
- Get: GET https://generativelanguage.googleapis.com/{api_version}/interactions/{interaction_id}
|
||||
- Delete: DELETE https://generativelanguage.googleapis.com/{api_version}/interactions/{interaction_id}
|
||||
|
||||
This is a thin wrapper - no transformation needed since we follow the spec directly.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import process_response_headers
|
||||
from litellm.llms.base_llm.interactions.transformation import BaseInteractionsAPIConfig
|
||||
from litellm.llms.gemini.common_utils import GeminiError, GeminiModelInfo
|
||||
from litellm.types.interactions import (
|
||||
CancelInteractionResult,
|
||||
DeleteInteractionResult,
|
||||
InteractionInput,
|
||||
InteractionsAPIOptionalRequestParams,
|
||||
InteractionsAPIResponse,
|
||||
InteractionsAPIStreamingResponse,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class GoogleAIStudioInteractionsConfig(BaseInteractionsAPIConfig):
    """
    Configuration for Google AI Studio Interactions API.

    Minimal config - we follow the OpenAPI spec directly with no transformation.

    Endpoints (per https://ai.google.dev/static/api/interactions.openapi.json):
      - Create: POST   {base}/{api_version}/interactions
      - Get:    GET    {base}/{api_version}/interactions/{interaction_id}
      - Delete: DELETE {base}/{api_version}/interactions/{interaction_id}

    Authentication is carried in the ``key`` URL query parameter, so
    ``validate_environment`` only sets Content-Type.
    """

    @property
    def custom_llm_provider(self) -> LlmProviders:
        # Routed through the generic "gemini" provider.
        return LlmProviders.GEMINI

    @property
    def api_version(self) -> str:
        # Interactions API is only exposed on v1beta.
        return "v1beta"

    def get_supported_params(self, model: str) -> List[str]:
        """Per OpenAPI spec CreateModelInteractionParams."""
        return [
            "model",
            "agent",
            "input",
            "tools",
            "system_instruction",
            "generation_config",
            "stream",
            "store",
            "background",
            "response_modalities",
            "response_format",
            "response_mime_type",
            "previous_interaction_id",
        ]

    def validate_environment(
        self,
        headers: dict,
        model: str,
        litellm_params: Optional[GenericLiteLLMParams],
    ) -> dict:
        """Google AI Studio uses API key in query params, not headers."""
        headers = headers or {}
        headers["Content-Type"] = "application/json"
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: Optional[str],
        agent: Optional[str] = None,
        litellm_params: Optional[dict] = None,
        stream: Optional[bool] = None,
    ) -> str:
        """
        POST /{api_version}/interactions

        Builds the full create-interaction URL including the ``key`` query
        parameter, and ``alt=sse`` when streaming is requested.

        Raises:
            ValueError: if no API key is found in litellm_params or the
                environment.
        """
        litellm_params = litellm_params or {}
        api_base = GeminiModelInfo.get_api_base(api_base)
        api_key = GeminiModelInfo.get_api_key(litellm_params.get("api_key"))

        if not api_key:
            raise ValueError(
                "Google API key is required. Set GOOGLE_API_KEY or GEMINI_API_KEY environment variable."
            )

        query_params = f"key={api_key}"
        if stream:
            # SSE framing for streamed interactions.
            query_params += "&alt=sse"

        return f"{api_base}/{self.api_version}/interactions?{query_params}"

    def transform_request(
        self,
        model: Optional[str],
        agent: Optional[str],
        input: Optional[InteractionInput],
        optional_params: InteractionsAPIOptionalRequestParams,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Dict:
        """
        Build request body per OpenAPI spec - minimal transformation.

        Exactly one of ``model`` or ``agent`` must be provided; optional
        params are copied through untouched since their names already match
        the spec.

        Raises:
            ValueError: if neither ``model`` nor ``agent`` is given.
        """
        request_body: Dict[str, Any] = {}

        # Model or Agent (one required)
        if model:
            # Strip any litellm routing prefix down to the base model name.
            request_body["model"] = GeminiModelInfo.get_base_model(model) or model
        elif agent:
            request_body["agent"] = agent
        else:
            raise ValueError("Either 'model' or 'agent' must be provided")

        # Input
        if input is not None:
            request_body["input"] = input

        # Pass through optional params directly (they match the spec)
        optional_keys = [
            "tools",
            "system_instruction",
            "generation_config",
            "stream",
            "store",
            "background",
            "response_modalities",
            "response_format",
            "response_mime_type",
            "previous_interaction_id",
        ]
        for key in optional_keys:
            if optional_params.get(key) is not None:
                request_body[key] = optional_params[key]

        return request_body

    def transform_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> InteractionsAPIResponse:
        """
        Parse response - it already matches our response type.

        Logs the raw body first (so failures are still recorded), then
        validates the JSON into InteractionsAPIResponse and stashes the
        response headers in ``_hidden_params``.

        Raises:
            GeminiError: if the body is not valid JSON (or logging fails).
        """
        try:
            # post_call runs before json() so the raw payload is logged even
            # when parsing below fails.
            logging_obj.post_call(
                original_response=raw_response.text,
                additional_args={"complete_input_dict": {}},
            )
            raw_json = raw_response.json()
        except Exception:
            raise GeminiError(
                message=raw_response.text,
                status_code=raw_response.status_code,
                headers=dict(raw_response.headers),
            )

        verbose_logger.debug("Google AI Interactions response: %s", raw_json)

        response = InteractionsAPIResponse(**raw_json)
        response._hidden_params["headers"] = dict(raw_response.headers)
        response._hidden_params["additional_headers"] = process_response_headers(
            dict(raw_response.headers)
        )

        return response

    def transform_streaming_response(
        self,
        model: Optional[str],
        parsed_chunk: dict,
        logging_obj: LiteLLMLoggingObj,
    ) -> InteractionsAPIStreamingResponse:
        """Parse streaming chunk."""
        verbose_logger.debug("Google AI Interactions streaming chunk: %s", parsed_chunk)
        return InteractionsAPIStreamingResponse(**parsed_chunk)

    # GET / DELETE / CANCEL - just build URLs, responses match spec directly

    def transform_get_interaction_request(
        self,
        interaction_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        GET /{api_version}/interactions/{interaction_id}

        Returns (url, request_body); the body is always empty for GET.
        """
        resolved_api_base = GeminiModelInfo.get_api_base(api_base)
        api_key = GeminiModelInfo.get_api_key(litellm_params.api_key)
        if not api_key:
            raise ValueError("Google API key is required")
        return (
            f"{resolved_api_base}/{self.api_version}/interactions/{interaction_id}?key={api_key}",
            {},
        )

    def transform_get_interaction_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> InteractionsAPIResponse:
        """Validate a GET response into InteractionsAPIResponse; raises GeminiError on bad JSON."""
        try:
            raw_json = raw_response.json()
        except Exception:
            raise GeminiError(
                message=raw_response.text,
                status_code=raw_response.status_code,
                headers=dict(raw_response.headers),
            )
        response = InteractionsAPIResponse(**raw_json)
        response._hidden_params["headers"] = dict(raw_response.headers)
        return response

    def transform_delete_interaction_request(
        self,
        interaction_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """
        DELETE /{api_version}/interactions/{interaction_id}

        Returns (url, request_body); the body is always empty for DELETE.
        """
        resolved_api_base = GeminiModelInfo.get_api_base(api_base)
        api_key = GeminiModelInfo.get_api_key(litellm_params.api_key)
        if not api_key:
            raise ValueError("Google API key is required")
        return (
            f"{resolved_api_base}/{self.api_version}/interactions/{interaction_id}?key={api_key}",
            {},
        )

    def transform_delete_interaction_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
        interaction_id: str,
    ) -> DeleteInteractionResult:
        """
        Map the HTTP status to a DeleteInteractionResult.

        Any 2xx counts as success; everything else raises GeminiError with
        the raw body as message.
        """
        if 200 <= raw_response.status_code < 300:
            return DeleteInteractionResult(success=True, id=interaction_id)
        raise GeminiError(
            message=raw_response.text,
            status_code=raw_response.status_code,
            headers=dict(raw_response.headers),
        )

    def transform_cancel_interaction_request(
        self,
        interaction_id: str,
        api_base: str,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> Tuple[str, Dict]:
        """POST /{api_version}/interactions/{interaction_id}:cancel (if supported)"""
        resolved_api_base = GeminiModelInfo.get_api_base(api_base)
        api_key = GeminiModelInfo.get_api_key(litellm_params.api_key)
        if not api_key:
            raise ValueError("Google API key is required")
        return (
            f"{resolved_api_base}/{self.api_version}/interactions/{interaction_id}:cancel?key={api_key}",
            {},
        )

    def transform_cancel_interaction_response(
        self,
        raw_response: httpx.Response,
        logging_obj: LiteLLMLoggingObj,
    ) -> CancelInteractionResult:
        """Validate a cancel response; raises GeminiError when the body is not JSON."""
        try:
            raw_json = raw_response.json()
        except Exception:
            raise GeminiError(
                message=raw_response.text,
                status_code=raw_response.status_code,
                headers=dict(raw_response.headers),
            )
        return CancelInteractionResult(**raw_json)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,5 @@
|
||||
"""Gemini File Search Vector Store module."""
|
||||
|
||||
from .transformation import GeminiVectorStoreConfig
|
||||
|
||||
__all__ = ["GeminiVectorStoreConfig"]
|
||||
@@ -0,0 +1,355 @@
|
||||
"""
|
||||
Gemini File Search Vector Store Transformation Layer.
|
||||
|
||||
Implements the transformation between LiteLLM's unified vector store API
|
||||
and Google Gemini's File Search API.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.llms.gemini.common_utils import (
|
||||
GeminiError,
|
||||
GeminiModelInfo,
|
||||
get_api_key_from_env,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
VECTOR_STORE_OPENAI_PARAMS,
|
||||
BaseVectorStoreAuthCredentials,
|
||||
VectorStoreCreateOptionalRequestParams,
|
||||
VectorStoreCreateResponse,
|
||||
VectorStoreFileCounts,
|
||||
VectorStoreIndexEndpoints,
|
||||
VectorStoreResultContent,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class GeminiVectorStoreConfig(BaseVectorStoreConfig):
    """
    Vector store configuration for Google Gemini File Search.

    Search is implemented via ``generateContent`` with a ``file_search``
    tool (there is no dedicated search endpoint); store creation uses the
    ``fileSearchStores`` endpoint. Authentication is always carried in the
    ``key`` URL query parameter.
    """

    def __init__(self) -> None:
        super().__init__()
        self.model_info = GeminiModelInfo()
        # API key captured in validate_environment, reused when building
        # create-store URLs.
        self._cached_api_key: Optional[str] = None

    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """Gemini uses API key in query params, not headers."""
        return {}

    def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
        """
        Gemini File Search endpoints.

        Note: Search is done via generateContent with file_search tool,
        not a dedicated search endpoint.
        """
        return {
            "read": [("POST", "/models/{model}:generateContent")],
            "write": [("POST", "/fileSearchStores")],
        }

    def get_supported_openai_params(
        self, model: str
    ) -> List[VECTOR_STORE_OPENAI_PARAMS]:
        """Supported parameters for Gemini File Search."""
        return ["max_num_results", "filters"]

    def validate_environment(
        self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate and set up headers for Gemini API.

        Side effect: caches the resolved API key on the instance for later
        URL building (see transform_create_vector_store_request).
        """
        headers = headers or {}
        headers.setdefault("Content-Type", "application/json")
        if litellm_params:
            # NOTE(review): .get() assumes a dict-like litellm_params despite
            # the GenericLiteLLMParams annotation - confirm callers pass a dict.
            api_key = litellm_params.get("api_key") or get_api_key_from_env()
            if api_key:
                self._cached_api_key = api_key

        return headers

    def get_complete_url(self, api_base: Optional[str], litellm_params: dict) -> str:
        """
        Get the complete base URL for Gemini API.

        Note: This returns the base URL WITHOUT the API key.
        The API key will be appended to specific endpoint URLs in the transform methods.

        Raises:
            ValueError: if no api_base is given and GEMINI_API_BASE resolves
                to None.
        """
        if api_base is None:
            api_base = GeminiModelInfo.get_api_base()

        if api_base is None:
            raise ValueError("GEMINI_API_BASE is not set")

        # Ensure we're using the v1beta version for File Search
        api_version = "v1beta"
        return f"{api_base}/{api_version}"

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> GeminiError:
        """Return Gemini-specific error class."""
        return GeminiError(
            status_code=status_code,
            message=error_message,
            headers=headers,
        )

    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict]:
        """
        Transform search request to Gemini's generateContent format.

        Gemini File Search works by calling generateContent with a file_search tool.

        Returns:
            Tuple of (url including ``key`` query param, JSON request body).

        Raises:
            ValueError: if no API key can be resolved.
        """
        # Convert query list to single string if needed
        if isinstance(query, list):
            query = " ".join(query)

        # Get model from litellm_params or use default
        # Note: File Search requires gemini-2.5-flash or later
        model = litellm_params.get("model") or "gemini-2.5-flash"
        if model and model.startswith("gemini/"):
            model = model.replace("gemini/", "")

        # Get API key - Gemini requires it as a query parameter
        api_key = litellm_params.get("api_key") or GeminiModelInfo.get_api_key()
        if not api_key:
            raise ValueError("GEMINI_API_KEY or GOOGLE_API_KEY is required")

        # Build the URL for generateContent with API key
        url = f"{api_base}/models/{model}:generateContent?key={api_key}"

        # Build file_search tool configuration (using snake_case as per Gemini docs)
        file_search_config: Dict[str, Any] = {
            "file_search_store_names": [vector_store_id]
        }

        # Add metadata filter if provided
        metadata_filter = vector_store_search_optional_params.get("filters")
        if metadata_filter:
            # Convert to Gemini filter syntax if it's a dict
            if isinstance(metadata_filter, dict):
                # Simple conversion - may need more sophisticated mapping
                # (all conditions are joined with AND; string values quoted).
                filter_parts = []
                for key, value in metadata_filter.items():
                    if isinstance(value, str):
                        filter_parts.append(f'{key} = "{value}"')
                    else:
                        filter_parts.append(f"{key} = {value}")
                file_search_config["metadata_filter"] = " AND ".join(filter_parts)
            else:
                # Assume the caller passed a ready-made filter string.
                file_search_config["metadata_filter"] = metadata_filter

        # Build request body
        request_body: Dict[str, Any] = {
            "contents": [{"parts": [{"text": query}]}],
            "tools": [{"file_search": file_search_config}],
        }

        # Add max_num_results if specified
        max_results = vector_store_search_optional_params.get("max_num_results")
        if max_results:
            # This might need to be added to generationConfig or tool config
            # depending on Gemini's API requirements
            # NOTE(review): max_results itself is never sent - only
            # candidateCount=1 is set. Confirm the intended mapping.
            request_body.setdefault("generationConfig", {})["candidateCount"] = 1

        # Stash query/store id so the response transform can echo them back.
        litellm_logging_obj.model_call_details["query"] = query
        litellm_logging_obj.model_call_details["vector_store_id"] = vector_store_id

        return url, request_body

    def transform_search_vector_store_response(
        self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
    ) -> VectorStoreSearchResponse:
        """
        Transform Gemini's generateContent response to standard format.

        Extracts grounding metadata and citations from the response.

        Raises:
            GeminiError (via get_error_class): if parsing fails anywhere.
        """
        try:
            response_data = response.json()
            results: List[VectorStoreSearchResult] = []

            # Extract candidates and grounding metadata
            candidates = response_data.get("candidates", [])

            for candidate in candidates:
                grounding_metadata = candidate.get("groundingMetadata", {})
                grounding_chunks = grounding_metadata.get("groundingChunks", [])

                # Process each grounding chunk
                for chunk in grounding_chunks:
                    retrieved_context = chunk.get("retrievedContext")

                    if retrieved_context:
                        # This is from file search
                        text = retrieved_context.get("text", "")
                        uri = retrieved_context.get("uri", "")
                        title = retrieved_context.get("title", "")

                        # Extract file_id from URI if available
                        file_id = uri if uri else None

                        results.append(
                            VectorStoreSearchResult(
                                score=None,  # Gemini doesn't provide explicit scores
                                content=[
                                    VectorStoreResultContent(text=text, type="text")
                                ],
                                file_id=file_id,
                                filename=title if title else None,
                                attributes={
                                    "uri": uri,
                                    "title": title,
                                },
                            )
                        )

                # Also extract from grounding supports for more detailed citations
                grounding_supports = grounding_metadata.get("groundingSupports", [])
                for support in grounding_supports:
                    segment = support.get("segment", {})
                    text = segment.get("text", "")

                    grounding_chunk_indices = support.get("groundingChunkIndices", [])
                    confidence_scores = support.get("confidenceScores", [])

                    # Use first confidence score as relevance score
                    score = confidence_scores[0] if confidence_scores else None

                    # Only add if we have meaningful text and it's not a duplicate
                    # (dedup compares against the first content entry of each
                    # result collected so far).
                    if text:
                        already_exists = False
                        for record in results:
                            contents = record.get("content") or []
                            if contents and contents[0].get("text") == text:
                                already_exists = True
                                break
                        if already_exists:
                            continue
                        results.append(
                            VectorStoreSearchResult(
                                score=score,
                                content=[
                                    VectorStoreResultContent(text=text, type="text")
                                ],
                                attributes={
                                    "grounding_chunk_indices": grounding_chunk_indices,
                                },
                            )
                        )

            # Echo back the query captured during request transformation.
            query = litellm_logging_obj.model_call_details.get("query", "")

            return VectorStoreSearchResponse(
                object="vector_store.search_results.page",
                search_query=query,
                data=results,
            )

        except Exception as e:
            raise self.get_error_class(
                error_message=f"Failed to parse Gemini response: {str(e)}",
                status_code=response.status_code,
                headers=response.headers,
            )

    def transform_create_vector_store_request(
        self,
        vector_store_create_optional_params: VectorStoreCreateOptionalRequestParams,
        api_base: str,
    ) -> Tuple[str, Dict]:
        """
        Transform create request to Gemini's fileSearchStores format.

        Returns:
            Tuple of (url, request body). The body only carries displayName
            when a name was supplied.
        """
        url = f"{api_base}/fileSearchStores"

        # Append API key as query parameter (required by Gemini)
        api_key = self._cached_api_key or get_api_key_from_env()
        if api_key:
            url = f"{url}?key={api_key}"

        request_body: Dict[str, Any] = {}

        # Add display name if provided
        name = vector_store_create_optional_params.get("name")
        if name:
            request_body["displayName"] = name

        return url, request_body

    def transform_create_vector_store_response(
        self, response: httpx.Response
    ) -> VectorStoreCreateResponse:
        """
        Transform Gemini's fileSearchStore response to standard format.

        The store resource name (``fileSearchStores/...``) becomes the
        vector store id; createTime is converted from ISO-8601 to a Unix
        timestamp (0 when absent/unparseable).

        Raises:
            GeminiError (via get_error_class): if parsing fails.
        """
        try:
            response_data = response.json()

            # Extract store name (format: fileSearchStores/xxxxxxx)
            store_name = response_data.get("name", "")
            display_name = response_data.get("displayName", "")
            create_time = response_data.get("createTime", "")

            # Convert ISO timestamp to Unix timestamp
            import datetime

            created_at = None
            if create_time:
                try:
                    # fromisoformat does not accept the trailing "Z", so
                    # normalize it to an explicit UTC offset first.
                    dt = datetime.datetime.fromisoformat(
                        create_time.replace("Z", "+00:00")
                    )
                    created_at = int(dt.timestamp())
                except Exception:
                    created_at = None

            return VectorStoreCreateResponse(
                id=store_name,
                object="vector_store",
                created_at=created_at or 0,
                name=display_name,
                bytes=0,  # Gemini doesn't provide size info on creation
                file_counts=VectorStoreFileCounts(
                    in_progress=0,
                    completed=0,
                    failed=0,
                    cancelled=0,
                    total=0,
                ),
                status="completed",
                expires_after=None,
                expires_at=None,
                last_active_at=None,
                metadata=None,
            )

        except Exception as e:
            raise self.get_error_class(
                error_message=f"Failed to parse Gemini create response: {str(e)}",
                status_code=response.status_code,
                headers=response.headers,
            )
|
||||
@@ -0,0 +1,4 @@
|
||||
# Gemini Video Generation Support
|
||||
from .transformation import GeminiVideoConfig
|
||||
|
||||
__all__ = ["GeminiVideoConfig"]
|
||||
@@ -0,0 +1,536 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
|
||||
import base64
|
||||
|
||||
import httpx
|
||||
from httpx._types import RequestFiles
|
||||
|
||||
from litellm.types.videos.main import VideoCreateOptionalRequestParams, VideoObject
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.videos.utils import (
|
||||
encode_video_id_with_provider,
|
||||
extract_original_video_id,
|
||||
)
|
||||
from litellm.images.utils import ImageEditRequestUtils
|
||||
import litellm
|
||||
from litellm.types.llms.gemini import (
|
||||
GeminiLongRunningOperationResponse,
|
||||
GeminiVideoGenerationInstance,
|
||||
GeminiVideoGenerationParameters,
|
||||
GeminiVideoGenerationRequest,
|
||||
)
|
||||
from litellm.constants import DEFAULT_GOOGLE_VIDEO_DURATION_SECONDS
|
||||
from litellm.llms.base_llm.videos.transformation import BaseVideoConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from ...base_llm.chat.transformation import BaseLLMException as _BaseLLMException
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
BaseLLMException = _BaseLLMException
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
BaseLLMException = Any
|
||||
|
||||
|
||||
def _convert_image_to_gemini_format(image_file) -> Dict[str, str]:
    """
    Encode an image file for the Gemini (Veo) video API.

    Args:
        image_file: Binary file-like object (e.g. ``open(path, "rb")``).

    Returns:
        Mapping with ``bytesBase64Encoded`` (base64 payload) and ``mimeType``.
    """
    detected_mime = ImageEditRequestUtils.get_image_content_type(image_file)

    # Rewind before reading: MIME detection above may have consumed bytes.
    seek = getattr(image_file, "seek", None)
    if seek is not None:
        seek(0)

    encoded_payload = base64.b64encode(image_file.read()).decode("utf-8")
    return {
        "bytesBase64Encoded": encoded_payload,
        "mimeType": detected_mime,
    }
|
||||
|
||||
|
||||
class GeminiVideoConfig(BaseVideoConfig):
|
||||
"""
|
||||
Configuration class for Gemini (Veo) video generation.
|
||||
|
||||
Veo uses a long-running operation model:
|
||||
1. POST to :predictLongRunning returns operation name
|
||||
2. Poll operation until done=true
|
||||
3. Extract video URI from response
|
||||
4. Download video using file API
|
||||
"""
|
||||
|
||||
def __init__(self):
    # No Veo-specific state to initialize; defer to BaseVideoConfig.
    super().__init__()
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
    """
    Return the OpenAI-style video parameters that Veo understands.

    Veo's surface is much smaller than OpenAI's; anything outside this set
    is treated as a provider-specific parameter by ``map_openai_params``.
    """
    supported = ["model", "prompt", "input_reference", "seconds", "size"]
    return supported
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
video_create_optional_params: VideoCreateOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Map OpenAI-style parameters to Veo format.
|
||||
|
||||
Mappings:
|
||||
- prompt → prompt
|
||||
- input_reference → image
|
||||
- size → aspectRatio (e.g., "1280x720" → "16:9")
|
||||
- seconds → durationSeconds (defaults to 4 seconds if not provided)
|
||||
|
||||
All other params are passed through as-is to support Gemini-specific parameters.
|
||||
"""
|
||||
mapped_params: Dict[str, Any] = {}
|
||||
|
||||
# Get supported OpenAI params (exclude "model" and "prompt" which are handled separately)
|
||||
supported_openai_params = self.get_supported_openai_params(model)
|
||||
openai_params_to_map = {
|
||||
param
|
||||
for param in supported_openai_params
|
||||
if param not in {"model", "prompt"}
|
||||
}
|
||||
|
||||
# Map input_reference to image
|
||||
if "input_reference" in video_create_optional_params:
|
||||
mapped_params["image"] = video_create_optional_params["input_reference"]
|
||||
|
||||
# Map size to aspectRatio
|
||||
if "size" in video_create_optional_params:
|
||||
size = video_create_optional_params["size"]
|
||||
if size is not None:
|
||||
aspect_ratio = self._convert_size_to_aspect_ratio(size)
|
||||
if aspect_ratio:
|
||||
mapped_params["aspectRatio"] = aspect_ratio
|
||||
|
||||
# Map seconds to durationSeconds, default to 4 seconds (matching OpenAI)
|
||||
if "seconds" in video_create_optional_params:
|
||||
seconds = video_create_optional_params["seconds"]
|
||||
try:
|
||||
duration = int(seconds) if isinstance(seconds, str) else seconds
|
||||
if duration is not None:
|
||||
mapped_params["durationSeconds"] = duration
|
||||
except (ValueError, TypeError):
|
||||
# If conversion fails, use default
|
||||
pass
|
||||
|
||||
# Pass through any other params that weren't mapped (Gemini-specific params)
|
||||
for key, value in video_create_optional_params.items():
|
||||
if key not in openai_params_to_map and key not in mapped_params:
|
||||
mapped_params[key] = value
|
||||
|
||||
return mapped_params
|
||||
|
||||
def _convert_size_to_aspect_ratio(self, size: str) -> Optional[str]:
|
||||
"""
|
||||
Convert OpenAI size format to Veo aspectRatio format.
|
||||
|
||||
https://cloud.google.com/vertex-ai/generative-ai/docs/image/generate-videos
|
||||
|
||||
Supported aspect ratios: 9:16 (portrait), 16:9 (landscape)
|
||||
"""
|
||||
if not size:
|
||||
return None
|
||||
|
||||
aspect_ratio_map = {
|
||||
"1280x720": "16:9",
|
||||
"1920x1080": "16:9",
|
||||
"720x1280": "9:16",
|
||||
"1080x1920": "9:16",
|
||||
}
|
||||
|
||||
return aspect_ratio_map.get(size, "16:9")
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
api_key: Optional[str] = None,
|
||||
litellm_params: Optional[GenericLiteLLMParams] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Validate environment and add Gemini API key to headers.
|
||||
Gemini uses x-goog-api-key header for authentication.
|
||||
"""
|
||||
# Use api_key from litellm_params if available, otherwise fall back to other sources
|
||||
if litellm_params and litellm_params.api_key:
|
||||
api_key = api_key or litellm_params.api_key
|
||||
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or get_secret_str("GOOGLE_API_KEY")
|
||||
or get_secret_str("GEMINI_API_KEY")
|
||||
)
|
||||
|
||||
if not api_key:
|
||||
raise ValueError(
|
||||
"GEMINI_API_KEY or GOOGLE_API_KEY is required for Veo video generation. "
|
||||
"Set it via environment variable or pass it as api_key parameter."
|
||||
)
|
||||
|
||||
headers.update(
|
||||
{
|
||||
"x-goog-api-key": api_key,
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
)
|
||||
return headers
|
||||
|
||||
def get_complete_url(
    self,
    model: str,
    api_base: Optional[str],
    litellm_params: dict,
) -> str:
    """
    Build the request URL for Veo video generation.

    With a model: ``{api_base}/v1beta/models/{model}:predictLongRunning``.
    Without a model (status/delete calls): the bare API base.
    When ``api_base`` is None, falls back to the GEMINI_API_BASE env var,
    then the public Google endpoint.
    """
    base = api_base
    if base is None:
        base = (
            get_secret_str("GEMINI_API_BASE")
            or "https://generativelanguage.googleapis.com"
        )
    base = base.rstrip("/")

    # Status/delete operations only need the base URL.
    if not model:
        return base

    # Drop the litellm provider prefix, e.g. "gemini/veo-2" -> "veo-2".
    model_name = model.replace("gemini/", "")
    return f"{base}/v1beta/models/{model_name}:predictLongRunning"
|
||||
|
||||
def transform_video_create_request(
    self,
    model: str,
    prompt: str,
    api_base: str,
    video_create_optional_request_params: Dict,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[Dict, RequestFiles, str]:
    """
    Build the Veo ``:predictLongRunning`` request body.

    Veo expects:
        {
          "instances": [{"prompt": "A cat playing with a ball of yarn"}],
          "parameters": {"aspectRatio": "16:9", "durationSeconds": 8,
                         "resolution": "720p"}
        }

    Returns (request_body, files, api_base); Veo takes no multipart files.
    """
    optional_params = video_create_optional_request_params.copy()

    # Image-to-video: convert the incoming image into Gemini's format.
    if optional_params.get("image") is not None:
        optional_params["image"] = _convert_image_to_gemini_format(
            optional_params["image"]
        )

    request_body = GeminiVideoGenerationRequest(
        instances=[GeminiVideoGenerationInstance(prompt=prompt)],
        parameters=GeminiVideoGenerationParameters(**optional_params),
    ).model_dump(exclude_none=True)

    return request_body, [], api_base
|
||||
|
||||
def transform_video_create_response(
    self,
    model: str,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: Optional[str] = None,
    request_data: Optional[Dict] = None,
) -> VideoObject:
    """
    Transform the Veo long-running-operation response into a VideoObject.

    Veo returns:
        {
            "name": "operations/generate_1234567890",
            "metadata": {...},
            "done": false,
            "error": {...}
        }

    The returned VideoObject carries:
      - id: the operation name (provider-encoded when custom_llm_provider
        is set) used for subsequent polling
      - status: always "processing" — creation is asynchronous
      - usage: duration_seconds from the request, for cost calculation

    Raises:
        ValueError: if the response cannot be parsed or contains no
            operation name.
    """
    response_data = raw_response.json()

    # Parse response using Pydantic model for type safety
    try:
        operation_response = GeminiLongRunningOperationResponse(**response_data)
    except Exception as e:
        # Chain the original exception so pydantic's validation details
        # are preserved in the traceback.
        raise ValueError(f"Failed to parse operation response: {e}") from e

    operation_name = operation_response.name
    if not operation_name:
        raise ValueError(f"No operation name in Veo response: {response_data}")

    # Encode the provider into the id so later status/content calls can
    # route back to this provider.
    if custom_llm_provider:
        video_id = encode_video_id_with_provider(
            operation_name, custom_llm_provider, model
        )
    else:
        video_id = operation_name

    video_obj = VideoObject(
        id=video_id,
        object="video",
        status="processing",
        model=model,
    )

    # Attach the requested duration so downstream cost calculation can
    # price the video; fall back to the provider default duration.
    usage_data: Dict[str, float] = {}
    if request_data:
        parameters = request_data.get("parameters", {})
        duration = (
            parameters.get("durationSeconds")
            or DEFAULT_GOOGLE_VIDEO_DURATION_SECONDS
        )
        if duration is not None:
            try:
                usage_data["duration_seconds"] = float(duration)
            except (ValueError, TypeError):
                # Non-numeric duration — omit usage rather than fail the call.
                pass

    video_obj.usage = usage_data
    return video_obj
|
||||
|
||||
def transform_video_status_retrieve_request(
    self,
    video_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Build the polling request for a Veo long-running operation.

    Veo polls operations at:
        GET https://generativelanguage.googleapis.com/v1beta/{operation_name}

    Returns (url, query_params); no query parameters are needed.
    """
    # The stored video id may be provider-encoded; recover the raw
    # operation name before building the URL.
    operation_name = extract_original_video_id(video_id)
    poll_url = f"{api_base.rstrip('/')}/v1beta/{operation_name}"
    query_params: Dict[str, Any] = {}
    return poll_url, query_params
|
||||
|
||||
def transform_video_status_retrieve_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: Optional[str] = None,
) -> VideoObject:
    """
    Map a Veo operation-status response onto a VideoObject.

    While running, Veo returns:
        {"name": "operations/generate_1234567890", "done": false}

    When done=true the response additionally carries
    ``response.generateVideoResponse.generatedSamples[].video.uri``.

    The VideoObject status is "completed" when done, else "processing".
    """
    # Parse response using Pydantic model for type safety
    operation_response = GeminiLongRunningOperationResponse(
        **raw_response.json()
    )

    # Keep the provider encoding consistent with the create response so
    # the id remains routable.
    if custom_llm_provider:
        video_id = encode_video_id_with_provider(
            operation_response.name, custom_llm_provider, None
        )
    else:
        video_id = operation_response.name

    status = "completed" if operation_response.done else "processing"
    return VideoObject(
        id=video_id,
        object="video",
        status=status,
    )
|
||||
|
||||
def transform_video_content_request(
    self,
    video_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
    variant: Optional[str] = None,
) -> Tuple[str, Dict]:
    """
    Resolve the download URL for a completed Veo video.

    For Veo, we need to:
    1. Poll the operation status to extract the generated video URI
    2. Return that URI as the download URL

    Raises:
        ValueError: if the operation is not finished, has no response
            body, or contains no generated video samples.
    """
    operation_name = extract_original_video_id(video_id)

    # Fetch the operation status synchronously to locate the video URI.
    status_url = f"{api_base.rstrip('/')}/v1beta/{operation_name}"
    client = litellm.module_level_client
    status_response = client.get(url=status_url, headers=headers)
    status_response.raise_for_status()
    response_data = status_response.json()

    operation_response = GeminiLongRunningOperationResponse(**response_data)

    if not operation_response.done:
        raise ValueError(
            "Video generation is not complete yet. "
            "Please check status with video_status() before downloading."
        )

    if not operation_response.response:
        raise ValueError("No response data in completed operation")

    generated_samples = (
        operation_response.response.generateVideoResponse.generatedSamples
    )
    # Guard against an empty sample list so callers get a clear error
    # instead of a bare IndexError.
    if not generated_samples:
        raise ValueError("No generated video samples in completed operation")
    download_url = generated_samples[0].video.uri

    params: Dict[str, Any] = {}

    return download_url, params
|
||||
|
||||
def transform_video_content_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> bytes:
    """
    Return the downloaded Veo video as raw bytes.

    The download endpoint serves the video file directly, so no parsing
    is needed — the response body is the payload.
    """
    return raw_response.content
|
||||
|
||||
def transform_video_remix_request(
    self,
    video_id: str,
    prompt: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
    extra_body: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
    """
    Video remix is not supported by Veo API.

    Raises:
        NotImplementedError: always — Veo exposes no remix endpoint.
    """
    raise NotImplementedError(
        "Video remix is not supported by Google Veo. "
        "Please use video_generation() to create new videos."
    )
|
||||
|
||||
def transform_video_remix_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: Optional[str] = None,
) -> VideoObject:
    """Video remix is not supported; always raises NotImplementedError."""
    raise NotImplementedError("Video remix is not supported by Google Veo.")
|
||||
|
||||
def transform_video_list_request(
    self,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
    after: Optional[str] = None,
    limit: Optional[int] = None,
    order: Optional[str] = None,
    extra_query: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict]:
    """
    Video list is not supported by Veo API.

    Raises:
        NotImplementedError: always — use Google's operations endpoint
            directly to enumerate long-running operations.
    """
    raise NotImplementedError(
        "Video list is not supported by Google Veo. "
        "Use the operations endpoint directly if you need to list operations."
    )
|
||||
|
||||
def transform_video_list_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    custom_llm_provider: Optional[str] = None,
) -> Dict[str, str]:
    """Video list is not supported; always raises NotImplementedError."""
    raise NotImplementedError("Video list is not supported by Google Veo.")
|
||||
|
||||
def transform_video_delete_request(
    self,
    video_id: str,
    api_base: str,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Tuple[str, Dict]:
    """
    Video delete is not supported by Veo API.

    Raises:
        NotImplementedError: always — generated videos are cleaned up
            automatically on Google's side.
    """
    raise NotImplementedError(
        "Video delete is not supported by Google Veo. "
        "Videos are automatically cleaned up by Google."
    )
|
||||
|
||||
def transform_video_delete_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
) -> VideoObject:
    """Video delete is not supported; always raises NotImplementedError."""
    raise NotImplementedError("Video delete is not supported by Google Veo.")
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """
    Wrap a provider error in a GeminiError for litellm's exception mapping.
    """
    # NOTE(review): imported inside the function, presumably to avoid a
    # circular import at module load time — confirm before hoisting.
    from ..common_utils import GeminiError

    return GeminiError(
        status_code=status_code, message=error_message, headers=headers
    )
|
||||
Reference in New Issue
Block a user