"""
Wrapper around the router cache. Meant to store the model id when a
prompt-caching-supported prompt is called.
"""

import hashlib
import json
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast

from typing_extensions import TypedDict

from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam

if TYPE_CHECKING:
    # Real types are only needed for static analysis; at runtime they are
    # replaced with `Any` below so the optional/heavy imports (opentelemetry,
    # litellm.router) are never loaded just for annotations.
    from opentelemetry.trace import Span as _Span

    from litellm.router import Router

    litellm_router = Router
    Span = Union[_Span, Any]
else:
    Span = Any
    litellm_router = Any
class PromptCachingCacheValue(TypedDict):
    """Value stored in the prompt-caching cache: the deployment that served the prompt."""

    # id of the deployment (model) that handled the prompt-cached request
    model_id: str
class PromptCachingCache:
    """
    Remembers which deployment (model id) served a prompt-caching request.

    The cache key is a SHA-256 hash of the prompt's *cacheable prefix*
    (everything up to and including the last ``cache_control`` block) plus any
    tools, so follow-up requests sharing that prefix map to the same key and
    can be routed back to the same deployment.

    NOTE: the serialization + hashing below defines a persisted cache-key
    format — changing it invalidates previously written entries.
    """

    def __init__(self, cache: DualCache):
        # Router-level DualCache used for every get/set in this class.
        self.cache = cache
        # NOTE(review): this in-memory cache is never read or written anywhere
        # in this class — presumably reserved for future use; confirm before
        # removing.
        self.in_memory_cache = InMemoryCache()

    @staticmethod
    def serialize_object(obj: Any) -> Any:
        """Helper function to serialize Pydantic objects, dictionaries, or fallback to string.

        Lists are serialized element-wise, primitives pass through unchanged,
        and anything else falls back to ``str(obj)``.
        """
        if hasattr(obj, "dict"):
            # If the object is a Pydantic model, use its `dict()` method
            return obj.dict()
        elif isinstance(obj, dict):
            # If the object is a dictionary, serialize it with sorted keys
            return json.dumps(
                obj, sort_keys=True, separators=(",", ":")
            )  # Standardize serialization

        elif isinstance(obj, list):
            # Serialize lists by ensuring each element is handled properly
            return [PromptCachingCache.serialize_object(item) for item in obj]
        elif isinstance(obj, (int, float, bool)):
            return obj  # Keep primitive types as-is
        return str(obj)

    @staticmethod
    def extract_cacheable_prefix(
        messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        """
        Extract the cacheable prefix from messages.

        The cacheable prefix is everything UP TO AND INCLUDING the LAST content block
        (across all messages) that has cache_control. This includes ALL blocks before
        the last cacheable block (even if they don't have cache_control).

        Args:
            messages: List of messages to extract cacheable prefix from

        Returns:
            List of messages containing only the cacheable prefix
            (empty list when no cache_control block is present)
        """
        if not messages:
            return messages

        # Find the last content block (across all messages) that has cache_control
        last_cacheable_message_idx: Optional[int] = None
        # None means "whole message is cacheable" (message-level cache_control);
        # an int means "cacheable up to this content-block index".
        last_cacheable_content_idx: Optional[int] = None

        for msg_idx, message in enumerate(messages):
            content = message.get("content")

            # Check for cache_control at message level (when content is a string)
            # This handles the case where cache_control is a sibling of string content:
            # {"role": "user", "content": "...", "cache_control": {"type": "ephemeral"}}
            message_level_cache_control = message.get("cache_control")
            if (
                message_level_cache_control is not None
                and isinstance(message_level_cache_control, dict)
                and message_level_cache_control.get("type") == "ephemeral"
            ):
                last_cacheable_message_idx = msg_idx
                # Set to None to indicate the entire message content is cacheable
                # (not a specific content block index within a list)
                last_cacheable_content_idx = None

            # Also check for cache_control within content blocks (when content is a list)
            if not isinstance(content, list):
                continue

            for content_idx, content_block in enumerate(content):
                if isinstance(content_block, dict):
                    cache_control = content_block.get("cache_control")
                    if (
                        cache_control is not None
                        and isinstance(cache_control, dict)
                        and cache_control.get("type") == "ephemeral"
                    ):
                        # Later matches overwrite earlier ones, so after the
                        # loop these point at the LAST cacheable block.
                        last_cacheable_message_idx = msg_idx
                        last_cacheable_content_idx = content_idx

        # If no cacheable block found, return empty list (no cacheable prefix)
        if last_cacheable_message_idx is None:
            return []

        # Build the cacheable prefix: all messages up to and including the last cacheable message
        cacheable_prefix: List[AllMessageValues] = []

        for msg_idx, message in enumerate(messages):
            if msg_idx < last_cacheable_message_idx:
                # Include entire message (comes before last cacheable block)
                cacheable_prefix.append(message)
            elif msg_idx == last_cacheable_message_idx:
                # Include message but only up to and including the last cacheable content block
                content = message.get("content")
                if isinstance(content, list) and last_cacheable_content_idx is not None:
                    # Create a copy of the message with only cacheable content blocks
                    # (shallow copy; the original message dict is not mutated)
                    message_copy = cast(
                        AllMessageValues,
                        {
                            **message,
                            "content": content[: last_cacheable_content_idx + 1],
                        },
                    )
                    cacheable_prefix.append(message_copy)
                else:
                    # Content is not a list or cacheable content idx is None, include full message
                    cacheable_prefix.append(message)
            else:
                # Message comes after last cacheable block, don't include
                break

        return cacheable_prefix

    @staticmethod
    def get_prompt_caching_cache_key(
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[str]:
        """
        Build a deterministic cache key from the cacheable prefix of `messages`
        plus `tools`.

        Returns:
            ``"deployment:<sha256>:prompt_caching"``, or None when there is
            nothing to key on (no messages/tools, or no cacheable prefix).
        """
        if messages is None and tools is None:
            return None

        # Extract cacheable prefix from messages (only include up to last cache_control block)
        cacheable_messages = None
        if messages is not None:
            cacheable_messages = PromptCachingCache.extract_cacheable_prefix(messages)
            # If no cacheable prefix found, return None (can't cache)
            if not cacheable_messages:
                return None

        # Use serialize_object for consistent and stable serialization
        data_to_hash: dict = {}
        if cacheable_messages is not None:
            serialized_messages = PromptCachingCache.serialize_object(
                cacheable_messages
            )
            data_to_hash["messages"] = serialized_messages
        if tools is not None:
            serialized_tools = PromptCachingCache.serialize_object(tools)
            data_to_hash["tools"] = serialized_tools

        # Combine serialized data into a single string
        data_to_hash_str = json.dumps(
            data_to_hash,
            sort_keys=True,
            separators=(",", ":"),
        )

        # Create a hash of the serialized data for a stable cache key
        hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
        return f"deployment:{hashed_data}:prompt_caching"

    def add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        """Synchronously record `model_id` under the prompt's cache key (300s TTL)."""
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # If no cacheable prefix found, don't cache (can't generate cache key)
        if cache_key is None:
            return None

        self.cache.set_cache(
            cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
        )
        return None

    async def async_add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        """Async counterpart of `add_model_id` (same key derivation and TTL)."""
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # If no cacheable prefix found, don't cache (can't generate cache key)
        if cache_key is None:
            return None

        await self.cache.async_set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=300,  # store for 5 minutes
        )
        return None

    async def async_get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """
        Get model ID from cache using the cacheable prefix.

        The cache key is based on the cacheable prefix (everything up to and including
        the last cache_control block), so requests with the same cacheable prefix but
        different user messages will have the same cache key.
        """
        if messages is None and tools is None:
            return None

        # Generate cache key using cacheable prefix
        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        if cache_key is None:
            return None

        # Perform cache lookup
        cache_result = await self.cache.async_get_cache(key=cache_key)
        return cache_result

    def get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """Sync counterpart of `async_get_model_id`: look up the stored model id."""
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        # If no cacheable prefix found, return None (can't cache)
        if cache_key is None:
            return None

        return self.cache.get_cache(cache_key)