chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,41 @@
|
||||
# Caching on LiteLLM
|
||||
|
||||
LiteLLM supports multiple caching mechanisms. This allows users to choose the most suitable caching solution for their use case.
|
||||
|
||||
The following caching mechanisms are supported:
|
||||
|
||||
1. **RedisCache**
|
||||
2. **RedisSemanticCache**
|
||||
3. **QdrantSemanticCache**
|
||||
4. **InMemoryCache**
|
||||
5. **DiskCache**
|
||||
6. **S3Cache**
|
||||
7. **AzureBlobCache**
|
||||
8. **DualCache** (updates both Redis and an in-memory cache simultaneously)
|
||||
|
||||
## Folder Structure
|
||||
|
||||
```
|
||||
litellm/caching/
|
||||
├── base_cache.py
|
||||
├── caching.py
|
||||
├── caching_handler.py
|
||||
├── disk_cache.py
|
||||
├── dual_cache.py
|
||||
├── in_memory_cache.py
|
||||
├── qdrant_semantic_cache.py
|
||||
├── redis_cache.py
|
||||
├── redis_semantic_cache.py
|
||||
├── s3_cache.py
|
||||
```
|
||||
|
||||
## Documentation
|
||||
- [Caching on LiteLLM Gateway](https://docs.litellm.ai/docs/proxy/caching)
|
||||
- [Caching on LiteLLM Python](https://docs.litellm.ai/docs/caching/all_caches)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from .azure_blob_cache import AzureBlobCache
|
||||
from .caching import Cache, LiteLLMCacheType
|
||||
from .disk_cache import DiskCache
|
||||
from .dual_cache import DualCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||
from .redis_cache import RedisCache
|
||||
from .redis_cluster_cache import RedisClusterCache
|
||||
from .redis_semantic_cache import RedisSemanticCache
|
||||
from .s3_cache import S3Cache
|
||||
from .gcs_cache import GCSCache
|
||||
@@ -0,0 +1,30 @@
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Optional, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def lru_cache_wrapper(
|
||||
maxsize: Optional[int] = None,
|
||||
) -> Callable[[Callable[..., T]], Callable[..., T]]:
|
||||
"""
|
||||
Wrapper for lru_cache that caches success and exceptions
|
||||
"""
|
||||
|
||||
def decorator(f: Callable[..., T]) -> Callable[..., T]:
|
||||
@lru_cache(maxsize=maxsize)
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return ("success", f(*args, **kwargs))
|
||||
except Exception as e:
|
||||
return ("error", e)
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
result = wrapper(*args, **kwargs)
|
||||
if result[0] == "error":
|
||||
raise result[1]
|
||||
return result[1]
|
||||
|
||||
return wrapped
|
||||
|
||||
return decorator
|
||||
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Azure Blob Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from contextlib import suppress
|
||||
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class AzureBlobCache(BaseCache):
|
||||
def __init__(self, account_url, container) -> None:
|
||||
from azure.storage.blob import BlobServiceClient
|
||||
from azure.core.exceptions import ResourceExistsError
|
||||
from azure.identity import DefaultAzureCredential
|
||||
from azure.identity.aio import (
|
||||
DefaultAzureCredential as AsyncDefaultAzureCredential,
|
||||
)
|
||||
from azure.storage.blob.aio import BlobServiceClient as AsyncBlobServiceClient
|
||||
|
||||
self.container_client = BlobServiceClient(
|
||||
account_url=account_url,
|
||||
credential=DefaultAzureCredential(),
|
||||
).get_container_client(container)
|
||||
self.async_container_client = AsyncBlobServiceClient(
|
||||
account_url=account_url,
|
||||
credential=AsyncDefaultAzureCredential(),
|
||||
).get_container_client(container)
|
||||
|
||||
with suppress(ResourceExistsError):
|
||||
self.container_client.create_container()
|
||||
|
||||
def set_cache(self, key, value, **kwargs) -> None:
|
||||
print_verbose(f"LiteLLM SET Cache - Azure Blob. Key={key}. Value={value}")
|
||||
serialized_value = json.dumps(value)
|
||||
try:
|
||||
self.container_client.upload_blob(key, serialized_value)
|
||||
except Exception as e:
|
||||
# NON blocking - notify users Azure Blob is throwing an exception
|
||||
print_verbose(f"LiteLLM set_cache() - Got exception from Azure Blob: {e}")
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs) -> None:
|
||||
print_verbose(f"LiteLLM SET Cache - Azure Blob. Key={key}. Value={value}")
|
||||
serialized_value = json.dumps(value)
|
||||
try:
|
||||
await self.async_container_client.upload_blob(
|
||||
key, serialized_value, overwrite=True
|
||||
)
|
||||
except Exception as e:
|
||||
# NON blocking - notify users Azure Blob is throwing an exception
|
||||
print_verbose(f"LiteLLM set_cache() - Got exception from Azure Blob: {e}")
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
|
||||
try:
|
||||
print_verbose(f"Get Azure Blob Cache: key: {key}")
|
||||
as_bytes = self.container_client.download_blob(key).readall()
|
||||
as_str = as_bytes.decode("utf-8")
|
||||
cached_response = json.loads(as_str)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"Got Azure Blob Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
|
||||
)
|
||||
|
||||
return cached_response
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
|
||||
try:
|
||||
print_verbose(f"Get Azure Blob Cache: key: {key}")
|
||||
blob = await self.async_container_client.download_blob(key)
|
||||
as_bytes = await blob.readall()
|
||||
as_str = as_bytes.decode("utf-8")
|
||||
cached_response = json.loads(as_str)
|
||||
verbose_logger.debug(
|
||||
f"Got Azure Blob Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
|
||||
)
|
||||
return cached_response
|
||||
except ResourceNotFoundError:
|
||||
return None
|
||||
|
||||
def flush_cache(self) -> None:
|
||||
for blob in self.container_client.walk_blobs():
|
||||
self.container_client.delete_blob(blob.name)
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
self.container_client.close()
|
||||
await self.async_container_client.close()
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs) -> None:
|
||||
tasks = []
|
||||
for val in cache_list:
|
||||
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
||||
await asyncio.gather(*tasks)
|
||||
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Base Cache implementation. All cache implementations should inherit from this class.
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class BaseCache(ABC):
|
||||
def __init__(self, default_ttl: int = 60):
|
||||
self.default_ttl = default_ttl
|
||||
|
||||
def get_ttl(self, **kwargs) -> Optional[int]:
|
||||
kwargs_ttl: Optional[int] = kwargs.get("ttl")
|
||||
if kwargs_ttl is not None:
|
||||
try:
|
||||
return int(kwargs_ttl)
|
||||
except ValueError:
|
||||
return self.default_ttl
|
||||
return self.default_ttl
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
||||
pass
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def batch_cache_write(self, key, value, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
async def disconnect(self):
|
||||
raise NotImplementedError
|
||||
|
||||
async def test_connection(self) -> dict:
|
||||
"""
|
||||
Test the cache connection.
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success" | "failed", "message": str, "error": Optional[str]}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,926 @@
|
||||
# +-----------------------------------------------+
|
||||
# | |
|
||||
# | Give Feedback / Get Help |
|
||||
# | https://github.com/BerriAI/litellm/issues/new |
|
||||
# | |
|
||||
# +-----------------------------------------------+
|
||||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import ast
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
|
||||
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
|
||||
from litellm.types.caching import *
|
||||
from litellm.types.utils import EmbeddingResponse, all_litellm_params
|
||||
|
||||
from .azure_blob_cache import AzureBlobCache
|
||||
from .base_cache import BaseCache
|
||||
from .disk_cache import DiskCache
|
||||
from .dual_cache import DualCache # noqa
|
||||
from .gcs_cache import GCSCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .qdrant_semantic_cache import QdrantSemanticCache
|
||||
from .redis_cache import RedisCache
|
||||
from .redis_cluster_cache import RedisClusterCache
|
||||
from .redis_semantic_cache import RedisSemanticCache
|
||||
from .s3_cache import S3Cache
|
||||
|
||||
|
||||
def print_verbose(print_statement):
|
||||
try:
|
||||
verbose_logger.debug(print_statement)
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
class CacheMode(str, Enum):
|
||||
default_on = "default_on"
|
||||
default_off = "default_off"
|
||||
|
||||
|
||||
#### LiteLLM.Completion / Embedding Cache ####
|
||||
class Cache:
|
||||
def __init__(
|
||||
self,
|
||||
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
||||
mode: Optional[
|
||||
CacheMode
|
||||
] = CacheMode.default_on, # when default_on cache is always on, when default_off cache is opt in
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
namespace: Optional[str] = None,
|
||||
ttl: Optional[float] = None,
|
||||
default_in_memory_ttl: Optional[float] = None,
|
||||
default_in_redis_ttl: Optional[float] = None,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
||||
"completion",
|
||||
"acompletion",
|
||||
"embedding",
|
||||
"aembedding",
|
||||
"atranscription",
|
||||
"transcription",
|
||||
"atext_completion",
|
||||
"text_completion",
|
||||
"arerank",
|
||||
"rerank",
|
||||
"responses",
|
||||
"aresponses",
|
||||
],
|
||||
# s3 Bucket, boto3 configuration
|
||||
azure_account_url: Optional[str] = None,
|
||||
azure_blob_container: Optional[str] = None,
|
||||
s3_bucket_name: Optional[str] = None,
|
||||
s3_region_name: Optional[str] = None,
|
||||
s3_api_version: Optional[str] = None,
|
||||
s3_use_ssl: Optional[bool] = True,
|
||||
s3_verify: Optional[Union[bool, str]] = None,
|
||||
s3_endpoint_url: Optional[str] = None,
|
||||
s3_aws_access_key_id: Optional[str] = None,
|
||||
s3_aws_secret_access_key: Optional[str] = None,
|
||||
s3_aws_session_token: Optional[str] = None,
|
||||
s3_config: Optional[Any] = None,
|
||||
s3_path: Optional[str] = None,
|
||||
gcs_bucket_name: Optional[str] = None,
|
||||
gcs_path_service_account: Optional[str] = None,
|
||||
gcs_path: Optional[str] = None,
|
||||
redis_semantic_cache_embedding_model: str = "text-embedding-ada-002",
|
||||
redis_semantic_cache_index_name: Optional[str] = None,
|
||||
redis_flush_size: Optional[int] = None,
|
||||
redis_startup_nodes: Optional[List] = None,
|
||||
disk_cache_dir: Optional[str] = None,
|
||||
qdrant_api_base: Optional[str] = None,
|
||||
qdrant_api_key: Optional[str] = None,
|
||||
qdrant_collection_name: Optional[str] = None,
|
||||
qdrant_quantization_config: Optional[str] = None,
|
||||
qdrant_semantic_cache_embedding_model: str = "text-embedding-ada-002",
|
||||
qdrant_semantic_cache_vector_size: Optional[int] = None,
|
||||
# GCP IAM authentication parameters
|
||||
gcp_service_account: Optional[str] = None,
|
||||
gcp_ssl_ca_certs: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initializes the cache based on the given type.
|
||||
|
||||
Args:
|
||||
type (str, optional): The type of cache to initialize. Can be "local", "redis", "redis-semantic", "qdrant-semantic", "s3" or "disk". Defaults to "local".
|
||||
|
||||
# Redis Cache Args
|
||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
namespace (str, optional): The namespace for the Redis cache. Required if type is "redis".
|
||||
ttl (float, optional): The ttl for the Redis cache
|
||||
redis_flush_size (int, optional): The number of keys to flush at a time. Defaults to 1000. Only used if batch redis set caching is used.
|
||||
redis_startup_nodes (list, optional): The list of startup nodes for the Redis cache. Defaults to None.
|
||||
|
||||
# Qdrant Cache Args
|
||||
qdrant_api_base (str, optional): The url for your qdrant cluster. Required if type is "qdrant-semantic".
|
||||
qdrant_api_key (str, optional): The api_key for the local or cloud qdrant cluster.
|
||||
qdrant_collection_name (str, optional): The name for your qdrant collection. Required if type is "qdrant-semantic".
|
||||
similarity_threshold (float, optional): The similarity threshold for semantic-caching, Required if type is "redis-semantic" or "qdrant-semantic".
|
||||
|
||||
# Disk Cache Args
|
||||
disk_cache_dir (str, optional): The directory for the disk cache. Defaults to None.
|
||||
|
||||
# S3 Cache Args
|
||||
s3_bucket_name (str, optional): The bucket name for the s3 cache. Defaults to None.
|
||||
s3_region_name (str, optional): The region name for the s3 cache. Defaults to None.
|
||||
s3_api_version (str, optional): The api version for the s3 cache. Defaults to None.
|
||||
s3_use_ssl (bool, optional): The use ssl for the s3 cache. Defaults to True.
|
||||
s3_verify (bool, optional): The verify for the s3 cache. Defaults to None.
|
||||
s3_endpoint_url (str, optional): The endpoint url for the s3 cache. Defaults to None.
|
||||
s3_aws_access_key_id (str, optional): The aws access key id for the s3 cache. Defaults to None.
|
||||
s3_aws_secret_access_key (str, optional): The aws secret access key for the s3 cache. Defaults to None.
|
||||
s3_aws_session_token (str, optional): The aws session token for the s3 cache. Defaults to None.
|
||||
s3_config (dict, optional): The config for the s3 cache. Defaults to None.
|
||||
|
||||
# GCS Cache Args
|
||||
gcs_bucket_name (str, optional): The bucket name for the gcs cache. Defaults to None.
|
||||
gcs_path_service_account (str, optional): Path to the service account json.
|
||||
gcs_path (str, optional): Folder path inside the bucket to store cache files.
|
||||
|
||||
# Common Cache Args
|
||||
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid cache type is provided.
|
||||
|
||||
Returns:
|
||||
None. Cache is set as a litellm param
|
||||
"""
|
||||
if type == LiteLLMCacheType.REDIS:
|
||||
# Check REDIS_CLUSTER_NODES env var if no explicit startup nodes
|
||||
if not redis_startup_nodes:
|
||||
_env_cluster_nodes = litellm.get_secret("REDIS_CLUSTER_NODES")
|
||||
if _env_cluster_nodes is not None and isinstance(
|
||||
_env_cluster_nodes, str
|
||||
):
|
||||
redis_startup_nodes = json.loads(_env_cluster_nodes)
|
||||
|
||||
if redis_startup_nodes:
|
||||
# Only pass GCP parameters if they are provided
|
||||
cluster_kwargs = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"password": password,
|
||||
"redis_flush_size": redis_flush_size,
|
||||
"startup_nodes": redis_startup_nodes,
|
||||
**kwargs,
|
||||
}
|
||||
if gcp_service_account is not None:
|
||||
cluster_kwargs["gcp_service_account"] = gcp_service_account
|
||||
if gcp_ssl_ca_certs is not None:
|
||||
cluster_kwargs["gcp_ssl_ca_certs"] = gcp_ssl_ca_certs
|
||||
|
||||
self.cache: BaseCache = RedisClusterCache(**cluster_kwargs)
|
||||
else:
|
||||
self.cache = RedisCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
redis_flush_size=redis_flush_size,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.REDIS_SEMANTIC:
|
||||
self.cache = RedisSemanticCache(
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
similarity_threshold=similarity_threshold,
|
||||
embedding_model=redis_semantic_cache_embedding_model,
|
||||
index_name=redis_semantic_cache_index_name,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.QDRANT_SEMANTIC:
|
||||
self.cache = QdrantSemanticCache(
|
||||
qdrant_api_base=qdrant_api_base,
|
||||
qdrant_api_key=qdrant_api_key,
|
||||
collection_name=qdrant_collection_name,
|
||||
similarity_threshold=similarity_threshold,
|
||||
quantization_config=qdrant_quantization_config,
|
||||
embedding_model=qdrant_semantic_cache_embedding_model,
|
||||
vector_size=qdrant_semantic_cache_vector_size,
|
||||
)
|
||||
elif type == LiteLLMCacheType.LOCAL:
|
||||
self.cache = InMemoryCache()
|
||||
elif type == LiteLLMCacheType.S3:
|
||||
self.cache = S3Cache(
|
||||
s3_bucket_name=s3_bucket_name,
|
||||
s3_region_name=s3_region_name,
|
||||
s3_api_version=s3_api_version,
|
||||
s3_use_ssl=s3_use_ssl,
|
||||
s3_verify=s3_verify,
|
||||
s3_endpoint_url=s3_endpoint_url,
|
||||
s3_aws_access_key_id=s3_aws_access_key_id,
|
||||
s3_aws_secret_access_key=s3_aws_secret_access_key,
|
||||
s3_aws_session_token=s3_aws_session_token,
|
||||
s3_config=s3_config,
|
||||
s3_path=s3_path,
|
||||
**kwargs,
|
||||
)
|
||||
elif type == LiteLLMCacheType.GCS:
|
||||
self.cache = GCSCache(
|
||||
bucket_name=gcs_bucket_name,
|
||||
path_service_account=gcs_path_service_account,
|
||||
gcs_path=gcs_path,
|
||||
)
|
||||
elif type == LiteLLMCacheType.AZURE_BLOB:
|
||||
self.cache = AzureBlobCache(
|
||||
account_url=azure_account_url,
|
||||
container=azure_blob_container,
|
||||
)
|
||||
elif type == LiteLLMCacheType.DISK:
|
||||
self.cache = DiskCache(disk_cache_dir=disk_cache_dir)
|
||||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
||||
if "cache" not in litellm._async_success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
||||
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
|
||||
self.type = type
|
||||
self.namespace = namespace
|
||||
self.redis_flush_size = redis_flush_size
|
||||
self.ttl = ttl
|
||||
self.mode: CacheMode = mode or CacheMode.default_on
|
||||
|
||||
if self.type == LiteLLMCacheType.LOCAL and default_in_memory_ttl is not None:
|
||||
self.ttl = default_in_memory_ttl
|
||||
|
||||
if (
|
||||
self.type == LiteLLMCacheType.REDIS
|
||||
or self.type == LiteLLMCacheType.REDIS_SEMANTIC
|
||||
) and default_in_redis_ttl is not None:
|
||||
self.ttl = default_in_redis_ttl
|
||||
|
||||
if self.namespace is not None and isinstance(self.cache, RedisCache):
|
||||
self.cache.namespace = self.namespace
|
||||
|
||||
def get_cache_key(self, **kwargs) -> str:
|
||||
"""
|
||||
Get the cache key for the given arguments.
|
||||
|
||||
Args:
|
||||
**kwargs: kwargs to litellm.completion() or embedding()
|
||||
|
||||
Returns:
|
||||
str: The cache key generated from the arguments, or None if no cache key could be generated.
|
||||
"""
|
||||
cache_key = ""
|
||||
# verbose_logger.debug("\nGetting Cache key. Kwargs: %s", kwargs)
|
||||
|
||||
preset_cache_key = self._get_preset_cache_key_from_kwargs(**kwargs)
|
||||
if preset_cache_key is not None:
|
||||
verbose_logger.debug("\nReturning preset cache key: %s", preset_cache_key)
|
||||
return preset_cache_key
|
||||
|
||||
combined_kwargs = ModelParamHelper._get_all_llm_api_params()
|
||||
litellm_param_kwargs = all_litellm_params
|
||||
for param in kwargs:
|
||||
if param in combined_kwargs:
|
||||
param_value: Optional[str] = self._get_param_value(param, kwargs)
|
||||
if param_value is not None:
|
||||
cache_key += f"{str(param)}: {str(param_value)}"
|
||||
elif (
|
||||
param not in litellm_param_kwargs
|
||||
): # check if user passed in optional param - e.g. top_k
|
||||
if (
|
||||
litellm.enable_caching_on_provider_specific_optional_params is True
|
||||
): # feature flagged for now
|
||||
if kwargs[param] is None:
|
||||
continue # ignore None params
|
||||
param_value = kwargs[param]
|
||||
cache_key += f"{str(param)}: {str(param_value)}"
|
||||
|
||||
verbose_logger.debug("\nCreated cache key: %s", cache_key)
|
||||
hashed_cache_key = Cache._get_hashed_cache_key(cache_key)
|
||||
hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs)
|
||||
self._set_preset_cache_key_in_kwargs(
|
||||
preset_cache_key=hashed_cache_key, **kwargs
|
||||
)
|
||||
return hashed_cache_key
|
||||
|
||||
def _get_param_value(
|
||||
self,
|
||||
param: str,
|
||||
kwargs: dict,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get the value for the given param from kwargs
|
||||
"""
|
||||
if param == "model":
|
||||
return self._get_model_param_value(kwargs)
|
||||
elif param == "file":
|
||||
return self._get_file_param_value(kwargs)
|
||||
return kwargs[param]
|
||||
|
||||
def _get_model_param_value(self, kwargs: dict) -> str:
|
||||
"""
|
||||
Handles getting the value for the 'model' param from kwargs
|
||||
|
||||
1. If caching groups are set, then return the caching group as the model https://docs.litellm.ai/docs/routing#caching-across-model-groups
|
||||
2. Else if a model_group is set, then return the model_group as the model. This is used for all requests sent through the litellm.Router()
|
||||
3. Else use the `model` passed in kwargs
|
||||
"""
|
||||
metadata: Dict = kwargs.get("metadata", {}) or {}
|
||||
litellm_params: Dict = kwargs.get("litellm_params", {}) or {}
|
||||
metadata_in_litellm_params: Dict = litellm_params.get("metadata", {}) or {}
|
||||
model_group: Optional[str] = metadata.get(
|
||||
"model_group"
|
||||
) or metadata_in_litellm_params.get("model_group")
|
||||
caching_group = self._get_caching_group(metadata, model_group)
|
||||
return caching_group or model_group or kwargs["model"]
|
||||
|
||||
def _get_caching_group(
|
||||
self, metadata: dict, model_group: Optional[str]
|
||||
) -> Optional[str]:
|
||||
caching_groups: Optional[List] = metadata.get("caching_groups", [])
|
||||
if caching_groups:
|
||||
for group in caching_groups:
|
||||
if model_group in group:
|
||||
return str(group)
|
||||
return None
|
||||
|
||||
def _get_file_param_value(self, kwargs: dict) -> str:
|
||||
"""
|
||||
Handles getting the value for the 'file' param from kwargs. Used for `transcription` requests
|
||||
"""
|
||||
file = kwargs.get("file")
|
||||
metadata = kwargs.get("metadata", {})
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
return (
|
||||
metadata.get("file_checksum")
|
||||
or getattr(file, "name", None)
|
||||
or metadata.get("file_name")
|
||||
or litellm_params.get("file_name")
|
||||
)
|
||||
|
||||
def _get_preset_cache_key_from_kwargs(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
Get the preset cache key from kwargs["litellm_params"]
|
||||
|
||||
We use _get_preset_cache_keys for two reasons
|
||||
|
||||
1. optional params like max_tokens, get transformed for bedrock -> max_new_tokens
|
||||
2. avoid doing duplicate / repeated work
|
||||
"""
|
||||
if kwargs:
|
||||
if "litellm_params" in kwargs:
|
||||
return kwargs["litellm_params"].get("preset_cache_key", None)
|
||||
return None
|
||||
|
||||
def _set_preset_cache_key_in_kwargs(self, preset_cache_key: str, **kwargs) -> None:
|
||||
"""
|
||||
Set the calculated cache key in kwargs
|
||||
|
||||
This is used to avoid doing duplicate / repeated work
|
||||
|
||||
Placed in kwargs["litellm_params"]
|
||||
"""
|
||||
if kwargs:
|
||||
if "litellm_params" in kwargs:
|
||||
kwargs["litellm_params"]["preset_cache_key"] = preset_cache_key
|
||||
|
||||
@staticmethod
|
||||
def _get_hashed_cache_key(cache_key: str) -> str:
|
||||
"""
|
||||
Get the hashed cache key for the given cache key.
|
||||
|
||||
Use hashlib to create a sha256 hash of the cache key
|
||||
|
||||
Args:
|
||||
cache_key (str): The cache key to hash.
|
||||
|
||||
Returns:
|
||||
str: The hashed cache key.
|
||||
"""
|
||||
hash_object = hashlib.sha256(cache_key.encode())
|
||||
# Hexadecimal representation of the hash
|
||||
hash_hex = hash_object.hexdigest()
|
||||
verbose_logger.debug("Hashed cache key (SHA-256): %s", hash_hex)
|
||||
return hash_hex
|
||||
|
||||
def _add_namespace_to_cache_key(self, hash_hex: str, **kwargs) -> str:
|
||||
"""
|
||||
If a redis namespace is provided, add it to the cache key
|
||||
|
||||
Args:
|
||||
hash_hex (str): The hashed cache key.
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
str: The final hashed cache key with the redis namespace.
|
||||
"""
|
||||
dynamic_cache_control: DynamicCacheControl = kwargs.get("cache", {})
|
||||
namespace = (
|
||||
dynamic_cache_control.get("namespace")
|
||||
or kwargs.get("metadata", {}).get("redis_namespace")
|
||||
or self.namespace
|
||||
)
|
||||
if namespace:
|
||||
hash_hex = f"{namespace}:{hash_hex}"
|
||||
verbose_logger.debug("Final hashed key: %s", hash_hex)
|
||||
return hash_hex
|
||||
|
||||
def generate_streaming_content(self, content):
|
||||
chunk_size = 5 # Adjust the chunk size as needed
|
||||
for i in range(0, len(content), chunk_size):
|
||||
yield {
|
||||
"choices": [
|
||||
{
|
||||
"delta": {
|
||||
"role": "assistant",
|
||||
"content": content[i : i + chunk_size],
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
time.sleep(CACHED_STREAMING_CHUNK_DELAY)
|
||||
|
||||
def _get_cache_logic(
|
||||
self,
|
||||
cached_result: Optional[Any],
|
||||
max_age: Optional[float],
|
||||
):
|
||||
"""
|
||||
Common get cache logic across sync + async implementations
|
||||
"""
|
||||
# Check if a timestamp was stored with the cached response
|
||||
if (
|
||||
cached_result is not None
|
||||
and isinstance(cached_result, dict)
|
||||
and "timestamp" in cached_result
|
||||
):
|
||||
timestamp = cached_result["timestamp"]
|
||||
current_time = time.time()
|
||||
|
||||
# Calculate age of the cached response
|
||||
response_age = current_time - timestamp
|
||||
|
||||
# Check if the cached response is older than the max-age
|
||||
if max_age is not None and response_age > max_age:
|
||||
return None # Cached response is too old
|
||||
|
||||
# If the response is fresh, or there's no max-age requirement, return the cached response
|
||||
# cached_response is in `b{} convert it to ModelResponse
|
||||
cached_response = cached_result.get("response")
|
||||
try:
|
||||
if isinstance(cached_response, dict):
|
||||
pass
|
||||
else:
|
||||
cached_response = json.loads(
|
||||
cached_response # type: ignore
|
||||
) # Convert string to dictionary
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response) # type: ignore
|
||||
return cached_response
|
||||
return cached_result
|
||||
|
||||
def get_cache(self, dynamic_cache_object: Optional[BaseCache] = None, **kwargs):
|
||||
"""
|
||||
Retrieves the cached result for the given arguments.
|
||||
|
||||
Args:
|
||||
*args: args to litellm.completion() or embedding()
|
||||
**kwargs: kwargs to litellm.completion() or embedding()
|
||||
|
||||
Returns:
|
||||
The cached result if it exists, otherwise None.
|
||||
"""
|
||||
try: # never block execution
|
||||
if self.should_use_cache(**kwargs) is not True:
|
||||
return
|
||||
messages = kwargs.get("messages", [])
|
||||
if "cache_key" in kwargs:
|
||||
cache_key = kwargs["cache_key"]
|
||||
else:
|
||||
cache_key = self.get_cache_key(**kwargs)
|
||||
if cache_key is not None:
|
||||
cache_control_args: DynamicCacheControl = kwargs.get("cache", {})
|
||||
max_age = (
|
||||
cache_control_args.get("s-maxage")
|
||||
or cache_control_args.get("s-max-age")
|
||||
or float("inf")
|
||||
)
|
||||
if dynamic_cache_object is not None:
|
||||
cached_result = dynamic_cache_object.get_cache(
|
||||
cache_key, messages=messages
|
||||
)
|
||||
else:
|
||||
cached_result = self.cache.get_cache(cache_key, messages=messages)
|
||||
return self._get_cache_logic(
|
||||
cached_result=cached_result, max_age=max_age
|
||||
)
|
||||
except Exception:
|
||||
print_verbose(f"An exception occurred: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
async def async_get_cache(
|
||||
self, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Async get cache implementation.
|
||||
|
||||
Used for embedding calls in async wrapper
|
||||
"""
|
||||
|
||||
try: # never block execution
|
||||
if self.should_use_cache(**kwargs) is not True:
|
||||
return
|
||||
|
||||
kwargs.get("messages", [])
|
||||
if "cache_key" in kwargs:
|
||||
cache_key = kwargs["cache_key"]
|
||||
else:
|
||||
cache_key = self.get_cache_key(**kwargs)
|
||||
if cache_key is not None:
|
||||
cache_control_args = kwargs.get("cache", {})
|
||||
max_age = cache_control_args.get(
|
||||
"s-max-age", cache_control_args.get("s-maxage", float("inf"))
|
||||
)
|
||||
if dynamic_cache_object is not None:
|
||||
cached_result = await dynamic_cache_object.async_get_cache(
|
||||
cache_key, **kwargs
|
||||
)
|
||||
else:
|
||||
cached_result = await self.cache.async_get_cache(
|
||||
cache_key, **kwargs
|
||||
)
|
||||
return self._get_cache_logic(
|
||||
cached_result=cached_result, max_age=max_age
|
||||
)
|
||||
except Exception:
|
||||
print_verbose(f"An exception occurred: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
def _add_cache_logic(self, result, **kwargs):
|
||||
"""
|
||||
Common implementation across sync + async add_cache functions
|
||||
"""
|
||||
try:
|
||||
if "cache_key" in kwargs:
|
||||
cache_key = kwargs["cache_key"]
|
||||
else:
|
||||
cache_key = self.get_cache_key(**kwargs)
|
||||
if cache_key is not None:
|
||||
if isinstance(result, BaseModel):
|
||||
result = result.model_dump_json()
|
||||
|
||||
## DEFAULT TTL ##
|
||||
if self.ttl is not None:
|
||||
kwargs["ttl"] = self.ttl
|
||||
## Get Cache-Controls ##
|
||||
_cache_kwargs = kwargs.get("cache", None)
|
||||
if isinstance(_cache_kwargs, dict):
|
||||
for k, v in _cache_kwargs.items():
|
||||
if k == "ttl":
|
||||
kwargs["ttl"] = v
|
||||
|
||||
cached_data = {"timestamp": time.time(), "response": result}
|
||||
return cache_key, cached_data, kwargs
|
||||
else:
|
||||
raise Exception("cache key is None")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def add_cache(self, result, **kwargs):
|
||||
"""
|
||||
Adds a result to the cache.
|
||||
|
||||
Args:
|
||||
*args: args to litellm.completion() or embedding()
|
||||
**kwargs: kwargs to litellm.completion() or embedding()
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
try:
|
||||
if self.should_use_cache(**kwargs) is not True:
|
||||
return
|
||||
cache_key, cached_data, kwargs = self._add_cache_logic(
|
||||
result=result, **kwargs
|
||||
)
|
||||
self.cache.set_cache(cache_key, cached_data, **kwargs)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
|
||||
async def async_add_cache(
|
||||
self, result, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Async implementation of add_cache
|
||||
"""
|
||||
try:
|
||||
if self.should_use_cache(**kwargs) is not True:
|
||||
return
|
||||
if self.type == "redis" and self.redis_flush_size is not None:
|
||||
# high traffic - fill in results in memory and then flush
|
||||
await self.batch_cache_write(result, **kwargs)
|
||||
else:
|
||||
cache_key, cached_data, kwargs = self._add_cache_logic(
|
||||
result=result, **kwargs
|
||||
)
|
||||
if dynamic_cache_object is not None:
|
||||
await dynamic_cache_object.async_set_cache(
|
||||
cache_key, cached_data, **kwargs
|
||||
)
|
||||
else:
|
||||
await self.cache.async_set_cache(cache_key, cached_data, **kwargs)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
|
||||
def _convert_to_cached_embedding(
|
||||
self, embedding_response: Any, model: Optional[str]
|
||||
) -> CachedEmbedding:
|
||||
"""
|
||||
Convert any embedding response into the standardized CachedEmbedding TypedDict format.
|
||||
"""
|
||||
try:
|
||||
if isinstance(embedding_response, dict):
|
||||
return {
|
||||
"embedding": embedding_response.get("embedding"),
|
||||
"index": embedding_response.get("index"),
|
||||
"object": embedding_response.get("object"),
|
||||
"model": model,
|
||||
}
|
||||
elif hasattr(embedding_response, "model_dump"):
|
||||
data = embedding_response.model_dump()
|
||||
return {
|
||||
"embedding": data.get("embedding"),
|
||||
"index": data.get("index"),
|
||||
"object": data.get("object"),
|
||||
"model": model,
|
||||
}
|
||||
else:
|
||||
data = vars(embedding_response)
|
||||
return {
|
||||
"embedding": data.get("embedding"),
|
||||
"index": data.get("index"),
|
||||
"object": data.get("object"),
|
||||
"model": model,
|
||||
}
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing expected key in embedding response: {e}")
|
||||
|
||||
def add_embedding_response_to_cache(
|
||||
self,
|
||||
result: EmbeddingResponse,
|
||||
input: str,
|
||||
kwargs: dict,
|
||||
idx_in_result_data: int = 0,
|
||||
) -> Tuple[str, dict, dict]:
|
||||
preset_cache_key = self.get_cache_key(**{**kwargs, "input": input})
|
||||
kwargs["cache_key"] = preset_cache_key
|
||||
embedding_response = result.data[idx_in_result_data]
|
||||
|
||||
# Always convert to properly typed CachedEmbedding
|
||||
model_name = result.model
|
||||
embedding_dict: CachedEmbedding = self._convert_to_cached_embedding(
|
||||
embedding_response, model_name
|
||||
)
|
||||
|
||||
cache_key, cached_data, kwargs = self._add_cache_logic(
|
||||
result=embedding_dict,
|
||||
**kwargs,
|
||||
)
|
||||
return cache_key, cached_data, kwargs
|
||||
|
||||
async def async_add_cache_pipeline(
|
||||
self, result, dynamic_cache_object: Optional[BaseCache] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
Async implementation of add_cache for Embedding calls
|
||||
|
||||
Does a bulk write, to prevent using too many clients
|
||||
"""
|
||||
try:
|
||||
if self.should_use_cache(**kwargs) is not True:
|
||||
return
|
||||
|
||||
# set default ttl if not set
|
||||
if self.ttl is not None:
|
||||
kwargs["ttl"] = self.ttl
|
||||
|
||||
cache_list = []
|
||||
if isinstance(kwargs["input"], list):
|
||||
for idx, i in enumerate(kwargs["input"]):
|
||||
(
|
||||
cache_key,
|
||||
cached_data,
|
||||
kwargs,
|
||||
) = self.add_embedding_response_to_cache(result, i, kwargs, idx)
|
||||
cache_list.append((cache_key, cached_data))
|
||||
elif isinstance(kwargs["input"], str):
|
||||
cache_key, cached_data, kwargs = self.add_embedding_response_to_cache(
|
||||
result, kwargs["input"], kwargs
|
||||
)
|
||||
cache_list.append((cache_key, cached_data))
|
||||
|
||||
if dynamic_cache_object is not None:
|
||||
await dynamic_cache_object.async_set_cache_pipeline(
|
||||
cache_list=cache_list, **kwargs
|
||||
)
|
||||
else:
|
||||
await self.cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list, **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
|
||||
def should_use_cache(self, **kwargs):
|
||||
"""
|
||||
Returns true if we should use the cache for LLM API calls
|
||||
|
||||
If cache is default_on then this is True
|
||||
If cache is default_off then this is only true when user has opted in to use cache
|
||||
"""
|
||||
if self.mode == CacheMode.default_on:
|
||||
return True
|
||||
|
||||
# when mode == default_off -> Cache is opt in only
|
||||
_cache = kwargs.get("cache", None)
|
||||
verbose_logger.debug("should_use_cache: kwargs: %s; _cache: %s", kwargs, _cache)
|
||||
if _cache and isinstance(_cache, dict):
|
||||
if _cache.get("use-cache", False) is True:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def batch_cache_write(self, result, **kwargs):
|
||||
cache_key, cached_data, kwargs = self._add_cache_logic(result=result, **kwargs)
|
||||
await self.cache.batch_cache_write(cache_key, cached_data, **kwargs)
|
||||
|
||||
async def ping(self):
|
||||
cache_ping = getattr(self.cache, "ping")
|
||||
if cache_ping:
|
||||
return await cache_ping()
|
||||
return None
|
||||
|
||||
async def delete_cache_keys(self, keys):
|
||||
cache_delete_cache_keys = getattr(self.cache, "delete_cache_keys")
|
||||
if cache_delete_cache_keys:
|
||||
return await cache_delete_cache_keys(keys)
|
||||
return None
|
||||
|
||||
async def disconnect(self):
|
||||
if hasattr(self.cache, "disconnect"):
|
||||
await self.cache.disconnect()
|
||||
|
||||
def _supports_async(self) -> bool:
|
||||
"""
|
||||
Internal method to check if the cache type supports async get/set operations
|
||||
|
||||
All cache types now support async operations
|
||||
|
||||
"""
|
||||
return True
|
||||
|
||||
|
||||
def enable_cache(
|
||||
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
||||
"completion",
|
||||
"acompletion",
|
||||
"embedding",
|
||||
"aembedding",
|
||||
"atranscription",
|
||||
"transcription",
|
||||
"atext_completion",
|
||||
"text_completion",
|
||||
"arerank",
|
||||
"rerank",
|
||||
"responses",
|
||||
"aresponses",
|
||||
],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Enable cache with the specified configuration.
|
||||
|
||||
Args:
|
||||
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache to enable. Defaults to "local".
|
||||
host (Optional[str]): The host address of the cache server. Defaults to None.
|
||||
port (Optional[str]): The port number of the cache server. Defaults to None.
|
||||
password (Optional[str]): The password for the cache server. Defaults to None.
|
||||
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
|
||||
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Raises:
|
||||
None
|
||||
"""
|
||||
print_verbose("LiteLLM: Enabling Cache")
|
||||
if "cache" not in litellm.input_callback:
|
||||
litellm.input_callback.append("cache")
|
||||
if "cache" not in litellm.success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_success_callback("cache")
|
||||
if "cache" not in litellm._async_success_callback:
|
||||
litellm.logging_callback_manager.add_litellm_async_success_callback("cache")
|
||||
|
||||
if litellm.cache is None:
|
||||
litellm.cache = Cache(
|
||||
type=type,
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
supported_call_types=supported_call_types,
|
||||
**kwargs,
|
||||
)
|
||||
print_verbose(f"LiteLLM: Cache enabled, litellm.cache={litellm.cache}")
|
||||
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
|
||||
|
||||
|
||||
def update_cache(
|
||||
type: Optional[LiteLLMCacheType] = LiteLLMCacheType.LOCAL,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
supported_call_types: Optional[List[CachingSupportedCallTypes]] = [
|
||||
"completion",
|
||||
"acompletion",
|
||||
"embedding",
|
||||
"aembedding",
|
||||
"atranscription",
|
||||
"transcription",
|
||||
"atext_completion",
|
||||
"text_completion",
|
||||
"arerank",
|
||||
"rerank",
|
||||
"responses",
|
||||
"aresponses",
|
||||
],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Update the cache for LiteLLM.
|
||||
|
||||
Args:
|
||||
type (Optional[Literal["local", "redis", "s3", "disk"]]): The type of cache. Defaults to "local".
|
||||
host (Optional[str]): The host of the cache. Defaults to None.
|
||||
port (Optional[str]): The port of the cache. Defaults to None.
|
||||
password (Optional[str]): The password for the cache. Defaults to None.
|
||||
supported_call_types (Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]]):
|
||||
The supported call types for the cache. Defaults to ["completion", "acompletion", "embedding", "aembedding"].
|
||||
**kwargs: Additional keyword arguments for the cache.
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
"""
|
||||
print_verbose("LiteLLM: Updating Cache")
|
||||
litellm.cache = Cache(
|
||||
type=type,
|
||||
host=host,
|
||||
port=port,
|
||||
password=password,
|
||||
supported_call_types=supported_call_types,
|
||||
**kwargs,
|
||||
)
|
||||
print_verbose(f"LiteLLM: Cache Updated, litellm.cache={litellm.cache}")
|
||||
print_verbose(f"LiteLLM Cache: {vars(litellm.cache)}")
|
||||
|
||||
|
||||
def disable_cache():
|
||||
"""
|
||||
Disable the cache used by LiteLLM.
|
||||
|
||||
This function disables the cache used by the LiteLLM module. It removes the cache-related callbacks from the input_callback, success_callback, and _async_success_callback lists. It also sets the litellm.cache attribute to None.
|
||||
|
||||
Parameters:
|
||||
None
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
from contextlib import suppress
|
||||
|
||||
print_verbose("LiteLLM: Disabling Cache")
|
||||
with suppress(ValueError):
|
||||
litellm.input_callback.remove("cache")
|
||||
litellm.success_callback.remove("cache")
|
||||
litellm._async_success_callback.remove("cache")
|
||||
|
||||
litellm.cache = None
|
||||
print_verbose(f"LiteLLM: Cache disabled, litellm.cache={litellm.cache}")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,93 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class DiskCache(BaseCache):
|
||||
def __init__(self, disk_cache_dir: Optional[str] = None):
|
||||
try:
|
||||
import diskcache as dc
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
"Please install litellm with `litellm[caching]` to use disk caching."
|
||||
) from e
|
||||
|
||||
# if users don't provider one, use the default litellm cache
|
||||
if disk_cache_dir is None:
|
||||
self.disk_cache = dc.Cache(".litellm_cache")
|
||||
else:
|
||||
self.disk_cache = dc.Cache(disk_cache_dir)
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
if "ttl" in kwargs:
|
||||
self.disk_cache.set(key, value, expire=kwargs["ttl"])
|
||||
else:
|
||||
self.disk_cache.set(key, value)
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
self.set_cache(key=key, value=value, **kwargs)
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
||||
for cache_key, cache_value in cache_list:
|
||||
if "ttl" in kwargs:
|
||||
self.set_cache(key=cache_key, value=cache_value, ttl=kwargs["ttl"])
|
||||
else:
|
||||
self.set_cache(key=cache_key, value=cache_value)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
original_cached_response = self.disk_cache.get(key)
|
||||
if original_cached_response:
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response) # type: ignore
|
||||
except Exception:
|
||||
cached_response = original_cached_response
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
def batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
def increment_cache(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or 0
|
||||
value = init_value + value # type: ignore
|
||||
self.set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
||||
async def async_batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
async def async_increment(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = await self.async_get_cache(key=key) or 0
|
||||
value = init_value + value # type: ignore
|
||||
await self.async_set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
def flush_cache(self):
|
||||
self.disk_cache.clear()
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
def delete_cache(self, key):
|
||||
self.disk_cache.pop(key)
|
||||
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
Dual Cache implementation - Class to update both Redis and an in-memory cache simultaneously.
|
||||
|
||||
Has 4 primary methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.caching import RedisPipelineIncrementOperation
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.constants import DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE
|
||||
|
||||
from .base_cache import BaseCache
|
||||
from .in_memory_cache import InMemoryCache
|
||||
from .redis_cache import RedisCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class LimitedSizeOrderedDict(OrderedDict):
|
||||
def __init__(self, *args, max_size=100, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.max_size = max_size
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
# If inserting a new key exceeds max size, remove the oldest item
|
||||
if len(self) >= self.max_size:
|
||||
self.popitem(last=False)
|
||||
super().__setitem__(key, value)
|
||||
|
||||
|
||||
class DualCache(BaseCache):
|
||||
"""
|
||||
DualCache is a cache implementation that updates both Redis and an in-memory cache simultaneously.
|
||||
When data is updated or inserted, it is written to both the in-memory cache + Redis.
|
||||
This ensures that even if Redis hasn't been updated yet, the in-memory cache reflects the most recent data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_memory_cache: Optional[InMemoryCache] = None,
|
||||
redis_cache: Optional[RedisCache] = None,
|
||||
default_in_memory_ttl: Optional[float] = None,
|
||||
default_redis_ttl: Optional[float] = None,
|
||||
default_redis_batch_cache_expiry: Optional[float] = None,
|
||||
default_max_redis_batch_cache_size: int = DEFAULT_MAX_REDIS_BATCH_CACHE_SIZE,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# If in_memory_cache is not provided, use the default InMemoryCache
|
||||
self.in_memory_cache = in_memory_cache or InMemoryCache()
|
||||
# If redis_cache is not provided, use the default RedisCache
|
||||
self.redis_cache = redis_cache
|
||||
self.last_redis_batch_access_time = LimitedSizeOrderedDict(
|
||||
max_size=default_max_redis_batch_cache_size
|
||||
)
|
||||
self._last_redis_batch_access_time_lock = Lock()
|
||||
self.redis_batch_cache_expiry = (
|
||||
default_redis_batch_cache_expiry
|
||||
or litellm.default_redis_batch_cache_expiry
|
||||
or 10
|
||||
)
|
||||
self.default_in_memory_ttl = (
|
||||
default_in_memory_ttl or litellm.default_in_memory_ttl
|
||||
)
|
||||
self.default_redis_ttl = default_redis_ttl or litellm.default_redis_ttl
|
||||
|
||||
def update_cache_ttl(
|
||||
self, default_in_memory_ttl: Optional[float], default_redis_ttl: Optional[float]
|
||||
):
|
||||
if default_in_memory_ttl is not None:
|
||||
self.default_in_memory_ttl = default_in_memory_ttl
|
||||
|
||||
if default_redis_ttl is not None:
|
||||
self.default_redis_ttl = default_redis_ttl
|
||||
|
||||
def set_cache(self, key, value, local_only: bool = False, **kwargs):
|
||||
# Update both Redis and in-memory cache
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
|
||||
kwargs["ttl"] = self.default_in_memory_ttl
|
||||
|
||||
self.in_memory_cache.set_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
self.redis_cache.set_cache(key, value, **kwargs)
|
||||
except Exception as e:
|
||||
print_verbose(e)
|
||||
|
||||
def increment_cache(
|
||||
self, key, value: int, local_only: bool = False, **kwargs
|
||||
) -> int:
|
||||
"""
|
||||
Key - the key in cache
|
||||
|
||||
Value - int - the value you want to increment by
|
||||
|
||||
Returns - int - the incremented value
|
||||
"""
|
||||
try:
|
||||
result: int = value
|
||||
if self.in_memory_cache is not None:
|
||||
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
result = self.redis_cache.increment_cache(key, value, **kwargs)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
verbose_logger.error(f"LiteLLM Cache: Excepton async add_cache: {str(e)}")
|
||||
raise e
|
||||
|
||||
def get_cache(
|
||||
self,
|
||||
key,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
local_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Try to fetch from in-memory cache first
|
||||
try:
|
||||
result = None
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = self.in_memory_cache.get_cache(key, **kwargs)
|
||||
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if result is None and self.redis_cache is not None and local_only is False:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = self.redis_cache.get_cache(
|
||||
key, parent_otel_span=parent_otel_span
|
||||
)
|
||||
|
||||
if redis_result is not None:
|
||||
# Update in-memory cache with the value from Redis
|
||||
self.in_memory_cache.set_cache(key, redis_result, **kwargs)
|
||||
|
||||
result = redis_result
|
||||
|
||||
print_verbose(f"get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
def batch_get_cache(
|
||||
self,
|
||||
keys: list,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
local_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
received_args = locals()
|
||||
received_args.pop("self")
|
||||
|
||||
def run_in_new_loop():
|
||||
"""Run the coroutine in a new event loop within this thread."""
|
||||
new_loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(new_loop)
|
||||
return new_loop.run_until_complete(
|
||||
self.async_batch_get_cache(**received_args)
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
try:
|
||||
# First, try to get the current event loop
|
||||
_ = asyncio.get_running_loop()
|
||||
# If we're already in an event loop, run in a separate thread
|
||||
# to avoid nested event loop issues
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result()
|
||||
|
||||
except RuntimeError:
|
||||
# No running event loop, we can safely run in this thread
|
||||
return run_in_new_loop()
|
||||
|
||||
async def async_get_cache(
|
||||
self,
|
||||
key,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
local_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
# Try to fetch from in-memory cache first
|
||||
try:
|
||||
print_verbose(
|
||||
f"async get cache: cache key: {key}; local_only: {local_only}"
|
||||
)
|
||||
result = None
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = await self.in_memory_cache.async_get_cache(
|
||||
key, **kwargs
|
||||
)
|
||||
|
||||
print_verbose(f"in_memory_result: {in_memory_result}")
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if result is None and self.redis_cache is not None and local_only is False:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = await self.redis_cache.async_get_cache(
|
||||
key, parent_otel_span=parent_otel_span
|
||||
)
|
||||
|
||||
if redis_result is not None:
|
||||
# Update in-memory cache with the value from Redis
|
||||
await self.in_memory_cache.async_set_cache(
|
||||
key, redis_result, **kwargs
|
||||
)
|
||||
|
||||
result = redis_result
|
||||
|
||||
print_verbose(f"get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
def _reserve_redis_batch_keys(
|
||||
self,
|
||||
current_time: float,
|
||||
keys: List[str],
|
||||
result: List[Any],
|
||||
) -> Tuple[List[str], Dict[str, Optional[float]]]:
|
||||
"""
|
||||
Atomically choose keys to fetch from Redis and reserve their access time.
|
||||
This prevents check-then-act races under concurrent async callers.
|
||||
"""
|
||||
sublist_keys: List[str] = []
|
||||
previous_access_times: Dict[str, Optional[float]] = {}
|
||||
|
||||
with self._last_redis_batch_access_time_lock:
|
||||
for key, value in zip(keys, result):
|
||||
if value is not None:
|
||||
continue
|
||||
|
||||
if (
|
||||
key not in self.last_redis_batch_access_time
|
||||
or current_time - self.last_redis_batch_access_time[key]
|
||||
>= self.redis_batch_cache_expiry
|
||||
):
|
||||
sublist_keys.append(key)
|
||||
previous_access_times[key] = self.last_redis_batch_access_time.get(
|
||||
key
|
||||
)
|
||||
self.last_redis_batch_access_time[key] = current_time
|
||||
|
||||
return sublist_keys, previous_access_times
|
||||
|
||||
def _rollback_redis_batch_key_reservations(
|
||||
self, previous_access_times: Dict[str, Optional[float]]
|
||||
) -> None:
|
||||
with self._last_redis_batch_access_time_lock:
|
||||
for key, previous_time in previous_access_times.items():
|
||||
if previous_time is None:
|
||||
self.last_redis_batch_access_time.pop(key, None)
|
||||
else:
|
||||
self.last_redis_batch_access_time[key] = previous_time
|
||||
|
||||
async def async_batch_get_cache(
|
||||
self,
|
||||
keys: list,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
local_only: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
result = [None] * len(keys)
|
||||
if self.in_memory_cache is not None:
|
||||
in_memory_result = await self.in_memory_cache.async_batch_get_cache(
|
||||
keys, **kwargs
|
||||
)
|
||||
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if None in result and self.redis_cache is not None and local_only is False:
|
||||
"""
|
||||
- for the none values in the result
|
||||
- check the redis cache
|
||||
"""
|
||||
current_time = time.time()
|
||||
sublist_keys, previous_access_times = self._reserve_redis_batch_keys(
|
||||
current_time, keys, result
|
||||
)
|
||||
|
||||
# Only hit Redis if enough time has passed since last access.
|
||||
if len(sublist_keys) > 0:
|
||||
try:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = await self.redis_cache.async_batch_get_cache(
|
||||
sublist_keys, parent_otel_span=parent_otel_span
|
||||
)
|
||||
except Exception:
|
||||
# Do not throttle subsequent callers if the Redis read fails.
|
||||
self._rollback_redis_batch_key_reservations(
|
||||
previous_access_times
|
||||
)
|
||||
raise
|
||||
|
||||
# Short-circuit if redis_result is None or contains only None values
|
||||
if redis_result is None or all(
|
||||
v is None for v in redis_result.values()
|
||||
):
|
||||
return result
|
||||
|
||||
# Pre-compute key-to-index mapping for O(1) lookup
|
||||
key_to_index = {key: i for i, key in enumerate(keys)}
|
||||
|
||||
# Update both result and in-memory cache in a single loop
|
||||
for key, value in redis_result.items():
|
||||
result[key_to_index[key]] = value
|
||||
|
||||
if value is not None and self.in_memory_cache is not None:
|
||||
await self.in_memory_cache.async_set_cache(
|
||||
key, value, **kwargs
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
|
||||
print_verbose(
|
||||
f"async set cache: cache key: {key}; local_only: {local_only}; value: {value}"
|
||||
)
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
|
||||
kwargs["ttl"] = self.default_in_memory_ttl
|
||||
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
await self.redis_cache.async_set_cache(key, value, **kwargs)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
|
||||
)
|
||||
|
||||
# async_batch_set_cache
|
||||
async def async_set_cache_pipeline(
|
||||
self, cache_list: list, local_only: bool = False, **kwargs
|
||||
):
|
||||
"""
|
||||
Batch write values to the cache
|
||||
"""
|
||||
print_verbose(
|
||||
f"async batch set cache: cache keys: {cache_list}; local_only: {local_only}"
|
||||
)
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
if "ttl" not in kwargs and self.default_in_memory_ttl is not None:
|
||||
kwargs["ttl"] = self.default_in_memory_ttl
|
||||
await self.in_memory_cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list, **kwargs
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
await self.redis_cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"LiteLLM Cache: Excepton async add_cache: {str(e)}"
|
||||
)
|
||||
|
||||
async def async_increment_cache(
|
||||
self,
|
||||
key,
|
||||
value: float,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
local_only: bool = False,
|
||||
**kwargs,
|
||||
) -> float:
|
||||
"""
|
||||
Key - the key in cache
|
||||
|
||||
Value - float - the value you want to increment by
|
||||
|
||||
Returns - float - the incremented value
|
||||
"""
|
||||
try:
|
||||
result: float = value
|
||||
if self.in_memory_cache is not None:
|
||||
result = await self.in_memory_cache.async_increment(
|
||||
key, value, **kwargs
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
result = await self.redis_cache.async_increment(
|
||||
key,
|
||||
value,
|
||||
parent_otel_span=parent_otel_span,
|
||||
ttl=kwargs.get("ttl", None),
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
raise e # don't log if exception is raised
|
||||
|
||||
async def async_increment_cache_pipeline(
|
||||
self,
|
||||
increment_list: List["RedisPipelineIncrementOperation"],
|
||||
local_only: bool = False,
|
||||
parent_otel_span: Optional[Span] = None,
|
||||
**kwargs,
|
||||
) -> Optional[List[float]]:
|
||||
try:
|
||||
result: Optional[List[float]] = None
|
||||
if self.in_memory_cache is not None:
|
||||
result = await self.in_memory_cache.async_increment_pipeline(
|
||||
increment_list=increment_list,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
result = await self.redis_cache.async_increment_pipeline(
|
||||
increment_list=increment_list,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
raise e # don't log if exception is raised
|
||||
|
||||
async def async_set_cache_sadd(
|
||||
self, key, value: List, local_only: bool = False, **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Add value to a set
|
||||
|
||||
Key - the key in cache
|
||||
|
||||
Value - str - the value you want to add to the set
|
||||
|
||||
Returns - None
|
||||
"""
|
||||
try:
|
||||
if self.in_memory_cache is not None:
|
||||
_ = await self.in_memory_cache.async_set_cache_sadd(
|
||||
key, value, ttl=kwargs.get("ttl", None)
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
_ = await self.redis_cache.async_set_cache_sadd(
|
||||
key, value, ttl=kwargs.get("ttl", None)
|
||||
)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
raise e # don't log, if exception is raised
|
||||
|
||||
def flush_cache(self):
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.flush_cache()
|
||||
if self.redis_cache is not None:
|
||||
self.redis_cache.flush_cache()
|
||||
|
||||
def delete_cache(self, key):
|
||||
"""
|
||||
Delete a key from the cache
|
||||
"""
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.delete_cache(key)
|
||||
if self.redis_cache is not None:
|
||||
self.redis_cache.delete_cache(key)
|
||||
|
||||
async def async_delete_cache(self, key: str):
|
||||
"""
|
||||
Delete a key from the cache
|
||||
"""
|
||||
if self.in_memory_cache is not None:
|
||||
self.in_memory_cache.delete_cache(key)
|
||||
if self.redis_cache is not None:
|
||||
await self.redis_cache.async_delete_cache(key)
|
||||
|
||||
async def async_get_ttl(self, key: str) -> Optional[int]:
|
||||
"""
|
||||
Get the remaining TTL of a key in in-memory cache or redis
|
||||
"""
|
||||
ttl = await self.in_memory_cache.async_get_ttl(key)
|
||||
if ttl is None and self.redis_cache is not None:
|
||||
ttl = await self.redis_cache.async_get_ttl(key)
|
||||
return ttl
|
||||
@@ -0,0 +1,113 @@
|
||||
"""GCS Cache implementation
|
||||
Supports syncing responses to Google Cloud Storage Buckets using HTTP requests.
|
||||
"""
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
_get_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class GCSCache(BaseCache):
|
||||
def __init__(
|
||||
self,
|
||||
bucket_name: Optional[str] = None,
|
||||
path_service_account: Optional[str] = None,
|
||||
gcs_path: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.bucket_name = bucket_name or GCSBucketBase(bucket_name=None).BUCKET_NAME
|
||||
self.path_service_account = (
|
||||
path_service_account
|
||||
or GCSBucketBase(bucket_name=None).path_service_account_json
|
||||
)
|
||||
self.key_prefix = gcs_path.rstrip("/") + "/" if gcs_path else ""
|
||||
# create httpx clients
|
||||
self.async_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.LoggingCallback
|
||||
)
|
||||
self.sync_client = _get_httpx_client()
|
||||
|
||||
def _construct_headers(self) -> dict:
|
||||
base = GCSBucketBase(bucket_name=self.bucket_name)
|
||||
base.path_service_account_json = self.path_service_account
|
||||
base.BUCKET_NAME = self.bucket_name
|
||||
return base.sync_construct_request_headers()
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
try:
|
||||
print_verbose(f"LiteLLM SET Cache - GCS. Key={key}. Value={value}")
|
||||
headers = self._construct_headers()
|
||||
object_name = self.key_prefix + key
|
||||
bucket_name = self.bucket_name
|
||||
url = f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
|
||||
data = json.dumps(value)
|
||||
self.sync_client.post(url=url, data=data, headers=headers)
|
||||
except Exception as e:
|
||||
print_verbose(f"GCS Caching: set_cache() - Got exception from GCS: {e}")
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
try:
|
||||
headers = self._construct_headers()
|
||||
object_name = self.key_prefix + key
|
||||
bucket_name = self.bucket_name
|
||||
url = f"https://storage.googleapis.com/upload/storage/v1/b/{bucket_name}/o?uploadType=media&name={object_name}"
|
||||
data = json.dumps(value)
|
||||
await self.async_client.post(url=url, data=data, headers=headers)
|
||||
except Exception as e:
|
||||
print_verbose(
|
||||
f"GCS Caching: async_set_cache() - Got exception from GCS: {e}"
|
||||
)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
try:
|
||||
headers = self._construct_headers()
|
||||
object_name = self.key_prefix + key
|
||||
bucket_name = self.bucket_name
|
||||
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
|
||||
response = self.sync_client.get(url=url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
cached_response = json.loads(response.text)
|
||||
verbose_logger.debug(
|
||||
f"Got GCS Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
|
||||
)
|
||||
return cached_response
|
||||
return None
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"GCS Caching: get_cache() - Got exception from GCS: {e}"
|
||||
)
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
try:
|
||||
headers = self._construct_headers()
|
||||
object_name = self.key_prefix + key
|
||||
bucket_name = self.bucket_name
|
||||
url = f"https://storage.googleapis.com/storage/v1/b/{bucket_name}/o/{object_name}?alt=media"
|
||||
response = await self.async_client.get(url=url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
return json.loads(response.text)
|
||||
return None
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"GCS Caching: async_get_cache() - Got exception from GCS: {e}"
|
||||
)
|
||||
|
||||
def flush_cache(self):
|
||||
pass
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
||||
tasks = []
|
||||
for val in cache_list:
|
||||
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
||||
await asyncio.gather(*tasks)
|
||||
@@ -0,0 +1,288 @@
|
||||
"""
|
||||
In-Memory Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
import heapq
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.caching import RedisPipelineIncrementOperation
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm.constants import MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class InMemoryCache(BaseCache):
|
||||
def __init__(
|
||||
self,
|
||||
max_size_in_memory: Optional[int] = 200,
|
||||
default_ttl: Optional[
|
||||
int
|
||||
] = 600, # default ttl is 10 minutes. At maximum litellm rate limiting logic requires objects to be in memory for 1 minute
|
||||
max_size_per_item: Optional[int] = 1024, # 1MB = 1024KB
|
||||
):
|
||||
"""
|
||||
max_size_in_memory [int]: Maximum number of items in cache. done to prevent memory leaks. Use 200 items as a default
|
||||
"""
|
||||
self.max_size_in_memory = (
|
||||
max_size_in_memory if max_size_in_memory is not None else 200
|
||||
) # set an upper bound of 200 items in-memory
|
||||
self.default_ttl = default_ttl or 600
|
||||
self.max_size_per_item = (
|
||||
max_size_per_item or MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
) # 1MB = 1024KB
|
||||
|
||||
# in-memory cache
|
||||
self.cache_dict: dict = {}
|
||||
self.ttl_dict: dict = {}
|
||||
self.expiration_heap: list[tuple[float, str]] = []
|
||||
|
||||
def check_value_size(self, value: Any):
|
||||
"""
|
||||
Check if value size exceeds max_size_per_item (1MB)
|
||||
Returns True if value size is acceptable, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Fast path for common primitive types that are typically small
|
||||
if (
|
||||
isinstance(value, (bool, int, float, str))
|
||||
and len(str(value))
|
||||
< self.max_size_per_item * MAX_SIZE_PER_ITEM_IN_MEMORY_CACHE_IN_KB
|
||||
): # Conservative estimate
|
||||
return True
|
||||
|
||||
# Direct size check for bytes objects
|
||||
if isinstance(value, bytes):
|
||||
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
|
||||
|
||||
# Handle special types without full conversion when possible
|
||||
if hasattr(value, "__sizeof__"): # Use __sizeof__ if available
|
||||
size = value.__sizeof__() / 1024
|
||||
return size <= self.max_size_per_item
|
||||
|
||||
# Fallback for complex types
|
||||
if isinstance(value, BaseModel) and hasattr(
|
||||
value, "model_dump"
|
||||
): # Pydantic v2
|
||||
value = value.model_dump()
|
||||
elif hasattr(value, "isoformat"): # datetime objects
|
||||
return True # datetime strings are always small
|
||||
|
||||
# Only convert to JSON if absolutely necessary
|
||||
if not isinstance(value, (str, bytes)):
|
||||
value = json.dumps(value, default=str)
|
||||
|
||||
return sys.getsizeof(value) / 1024 <= self.max_size_per_item
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _is_key_expired(self, key: str) -> bool:
|
||||
"""
|
||||
Check if a specific key is expired
|
||||
"""
|
||||
return key in self.ttl_dict and time.time() > self.ttl_dict[key]
|
||||
|
||||
def _remove_key(self, key: str) -> None:
|
||||
"""
|
||||
Remove a key from both cache_dict and ttl_dict
|
||||
"""
|
||||
self.cache_dict.pop(key, None)
|
||||
self.ttl_dict.pop(key, None)
|
||||
|
||||
def evict_cache(self):
|
||||
"""
|
||||
Eviction policy:
|
||||
1. First, remove expired items from ttl_dict and cache_dict
|
||||
2. If cache is still at or above max_size_in_memory, evict items with earliest expiration times
|
||||
|
||||
|
||||
This guarantees the following:
|
||||
- 1. When item ttl not set: At minimum each item will remain in memory for the default ttl
|
||||
- 2. When ttl is set: the item will remain in memory for at least that amount of time, unless cache size requires eviction
|
||||
- 3. the size of in-memory cache is bounded
|
||||
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# Step 1: Remove expired or outdated items
|
||||
while self.expiration_heap:
|
||||
expiration_time, key = self.expiration_heap[0]
|
||||
|
||||
# Case 1: Heap entry is outdated
|
||||
if expiration_time != self.ttl_dict.get(key):
|
||||
heapq.heappop(self.expiration_heap)
|
||||
# Case 2: Entry is valid but expired
|
||||
elif expiration_time <= current_time:
|
||||
heapq.heappop(self.expiration_heap)
|
||||
self._remove_key(key)
|
||||
else:
|
||||
# Case 3: Entry is valid and not expired
|
||||
break
|
||||
|
||||
# Step 2: Evict if cache is still full
|
||||
while len(self.cache_dict) >= self.max_size_in_memory:
|
||||
expiration_time, key = heapq.heappop(self.expiration_heap)
|
||||
# Skip if key was removed or updated
|
||||
if self.ttl_dict.get(key) == expiration_time:
|
||||
self._remove_key(key)
|
||||
|
||||
# de-reference the removed item
|
||||
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
|
||||
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
|
||||
# This can occur when an object is referenced by another object, but the reference is never removed.
|
||||
|
||||
def allow_ttl_override(self, key: str) -> bool:
|
||||
"""
|
||||
Check if ttl is set for a key
|
||||
"""
|
||||
ttl_time = self.ttl_dict.get(key)
|
||||
if ttl_time is None: # if ttl is not set, allow override
|
||||
return True
|
||||
elif float(ttl_time) < time.time(): # if ttl is expired, allow override
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
# Handle the edge case where max_size_in_memory is 0
|
||||
if self.max_size_in_memory == 0:
|
||||
return # Don't cache anything if max size is 0
|
||||
|
||||
if len(self.cache_dict) >= self.max_size_in_memory:
|
||||
# only evict when cache is full
|
||||
self.evict_cache()
|
||||
if not self.check_value_size(value):
|
||||
return
|
||||
|
||||
self.cache_dict[key] = value
|
||||
if self.allow_ttl_override(key): # if ttl is not set, set it to default ttl
|
||||
if "ttl" in kwargs and kwargs["ttl"] is not None:
|
||||
self.ttl_dict[key] = time.time() + float(kwargs["ttl"])
|
||||
heapq.heappush(self.expiration_heap, (self.ttl_dict[key], key))
|
||||
else:
|
||||
self.ttl_dict[key] = time.time() + self.default_ttl
|
||||
heapq.heappush(self.expiration_heap, (self.ttl_dict[key], key))
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
self.set_cache(key=key, value=value, **kwargs)
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, ttl=None, **kwargs):
|
||||
for cache_key, cache_value in cache_list:
|
||||
if ttl is not None:
|
||||
self.set_cache(key=cache_key, value=cache_value, ttl=ttl)
|
||||
else:
|
||||
self.set_cache(key=cache_key, value=cache_value)
|
||||
|
||||
async def async_set_cache_sadd(self, key, value: List, ttl: Optional[float]):
|
||||
"""
|
||||
Add value to set
|
||||
"""
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or set()
|
||||
for val in value:
|
||||
init_value.add(val)
|
||||
self.set_cache(key, init_value, ttl=ttl)
|
||||
return value
|
||||
|
||||
def evict_element_if_expired(self, key: str) -> bool:
|
||||
"""
|
||||
Returns True if the element is expired and removed from the cache
|
||||
|
||||
Returns False if the element is not expired
|
||||
"""
|
||||
if self._is_key_expired(key):
|
||||
self._remove_key(key)
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
if key in self.cache_dict:
|
||||
if self.evict_element_if_expired(key):
|
||||
return None
|
||||
original_cached_response = self.cache_dict[key]
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response)
|
||||
except Exception:
|
||||
cached_response = original_cached_response
|
||||
return cached_response
|
||||
return None
|
||||
|
||||
def batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
def increment_cache(self, key, value: int, **kwargs) -> int:
|
||||
# get the value
|
||||
init_value = self.get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
self.set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
return self.get_cache(key=key, **kwargs)
|
||||
|
||||
async def async_batch_get_cache(self, keys: list, **kwargs):
|
||||
return_val = []
|
||||
for k in keys:
|
||||
val = self.get_cache(key=k, **kwargs)
|
||||
return_val.append(val)
|
||||
return return_val
|
||||
|
||||
async def async_increment(self, key, value: float, **kwargs) -> float:
|
||||
# get the value
|
||||
init_value = await self.async_get_cache(key=key) or 0
|
||||
value = init_value + value
|
||||
await self.async_set_cache(key, value, **kwargs)
|
||||
return value
|
||||
|
||||
async def async_increment_pipeline(
|
||||
self, increment_list: List["RedisPipelineIncrementOperation"], **kwargs
|
||||
) -> Optional[List[float]]:
|
||||
results = []
|
||||
for increment in increment_list:
|
||||
result = await self.async_increment(
|
||||
increment["key"], increment["increment_value"], **kwargs
|
||||
)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
def flush_cache(self):
|
||||
self.cache_dict.clear()
|
||||
self.ttl_dict.clear()
|
||||
self.expiration_heap.clear()
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
def delete_cache(self, key):
|
||||
self._remove_key(key)
|
||||
|
||||
async def async_get_ttl(self, key: str) -> Optional[int]:
|
||||
"""
|
||||
Get the remaining TTL of a key in in-memory cache
|
||||
"""
|
||||
return self.ttl_dict.get(key, None)
|
||||
|
||||
async def async_get_oldest_n_keys(self, n: int) -> List[str]:
|
||||
"""
|
||||
Get the oldest n keys in the cache
|
||||
"""
|
||||
# sorted ttl dict by ttl
|
||||
sorted_ttl_dict = sorted(self.ttl_dict.items(), key=lambda x: x[1])
|
||||
return [key for key, _ in sorted_ttl_dict[:n]]
|
||||
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from .in_memory_cache import InMemoryCache
|
||||
|
||||
|
||||
class LLMClientCache(InMemoryCache):
|
||||
"""Cache for LLM HTTP clients (OpenAI, Azure, httpx, etc.).
|
||||
|
||||
IMPORTANT: This cache intentionally does NOT close clients on eviction.
|
||||
Evicted clients may still be in use by in-flight requests. Closing them
|
||||
eagerly causes ``RuntimeError: Cannot send a request, as the client has
|
||||
been closed.`` errors in production after the TTL (1 hour) expires.
|
||||
|
||||
Clients that are no longer referenced will be garbage-collected normally.
|
||||
For explicit shutdown cleanup, use ``close_litellm_async_clients()``.
|
||||
"""
|
||||
|
||||
def update_cache_key_with_event_loop(self, key):
|
||||
"""
|
||||
Add the event loop to the cache key, to prevent event loop closed errors.
|
||||
If none, use the key as is.
|
||||
"""
|
||||
try:
|
||||
event_loop = asyncio.get_running_loop()
|
||||
stringified_event_loop = str(id(event_loop))
|
||||
return f"{key}-{stringified_event_loop}"
|
||||
except RuntimeError: # handle no current running event loop
|
||||
return key
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
return super().set_cache(key, value, **kwargs)
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
return await super().async_set_cache(key, value, **kwargs)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
|
||||
return super().get_cache(key, **kwargs)
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
key = self.update_cache_key_with_event_loop(key)
|
||||
|
||||
return await super().async_get_cache(key, **kwargs)
|
||||
@@ -0,0 +1,446 @@
|
||||
"""
|
||||
Qdrant Semantic Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache
|
||||
- async_get_cache
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class QdrantSemanticCache(BaseCache):
|
||||
def __init__( # noqa: PLR0915
|
||||
self,
|
||||
qdrant_api_base=None,
|
||||
qdrant_api_key=None,
|
||||
collection_name=None,
|
||||
similarity_threshold=None,
|
||||
quantization_config=None,
|
||||
embedding_model="text-embedding-ada-002",
|
||||
host_type=None,
|
||||
vector_size=None,
|
||||
):
|
||||
import os
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
if collection_name is None:
|
||||
raise Exception("collection_name must be provided, passed None")
|
||||
|
||||
self.collection_name = collection_name
|
||||
print_verbose(
|
||||
f"qdrant semantic-cache initializing COLLECTION - {self.collection_name}"
|
||||
)
|
||||
|
||||
if similarity_threshold is None:
|
||||
raise Exception("similarity_threshold must be provided, passed None")
|
||||
self.similarity_threshold = similarity_threshold
|
||||
self.embedding_model = embedding_model
|
||||
self.vector_size = (
|
||||
vector_size if vector_size is not None else QDRANT_VECTOR_SIZE
|
||||
)
|
||||
headers = {}
|
||||
|
||||
# check if defined as os.environ/ variable
|
||||
if qdrant_api_base:
|
||||
if isinstance(qdrant_api_base, str) and qdrant_api_base.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
qdrant_api_base = get_secret_str(qdrant_api_base)
|
||||
if qdrant_api_key:
|
||||
if isinstance(qdrant_api_key, str) and qdrant_api_key.startswith(
|
||||
"os.environ/"
|
||||
):
|
||||
qdrant_api_key = get_secret_str(qdrant_api_key)
|
||||
|
||||
qdrant_api_base = (
|
||||
qdrant_api_base or os.getenv("QDRANT_URL") or os.getenv("QDRANT_API_BASE")
|
||||
)
|
||||
qdrant_api_key = qdrant_api_key or os.getenv("QDRANT_API_KEY")
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if qdrant_api_key:
|
||||
headers["api-key"] = qdrant_api_key
|
||||
|
||||
if qdrant_api_base is None:
|
||||
raise ValueError("Qdrant url must be provided")
|
||||
|
||||
self.qdrant_api_base = qdrant_api_base
|
||||
self.qdrant_api_key = qdrant_api_key
|
||||
print_verbose(f"qdrant semantic-cache qdrant_api_base: {self.qdrant_api_base}")
|
||||
|
||||
self.headers = headers
|
||||
|
||||
self.sync_client = _get_httpx_client()
|
||||
self.async_client = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.Caching
|
||||
)
|
||||
|
||||
if quantization_config is None:
|
||||
print_verbose(
|
||||
"Quantization config is not provided. Default binary quantization will be used."
|
||||
)
|
||||
collection_exists = self.sync_client.get(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/exists",
|
||||
headers=self.headers,
|
||||
)
|
||||
if collection_exists.status_code != 200:
|
||||
raise ValueError(
|
||||
f"Error from qdrant checking if /collections exist {collection_exists.text}"
|
||||
)
|
||||
|
||||
if collection_exists.json()["result"]["exists"]:
|
||||
collection_details = self.sync_client.get(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
||||
headers=self.headers,
|
||||
)
|
||||
self.collection_info = collection_details.json()
|
||||
print_verbose(
|
||||
f"Collection already exists.\nCollection details:{self.collection_info}"
|
||||
)
|
||||
else:
|
||||
if quantization_config is None or quantization_config == "binary":
|
||||
quantization_params = {
|
||||
"binary": {
|
||||
"always_ram": False,
|
||||
}
|
||||
}
|
||||
elif quantization_config == "scalar":
|
||||
quantization_params = {
|
||||
"scalar": {
|
||||
"type": "int8",
|
||||
"quantile": QDRANT_SCALAR_QUANTILE,
|
||||
"always_ram": False,
|
||||
}
|
||||
}
|
||||
elif quantization_config == "product":
|
||||
quantization_params = {
|
||||
"product": {"compression": "x16", "always_ram": False}
|
||||
}
|
||||
else:
|
||||
raise Exception(
|
||||
"Quantization config must be one of 'scalar', 'binary' or 'product'"
|
||||
)
|
||||
|
||||
new_collection_status = self.sync_client.put(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
||||
json={
|
||||
"vectors": {"size": self.vector_size, "distance": "Cosine"},
|
||||
"quantization_config": quantization_params,
|
||||
},
|
||||
headers=self.headers,
|
||||
)
|
||||
if new_collection_status.json()["result"]:
|
||||
collection_details = self.sync_client.get(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}",
|
||||
headers=self.headers,
|
||||
)
|
||||
self.collection_info = collection_details.json()
|
||||
print_verbose(
|
||||
f"New collection created.\nCollection details:{self.collection_info}"
|
||||
)
|
||||
else:
|
||||
raise Exception("Error while creating new collection")
|
||||
|
||||
def _get_cache_logic(self, cached_response: Any):
|
||||
if cached_response is None:
|
||||
return cached_response
|
||||
try:
|
||||
cached_response = json.loads(
|
||||
cached_response
|
||||
) # Convert string to dictionary
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
return cached_response
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
||||
from litellm._uuid import uuid
|
||||
|
||||
# get the prompt
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
|
||||
# create an embedding for prompt
|
||||
embedding_response = cast(
|
||||
EmbeddingResponse,
|
||||
litellm.embedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
),
|
||||
)
|
||||
|
||||
# get the embedding
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
|
||||
value = str(value)
|
||||
assert isinstance(value, str)
|
||||
|
||||
data = {
|
||||
"points": [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"vector": embedding,
|
||||
"payload": {
|
||||
"text": prompt,
|
||||
"response": value,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
self.sync_client.put(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
)
|
||||
return
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
print_verbose(f"sync qdrant semantic-cache get_cache, kwargs: {kwargs}")
|
||||
|
||||
# get the messages
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
|
||||
# convert to embedding
|
||||
embedding_response = cast(
|
||||
EmbeddingResponse,
|
||||
litellm.embedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
),
|
||||
)
|
||||
|
||||
# get the embedding
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
|
||||
data = {
|
||||
"vector": embedding,
|
||||
"params": {
|
||||
"quantization": {
|
||||
"ignore": False,
|
||||
"rescore": True,
|
||||
"oversampling": 3.0,
|
||||
}
|
||||
},
|
||||
"limit": 1,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
search_response = self.sync_client.post(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
)
|
||||
results = search_response.json()["result"]
|
||||
|
||||
if results is None:
|
||||
return None
|
||||
if isinstance(results, list):
|
||||
if len(results) == 0:
|
||||
return None
|
||||
|
||||
similarity = results[0]["score"]
|
||||
cached_prompt = results[0]["payload"]["text"]
|
||||
|
||||
# check similarity, if more than self.similarity_threshold, return results
|
||||
print_verbose(
|
||||
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
|
||||
)
|
||||
if similarity >= self.similarity_threshold:
|
||||
# cache hit !
|
||||
cached_value = results[0]["payload"]["response"]
|
||||
print_verbose(
|
||||
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
|
||||
)
|
||||
return self._get_cache_logic(cached_response=cached_value)
|
||||
else:
|
||||
# cache miss !
|
||||
return None
|
||||
pass
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
from litellm._uuid import uuid
|
||||
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
||||
|
||||
# get the prompt
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
# create an embedding for prompt
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||
embedding_response = await llm_router.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
metadata={
|
||||
"user_api_key": user_api_key,
|
||||
"semantic-cache-embedding": True,
|
||||
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# convert to embedding
|
||||
embedding_response = await litellm.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
)
|
||||
|
||||
# get the embedding
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
|
||||
value = str(value)
|
||||
assert isinstance(value, str)
|
||||
|
||||
data = {
|
||||
"points": [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"vector": embedding,
|
||||
"payload": {
|
||||
"text": prompt,
|
||||
"response": value,
|
||||
},
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
await self.async_client.put(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
)
|
||||
return
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
# get the messages
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||
embedding_response = await llm_router.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
metadata={
|
||||
"user_api_key": user_api_key,
|
||||
"semantic-cache-embedding": True,
|
||||
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# convert to embedding
|
||||
embedding_response = await litellm.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
)
|
||||
|
||||
# get the embedding
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
|
||||
data = {
|
||||
"vector": embedding,
|
||||
"params": {
|
||||
"quantization": {
|
||||
"ignore": False,
|
||||
"rescore": True,
|
||||
"oversampling": 3.0,
|
||||
}
|
||||
},
|
||||
"limit": 1,
|
||||
"with_payload": True,
|
||||
}
|
||||
|
||||
search_response = await self.async_client.post(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
)
|
||||
|
||||
results = search_response.json()["result"]
|
||||
|
||||
if results is None:
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
if isinstance(results, list):
|
||||
if len(results) == 0:
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
similarity = results[0]["score"]
|
||||
cached_prompt = results[0]["payload"]["text"]
|
||||
|
||||
# check similarity, if more than self.similarity_threshold, return results
|
||||
print_verbose(
|
||||
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
|
||||
|
||||
if similarity >= self.similarity_threshold:
|
||||
# cache hit !
|
||||
cached_value = results[0]["payload"]["response"]
|
||||
print_verbose(
|
||||
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
|
||||
)
|
||||
return self._get_cache_logic(cached_response=cached_value)
|
||||
else:
|
||||
# cache miss !
|
||||
return None
|
||||
pass
|
||||
|
||||
async def _collection_info(self):
|
||||
return self.collection_info
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
||||
tasks = []
|
||||
for val in cache_list:
|
||||
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
||||
await asyncio.gather(*tasks)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Redis Cluster Cache implementation
|
||||
|
||||
Key differences:
|
||||
- RedisClient NEEDs to be re-used across requests, adds 3000ms latency if it's re-created
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
|
||||
from litellm.caching.redis_cache import RedisCache
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
from redis.asyncio import Redis, RedisCluster
|
||||
from redis.asyncio.client import Pipeline
|
||||
|
||||
pipeline = Pipeline
|
||||
async_redis_client = Redis
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
pipeline = Any
|
||||
async_redis_client = Any
|
||||
Span = Any
|
||||
|
||||
|
||||
class RedisClusterCache(RedisCache):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.redis_async_redis_cluster_client: Optional[RedisCluster] = None
|
||||
self.redis_sync_redis_cluster_client: Optional[RedisCluster] = None
|
||||
|
||||
def init_async_client(self):
|
||||
from redis.asyncio import RedisCluster
|
||||
|
||||
from .._redis import get_redis_async_client
|
||||
|
||||
if self.redis_async_redis_cluster_client:
|
||||
return self.redis_async_redis_cluster_client
|
||||
|
||||
_redis_client = get_redis_async_client(
|
||||
connection_pool=self.async_redis_conn_pool, **self.redis_kwargs
|
||||
)
|
||||
if isinstance(_redis_client, RedisCluster):
|
||||
self.redis_async_redis_cluster_client = _redis_client
|
||||
|
||||
return _redis_client
|
||||
|
||||
def _run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
|
||||
"""
|
||||
Overrides `_run_redis_mget_operation` in redis_cache.py
|
||||
"""
|
||||
return self.redis_client.mget_nonatomic(keys=keys) # type: ignore
|
||||
|
||||
async def _async_run_redis_mget_operation(self, keys: List[str]) -> List[Any]:
|
||||
"""
|
||||
Overrides `_async_run_redis_mget_operation` in redis_cache.py
|
||||
"""
|
||||
async_redis_cluster_client = self.init_async_client()
|
||||
return await async_redis_cluster_client.mget_nonatomic(keys=keys) # type: ignore
|
||||
|
||||
async def test_connection(self) -> dict:
|
||||
"""
|
||||
Test the Redis Cluster connection.
|
||||
|
||||
Returns:
|
||||
dict: {"status": "success" | "failed", "message": str, "error": Optional[str]}
|
||||
"""
|
||||
try:
|
||||
import redis.asyncio as redis_async
|
||||
from redis.cluster import ClusterNode
|
||||
|
||||
# Create ClusterNode objects from startup_nodes
|
||||
cluster_kwargs = self.redis_kwargs.copy()
|
||||
startup_nodes = cluster_kwargs.pop("startup_nodes", [])
|
||||
|
||||
new_startup_nodes: List[ClusterNode] = []
|
||||
for item in startup_nodes:
|
||||
new_startup_nodes.append(ClusterNode(**item))
|
||||
|
||||
# Create a fresh Redis Cluster client with current settings
|
||||
redis_client = redis_async.RedisCluster(
|
||||
startup_nodes=new_startup_nodes, **cluster_kwargs # type: ignore
|
||||
)
|
||||
|
||||
# Test the connection
|
||||
ping_result = await redis_client.ping() # type: ignore[attr-defined, misc]
|
||||
|
||||
# Close the connection
|
||||
await redis_client.aclose() # type: ignore[attr-defined]
|
||||
|
||||
if ping_result:
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "Redis Cluster connection test successful",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "failed",
|
||||
"message": "Redis Cluster ping returned False",
|
||||
}
|
||||
except Exception as e:
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
verbose_logger.error(f"Redis Cluster connection test failed: {str(e)}")
|
||||
return {
|
||||
"status": "failed",
|
||||
"message": f"Redis Cluster connection failed: {str(e)}",
|
||||
"error": str(e),
|
||||
}
|
||||
@@ -0,0 +1,450 @@
|
||||
"""
|
||||
Redis Semantic Cache implementation for LiteLLM
|
||||
|
||||
The RedisSemanticCache provides semantic caching functionality using Redis as a backend.
|
||||
This cache stores responses based on the semantic similarity of prompts rather than
|
||||
exact matching, allowing for more flexible caching of LLM responses.
|
||||
|
||||
This implementation uses RedisVL's SemanticCache to find semantically similar prompts
|
||||
and their cached responses.
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_str_from_messages,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class RedisSemanticCache(BaseCache):
|
||||
"""
|
||||
Redis-backed semantic cache for LLM responses.
|
||||
|
||||
This cache uses vector similarity to find semantically similar prompts that have been
|
||||
previously sent to the LLM, allowing for cache hits even when prompts are not identical
|
||||
but carry similar meaning.
|
||||
"""
|
||||
|
||||
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
redis_url: Optional[str] = None,
|
||||
similarity_threshold: Optional[float] = None,
|
||||
embedding_model: str = "text-embedding-ada-002",
|
||||
index_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the Redis Semantic Cache.
|
||||
|
||||
Args:
|
||||
host: Redis host address
|
||||
port: Redis port
|
||||
password: Redis password
|
||||
redis_url: Full Redis URL (alternative to separate host/port/password)
|
||||
similarity_threshold: Threshold for semantic similarity (0.0 to 1.0)
|
||||
where 1.0 requires exact matches and 0.0 accepts any match
|
||||
embedding_model: Model to use for generating embeddings
|
||||
index_name: Name for the Redis index
|
||||
ttl: Default time-to-live for cache entries in seconds
|
||||
**kwargs: Additional arguments passed to the Redis client
|
||||
|
||||
Raises:
|
||||
Exception: If similarity_threshold is not provided or required Redis
|
||||
connection information is missing
|
||||
"""
|
||||
from redisvl.extensions.llmcache import SemanticCache
|
||||
from redisvl.utils.vectorize import CustomTextVectorizer
|
||||
|
||||
if index_name is None:
|
||||
index_name = self.DEFAULT_REDIS_INDEX_NAME
|
||||
|
||||
print_verbose(f"Redis semantic-cache initializing index - {index_name}")
|
||||
|
||||
# Validate similarity threshold
|
||||
if similarity_threshold is None:
|
||||
raise ValueError("similarity_threshold must be provided, passed None")
|
||||
|
||||
# Store configuration
|
||||
self.similarity_threshold = similarity_threshold
|
||||
|
||||
# Convert similarity threshold [0,1] to distance threshold [0,2]
|
||||
# For cosine distance: 0 = most similar, 2 = least similar
|
||||
# While similarity: 1 = most similar, 0 = least similar
|
||||
self.distance_threshold = 1 - similarity_threshold
|
||||
self.embedding_model = embedding_model
|
||||
|
||||
# Set up Redis connection
|
||||
if redis_url is None:
|
||||
try:
|
||||
# Attempt to use provided parameters or fallback to environment variables
|
||||
host = host or os.environ["REDIS_HOST"]
|
||||
port = port or os.environ["REDIS_PORT"]
|
||||
password = password or os.environ["REDIS_PASSWORD"]
|
||||
except KeyError as e:
|
||||
# Raise a more informative exception if any of the required keys are missing
|
||||
missing_var = e.args[0]
|
||||
raise ValueError(
|
||||
f"Missing required Redis configuration: {missing_var}. "
|
||||
f"Provide {missing_var} or redis_url."
|
||||
) from e
|
||||
|
||||
redis_url = f"redis://:{password}@{host}:{port}"
|
||||
|
||||
print_verbose(f"Redis semantic-cache redis_url: {redis_url}")
|
||||
|
||||
# Initialize the Redis vectorizer and cache
|
||||
cache_vectorizer = CustomTextVectorizer(self._get_embedding)
|
||||
|
||||
self.llmcache = SemanticCache(
|
||||
name=index_name,
|
||||
redis_url=redis_url,
|
||||
vectorizer=cache_vectorizer,
|
||||
distance_threshold=self.distance_threshold,
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
def _get_ttl(self, **kwargs) -> Optional[int]:
|
||||
"""
|
||||
Get the TTL (time-to-live) value for cache entries.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments that may contain a custom TTL
|
||||
|
||||
Returns:
|
||||
Optional[int]: The TTL value in seconds, or None if no TTL should be applied
|
||||
"""
|
||||
ttl = kwargs.get("ttl")
|
||||
if ttl is not None:
|
||||
ttl = int(ttl)
|
||||
return ttl
|
||||
|
||||
def _get_embedding(self, prompt: str) -> List[float]:
|
||||
"""
|
||||
Generate an embedding vector for the given prompt using the configured embedding model.
|
||||
|
||||
Args:
|
||||
prompt: The text to generate an embedding for
|
||||
|
||||
Returns:
|
||||
List[float]: The embedding vector
|
||||
"""
|
||||
# Create an embedding from prompt
|
||||
embedding_response = cast(
|
||||
EmbeddingResponse,
|
||||
litellm.embedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
),
|
||||
)
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
return embedding
|
||||
|
||||
def _get_cache_logic(self, cached_response: Any) -> Any:
|
||||
"""
|
||||
Process the cached response to prepare it for use.
|
||||
|
||||
Args:
|
||||
cached_response: The raw cached response
|
||||
|
||||
Returns:
|
||||
The processed cache response, or None if input was None
|
||||
"""
|
||||
if cached_response is None:
|
||||
return cached_response
|
||||
|
||||
# Convert bytes to string if needed
|
||||
if isinstance(cached_response, bytes):
|
||||
cached_response = cached_response.decode("utf-8")
|
||||
|
||||
# Convert string representation to Python object
|
||||
try:
|
||||
cached_response = json.loads(cached_response)
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
except (ValueError, SyntaxError) as e:
|
||||
print_verbose(f"Error parsing cached response: {str(e)}")
|
||||
return None
|
||||
|
||||
return cached_response
|
||||
|
||||
def set_cache(self, key: str, value: Any, **kwargs) -> None:
|
||||
"""
|
||||
Store a value in the semantic cache.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
value: The response value to cache
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
and optional 'ttl' for time-to-live
|
||||
"""
|
||||
print_verbose(f"Redis semantic-cache set_cache, kwargs: {kwargs}")
|
||||
|
||||
value_str: Optional[str] = None
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic caching")
|
||||
return
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
value_str = str(value)
|
||||
|
||||
# Get TTL and store in Redis semantic cache
|
||||
ttl = self._get_ttl(**kwargs)
|
||||
if ttl is not None:
|
||||
self.llmcache.store(prompt, value_str, ttl=int(ttl))
|
||||
else:
|
||||
self.llmcache.store(prompt, value_str)
|
||||
except Exception as e:
|
||||
print_verbose(
|
||||
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
|
||||
)
|
||||
|
||||
def get_cache(self, key: str, **kwargs) -> Any:
|
||||
"""
|
||||
Retrieve a semantically similar cached response.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
|
||||
Returns:
|
||||
The cached response if a semantically similar prompt is found, else None
|
||||
"""
|
||||
print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}")
|
||||
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic cache lookup")
|
||||
return None
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
# Check the cache for semantically similar prompts
|
||||
results = self.llmcache.check(prompt=prompt)
|
||||
|
||||
# Return None if no similar prompts found
|
||||
if not results:
|
||||
return None
|
||||
|
||||
# Process the best matching result
|
||||
cache_hit = results[0]
|
||||
vector_distance = float(cache_hit["vector_distance"])
|
||||
|
||||
# Convert vector distance back to similarity score
|
||||
# For cosine distance: 0 = most similar, 2 = least similar
|
||||
# While similarity: 1 = most similar, 0 = least similar
|
||||
similarity = 1 - vector_distance
|
||||
|
||||
cached_prompt = cache_hit["prompt"]
|
||||
cached_response = cache_hit["response"]
|
||||
|
||||
print_verbose(
|
||||
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
|
||||
f"actual similarity: {similarity}, "
|
||||
f"current prompt: {prompt}, "
|
||||
f"cached prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
return self._get_cache_logic(cached_response=cached_response)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
|
||||
|
||||
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
|
||||
"""
|
||||
Asynchronously generate an embedding for the given prompt.
|
||||
|
||||
Args:
|
||||
prompt: The text to generate an embedding for
|
||||
**kwargs: Additional arguments that may contain metadata
|
||||
|
||||
Returns:
|
||||
List[float]: The embedding vector
|
||||
"""
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
# Route the embedding request through the proxy if appropriate
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
|
||||
try:
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
# Use the router for embedding generation
|
||||
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||
embedding_response = await llm_router.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
metadata={
|
||||
"user_api_key": user_api_key,
|
||||
"semantic-cache-embedding": True,
|
||||
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# Generate embedding directly
|
||||
embedding_response = await litellm.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
)
|
||||
|
||||
# Extract and return the embedding vector
|
||||
return embedding_response["data"][0]["embedding"]
|
||||
except Exception as e:
|
||||
print_verbose(f"Error generating async embedding: {str(e)}")
|
||||
raise ValueError(f"Failed to generate embedding: {str(e)}") from e
|
||||
|
||||
async def async_set_cache(self, key: str, value: Any, **kwargs) -> None:
|
||||
"""
|
||||
Asynchronously store a value in the semantic cache.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
value: The response value to cache
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
and optional 'ttl' for time-to-live
|
||||
"""
|
||||
print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}")
|
||||
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic caching")
|
||||
return
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
value_str = str(value)
|
||||
|
||||
# Generate embedding for the value (response) to cache
|
||||
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
|
||||
|
||||
# Get TTL and store in Redis semantic cache
|
||||
ttl = self._get_ttl(**kwargs)
|
||||
if ttl is not None:
|
||||
await self.llmcache.astore(
|
||||
prompt,
|
||||
value_str,
|
||||
vector=prompt_embedding, # Pass through custom embedding
|
||||
ttl=ttl,
|
||||
)
|
||||
else:
|
||||
await self.llmcache.astore(
|
||||
prompt,
|
||||
value_str,
|
||||
vector=prompt_embedding, # Pass through custom embedding
|
||||
)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error in async_set_cache: {str(e)}")
|
||||
|
||||
async def async_get_cache(self, key: str, **kwargs) -> Any:
|
||||
"""
|
||||
Asynchronously retrieve a semantically similar cached response.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
|
||||
Returns:
|
||||
The cached response if a semantically similar prompt is found, else None
|
||||
"""
|
||||
print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}")
|
||||
|
||||
try:
|
||||
# Extract the prompt from messages
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic cache lookup")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
|
||||
# Generate embedding for the prompt
|
||||
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
|
||||
|
||||
# Check the cache for semantically similar prompts
|
||||
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
|
||||
|
||||
# handle results / cache hit
|
||||
if not results:
|
||||
kwargs.setdefault("metadata", {})[
|
||||
"semantic-similarity"
|
||||
] = 0.0 # TODO why here but not above??
|
||||
return None
|
||||
|
||||
cache_hit = results[0]
|
||||
vector_distance = float(cache_hit["vector_distance"])
|
||||
|
||||
# Convert vector distance back to similarity
|
||||
# For cosine distance: 0 = most similar, 2 = least similar
|
||||
# While similarity: 1 = most similar, 0 = least similar
|
||||
similarity = 1 - vector_distance
|
||||
|
||||
cached_prompt = cache_hit["prompt"]
|
||||
cached_response = cache_hit["response"]
|
||||
|
||||
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
|
||||
|
||||
print_verbose(
|
||||
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
|
||||
f"actual similarity: {similarity}, "
|
||||
f"current prompt: {prompt}, "
|
||||
f"cached prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
return self._get_cache_logic(cached_response=cached_response)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error in async_get_cache: {str(e)}")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
|
||||
async def _index_info(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about the Redis index.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Information about the Redis index
|
||||
"""
|
||||
aindex = await self.llmcache._get_async_index()
|
||||
return await aindex.info()
|
||||
|
||||
async def async_set_cache_pipeline(
|
||||
self, cache_list: List[Tuple[str, Any]], **kwargs
|
||||
) -> None:
|
||||
"""
|
||||
Asynchronously store multiple values in the semantic cache.
|
||||
|
||||
Args:
|
||||
cache_list: List of (key, value) tuples to cache
|
||||
**kwargs: Additional arguments
|
||||
"""
|
||||
try:
|
||||
tasks = []
|
||||
for val in cache_list:
|
||||
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
||||
await asyncio.gather(*tasks)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error in async_set_cache_pipeline: {str(e)}")
|
||||
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
S3 Cache implementation
|
||||
|
||||
Has 4 methods:
|
||||
- set_cache
|
||||
- get_cache
|
||||
- async_set_cache (uses run_in_executor)
|
||||
- async_get_cache (uses run_in_executor)
|
||||
"""
|
||||
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
from datetime import datetime, timezone, timedelta
|
||||
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class S3Cache(BaseCache):
|
||||
def __init__(
|
||||
self,
|
||||
s3_bucket_name,
|
||||
s3_region_name=None,
|
||||
s3_api_version=None,
|
||||
s3_use_ssl: Optional[bool] = True,
|
||||
s3_verify=None,
|
||||
s3_endpoint_url=None,
|
||||
s3_aws_access_key_id=None,
|
||||
s3_aws_secret_access_key=None,
|
||||
s3_aws_session_token=None,
|
||||
s3_config=None,
|
||||
s3_path=None,
|
||||
**kwargs,
|
||||
):
|
||||
import boto3
|
||||
|
||||
self.bucket_name = s3_bucket_name
|
||||
self.key_prefix = s3_path.rstrip("/") + "/" if s3_path else ""
|
||||
# Create an S3 client with custom endpoint URL
|
||||
|
||||
self.s3_client = boto3.client(
|
||||
"s3",
|
||||
region_name=s3_region_name,
|
||||
endpoint_url=s3_endpoint_url,
|
||||
api_version=s3_api_version,
|
||||
use_ssl=s3_use_ssl,
|
||||
verify=s3_verify,
|
||||
aws_access_key_id=s3_aws_access_key_id,
|
||||
aws_secret_access_key=s3_aws_secret_access_key,
|
||||
aws_session_token=s3_aws_session_token,
|
||||
config=s3_config,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _to_s3_key(self, key: str) -> str:
|
||||
"""Convert cache key to S3 key"""
|
||||
return self.key_prefix + key.replace(":", "/")
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
try:
|
||||
print_verbose(f"LiteLLM SET Cache - S3. Key={key}. Value={value}")
|
||||
ttl = kwargs.get("ttl", None)
|
||||
# Convert value to JSON before storing in S3
|
||||
serialized_value = json.dumps(value)
|
||||
key = self._to_s3_key(key)
|
||||
|
||||
if ttl is not None:
|
||||
cache_control = f"immutable, max-age={ttl}, s-maxage={ttl}"
|
||||
|
||||
# Calculate expiration time
|
||||
expiration_time = datetime.now(timezone.utc) + timedelta(seconds=ttl)
|
||||
# Upload the data to S3 with the calculated expiration time
|
||||
self.s3_client.put_object(
|
||||
Bucket=self.bucket_name,
|
||||
Key=key,
|
||||
Body=serialized_value,
|
||||
Expires=expiration_time,
|
||||
CacheControl=cache_control,
|
||||
ContentType="application/json",
|
||||
ContentLanguage="en",
|
||||
ContentDisposition=f'inline; filename="{key}.json"',
|
||||
)
|
||||
else:
|
||||
cache_control = "immutable, max-age=31536000, s-maxage=31536000"
|
||||
# Upload the data to S3 without specifying Expires
|
||||
self.s3_client.put_object(
|
||||
Bucket=self.bucket_name,
|
||||
Key=key,
|
||||
Body=serialized_value,
|
||||
CacheControl=cache_control,
|
||||
ContentType="application/json",
|
||||
ContentLanguage="en",
|
||||
ContentDisposition=f'inline; filename="{key}.json"',
|
||||
)
|
||||
except Exception as e:
|
||||
print_verbose(f"S3 Caching: set_cache() - Got exception from S3: {e}")
|
||||
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
"""
|
||||
Asynchronously set cache using run_in_executor to avoid blocking the event loop.
|
||||
Compatible with Python 3.8+.
|
||||
"""
|
||||
try:
|
||||
verbose_logger.debug(f"Set ASYNC S3 Cache: Key={key}. Value={value}")
|
||||
loop = asyncio.get_event_loop()
|
||||
func = partial(self.set_cache, key, value, **kwargs)
|
||||
await loop.run_in_executor(None, func)
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"S3 Caching: async_set_cache() - Got exception from S3: {e}"
|
||||
)
|
||||
|
||||
def get_cache(self, key, **kwargs):
|
||||
import botocore
|
||||
|
||||
try:
|
||||
key = self._to_s3_key(key)
|
||||
|
||||
print_verbose(f"Get S3 Cache: key: {key}")
|
||||
# Download the data from S3
|
||||
cached_response = self.s3_client.get_object(
|
||||
Bucket=self.bucket_name, Key=key
|
||||
)
|
||||
|
||||
if cached_response is not None:
|
||||
if "Expires" in cached_response:
|
||||
expires_time = cached_response["Expires"]
|
||||
current_time = datetime.now(expires_time.tzinfo)
|
||||
|
||||
if current_time > expires_time:
|
||||
return None
|
||||
|
||||
# cached_response is in `b{} convert it to ModelResponse
|
||||
cached_response = (
|
||||
cached_response["Body"].read().decode("utf-8")
|
||||
) # Convert bytes to string
|
||||
try:
|
||||
cached_response = json.loads(
|
||||
cached_response
|
||||
) # Convert string to dictionary
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
if not isinstance(cached_response, dict):
|
||||
cached_response = dict(cached_response)
|
||||
verbose_logger.debug(
|
||||
f"Got S3 Cache: key: {key}, cached_response {cached_response}. Type Response {type(cached_response)}"
|
||||
)
|
||||
|
||||
return cached_response
|
||||
except botocore.exceptions.ClientError as e: # type: ignore
|
||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||
verbose_logger.debug(
|
||||
f"S3 Cache: The specified key '{key}' does not exist in the S3 bucket."
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"S3 Caching: get_cache() - Got exception from S3: {e}"
|
||||
)
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
"""
|
||||
Asynchronously get cache using run_in_executor to avoid blocking the event loop.
|
||||
Compatible with Python 3.8+.
|
||||
"""
|
||||
try:
|
||||
verbose_logger.debug(f"Get ASYNC S3 Cache: key: {key}")
|
||||
loop = asyncio.get_event_loop()
|
||||
func = partial(self.get_cache, key, **kwargs)
|
||||
result = await loop.run_in_executor(None, func)
|
||||
return result
|
||||
except Exception as e:
|
||||
verbose_logger.error(
|
||||
f"S3 Caching: async_get_cache() - Got exception from S3: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
def flush_cache(self):
|
||||
pass
|
||||
|
||||
async def disconnect(self):
|
||||
pass
|
||||
|
||||
async def async_set_cache_pipeline(self, cache_list, **kwargs):
|
||||
tasks = []
|
||||
for val in cache_list:
|
||||
tasks.append(self.async_set_cache(val[0], val[1], **kwargs))
|
||||
await asyncio.gather(*tasks)
|
||||
Reference in New Issue
Block a user