chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Transformation logic for context caching.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Tuple, Literal
|
||||
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.llms.vertex_ai import CachedContentRequestBody
|
||||
from litellm.utils import is_cached_message
|
||||
|
||||
from ..common_utils import get_supports_system_message
|
||||
from ..gemini.transformation import (
|
||||
_gemini_convert_messages_with_history,
|
||||
_transform_system_message,
|
||||
)
|
||||
|
||||
|
||||
def get_first_continuous_block_idx(
    filtered_messages: List[Tuple[int, AllMessageValues]]  # (idx, message)
) -> int:
    """
    Locate where the first run of consecutive message indices ends.

    Args:
        filtered_messages: (original_index, message) pairs, in order.

    Returns:
        int: Position within ``filtered_messages`` of the last element of the
        first consecutive run of indices. -1 when the list is empty.
    """
    if not filtered_messages:
        return -1

    previous_idx = filtered_messages[0][0]
    for position, (message_idx, _) in enumerate(filtered_messages[1:], start=1):
        if message_idx != previous_idx + 1:
            # Gap in the indices - the first run ended at the prior element.
            return position - 1
        previous_idx = message_idx

    # All indices were consecutive; the run spans the entire list.
    return len(filtered_messages) - 1
|
||||
|
||||
|
||||
def extract_ttl_from_cached_messages(messages: List[AllMessageValues]) -> Optional[str]:
    """
    Scan cached messages for a TTL and return the first valid one found.

    Args:
        messages: Messages to inspect.

    Returns:
        Optional[str]: TTL such as "3600s", or None when no valid TTL exists.
    """
    for msg in messages:
        if not is_cached_message(msg):
            continue

        msg_content = msg.get("content")
        # Plain-string content carries no cache_control metadata.
        if not msg_content or isinstance(msg_content, str):
            continue

        for part in msg_content:
            # Only dict parts can expose .get(); skip anything else.
            if not isinstance(part, dict):
                continue

            control = part.get("cache_control")
            if not isinstance(control, dict) or not control:
                continue

            if control.get("type") != "ephemeral":
                continue

            candidate_ttl = control.get("ttl")
            if candidate_ttl and _is_valid_ttl_format(candidate_ttl):
                return str(candidate_ttl)

    return None
|
||||
|
||||
|
||||
def _is_valid_ttl_format(ttl: str) -> bool:
|
||||
"""
|
||||
Validate TTL format. Should be a string ending with 's' for seconds.
|
||||
Examples: "3600s", "7200s", "1.5s"
|
||||
|
||||
Args:
|
||||
ttl: TTL string to validate
|
||||
|
||||
Returns:
|
||||
bool: True if valid format, False otherwise
|
||||
"""
|
||||
if not isinstance(ttl, str):
|
||||
return False
|
||||
|
||||
# TTL should end with 's' and contain a valid number before it
|
||||
pattern = r"^([0-9]*\.?[0-9]+)s$"
|
||||
match = re.match(pattern, ttl)
|
||||
|
||||
if not match:
|
||||
return False
|
||||
|
||||
try:
|
||||
# Ensure the numeric part is valid and positive
|
||||
numeric_part = float(match.group(1))
|
||||
return numeric_part > 0
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def separate_cached_messages(
    messages: List[AllMessageValues],
) -> Tuple[List[AllMessageValues], List[AllMessageValues]]:
    """
    Returns separated cached and non-cached messages.

    Only the FIRST continuous run of cached messages is treated as cached;
    cached messages appearing after a gap remain in the non-cached list.

    Args:
        messages: List of messages to be separated.

    Returns:
        Tuple containing:
        - cached_messages: Messages in the first continuous cached block.
        - non_cached_messages: All remaining messages, in original order.
    """
    cached_messages: List[AllMessageValues] = []
    non_cached_messages: List[AllMessageValues] = []

    # Extract cached messages together with their original indices.
    filtered_messages: List[Tuple[int, AllMessageValues]] = []
    for idx, message in enumerate(messages):
        if is_cached_message(message=message):
            filtered_messages.append((idx, message))

    if not filtered_messages:
        # Nothing is cached - everything stays in the non-cached list.
        return cached_messages, messages

    # Only honor the first continuous block of cached messages.
    # get_first_continuous_block_idx returns an int (-1 only for an empty
    # list, handled above), so no None-check is needed here - the previous
    # `is not None` guard was unreachable dead code.
    last_continuous_block_idx = get_first_continuous_block_idx(filtered_messages)
    first_cached_idx = filtered_messages[0][0]
    last_cached_idx = filtered_messages[last_continuous_block_idx][0]

    cached_messages = messages[first_cached_idx : last_cached_idx + 1]
    non_cached_messages = (
        messages[:first_cached_idx] + messages[last_cached_idx + 1 :]
    )

    return cached_messages, non_cached_messages
