chore: initial snapshot for gitea/github upload

Author: Your Name
Date: 2026-03-26 16:04:46 +08:00
Commit: a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

@@ -0,0 +1,259 @@
"""
Wrapper around router cache. Meant to store model id when prompt caching supported prompt is called.
"""
import hashlib
import json
from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
from typing_extensions import TypedDict
from litellm.caching.caching import DualCache
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.types.llms.openai import AllMessageValues, ChatCompletionToolParam

if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.router import Router
litellm_router = Router
Span = Union[_Span, Any]
else:
Span = Any
litellm_router = Any


class PromptCachingCacheValue(TypedDict):
model_id: str


class PromptCachingCache:
def __init__(self, cache: DualCache):
self.cache = cache
self.in_memory_cache = InMemoryCache()

    @staticmethod
def serialize_object(obj: Any) -> Any:
"""Helper function to serialize Pydantic objects, dictionaries, or fallback to string."""
if hasattr(obj, "dict"):
# If the object is a Pydantic model, use its `dict()` method
return obj.dict()
elif isinstance(obj, dict):
# If the object is a dictionary, serialize it with sorted keys
return json.dumps(
obj, sort_keys=True, separators=(",", ":")
) # Standardize serialization
elif isinstance(obj, list):
# Serialize lists by ensuring each element is handled properly
return [PromptCachingCache.serialize_object(item) for item in obj]
elif isinstance(obj, (int, float, bool)):
return obj # Keep primitive types as-is
return str(obj)

    @staticmethod
def extract_cacheable_prefix(
messages: List[AllMessageValues],
) -> List[AllMessageValues]:
"""
Extract the cacheable prefix from messages.
The cacheable prefix is everything UP TO AND INCLUDING the LAST content block
(across all messages) that has cache_control. This includes ALL blocks before
the last cacheable block (even if they don't have cache_control).
Args:
messages: List of messages to extract cacheable prefix from
Returns:
List of messages containing only the cacheable prefix
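
        Example (illustrative):
            messages = [
                {"role": "system", "content": [{"type": "text", "text": "...",
                                                "cache_control": {"type": "ephemeral"}}]},
                {"role": "user", "content": "latest question"},
            ]
            -> returns only the system message, since it holds the last
               cache_control block; the trailing user turn is dropped.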
"""
if not messages:
return messages
# Find the last content block (across all messages) that has cache_control
last_cacheable_message_idx = None
last_cacheable_content_idx = None
for msg_idx, message in enumerate(messages):
content = message.get("content")
# Check for cache_control at message level (when content is a string)
# This handles the case where cache_control is a sibling of string content:
# {"role": "user", "content": "...", "cache_control": {"type": "ephemeral"}}
message_level_cache_control = message.get("cache_control")
if (
message_level_cache_control is not None
and isinstance(message_level_cache_control, dict)
and message_level_cache_control.get("type") == "ephemeral"
):
last_cacheable_message_idx = msg_idx
# Set to None to indicate the entire message content is cacheable
# (not a specific content block index within a list)
last_cacheable_content_idx = None
# Also check for cache_control within content blocks (when content is a list)
if not isinstance(content, list):
continue
for content_idx, content_block in enumerate(content):
if isinstance(content_block, dict):
cache_control = content_block.get("cache_control")
if (
cache_control is not None
and isinstance(cache_control, dict)
and cache_control.get("type") == "ephemeral"
):
last_cacheable_message_idx = msg_idx
last_cacheable_content_idx = content_idx
# If no cacheable block found, return empty list (no cacheable prefix)
if last_cacheable_message_idx is None:
return []
# Build the cacheable prefix: all messages up to and including the last cacheable message
cacheable_prefix = []
for msg_idx, message in enumerate(messages):
if msg_idx < last_cacheable_message_idx:
# Include entire message (comes before last cacheable block)
cacheable_prefix.append(message)
elif msg_idx == last_cacheable_message_idx:
# Include message but only up to and including the last cacheable content block
content = message.get("content")
if isinstance(content, list) and last_cacheable_content_idx is not None:
# Create a copy of the message with only cacheable content blocks
message_copy = cast(
AllMessageValues,
{
**message,
"content": content[: last_cacheable_content_idx + 1],
},
)
cacheable_prefix.append(message_copy)
else:
# Content is not a list or cacheable content idx is None, include full message
cacheable_prefix.append(message)
else:
# Message comes after last cacheable block, don't include
break
return cacheable_prefix

    @staticmethod
def get_prompt_caching_cache_key(
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[str]:
if messages is None and tools is None:
return None
# Extract cacheable prefix from messages (only include up to last cache_control block)
cacheable_messages = None
if messages is not None:
cacheable_messages = PromptCachingCache.extract_cacheable_prefix(messages)
# If no cacheable prefix found, return None (can't cache)
if not cacheable_messages:
return None
# Use serialize_object for consistent and stable serialization
data_to_hash = {}
if cacheable_messages is not None:
serialized_messages = PromptCachingCache.serialize_object(
cacheable_messages
)
data_to_hash["messages"] = serialized_messages
if tools is not None:
serialized_tools = PromptCachingCache.serialize_object(tools)
data_to_hash["tools"] = serialized_tools
# Combine serialized data into a single string
data_to_hash_str = json.dumps(
data_to_hash,
sort_keys=True,
separators=(",", ":"),
)
# Create a hash of the serialized data for a stable cache key
hashed_data = hashlib.sha256(data_to_hash_str.encode()).hexdigest()
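        # Resulting key shape (hash shortened for illustration):
        #   "deployment:3b1f...e9a2:prompt_caching"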
return f"deployment:{hashed_data}:prompt_caching"

    def add_model_id(
self,
model_id: str,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> None:
if messages is None and tools is None:
return None
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
# If no cacheable prefix found, don't cache (can't generate cache key)
if cache_key is None:
return None
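        # Store for 5 minutes, comparable to typical provider-side ephemeral
        # prompt-cache lifetimes.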
self.cache.set_cache(
cache_key, PromptCachingCacheValue(model_id=model_id), ttl=300
)
return None

    async def async_add_model_id(
self,
model_id: str,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> None:
if messages is None and tools is None:
return None
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
# If no cacheable prefix found, don't cache (can't generate cache key)
if cache_key is None:
return None
await self.cache.async_set_cache(
cache_key,
PromptCachingCacheValue(model_id=model_id),
ttl=300, # store for 5 minutes
)
return None

    async def async_get_model_id(
self,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[PromptCachingCacheValue]:
"""
Get model ID from cache using the cacheable prefix.
The cache key is based on the cacheable prefix (everything up to and including
the last cache_control block), so requests with the same cacheable prefix but
different user messages will have the same cache key.
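
        For example (illustrative), two requests that share the same
        cache_control-marked system prompt but end with different user turns
        hash to the same key and therefore resolve to the same model_id.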
"""
if messages is None and tools is None:
return None
# Generate cache key using cacheable prefix
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
if cache_key is None:
return None
# Perform cache lookup
cache_result = await self.cache.async_get_cache(key=cache_key)
return cache_result

    def get_model_id(
self,
messages: Optional[List[AllMessageValues]],
tools: Optional[List[ChatCompletionToolParam]],
) -> Optional[PromptCachingCacheValue]:
if messages is None and tools is None:
return None
cache_key = PromptCachingCache.get_prompt_caching_cache_key(messages, tools)
# If no cacheable prefix found, return None (can't cache)
if cache_key is None:
return None
return self.cache.get_cache(cache_key)
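

# --- Usage sketch (illustrative only, not part of the module) ---------------
# A rough sketch of how a router could keep prompt-cached requests pinned to
# one deployment. `router.cache` is assumed to be a DualCache; `deployment_id`,
# `request_messages`, and `request_tools` are hypothetical names:
#
#     prompt_cache = PromptCachingCache(cache=router.cache)
#
#     # after the first (cache-writing) call succeeds on a deployment:
#     await prompt_cache.async_add_model_id(
#         model_id=deployment_id,
#         messages=request_messages,
#         tools=request_tools,
#     )
#
#     # on a follow-up request with the same cacheable prefix:
#     hit = await prompt_cache.async_get_model_id(
#         messages=request_messages, tools=request_tools
#     )
#     if hit is not None:
#         preferred_deployment_id = hit["model_id"]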