chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
|
||||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
from litellm.router import Router
|
||||
|
||||
litellm_router = Router
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
litellm_router = Any
|
||||
|
||||
|
||||
class PromptCachingCacheValue(TypedDict):
    """Value stored in the prompt-caching cache for a cacheable prompt prefix."""

    # id of the deployment (model) that served the prompt-caching-supported call
    model_id: str
|
||||
|
||||
|
||||
class PromptCachingCache:
    """Router-level cache mapping a prompt-caching-eligible prompt prefix to a model id.

    When a request whose messages carry Anthropic-style ``cache_control``
    markers is routed, the chosen deployment's model id is stored under a
    hash of the cacheable prefix, so follow-up requests sharing that prefix
    can be pinned to the same deployment.
    """

    def __init__(self, cache: DualCache):
        # Backing store shared with the router; entries expire after 5 minutes.
        self.cache = cache
        self.in_memory_cache = InMemoryCache()

    @staticmethod
    def serialize_object(obj: Any) -> Any:
        """Canonicalize ``obj`` so it can be hashed deterministically.

        Pydantic models are converted via ``.dict()``, dicts are rendered as
        compact JSON with sorted keys, lists are canonicalized element-wise,
        primitives pass through unchanged, and anything else falls back to
        ``str()``.
        """
        if hasattr(obj, "dict"):
            # Pydantic model -> plain dict representation
            return obj.dict()
        if isinstance(obj, dict):
            # Stable, key-order-independent rendering
            return json.dumps(obj, sort_keys=True, separators=(",", ":"))
        if isinstance(obj, list):
            # Canonicalize each element individually
            return [PromptCachingCache.serialize_object(element) for element in obj]
        if isinstance(obj, (int, float, bool)):
            return obj  # primitives are already stable
        return str(obj)

    @staticmethod
    def _locate_last_cache_control(
        messages: List[AllMessageValues],
    ) -> Optional[tuple]:
        """Find the LAST ephemeral ``cache_control`` marker across all messages.

        Returns ``(message_idx, content_block_idx)``, where
        ``content_block_idx`` is ``None`` when the marker sits at the message
        level (sibling of string content), meaning the whole message is
        cacheable. Returns ``None`` when no marker exists.
        """
        located: Optional[tuple] = None

        for message_idx, message in enumerate(messages):
            # Message-level marker, e.g.
            # {"role": "user", "content": "...", "cache_control": {"type": "ephemeral"}}
            message_marker = message.get("cache_control")
            if (
                isinstance(message_marker, dict)
                and message_marker.get("type") == "ephemeral"
            ):
                # None signals "entire message content is cacheable"
                located = (message_idx, None)

            # Block-level markers inside list-style content
            content = message.get("content")
            if not isinstance(content, list):
                continue
            for block_idx, block in enumerate(content):
                if not isinstance(block, dict):
                    continue
                block_marker = block.get("cache_control")
                if (
                    isinstance(block_marker, dict)
                    and block_marker.get("type") == "ephemeral"
                ):
                    located = (message_idx, block_idx)

        return located

    @staticmethod
    def extract_cacheable_prefix(
        messages: List[AllMessageValues],
    ) -> List[AllMessageValues]:
        """Return the cacheable prefix of ``messages``.

        The prefix is everything UP TO AND INCLUDING the last content block
        (across all messages) that carries an ephemeral ``cache_control``
        marker — including earlier blocks that carry no marker themselves.

        Args:
            messages: Messages to scan for ``cache_control`` markers.

        Returns:
            The messages forming the cacheable prefix; an empty list when no
            marker is present anywhere.
        """
        if not messages:
            return messages

        boundary = PromptCachingCache._locate_last_cache_control(messages)
        if boundary is None:
            # Nothing is marked cacheable -> no cacheable prefix at all.
            return []

        boundary_message_idx, boundary_block_idx = boundary

        # Every message strictly before the boundary is included whole.
        prefix: List[AllMessageValues] = list(messages[:boundary_message_idx])

        boundary_message = messages[boundary_message_idx]
        boundary_content = boundary_message.get("content")
        if boundary_block_idx is not None and isinstance(boundary_content, list):
            # Trim the boundary message's content to the blocks up to and
            # including the last marked one.
            trimmed = cast(
                AllMessageValues,
                {
                    **boundary_message,
                    "content": boundary_content[: boundary_block_idx + 1],
                },
            )
            prefix.append(trimmed)
        else:
            # String content or message-level marker: keep the whole message.
            prefix.append(boundary_message)

        return prefix

    @staticmethod
    def get_prompt_caching_cache_key(
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[str]:
        """Build a deterministic cache key from the cacheable prompt prefix.

        Returns ``None`` when there is nothing to key on: both arguments are
        ``None``, or ``messages`` contains no ``cache_control`` marker.
        """
        if messages is None and tools is None:
            return None

        payload: dict = {}

        if messages is not None:
            # Only the prefix up to the last cache_control block participates
            # in the key, so differing user turns after it still collide.
            cacheable_prefix = PromptCachingCache.extract_cacheable_prefix(messages)
            if not cacheable_prefix:
                return None
            payload["messages"] = PromptCachingCache.serialize_object(
                cacheable_prefix
            )

        if tools is not None:
            payload["tools"] = PromptCachingCache.serialize_object(tools)

        # Compact, key-sorted JSON gives a stable string to hash.
        canonical = json.dumps(
            payload,
            sort_keys=True,
            separators=(",", ":"),
        )
        digest = hashlib.sha256(canonical.encode()).hexdigest()
        return f"deployment:{digest}:prompt_caching"

    def add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        """Synchronously record which deployment served this cacheable prompt."""
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        if cache_key is None:
            # No cacheable prefix -> nothing to record.
            return None

        self.cache.set_cache(
            cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
        )
        return None

    async def async_add_model_id(
        self,
        model_id: str,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> None:
        """Async variant of :meth:`add_model_id`."""
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        if cache_key is None:
            # No cacheable prefix -> nothing to record.
            return None

        await self.cache.async_set_cache(
            cache_key,
            PromptCachingCacheValue(model_id=model_id),
            ttl=300,  # store for 5 minutes
        )
        return None

    async def async_get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """Async lookup of the deployment previously recorded for this prefix.

        The key depends only on the cacheable prefix (everything up to and
        including the last ``cache_control`` block), so requests that differ
        only after that prefix resolve to the same entry.
        """
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        if cache_key is None:
            return None

        return await self.cache.async_get_cache(key=cache_key)

    def get_model_id(
        self,
        messages: Optional[List[AllMessageValues]],
        tools: Optional[List[ChatCompletionToolParam]],
    ) -> Optional[PromptCachingCacheValue]:
        """Synchronous variant of :meth:`async_get_model_id`."""
        if messages is None and tools is None:
            return None

        cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
        if cache_key is None:
            return None

        return self.cache.get_cache(cache_key)
|
||||
Reference in New Issue
Block a user