|
||||
|
||||
|
||||
def transform_openai_messages_to_gemini_context_caching(
    model: str,
    messages: List[AllMessageValues],
    custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
    cache_key: str,
    vertex_project: Optional[str],
    vertex_location: Optional[str],
) -> CachedContentRequestBody:
    """
    Build the cachedContents request body from OpenAI-format messages.

    Args:
        model: Model name (without any "models/" prefix).
        messages: OpenAI-format messages destined for the cache.
        custom_llm_provider: Which Google surface is targeted.
        cache_key: Stored as the cache entry's displayName for later lookup.
        vertex_project: GCP project (used for Vertex AI providers only).
        vertex_location: GCP region (used for Vertex AI providers only).

    Returns:
        CachedContentRequestBody: Body for the context-caching endpoint.
    """
    # Read TTL BEFORE the system-message transform rewrites the messages.
    ttl = extract_ttl_from_cached_messages(messages)

    supports_system_message = get_supports_system_message(
        model=model, custom_llm_provider=custom_llm_provider
    )
    transformed_system_messages, new_messages = _transform_system_message(
        supports_system_message=supports_system_message, messages=messages
    )
    transformed_messages = _gemini_convert_messages_with_history(
        messages=new_messages, model=model
    )

    # Google AI Studio addresses models as "models/<name>"; Vertex AI
    # additionally needs the fully-qualified publisher path.
    model_name = f"models/{model}"
    if custom_llm_provider in ("vertex_ai", "vertex_ai_beta"):
        model_name = f"projects/{vertex_project}/locations/{vertex_location}/publishers/google/{model_name}"

    data = CachedContentRequestBody(
        contents=transformed_messages,
        model=model_name,
        displayName=cache_key,
    )

    # Optional fields: only attach when present.
    if ttl:
        data["ttl"] = ttl
    if transformed_system_messages is not None:
        data["system_instruction"] = transformed_system_messages

    return data
|
||||
@@ -0,0 +1,578 @@
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.caching.caching import Cache, LiteLLMCacheType
|
||||
from litellm.constants import MINIMUM_PROMPT_CACHE_TOKEN_COUNT
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.openai.openai import AllMessageValues
|
||||
from litellm.utils import is_prompt_caching_valid_prompt
|
||||
from litellm.types.llms.vertex_ai import (
|
||||
CachedContentListAllResponseBody,
|
||||
VertexAICachedContentResponseObject,
|
||||
)
|
||||
|
||||
from ..common_utils import VertexAIError
|
||||
from ..vertex_llm_base import VertexBase
|
||||
from .transformation import (
|
||||
separate_cached_messages,
|
||||
transform_openai_messages_to_gemini_context_caching,
|
||||
)
|
||||
|
||||
# In-memory cache instance used ONLY for its deterministic
# 'get_cache_key' helper (nothing is actually stored in it).
local_cache_obj = Cache(
    type=LiteLLMCacheType.LOCAL
)  # only used for calling 'get_cache_key' function

# Upper bound on pages fetched when listing cached contents, so a
# bad/looping nextPageToken can never spin forever.
MAX_PAGINATION_PAGES = 100  # Reasonable upper bound for pagination
|
||||
|
||||
|
||||
class ContextCachingEndpoints(VertexBase):
|
||||
"""
|
||||
Covers context caching endpoints for Vertex AI + Google AI Studio
|
||||
|
||||
v0: covers Google AI Studio
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_token_and_url_context_caching(
|
||||
self,
|
||||
gemini_api_key: Optional[str],
|
||||
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
|
||||
api_base: Optional[str],
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_auth_header: Optional[str],
|
||||
) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Internal function. Returns the token and url for the call.
|
||||
|
||||
Handles logic if it's google ai studio vs. vertex ai.
|
||||
|
||||
Returns
|
||||
token, url
|
||||
"""
|
||||
if custom_llm_provider == "gemini":
|
||||
auth_header = None
|
||||
endpoint = "cachedContents"
|
||||
url = "https://generativelanguage.googleapis.com/v1beta/{}?key={}".format(
|
||||
endpoint, gemini_api_key
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
auth_header = vertex_auth_header
|
||||
endpoint = "cachedContents"
|
||||
if vertex_location == "global":
|
||||
url = f"https://aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
|
||||
else:
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
|
||||
else:
|
||||
auth_header = vertex_auth_header
|
||||
endpoint = "cachedContents"
|
||||
if vertex_location == "global":
|
||||
url = f"https://aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
|
||||
else:
|
||||
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/{endpoint}"
|
||||
|
||||
return self._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
gemini_api_key=gemini_api_key,
|
||||
endpoint=endpoint,
|
||||
stream=None,
|
||||
auth_header=auth_header,
|
||||
url=url,
|
||||
model=None,
|
||||
vertex_project=vertex_project,
|
||||
vertex_location=vertex_location,
|
||||
vertex_api_version="v1beta1"
|
||||
if custom_llm_provider == "vertex_ai_beta"
|
||||
else "v1",
|
||||
)
|
||||
|
||||
    def check_cache(
        self,
        cache_key: str,
        client: HTTPHandler,
        headers: dict,
        api_key: str,
        api_base: Optional[str],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
    ) -> Optional[str]:
        """
        Checks if content already cached.

        Pages through the provider's cachedContents list (up to
        MAX_PAGINATION_PAGES pages) looking for an entry whose displayName
        equals cache_key, since Google doesn't let us set the name of the
        cache (their API docs are out of sync with actual implementation).

        Returns
        - cached_content_name - str - cached content name stored on google. (if found.)
        OR
        - None
        """

        _, base_url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        page_token: Optional[str] = None

        # Iterate through all pages (bounded so a looping nextPageToken
        # can never spin forever)
        for _ in range(MAX_PAGINATION_PAGES):
            # Build URL with pagination token if present
            if page_token:
                separator = "&" if "?" in base_url else "?"
                url = f"{base_url}{separator}pageToken={page_token}"
            else:
                url = base_url

            try:
                ## LOGGING
                logging_obj.pre_call(
                    input="",
                    api_key="",
                    additional_args={
                        "complete_input_dict": {},
                        "api_base": url,
                        "headers": headers,
                    },
                )

                resp = client.get(url=url, headers=headers)
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                # 403 = caller can't list caches; treat as "not cached"
                # rather than failing the whole completion request.
                if e.response.status_code == 403:
                    return None
                raise VertexAIError(
                    status_code=e.response.status_code, message=e.response.text
                )
            except Exception as e:
                # Anything else (network, JSON, etc.) surfaces as a generic 500.
                raise VertexAIError(status_code=500, message=str(e))

            raw_response = resp.json()
            logging_obj.post_call(original_response=raw_response)

            if "cachedContents" not in raw_response:
                # Empty listing - nothing cached at all.
                return None

            all_cached_items = CachedContentListAllResponseBody(**raw_response)

            # NOTE(review): redundant with the raw_response check above; kept
            # since TypedDict construction performs no validation.
            if "cachedContents" not in all_cached_items:
                return None

            # Check current page for matching cache_key
            for cached_item in all_cached_items["cachedContents"]:
                display_name = cached_item.get("displayName")
                if display_name is not None and display_name == cache_key:
                    return cached_item.get("name")

            # Check if there are more pages
            page_token = all_cached_items.get("nextPageToken")
            if not page_token:
                # No more pages, cache not found
                break

        return None
|
||||
|
||||
    async def async_check_cache(
        self,
        cache_key: str,
        client: AsyncHTTPHandler,
        headers: dict,
        api_key: str,
        api_base: Optional[str],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
    ) -> Optional[str]:
        """
        Checks if content already cached. Async mirror of ``check_cache``.

        Pages through the provider's cachedContents list (up to
        MAX_PAGINATION_PAGES pages) looking for an entry whose displayName
        equals cache_key, since Google doesn't let us set the name of the
        cache (their API docs are out of sync with actual implementation).

        Returns
        - cached_content_name - str - cached content name stored on google. (if found.)
        OR
        - None
        """

        _, base_url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        page_token: Optional[str] = None

        # Iterate through all pages (bounded so a looping nextPageToken
        # can never spin forever)
        for _ in range(MAX_PAGINATION_PAGES):
            # Build URL with pagination token if present
            if page_token:
                separator = "&" if "?" in base_url else "?"
                url = f"{base_url}{separator}pageToken={page_token}"
            else:
                url = base_url

            try:
                ## LOGGING
                logging_obj.pre_call(
                    input="",
                    api_key="",
                    additional_args={
                        "complete_input_dict": {},
                        "api_base": url,
                        "headers": headers,
                    },
                )

                resp = await client.get(url=url, headers=headers)
                resp.raise_for_status()
            except httpx.HTTPStatusError as e:
                # 403 = caller can't list caches; treat as "not cached"
                # rather than failing the whole completion request.
                if e.response.status_code == 403:
                    return None
                raise VertexAIError(
                    status_code=e.response.status_code, message=e.response.text
                )
            except Exception as e:
                # Anything else (network, JSON, etc.) surfaces as a generic 500.
                raise VertexAIError(status_code=500, message=str(e))

            raw_response = resp.json()
            logging_obj.post_call(original_response=raw_response)

            if "cachedContents" not in raw_response:
                # Empty listing - nothing cached at all.
                return None

            all_cached_items = CachedContentListAllResponseBody(**raw_response)

            # NOTE(review): redundant with the raw_response check above; kept
            # since TypedDict construction performs no validation.
            if "cachedContents" not in all_cached_items:
                return None

            # Check current page for matching cache_key
            for cached_item in all_cached_items["cachedContents"]:
                display_name = cached_item.get("displayName")
                if display_name is not None and display_name == cache_key:
                    return cached_item.get("name")

            # Check if there are more pages
            page_token = all_cached_items.get("nextPageToken")
            if not page_token:
                # No more pages, cache not found
                break

        return None
|
||||
|
||||
    def check_and_create_cache(
        self,
        messages: List[AllMessageValues],  # receives openai format messages
        optional_params: dict,  # cache the tools if present, in case cache content exists in messages
        api_key: str,
        api_base: Optional[str],
        model: str,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
        extra_headers: Optional[dict] = None,
        cached_content: Optional[str] = None,
    ) -> Tuple[List[AllMessageValues], dict, Optional[str]]:
        """
        Receives
        - messages: List of dict - messages in the openai format

        Returns
        - messages - List[dict] - filtered list of messages in the openai format.
        - optional_params - dict - with 'tools' removed when they were folded into the cache.
        - cached_content - str - the cache content id, to be passed in the gemini request body

        Follows - https://ai.google.dev/api/caching#request-body
        """
        # Caller already holds a cache id - nothing to create or look up.
        if cached_content is not None:
            return messages, optional_params, cached_content

        cached_messages, non_cached_messages = separate_cached_messages(
            messages=messages
        )

        # No cache_control-marked messages: pass everything through unchanged.
        if len(cached_messages) == 0:
            return messages, optional_params, None

        # Gemini requires a minimum of 1024 tokens for context caching.
        # Skip caching if the cached content is too small to avoid API errors.
        if not is_prompt_caching_valid_prompt(
            model=model,
            messages=cached_messages,
            custom_llm_provider=custom_llm_provider,
        ):
            verbose_logger.debug(
                "Vertex AI context caching: cached content is below minimum token "
                "count (%d). Skipping context caching.",
                MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
            )
            return messages, optional_params, None

        # Tools are cached along with the messages, so remove them from the
        # per-request params (they also feed into the cache key below).
        tools = optional_params.pop("tools", None)

        ## AUTHORIZATION ##
        token, url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        headers = {
            "Content-Type": "application/json",
        }
        if token is not None:
            headers["Authorization"] = f"Bearer {token}"
        if extra_headers is not None:
            headers.update(extra_headers)

        # Fall back to a fresh sync client when none (or the wrong type) is given.
        if client is None or not isinstance(client, HTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = HTTPHandler(**_params)  # type: ignore
        else:
            client = client

        ## CHECK IF CACHED ALREADY
        # The deterministic local cache key doubles as the Google displayName.
        generated_cache_key = local_cache_obj.get_cache_key(
            messages=cached_messages, tools=tools, model=model
        )
        google_cache_name = self.check_cache(
            cache_key=generated_cache_key,
            client=client,
            headers=headers,
            api_key=api_key,
            api_base=api_base,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )
        if google_cache_name:
            return non_cached_messages, optional_params, google_cache_name

        ## TRANSFORM REQUEST
        cached_content_request_body = (
            transform_openai_messages_to_gemini_context_caching(
                model=model,
                messages=cached_messages,
                cache_key=generated_cache_key,
                custom_llm_provider=custom_llm_provider,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
            )
        )

        # NOTE(review): this sets "tools" even when tools is None - confirm
        # the caching endpoint accepts a null tools field.
        cached_content_request_body["tools"] = tools

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": cached_content_request_body,
                "api_base": url,
                "headers": headers,
            },
        )

        try:
            response = client.post(
                url=url, headers=headers, json=cached_content_request_body  # type: ignore
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        raw_response_cached = response.json()
        cached_content_response_obj = VertexAICachedContentResponseObject(
            name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
        )
        # Return the newly created cache's name for use as cached_content.
        return (
            non_cached_messages,
            optional_params,
            cached_content_response_obj["name"],
        )
|
||||
|
||||
    async def async_check_and_create_cache(
        self,
        messages: List[AllMessageValues],  # receives openai format messages
        optional_params: dict,  # cache the tools if present, in case cache content exists in messages
        api_key: str,
        api_base: Optional[str],
        model: str,
        client: Optional[AsyncHTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        logging_obj: Logging,
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        vertex_project: Optional[str],
        vertex_location: Optional[str],
        vertex_auth_header: Optional[str],
        extra_headers: Optional[dict] = None,
        cached_content: Optional[str] = None,
    ) -> Tuple[List[AllMessageValues], dict, Optional[str]]:
        """
        Async mirror of ``check_and_create_cache``.

        Receives
        - messages: List of dict - messages in the openai format

        Returns
        - messages - List[dict] - filtered list of messages in the openai format.
        - optional_params - dict - with 'tools' removed when they were folded into the cache.
        - cached_content - str - the cache content id, to be passed in the gemini request body

        Follows - https://ai.google.dev/api/caching#request-body
        """
        # Caller already holds a cache id - nothing to create or look up.
        if cached_content is not None:
            return messages, optional_params, cached_content

        cached_messages, non_cached_messages = separate_cached_messages(
            messages=messages
        )

        # No cache_control-marked messages: pass everything through unchanged.
        if len(cached_messages) == 0:
            return messages, optional_params, None

        # Gemini requires a minimum of 1024 tokens for context caching.
        # Skip caching if the cached content is too small to avoid API errors.
        if not is_prompt_caching_valid_prompt(
            model=model,
            messages=cached_messages,
            custom_llm_provider=custom_llm_provider,
        ):
            verbose_logger.debug(
                "Vertex AI context caching: cached content is below minimum token "
                "count (%d). Skipping context caching.",
                MINIMUM_PROMPT_CACHE_TOKEN_COUNT,
            )
            return messages, optional_params, None

        # Tools are cached along with the messages, so remove them from the
        # per-request params (they also feed into the cache key below).
        tools = optional_params.pop("tools", None)

        ## AUTHORIZATION ##
        token, url = self._get_token_and_url_context_caching(
            gemini_api_key=api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        headers = {
            "Content-Type": "application/json",
        }
        if token is not None:
            headers["Authorization"] = f"Bearer {token}"
        if extra_headers is not None:
            headers.update(extra_headers)

        # Fall back to the shared async client when none (or the wrong type)
        # is given.
        if client is None or not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                params={"timeout": timeout}, llm_provider=litellm.LlmProviders.VERTEX_AI
            )
        else:
            client = client

        ## CHECK IF CACHED ALREADY
        # The deterministic local cache key doubles as the Google displayName.
        generated_cache_key = local_cache_obj.get_cache_key(
            messages=cached_messages, tools=tools, model=model
        )
        google_cache_name = await self.async_check_cache(
            cache_key=generated_cache_key,
            client=client,
            headers=headers,
            api_key=api_key,
            api_base=api_base,
            logging_obj=logging_obj,
            custom_llm_provider=custom_llm_provider,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_auth_header=vertex_auth_header,
        )

        if google_cache_name:
            return non_cached_messages, optional_params, google_cache_name

        ## TRANSFORM REQUEST
        cached_content_request_body = (
            transform_openai_messages_to_gemini_context_caching(
                model=model,
                messages=cached_messages,
                cache_key=generated_cache_key,
                custom_llm_provider=custom_llm_provider,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
            )
        )

        # NOTE(review): this sets "tools" even when tools is None - confirm
        # the caching endpoint accepts a null tools field.
        cached_content_request_body["tools"] = tools

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": cached_content_request_body,
                "api_base": url,
                "headers": headers,
            },
        )

        try:
            response = await client.post(
                url=url, headers=headers, json=cached_content_request_body  # type: ignore
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise VertexAIError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise VertexAIError(status_code=408, message="Timeout error occurred.")

        raw_response_cached = response.json()
        cached_content_response_obj = VertexAICachedContentResponseObject(
            name=raw_response_cached.get("name"), model=raw_response_cached.get("model")
        )
        # Return the newly created cache's name for use as cached_content.
        return (
            non_cached_messages,
            optional_params,
            cached_content_response_obj["name"],
        )
|
||||
|
||||
def get_cache(self):
|
||||
pass
|
||||
|
||||
async def async_get_cache(self):
|
||||
pass
|
||||
Reference in New Issue
Block a user