chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Auto-Routing Strategy that works with a Semantic Router Config
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from semantic_router.routers.base import Route
|
||||
|
||||
from litellm.router import Router
|
||||
from litellm.types.router import PreRoutingHookResponse
|
||||
else:
|
||||
Router = Any
|
||||
PreRoutingHookResponse = Any
|
||||
Route = Any
|
||||
|
||||
|
||||
class AutoRouter(CustomLogger):
    """Routes a request to a concrete model by classifying the latest user
    message with a semantic-router layer built from a router config."""

    # auto_sync mode handed to SemanticRouter; "local" syncs the index locally.
    DEFAULT_AUTO_SYNC_VALUE = "local"

    def __init__(
        self,
        model_name: str,
        default_model: str,
        embedding_model: str,
        litellm_router_instance: "Router",
        auto_router_config_path: Optional[str] = None,
        auto_router_config: Optional[str] = None,
    ):
        """
        Auto-Router that uses a semantic router to pick the appropriate model.

        Args:
            model_name: Name this auto-router answers to (e.g. "auto-router1").
                NOTE: accepted but not stored on the instance.
            default_model: Fallback model when no route matches.
            embedding_model: Embedding model used to embed utterances/queries.
            litellm_router_instance: The LiteLLM Router instance.
            auto_router_config_path: Path to a router config JSON file.
            auto_router_config: Router config as a JSON string. Provide either
                this or auto_router_config_path, not both.
        """
        from semantic_router.routers import SemanticRouter

        self.auto_router_config_path: Optional[str] = auto_router_config_path
        self.auto_router_config: Optional[str] = auto_router_config
        self.auto_sync_value = self.DEFAULT_AUTO_SYNC_VALUE
        # Routes must be loaded after the config attributes above are set.
        self.loaded_routes: List[Route] = self._load_semantic_routing_routes()
        # Built lazily on the first pre-routing hook call.
        self.routelayer: Optional[SemanticRouter] = None
        self.default_model = default_model
        self.embedding_model: str = embedding_model
        self.litellm_router_instance: "Router" = litellm_router_instance

    def _load_semantic_routing_routes(self) -> List[Route]:
        """Load routes from whichever config source was provided (path wins)."""
        from semantic_router.routers import SemanticRouter

        if self.auto_router_config_path:
            return SemanticRouter.from_json(self.auto_router_config_path).routes
        if self.auto_router_config:
            return self._load_auto_router_routes_from_config_json()
        raise ValueError("No router config provided")

    def _load_auto_router_routes_from_config_json(self) -> List[Route]:
        """Build Route objects from the JSON-string config's "routes" entries."""
        import json

        from semantic_router.routers.base import Route

        if self.auto_router_config is None:
            raise ValueError("No auto router config provided")
        parsed = json.loads(self.auto_router_config)
        return [
            Route(
                name=entry.get("name"),
                description=entry.get("description"),
                utterances=entry.get("utterances", []),
                score_threshold=entry.get("score_threshold"),
            )
            for entry in parsed.get("routes", [])
        ]

    async def async_pre_routing_hook(
        self,
        model: str,
        request_kwargs: Dict,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
    ) -> Optional["PreRoutingHookResponse"]:
        """
        Called before the routing decision is made.

        Classifies the last message's content against the route layer and, on a
        match, swaps `model` for the matched route's name (or `default_model`
        when the match carries no name). Returns None when there are no
        messages to classify.
        """
        from semantic_router.routers import SemanticRouter
        from semantic_router.schema import RouteChoice

        from litellm.router_strategy.auto_router.litellm_encoder import (
            LiteLLMRouterEncoder,
        )
        from litellm.types.router import PreRoutingHookResponse

        if messages is None:
            # Nothing to classify; leave routing untouched.
            return None

        if self.routelayer is None:
            # Lazily build the semantic route layer on first use.
            self.routelayer = SemanticRouter(
                routes=self.loaded_routes,
                encoder=LiteLLMRouterEncoder(
                    litellm_router_instance=self.litellm_router_instance,
                    model_name=self.embedding_model,
                ),
                auto_sync=self.auto_sync_value,
            )

        last_message: Dict[str, str] = messages[-1]
        text: str = last_message.get("content", "")
        route_choice: Optional[Union[RouteChoice, List[RouteChoice]]] = self.routelayer(
            text=text
        )
        verbose_router_logger.debug(f"route_choice: {route_choice}")

        chosen: Optional[RouteChoice] = None
        if isinstance(route_choice, RouteChoice):
            chosen = route_choice
        elif isinstance(route_choice, list):
            chosen = route_choice[0]
        if chosen is not None:
            model = chosen.name or self.default_model

        return PreRoutingHookResponse(
            model=model,
            messages=messages,
        )
|
||||
@@ -0,0 +1,139 @@
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from semantic_router.encoders import DenseEncoder
|
||||
from semantic_router.encoders.base import AsymmetricDenseMixin
|
||||
|
||||
import litellm
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router
|
||||
else:
|
||||
Router = Any
|
||||
|
||||
|
||||
def litellm_to_list(embeds: litellm.EmbeddingResponse) -> list[list[float]]:
    """Convert a LiteLLM embedding response to a list of embedding vectors.

    :param embeds: The LiteLLM embedding response.
    :return: One embedding (list of floats) per input document.
    :raises ValueError: if the response is falsy, not an EmbeddingResponse,
        or carries no data entries.
    """
    usable = (
        bool(embeds)
        and isinstance(embeds, litellm.EmbeddingResponse)
        and bool(embeds.data)
    )
    if not usable:
        raise ValueError("No embeddings found in LiteLLM embedding response.")
    return [entry["embedding"] for entry in embeds.data]
|
||||
|
||||
|
||||
class CustomDenseEncoder(DenseEncoder):
    """DenseEncoder that carries an optional LiteLLM Router instance.

    Accepts `litellm_router_instance` either as a named parameter or inside
    **kwargs; the kwargs copy is popped so pydantic never sees it.
    """

    model_config = ConfigDict(extra="allow")

    def __init__(self, litellm_router_instance: Optional["Router"] = None, **kwargs):
        # A kwargs-supplied router instance takes precedence over the
        # positional/keyword parameter.
        litellm_router_instance = kwargs.pop(
            "litellm_router_instance", litellm_router_instance
        )
        super().__init__(**kwargs)
        self.litellm_router_instance = litellm_router_instance
|
||||
|
||||
|
||||
class LiteLLMRouterEncoder(CustomDenseEncoder, AsymmetricDenseMixin):
    """LiteLLM encoder class for generating embeddings using LiteLLM.

    The LiteLLMRouterEncoder class is a subclass of DenseEncoder and utilizes the LiteLLM Router SDK
    to generate embeddings for given documents. It supports all encoders supported by LiteLLM
    and supports customization of the score threshold for filtering or processing the embeddings.

    Queries and documents are embedded identically here (the asymmetric
    query/document pairs delegate to the same helpers).
    """

    type: str = "internal_litellm_router"

    def __init__(
        self,
        litellm_router_instance: "Router",
        model_name: str,
        score_threshold: Union[float, None] = None,
    ):
        """Initialize the LiteLLMEncoder.

        :param litellm_router_instance: The instance of the LiteLLM Router.
        :type litellm_router_instance: Router
        :param model_name: The name of the embedding model to use. Must use LiteLLM naming
            convention (e.g. "openai/text-embedding-3-small" or "mistral/mistral-embed").
        :type model_name: str
        :param score_threshold: The score threshold for the embeddings; defaults to 0.3.
        :type score_threshold: float
        """
        super().__init__(
            name=model_name,
            score_threshold=score_threshold if score_threshold is not None else 0.3,
        )
        self.model_name = model_name
        self.litellm_router_instance = litellm_router_instance

    def __call__(self, docs: list[Any], **kwargs) -> list[list[float]]:
        """Encode a list of text documents into embeddings using LiteLLM.

        :param docs: List of text documents to encode.
        :return: List of embeddings for each document."""
        return self.encode_queries(docs, **kwargs)

    async def acall(self, docs: list[Any], **kwargs) -> list[list[float]]:
        """Encode a list of documents into embeddings using LiteLLM asynchronously.

        :param docs: List of documents to encode.
        :return: List of embeddings for each document."""
        return await self.aencode_queries(docs, **kwargs)

    def _embed(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Shared sync embedding path: call the router and normalize the response.

        :raises ValueError: if no router instance is set or the API call fails.
        """
        if self.litellm_router_instance is None:
            raise ValueError("litellm_router_instance is not set")
        try:
            embeds = self.litellm_router_instance.embedding(
                input=docs, model=self.model_name, **kwargs
            )
            return litellm_to_list(embeds)
        except Exception as e:
            raise ValueError(
                f"{self.type.capitalize()} API call failed. Error: {e}"
            ) from e

    async def _aembed(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Shared async embedding path: call the router and normalize the response.

        :raises ValueError: if no router instance is set or the API call fails.
        """
        if self.litellm_router_instance is None:
            raise ValueError("litellm_router_instance is not set")
        try:
            embeds = await self.litellm_router_instance.aembedding(
                input=docs, model=self.model_name, **kwargs
            )
            return litellm_to_list(embeds)
        except Exception as e:
            raise ValueError(
                f"{self.type.capitalize()} API call failed. Error: {e}"
            ) from e

    def encode_queries(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Synchronously embed query texts."""
        return self._embed(docs, **kwargs)

    def encode_documents(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Synchronously embed document texts (same path as queries)."""
        return self._embed(docs, **kwargs)

    async def aencode_queries(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Asynchronously embed query texts."""
        return await self._aembed(docs, **kwargs)

    async def aencode_documents(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Asynchronously embed document texts (same path as queries)."""
        return await self._aembed(docs, **kwargs)
|
||||
@@ -0,0 +1,261 @@
|
||||
"""
|
||||
Base class across routing strategies to abstract common functions like batch incrementing redis
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from abc import ABC
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
|
||||
from litellm.constants import DEFAULT_REDIS_SYNC_INTERVAL
|
||||
|
||||
|
||||
class BaseRoutingStrategy(ABC):
    """
    Base for routing strategies: abstracts common spend-tracking plumbing.

    Increments are applied to the in-memory cache immediately and queued as
    Redis pipeline operations; a periodic background task flushes the queue
    and merges Redis values back so multiple LiteLLM instances converge.
    """

    def __init__(
        self,
        dual_cache: DualCache,
        should_batch_redis_writes: bool,
        default_sync_interval: Optional[Union[int, float]],
    ):
        """
        Args:
            dual_cache: Combined in-memory + (optional) Redis cache.
            should_batch_redis_writes: When True, start the periodic
                background sync task at construction time.
            default_sync_interval: Seconds between syncs; falls back to
                DEFAULT_REDIS_SYNC_INTERVAL when None.
        """
        self.dual_cache = dual_cache
        # Pending increments, flushed in one Redis pipeline per sync cycle.
        self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
        self._sync_task: Optional[asyncio.Task[None]] = None
        if should_batch_redis_writes:
            self.setup_sync_task(default_sync_interval)

        # Keys whose in-memory values changed since the last sync.
        # NOTE(review): nothing enforces a size cap on this set (the original
        # comment claimed a 1000-key max) — it grows until reset.
        self.in_memory_keys_to_update: set[
            str
        ] = set()

    def setup_sync_task(self, default_sync_interval: Optional[Union[int, float]]):
        """Setup the sync task in a way that's compatible with FastAPI."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (constructed outside async context): create one
            # so the periodic task can still be scheduled.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)

        self._sync_task = loop.create_task(
            self.periodic_sync_in_memory_spend_with_redis(
                default_sync_interval=default_sync_interval
            )
        )

    async def cleanup(self):
        """Cancel and await the background sync task; call on shutdown."""
        if self._sync_task is not None:
            self._sync_task.cancel()
            try:
                await self._sync_task
            except asyncio.CancelledError:
                # Expected: we cancelled the task above.
                pass

    async def _increment_value_list_in_current_window(
        self, increment_list: List[Tuple[str, int]], ttl: int
    ) -> List[float]:
        """
        Increment a list of (key, value) pairs in the current window.

        Returns the post-increment values in input order.
        """
        results = []
        for key, value in increment_list:
            result = await self._increment_value_in_current_window(
                key=key, value=value, ttl=ttl
            )
            results.append(result)
        return results

    async def _increment_value_in_current_window(
        self, key: str, value: Union[int, float], ttl: int
    ):
        """
        Increment spend within the existing budget window.

        Runs once the budget start time exists in Redis Cache (on the 2nd and
        subsequent requests to the same provider).

        - Increments the spend in the in-memory cache (instantly visible)
        - Queues the increment for the batched Redis pipeline (Redis is used
          to sync spend across multiple LiteLLM instances)

        Returns the post-increment in-memory value.
        """
        result = await self.dual_cache.in_memory_cache.async_increment(
            key=key,
            value=value,
            ttl=ttl,
        )
        increment_op = RedisPipelineIncrementOperation(
            key=key,
            increment_value=value,
            ttl=ttl,
        )

        self.redis_increment_operation_queue.append(increment_op)
        # Remember this key so the periodic sync pulls its Redis value back.
        self.add_to_in_memory_keys_to_update(key=key)
        return result

    async def periodic_sync_in_memory_spend_with_redis(
        self, default_sync_interval: Optional[Union[int, float]]
    ):
        """
        Loop forever, triggering _sync_in_memory_spend_with_redis every
        `default_sync_interval` seconds (DEFAULT_REDIS_SYNC_INTERVAL if None).

        Required for multi-instance usage of provider budgets.
        """
        default_sync_interval = default_sync_interval or DEFAULT_REDIS_SYNC_INTERVAL
        while True:
            try:
                await self._sync_in_memory_spend_with_redis()
                await asyncio.sleep(
                    default_sync_interval
                )  # Wait before the next sync
            except Exception as e:
                verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
                await asyncio.sleep(
                    default_sync_interval
                )  # Still wait on error before retrying

    async def _push_in_memory_increments_to_redis(self):
        """
        Flush queued increments to Redis in one pipeline.

        - Compresses multiple increments for the same key into one operation
        - Pushes the compressed batch via a single Redis pipeline
        - Only runs if Redis is initialized

        Returns a dict mapping key -> post-increment Redis value, or None when
        Redis is absent / the queue was empty / an error occurred.
        """
        try:
            if not self.dual_cache.redis_cache:
                return  # Redis is not initialized

            if len(self.redis_increment_operation_queue) > 0:
                # Compress operations for the same key.
                # NOTE(review): the compression mutates the queued op dict
                # in place, and every index is appended to ops_to_remove, so
                # the rebuild below always empties the queue.
                compressed_ops: Dict[str, RedisPipelineIncrementOperation] = {}
                ops_to_remove = []
                for idx, op in enumerate(self.redis_increment_operation_queue):
                    if op["key"] in compressed_ops:
                        # Fold into the existing increment for this key
                        compressed_ops[op["key"]]["increment_value"] += op[
                            "increment_value"
                        ]
                    else:
                        compressed_ops[op["key"]] = op

                    ops_to_remove.append(idx)

                # Convert back to list for the pipeline call
                compressed_queue = list(compressed_ops.values())

                increment_result = (
                    await self.dual_cache.redis_cache.async_increment_pipeline(
                        increment_list=compressed_queue,
                    )
                )

                # Drop all ops we just pushed from the queue
                self.redis_increment_operation_queue = [
                    op
                    for idx, op in enumerate(self.redis_increment_operation_queue)
                    if idx not in ops_to_remove
                ]

                if increment_result is not None:
                    # NOTE(review): loop names are swapped here — `key` is a
                    # queued op dict and `op` the Redis result value; the
                    # mapping produced (key-string -> result value) is correct.
                    return_result = {
                        key["key"]: op
                        for key, op in zip(compressed_queue, increment_result)
                    }
                else:
                    return_result = {}
                return return_result

        except Exception as e:
            verbose_router_logger.error(
                f"Error syncing in-memory cache with Redis: {str(e)}"
            )
            # Drop the queue so a poisoned op cannot fail every future flush.
            self.redis_increment_operation_queue = []

    def add_to_in_memory_keys_to_update(self, key: str):
        # Set semantics: re-adding an existing key is a no-op.
        self.in_memory_keys_to_update.add(key)

    def get_key_pattern_to_sync(self) -> Optional[str]:
        """
        Get the key pattern to sync. Subclasses may override; the base
        implementation returns None (sync uses the in-memory key set).
        """
        return None

    def get_in_memory_keys_to_update(self) -> Set[str]:
        # Returns the live set (not a copy).
        return self.in_memory_keys_to_update

    def get_and_reset_in_memory_keys_to_update(self) -> Set[str]:
        """Atomic get-and-reset of the in-memory keys-to-update set."""
        keys = self.in_memory_keys_to_update
        self.in_memory_keys_to_update = set()
        return keys

    def reset_in_memory_keys_to_update(self):
        # Rebind rather than clear(), so previously returned sets are untouched.
        self.in_memory_keys_to_update = set()

    async def _sync_in_memory_spend_with_redis(self):
        """
        Ensures the in-memory cache is updated with the latest Redis values
        for all tracked spend keys.

        Why do we need this?
        - Optimization for sub-100ms latency: per-request Redis read/write hurt
          performance, so Redis work is batched here instead
        - In multi-instance environments, Redis is the source that syncs spend
          across all instances

        What this does:
        1. Push all queued spend increments to Redis
        2. Fetch current values from Redis and merge into the in-memory cache
        """

        try:
            # No need to sync if Redis cache is not initialized
            if self.dual_cache.redis_cache is None:
                return

            # Keys to refresh from Redis (in-memory key set; no scan pattern)
            cache_keys = (
                self.get_in_memory_keys_to_update()
            )  # if no pattern OR redis cache does not support scan_iter, use in-memory keys

            cache_keys_list = list(cache_keys)

            # Snapshot in-memory values BEFORE the push so that writes landing
            # while the push is in flight can be preserved as a delta below.
            in_memory_before_dict = {}
            in_memory_before = (
                await self.dual_cache.in_memory_cache.async_batch_get_cache(
                    keys=cache_keys_list
                )
            )
            for k, v in zip(cache_keys_list, in_memory_before):
                in_memory_before_dict[k] = float(v or 0)

            # Push all queued increments to Redis
            redis_values = await self._push_in_memory_increments_to_redis()
            if redis_values is None:
                return

            # Merge: adopt the Redis value plus whatever was written locally
            # during the push (delta = after - before).
            for key in cache_keys_list:
                redis_val = float(redis_values.get(key, 0) or 0)
                before = float(in_memory_before_dict.get(key, 0) or 0)
                after = float(
                    await self.dual_cache.in_memory_cache.async_get_cache(key=key) or 0
                )
                delta = after - before
                if after <= redis_val:
                    merged = redis_val + delta
                else:
                    # In-memory is ahead of Redis; keep the local value as-is.
                    continue
                await self.dual_cache.in_memory_cache.async_set_cache(
                    key=key, value=merged
                )

        except Exception as e:
            verbose_router_logger.exception(
                f"Error syncing in-memory cache with Redis: {str(e)}"
            )
|
||||
@@ -0,0 +1,898 @@
|
||||
"""
|
||||
Provider budget limiting
|
||||
|
||||
Use this if you want to set $ budget limits for each provider.
|
||||
|
||||
Note: This is a filter, like tag-routing. Meaning it will accept healthy deployments and then filter out deployments that have exceeded their budget limit.
|
||||
|
||||
This means you can use this with weighted-pick, lowest-latency, simple-shuffle, routing etc
|
||||
|
||||
Example:
|
||||
```
|
||||
openai:
|
||||
budget_limit: 0.000000000001
|
||||
time_period: 1d
|
||||
anthropic:
|
||||
budget_limit: 100
|
||||
time_period: 7d
|
||||
```
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.caching.redis_cache import RedisPipelineIncrementOperation
|
||||
from litellm.integrations.custom_logger import CustomLogger, Span
|
||||
from litellm.litellm_core_utils.duration_parser import duration_in_seconds
|
||||
from litellm.router_strategy.tag_based_routing import _get_tags_from_request_kwargs
|
||||
from litellm.router_utils.cooldown_callbacks import (
|
||||
_get_prometheus_logger_from_callbacks,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.router import DeploymentTypedDict, LiteLLM_Params, RouterErrors
|
||||
from litellm.types.utils import BudgetConfig
|
||||
from litellm.types.utils import BudgetConfig as GenericBudgetInfo
|
||||
from litellm.types.utils import GenericBudgetConfigType, StandardLoggingPayload
|
||||
|
||||
DEFAULT_REDIS_SYNC_INTERVAL = 1
|
||||
|
||||
|
||||
class _LiteLLMParamsDictView:
|
||||
"""
|
||||
Lightweight attribute view over `litellm_params` dict.
|
||||
|
||||
This avoids pydantic construction in request hot-path while preserving
|
||||
attribute-style access used by `litellm.get_llm_provider(...)`.
|
||||
"""
|
||||
|
||||
__slots__ = ("_params",)
|
||||
|
||||
def __init__(self, params: Dict[str, Any]):
|
||||
self._params = params
|
||||
|
||||
def __getattr__(self, key: str) -> Any:
|
||||
return self._params.get(key)
|
||||
|
||||
def __getitem__(self, key: str) -> Any:
|
||||
return self._params.get(key)
|
||||
|
||||
def __contains__(self, key: str) -> bool:
|
||||
return key in self._params
|
||||
|
||||
def get(self, key: str, default: Any = None) -> Any:
|
||||
return self._params.get(key, default)
|
||||
|
||||
def keys(self):
|
||||
return self._params.keys()
|
||||
|
||||
def values(self):
|
||||
return self._params.values()
|
||||
|
||||
def items(self):
|
||||
return self._params.items()
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._params)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._params)
|
||||
|
||||
def dict(self) -> Dict[str, Any]:
|
||||
return dict(self._params)
|
||||
|
||||
def model_dump(self) -> Dict[str, Any]:
|
||||
return dict(self._params)
|
||||
|
||||
|
||||
class RouterBudgetLimiting(CustomLogger):
|
||||
    def __init__(
        self,
        dual_cache: DualCache,
        provider_budget_config: Optional[dict],
        model_list: Optional[
            Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
        ] = None,
    ):
        """
        Args:
            dual_cache: Combined in-memory + (optional) Redis cache used for
                spend tracking.
            provider_budget_config: Mapping of provider -> budget config.
            model_list: Router model list; scanned for per-deployment budgets.
        """
        self.dual_cache = dual_cache
        # Pending spend increments awaiting the batched Redis flush.
        self.redis_increment_operation_queue: List[RedisPipelineIncrementOperation] = []
        # NOTE(review): create_task requires an already-running event loop and
        # the task handle is not stored (cannot be cancelled later) — confirm
        # this class is only constructed from async context.
        asyncio.create_task(self.periodic_sync_in_memory_spend_with_redis())
        self.provider_budget_config: Optional[
            GenericBudgetConfigType
        ] = provider_budget_config
        self.deployment_budget_config: Optional[GenericBudgetConfigType] = None
        self.tag_budget_config: Optional[GenericBudgetConfigType] = None
        # Populate the three budget-config maps from raw inputs.
        self._init_provider_budgets()
        self._init_deployment_budgets(model_list=model_list)
        self._init_tag_budgets()

        # Register as a litellm callback so spend events reach this limiter
        if isinstance(litellm.callbacks, list):
            litellm.logging_callback_manager.add_litellm_callback(self)  # type: ignore
|
||||
|
||||
    async def async_filter_deployments(
        self,
        model: str,
        healthy_deployments: List,
        messages: Optional[List[AllMessageValues]],
        request_kwargs: Optional[dict] = None,
        parent_otel_span: Optional[Span] = None,  # type: ignore
    ) -> List[dict]:
        """
        Filter out deployments that have exceeded their provider budget limit.

        Example:
            if deployment = openai/gpt-3.5-turbo
            and openai spend > openai budget limit
            then skip this deployment

        Raises:
            ValueError: when budget checks ran and every deployment was
                filtered out.
        """

        # If a single deployment is passed, convert it to a list
        if isinstance(healthy_deployments, dict):
            healthy_deployments = [healthy_deployments]

        # Don't do any filtering if there are no healthy deployments
        if len(healthy_deployments) == 0:
            return healthy_deployments

        potential_deployments: List[Dict] = []

        # Resolve which spend keys / budget configs apply to this batch
        (
            cache_keys,
            provider_configs,
            deployment_configs,
            deployment_providers,
        ) = await self._async_get_cache_keys_for_router_budget_limiting(
            healthy_deployments=healthy_deployments,
            request_kwargs=request_kwargs,
        )

        # Single batched cache read for all spend values (hot-path optimization)
        if len(cache_keys) > 0:
            _current_spends = await self.dual_cache.async_batch_get_cache(
                keys=cache_keys,
                parent_otel_span=parent_otel_span,
            )
            # A failed read counts as zero spend for every key
            current_spends: List = _current_spends or [0.0] * len(cache_keys)

            # Map spends to their respective keys; None entries become 0.0
            spend_map: Dict[str, float] = {}
            for idx, key in enumerate(cache_keys):
                spend_map[key] = float(current_spends[idx] or 0.0)

            (
                potential_deployments,
                deployment_above_budget_info,
            ) = self._filter_out_deployments_above_budget(
                healthy_deployments=healthy_deployments,
                provider_configs=provider_configs,
                deployment_configs=deployment_configs,
                deployment_providers=deployment_providers,
                spend_map=spend_map,
                potential_deployments=potential_deployments,
                request_tags=_get_tags_from_request_kwargs(
                    request_kwargs=request_kwargs
                ),
            )

            # Everything was over budget: surface the per-deployment reasons
            if len(potential_deployments) == 0:
                raise ValueError(
                    f"{RouterErrors.no_deployments_with_provider_budget_routing.value}: {deployment_above_budget_info}"
                )

            return potential_deployments
        else:
            # No budget keys apply — nothing to filter on
            return healthy_deployments
||||
|
||||
def _filter_out_deployments_above_budget(
|
||||
self,
|
||||
potential_deployments: List[Dict[str, Any]],
|
||||
healthy_deployments: List[Dict[str, Any]],
|
||||
provider_configs: Dict[str, GenericBudgetInfo],
|
||||
deployment_configs: Dict[str, GenericBudgetInfo],
|
||||
deployment_providers: List[Optional[str]],
|
||||
spend_map: Dict[str, float],
|
||||
request_tags: List[str],
|
||||
) -> Tuple[List[Dict[str, Any]], str]:
|
||||
"""
|
||||
Filter out deployments that have exceeded their budget limit.
|
||||
Follow budget checks are run here:
|
||||
- Provider budget
|
||||
- Deployment budget
|
||||
- Request tags budget
|
||||
Returns:
|
||||
Tuple[List[Dict[str, Any]], str]:
|
||||
- A tuple containing the filtered deployments
|
||||
- A string containing debug information about deployments that exceeded their budget limit.
|
||||
"""
|
||||
# Filter deployments based on both provider and deployment budgets
|
||||
deployment_above_budget_info: str = ""
|
||||
for idx, deployment in enumerate(healthy_deployments):
|
||||
is_within_budget = True
|
||||
|
||||
# Check provider budget
|
||||
if self.provider_budget_config:
|
||||
if idx < len(deployment_providers):
|
||||
provider = deployment_providers[idx]
|
||||
else:
|
||||
provider = self._get_llm_provider_for_deployment(deployment)
|
||||
if provider in provider_configs:
|
||||
config = provider_configs[provider]
|
||||
if config.max_budget is None:
|
||||
continue
|
||||
current_spend = spend_map.get(
|
||||
f"provider_spend:{provider}:{config.budget_duration}", 0.0
|
||||
)
|
||||
self._track_provider_remaining_budget_prometheus(
|
||||
provider=provider,
|
||||
spend=current_spend,
|
||||
budget_limit=config.max_budget,
|
||||
)
|
||||
|
||||
if config.max_budget and current_spend >= config.max_budget:
|
||||
debug_msg = f"Exceeded budget for provider {provider}: {current_spend} >= {config.max_budget}"
|
||||
deployment_above_budget_info += f"{debug_msg}\n"
|
||||
is_within_budget = False
|
||||
continue
|
||||
|
||||
# Check deployment budget
|
||||
if self.deployment_budget_config and is_within_budget:
|
||||
_model_name = deployment.get("model_name")
|
||||
_litellm_params = deployment.get("litellm_params") or {}
|
||||
_litellm_model_name = _litellm_params.get("model")
|
||||
model_id = deployment.get("model_info", {}).get("id")
|
||||
if model_id in deployment_configs:
|
||||
config = deployment_configs[model_id]
|
||||
current_spend = spend_map.get(
|
||||
f"deployment_spend:{model_id}:{config.budget_duration}", 0.0
|
||||
)
|
||||
if config.max_budget and current_spend >= config.max_budget:
|
||||
debug_msg = f"Exceeded budget for deployment model_name: {_model_name}, litellm_params.model: {_litellm_model_name}, model_id: {model_id}: {current_spend} >= {config.budget_duration}"
|
||||
verbose_router_logger.debug(debug_msg)
|
||||
deployment_above_budget_info += f"{debug_msg}\n"
|
||||
is_within_budget = False
|
||||
continue
|
||||
# Check tag budget
|
||||
if self.tag_budget_config and is_within_budget:
|
||||
for _tag in request_tags:
|
||||
_tag_budget_config = self._get_budget_config_for_tag(_tag)
|
||||
if _tag_budget_config:
|
||||
_tag_spend = spend_map.get(
|
||||
f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}",
|
||||
0.0,
|
||||
)
|
||||
if (
|
||||
_tag_budget_config.max_budget
|
||||
and _tag_spend >= _tag_budget_config.max_budget
|
||||
):
|
||||
debug_msg = f"Exceeded budget for tag='{_tag}', tag_spend={_tag_spend}, tag_budget_limit={_tag_budget_config.max_budget}"
|
||||
verbose_router_logger.debug(debug_msg)
|
||||
deployment_above_budget_info += f"{debug_msg}\n"
|
||||
is_within_budget = False
|
||||
continue
|
||||
if is_within_budget:
|
||||
potential_deployments.append(deployment)
|
||||
|
||||
return potential_deployments, deployment_above_budget_info
|
||||
|
||||
async def _async_get_cache_keys_for_router_budget_limiting(
    self,
    healthy_deployments: List[Dict[str, Any]],
    request_kwargs: Optional[Dict] = None,
) -> Tuple[
    List[str],
    Dict[str, GenericBudgetInfo],
    Dict[str, GenericBudgetInfo],
    List[Optional[str]],
]:
    """
    Returns list of cache keys to fetch from router cache for budget limiting and provider and deployment configs

    Walks every healthy deployment once and collects the spend-tracking cache
    keys for all three budget kinds (provider, deployment, tag) that are
    configured on this limiter.

    Args:
        healthy_deployments: deployment dicts to inspect.
        request_kwargs: original request kwargs; only used to extract tags
            when tag budgets are configured.

    Returns:
        Tuple[List[str], Dict[str, GenericBudgetInfo], Dict[str, GenericBudgetInfo], List[Optional[str]]]:
            - List of cache keys to fetch from router cache for budget limiting
            - Dict of provider budget configs `provider_configs`
            - Dict of deployment budget configs `deployment_configs`
            - List of resolved providers aligned by deployment index `deployment_providers`

    NOTE(review): `deployment_providers` is only appended to when
    `self.provider_budget_config` is truthy, so it is empty (not
    index-aligned) when provider budgets are unset — confirm callers
    account for that.
    """
    cache_keys: List[str] = []
    provider_configs: Dict[str, GenericBudgetInfo] = {}
    deployment_configs: Dict[str, GenericBudgetInfo] = {}
    deployment_providers: List[Optional[str]] = []

    for deployment in healthy_deployments:
        # Check provider budgets
        if self.provider_budget_config:
            provider = self._get_llm_provider_for_deployment(deployment)
            deployment_providers.append(provider)
            if provider is not None:
                budget_config = self._get_budget_config_for_provider(provider)
                # Only track providers with a configured budget window.
                if (
                    budget_config is not None
                    and budget_config.budget_duration is not None
                ):
                    provider_configs[provider] = budget_config
                    cache_keys.append(
                        f"provider_spend:{provider}:{budget_config.budget_duration}"
                    )

        # Check deployment budgets
        if self.deployment_budget_config:
            model_id = deployment.get("model_info", {}).get("id")
            if model_id is not None:
                budget_config = self._get_budget_config_for_deployment(model_id)
                if budget_config is not None:
                    deployment_configs[model_id] = budget_config
                    cache_keys.append(
                        f"deployment_spend:{model_id}:{budget_config.budget_duration}"
                    )
        # Check tag budgets
        if self.tag_budget_config:
            request_tags = _get_tags_from_request_kwargs(
                request_kwargs=request_kwargs
            )
            for _tag in request_tags:
                _tag_budget_config = self._get_budget_config_for_tag(_tag)
                if _tag_budget_config:
                    cache_keys.append(
                        f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
                    )
    return (
        cache_keys,
        provider_configs,
        deployment_configs,
        deployment_providers,
    )
|
||||
|
||||
async def _get_or_set_budget_start_time(
|
||||
self, start_time_key: str, current_time: float, ttl_seconds: int
|
||||
) -> float:
|
||||
"""
|
||||
Checks if the key = `provider_budget_start_time:{provider}` exists in cache.
|
||||
|
||||
If it does, return the value.
|
||||
If it does not, set the key to `current_time` and return the value.
|
||||
"""
|
||||
budget_start = await self.dual_cache.async_get_cache(start_time_key)
|
||||
if budget_start is None:
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=start_time_key, value=current_time, ttl=ttl_seconds
|
||||
)
|
||||
return current_time
|
||||
return float(budget_start)
|
||||
|
||||
async def _handle_new_budget_window(
|
||||
self,
|
||||
spend_key: str,
|
||||
start_time_key: str,
|
||||
current_time: float,
|
||||
response_cost: float,
|
||||
ttl_seconds: int,
|
||||
) -> float:
|
||||
"""
|
||||
Handle start of new budget window by resetting spend and start time
|
||||
|
||||
Enters this when:
|
||||
- The budget does not exist in cache, so we need to set it
|
||||
- The budget window has expired, so we need to reset everything
|
||||
|
||||
Does 2 things:
|
||||
- stores key: `provider_spend:{provider}:1d`, value: response_cost
|
||||
- stores key: `provider_budget_start_time:{provider}`, value: current_time.
|
||||
This stores the start time of the new budget window
|
||||
"""
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=spend_key, value=response_cost, ttl=ttl_seconds
|
||||
)
|
||||
await self.dual_cache.async_set_cache(
|
||||
key=start_time_key, value=current_time, ttl=ttl_seconds
|
||||
)
|
||||
return current_time
|
||||
|
||||
async def _increment_spend_in_current_window(
    self, spend_key: str, response_cost: float, ttl: int
):
    """
    Increment spend within existing budget window

    Runs once the budget start time exists in Redis Cache (on the 2nd and subsequent requests to the same provider)

    - Increments the spend in memory cache (so spend instantly updated in memory)
    - Queues the increment operation to Redis Pipeline (using batched pipeline to optimize performance. Using Redis for multi instance environment of LiteLLM)
    """
    # Update the in-memory copy first so budget checks on this instance
    # see the new spend immediately.
    await self.dual_cache.in_memory_cache.async_increment(
        key=spend_key,
        value=response_cost,
        ttl=ttl,
    )
    # Defer the Redis write: queue it for the next batched pipeline flush
    # (see _push_in_memory_increments_to_redis).
    increment_op = RedisPipelineIncrementOperation(
        key=spend_key,
        increment_value=response_cost,
        ttl=ttl,
    )
    self.redis_increment_operation_queue.append(increment_op)
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Post-success spend-tracking hook.

    Increments the spend counters relevant to the completed call — provider,
    specific deployment, and any request tags — using the response cost
    recorded in the standard logging payload.

    Raises:
        ValueError: if the standard logging payload or the custom LLM
            provider is missing from ``kwargs``.
    """
    verbose_router_logger.debug("in RouterBudgetLimiting.async_log_success_event")
    standard_logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
        "standard_logging_object", None
    )
    if standard_logging_payload is None:
        raise ValueError("standard_logging_payload is required")

    response_cost: float = standard_logging_payload.get("response_cost", 0)
    model_id: str = str(standard_logging_payload.get("model_id", ""))
    # May be absent from litellm_params, hence Optional until validated below.
    custom_llm_provider: Optional[str] = kwargs.get("litellm_params", {}).get(
        "custom_llm_provider", None
    )
    if custom_llm_provider is None:
        raise ValueError("custom_llm_provider is required")

    budget_config = self._get_budget_config_for_provider(custom_llm_provider)
    if budget_config:
        # increment spend for provider
        spend_key = (
            f"provider_spend:{custom_llm_provider}:{budget_config.budget_duration}"
        )
        start_time_key = f"provider_budget_start_time:{custom_llm_provider}"
        await self._increment_spend_for_key(
            budget_config=budget_config,
            spend_key=spend_key,
            start_time_key=start_time_key,
            response_cost=response_cost,
        )

    deployment_budget_config = self._get_budget_config_for_deployment(model_id)
    if deployment_budget_config:
        # increment spend for specific deployment id
        deployment_spend_key = f"deployment_spend:{model_id}:{deployment_budget_config.budget_duration}"
        deployment_start_time_key = f"deployment_budget_start_time:{model_id}"
        await self._increment_spend_for_key(
            budget_config=deployment_budget_config,
            spend_key=deployment_spend_key,
            start_time_key=deployment_start_time_key,
            response_cost=response_cost,
        )

    # Tag budgets: only tags with a configured budget are tracked.
    request_tags = _get_tags_from_request_kwargs(kwargs)
    if len(request_tags) > 0:
        for _tag in request_tags:
            _tag_budget_config = self._get_budget_config_for_tag(_tag)
            if _tag_budget_config:
                _tag_spend_key = (
                    f"tag_spend:{_tag}:{_tag_budget_config.budget_duration}"
                )
                _tag_start_time_key = f"tag_budget_start_time:{_tag}"
                await self._increment_spend_for_key(
                    budget_config=_tag_budget_config,
                    spend_key=_tag_spend_key,
                    start_time_key=_tag_start_time_key,
                    response_cost=response_cost,
                )
|
||||
|
||||
async def _increment_spend_for_key(
    self,
    budget_config: GenericBudgetInfo,
    spend_key: str,
    start_time_key: str,
    response_cost: float,
):
    """
    Add ``response_cost`` to the spend tracked under ``spend_key``.

    Starts a new budget window (resetting spend to the cost) when the current
    window has expired; otherwise increments spend within the existing
    window, capping the cache TTL to the window's remaining lifetime.

    No-op when the budget config has no duration configured.
    """
    if budget_config.budget_duration is None:
        return

    current_time = datetime.now(timezone.utc).timestamp()
    ttl_seconds = duration_in_seconds(budget_config.budget_duration)

    budget_start = await self._get_or_set_budget_start_time(
        start_time_key=start_time_key,
        current_time=current_time,
        ttl_seconds=ttl_seconds,
    )

    if budget_start is None:
        # NOTE(review): defensive only — _get_or_set_budget_start_time always
        # returns a float, so this branch looks unreachable; confirm.
        # First spend for this provider
        budget_start = await self._handle_new_budget_window(
            spend_key=spend_key,
            start_time_key=start_time_key,
            current_time=current_time,
            response_cost=response_cost,
            ttl_seconds=ttl_seconds,
        )
    elif (current_time - budget_start) > ttl_seconds:
        # Budget window expired - reset everything
        verbose_router_logger.debug("Budget window expired - resetting everything")
        budget_start = await self._handle_new_budget_window(
            spend_key=spend_key,
            start_time_key=start_time_key,
            current_time=current_time,
            response_cost=response_cost,
            ttl_seconds=ttl_seconds,
        )
    else:
        # Within existing window - increment spend
        # Shrink the TTL so the spend key expires exactly when the window ends.
        remaining_time = ttl_seconds - (current_time - budget_start)
        ttl_for_increment = int(remaining_time)

        await self._increment_spend_in_current_window(
            spend_key=spend_key, response_cost=response_cost, ttl=ttl_for_increment
        )

    verbose_router_logger.debug(
        f"Incremented spend for {spend_key} by {response_cost}"
    )
|
||||
|
||||
async def periodic_sync_in_memory_spend_with_redis(self):
    """
    Handler that triggers sync_in_memory_spend_with_redis every DEFAULT_REDIS_SYNC_INTERVAL seconds

    Required for multi-instance environment usage of provider budgets.

    Runs forever; errors from a single sync are logged and the loop keeps
    going, so one failed sync never kills the background task.
    """
    while True:
        try:
            await self._sync_in_memory_spend_with_redis()
        except Exception as e:
            verbose_router_logger.error(f"Error in periodic sync task: {str(e)}")
        # Wait DEFAULT_REDIS_SYNC_INTERVAL seconds before the next sync
        # attempt — the same delay doubles as a backoff after an error.
        await asyncio.sleep(DEFAULT_REDIS_SYNC_INTERVAL)
|
||||
|
||||
async def _push_in_memory_increments_to_redis(self):
    """
    How this works:
    - async_log_success_event collects all provider spend increments in `redis_increment_operation_queue`
    - This function pushes all increments to Redis in a batched pipeline to optimize performance

    Only runs if Redis is initialized
    """
    try:
        if not self.dual_cache.redis_cache:
            return  # Redis is not initialized

        verbose_router_logger.debug(
            "Pushing Redis Increment Pipeline for queue: %s",
            self.redis_increment_operation_queue,
        )
        if len(self.redis_increment_operation_queue) > 0:
            # Fire-and-forget: the pipeline write runs in the background.
            # NOTE(review): the queue is cleared below before the task
            # completes, so increments are lost if the write fails —
            # confirm this best-effort behavior is intended.
            asyncio.create_task(
                self.dual_cache.redis_cache.async_increment_pipeline(
                    increment_list=self.redis_increment_operation_queue,
                )
            )

        self.redis_increment_operation_queue = []

    except Exception as e:
        verbose_router_logger.error(
            f"Error syncing in-memory cache with Redis: {str(e)}"
        )
|
||||
|
||||
async def _sync_in_memory_spend_with_redis(self):
    """
    Ensures in-memory cache is updated with latest Redis values for all provider spends.

    Why Do we need this?
    - Optimization to hit sub 100ms latency. Performance was impacted when redis was used for read/write per request
    - Use provider budgets in multi-instance environment, we use Redis to sync spend across all instances

    What this does:
    1. Push all provider spend increments to Redis
    2. Fetch all current provider spend from Redis to update in-memory cache

    Any error is logged and swallowed — this is a best-effort background sync.
    """

    try:
        # No need to sync if Redis cache is not initialized
        if self.dual_cache.redis_cache is None:
            return

        # 1. Push all provider spend increments to Redis
        await self._push_in_memory_increments_to_redis()

        # 2. Fetch all current provider spend from Redis to update in-memory cache
        # Collect every spend key tracked by this limiter: provider,
        # deployment, and tag budgets.
        cache_keys = []

        if self.provider_budget_config is not None:
            for provider, config in self.provider_budget_config.items():
                if config is None:
                    continue
                cache_keys.append(
                    f"provider_spend:{provider}:{config.budget_duration}"
                )

        if self.deployment_budget_config is not None:
            for model_id, config in self.deployment_budget_config.items():
                if config is None:
                    continue
                cache_keys.append(
                    f"deployment_spend:{model_id}:{config.budget_duration}"
                )

        if self.tag_budget_config is not None:
            for tag, config in self.tag_budget_config.items():
                if config is None:
                    continue
                cache_keys.append(f"tag_spend:{tag}:{config.budget_duration}")

        # Batch fetch current spend values from Redis
        redis_values = await self.dual_cache.redis_cache.async_batch_get_cache(
            key_list=cache_keys
        )

        # Update in-memory cache with Redis values
        if isinstance(redis_values, dict):  # Check if redis_values is a dictionary
            for key, value in redis_values.items():
                if value is not None:
                    await self.dual_cache.in_memory_cache.async_set_cache(
                        key=key, value=float(value)
                    )
                    verbose_router_logger.debug(
                        f"Updated in-memory cache for {key}: {value}"
                    )

    except Exception as e:
        verbose_router_logger.error(
            f"Error syncing in-memory cache with Redis: {str(e)}"
        )
|
||||
|
||||
def _get_budget_config_for_deployment(
    self,
    model_id: str,
) -> Optional[GenericBudgetInfo]:
    """Return the budget config for deployment ``model_id``, or None when
    deployment budgets are unconfigured or the id has no entry."""
    deployment_budgets = self.deployment_budget_config
    return None if deployment_budgets is None else deployment_budgets.get(model_id)
|
||||
|
||||
def _get_budget_config_for_provider(
    self, provider: str
) -> Optional[GenericBudgetInfo]:
    """Return the budget config for ``provider``, or None when provider
    budgets are unconfigured or the provider has no entry."""
    provider_budgets = self.provider_budget_config
    if provider_budgets is None:
        return None
    return provider_budgets.get(provider)
|
||||
|
||||
def _get_budget_config_for_tag(self, tag: str) -> Optional[GenericBudgetInfo]:
    """Return the budget config for ``tag``, or None when tag budgets are
    unconfigured or the tag has no entry."""
    tag_budgets = self.tag_budget_config
    return tag_budgets.get(tag) if tag_budgets is not None else None
|
||||
|
||||
def _get_llm_provider_for_deployment(self, deployment: Dict) -> Optional[str]:
    """
    Resolve the LLM provider (e.g. "openai") for a deployment dict.

    Accepts the deployment's ``litellm_params`` either as a LiteLLM_Params
    object or as a plain dict (wrapped in a read-only view for provider
    resolution).

    Returns:
        Optional[str]: the resolved provider name, or None if resolution
            raised (the failure is logged).
    """
    try:
        deployment_litellm_params = deployment.get("litellm_params") or {}

        if isinstance(deployment_litellm_params, LiteLLM_Params):
            model = deployment_litellm_params.model or ""
            provider_resolution_params: Any = deployment_litellm_params
        elif isinstance(deployment_litellm_params, dict):
            model = deployment_litellm_params.get("model") or ""
            provider_resolution_params = _LiteLLMParamsDictView(
                deployment_litellm_params
            )
        else:
            # Unexpected params type: attempt resolution with an empty view.
            model = ""
            provider_resolution_params = _LiteLLMParamsDictView({})

        _, custom_llm_provider, _, _ = litellm.get_llm_provider(
            model=str(model),
            litellm_params=provider_resolution_params,
        )
    except Exception:
        verbose_router_logger.error(
            f"Error getting LLM provider for deployment: {deployment}"
        )
        return None
    return custom_llm_provider
|
||||
|
||||
def _track_provider_remaining_budget_prometheus(
    self, provider: str, spend: float, budget_limit: float
):
    """
    Optional helper - emit the provider's remaining-budget metric to
    Prometheus. Helpful for debugging and monitoring provider budget limits.

    No-op when no Prometheus logger is registered among the callbacks.
    """
    prometheus_logger = _get_prometheus_logger_from_callbacks()
    if not prometheus_logger:
        return
    prometheus_logger.track_provider_remaining_budget(
        provider=provider,
        spend=spend,
        budget_limit=budget_limit,
    )
|
||||
|
||||
async def _get_current_provider_spend(self, provider: str) -> Optional[float]:
    """
    GET the current spend for a provider from cache

    used for GET /provider/budgets endpoint in spend_management_endpoints.py

    Args:
        provider (str): The provider to get spend for (e.g., "openai", "anthropic")

    Returns:
        Optional[float]: None if no budget is configured for the provider;
            otherwise the current spend, or 0.0 when a budget is configured
            but no spend has been recorded yet.
    """
    budget_config = self._get_budget_config_for_provider(provider)
    if budget_config is None:
        return None

    spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"

    if self.dual_cache.redis_cache:
        # use Redis as source of truth since that has spend across all instances
        current_spend = await self.dual_cache.redis_cache.async_get_cache(spend_key)
    else:
        # use in-memory cache if Redis is not initialized
        current_spend = await self.dual_cache.async_get_cache(spend_key)
    return float(current_spend) if current_spend is not None else 0.0
|
||||
|
||||
async def _get_current_provider_budget_reset_at(
    self, provider: str
) -> Optional[str]:
    """
    Return the ISO-8601 timestamp at which the provider's budget window resets.

    Derived from the TTL remaining on the provider's spend key. Returns None
    when the provider has no budget config or the spend key has no TTL.
    """
    budget_config = self._get_budget_config_for_provider(provider)
    if budget_config is None:
        return None

    spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
    if self.dual_cache.redis_cache:
        # Redis holds the authoritative TTL in multi-instance setups.
        ttl_seconds = await self.dual_cache.redis_cache.async_get_ttl(spend_key)
    else:
        ttl_seconds = await self.dual_cache.async_get_ttl(spend_key)

    if ttl_seconds is None:
        return None

    return (datetime.now(timezone.utc) + timedelta(seconds=ttl_seconds)).isoformat()
|
||||
|
||||
async def _init_provider_budget_in_cache(
    self, provider: str, budget_config: GenericBudgetInfo
):
    """
    Initialize provider budget in cache by storing the following keys if they don't exist:
    - provider_spend:{provider}:{budget_config.time_period} - stores the current spend
    - provider_budget_start_time:{provider} - stores the start time of the budget window

    Existing values are left untouched, so calling this repeatedly is safe.
    """

    spend_key = f"provider_spend:{provider}:{budget_config.budget_duration}"
    start_time_key = f"provider_budget_start_time:{provider}"
    # TTL is only applied when a duration is configured; otherwise the keys
    # are stored without expiry.
    ttl_seconds: Optional[int] = None
    if budget_config.budget_duration is not None:
        ttl_seconds = duration_in_seconds(budget_config.budget_duration)

    budget_start = await self.dual_cache.async_get_cache(start_time_key)
    if budget_start is None:
        # No window yet — start the budget window now.
        budget_start = datetime.now(timezone.utc).timestamp()
        await self.dual_cache.async_set_cache(
            key=start_time_key, value=budget_start, ttl=ttl_seconds
        )

    _spend_key = await self.dual_cache.async_get_cache(spend_key)
    if _spend_key is None:
        # Seed spend at zero so subsequent increments have a base value.
        await self.dual_cache.async_set_cache(
            key=spend_key, value=0.0, ttl=ttl_seconds
        )
|
||||
|
||||
@staticmethod
def should_init_router_budget_limiter(
    provider_budget_config: Optional[dict],
    model_list: Optional[
        Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
    ] = None,
):
    """
    Returns `True` if the router budget routing settings are set and RouterBudgetLimiting should be initialized

    Either:
    - provider_budget_config is set
    - budgets are set for deployments in the model_list
    - litellm.tag_budget_config (module-level setting) is set
    """
    if provider_budget_config is not None:
        return True

    if litellm.tag_budget_config is not None:
        return True

    if model_list is None:
        return False

    for _model in model_list:
        _litellm_params = _model.get("litellm_params", {})
        # NOTE(review): max_budget uses a truthiness check (so 0 is treated
        # as "unset") while budget_duration uses `is not None` — confirm the
        # asymmetry is intentional.
        if (
            _litellm_params.get("max_budget")
            or _litellm_params.get("budget_duration") is not None
        ):
            return True
    return False
|
||||
|
||||
def _init_provider_budgets(self):
    """
    Normalize provider budget configs and seed their cache entries.

    - Casts dict-style entries in ``self.provider_budget_config`` to
      ``GenericBudgetInfo``.
    - Schedules a background task per provider to initialize its spend and
      window-start keys in cache.

    No-op when no provider budget config is set.

    Raises:
        ValueError: if a provider maps to ``None`` in the config.
    """
    if self.provider_budget_config is None:
        return

    for provider, config in self.provider_budget_config.items():
        if config is None:
            raise ValueError(
                f"No budget config found for provider {provider}, provider_budget_config: {self.provider_budget_config}"
            )

        # cast elements of provider_budget_config to GenericBudgetInfo
        if not isinstance(config, GenericBudgetInfo):
            self.provider_budget_config[provider] = GenericBudgetInfo(
                budget_limit=config.get("budget_limit"),
                time_period=config.get("time_period"),
            )
        # Fire-and-forget: seed the provider's spend/start-time cache keys.
        asyncio.create_task(
            self._init_provider_budget_in_cache(
                provider=provider,
                budget_config=self.provider_budget_config[provider],
            )
        )

    verbose_router_logger.debug(
        f"Initialized Provider budget config: {self.provider_budget_config}"
    )
|
||||
|
||||
def _init_deployment_budgets(
    self,
    model_list: Optional[
        Union[List[DeploymentTypedDict], List[Dict[str, Any]]]
    ] = None,
):
    """
    Build ``self.deployment_budget_config`` from per-deployment budget params.

    A deployment contributes an entry only when its litellm_params carry both
    `max_budget` and `budget_duration`, and its model_info has an `id`.
    No-op when ``model_list`` is None.
    """
    if model_list is None:
        return
    for _model in model_list:
        _litellm_params = _model.get("litellm_params", {})
        _model_info: Dict = _model.get("model_info") or {}
        _model_id = _model_info.get("id")
        _max_budget = _litellm_params.get("max_budget")
        _budget_duration = _litellm_params.get("budget_duration")

        verbose_router_logger.debug(
            f"Init Deployment Budget: max_budget: {_max_budget}, budget_duration: {_budget_duration}, model_id: {_model_id}"
        )
        if (
            _max_budget is not None
            and _budget_duration is not None
            and _model_id is not None
        ):
            _budget_config = GenericBudgetInfo(
                time_period=_budget_duration,
                budget_limit=_max_budget,
            )
            # Lazily create the config map on the first budgeted deployment.
            if self.deployment_budget_config is None:
                self.deployment_budget_config = {}
            self.deployment_budget_config[_model_id] = _budget_config

    verbose_router_logger.debug(
        f"Initialized Deployment Budget Config: {self.deployment_budget_config}"
    )
|
||||
|
||||
def _init_tag_budgets(self):
    """
    Load ``litellm.tag_budget_config`` into ``self.tag_budget_config``,
    normalizing each entry to ``GenericBudgetInfo``.

    No-op when the module-level tag budget config is unset.

    Raises:
        ValueError: if tag budgets are configured without a premium license
            (Enterprise-only feature).
    """
    if litellm.tag_budget_config is None:
        return
    # Imported here (not at module top) — presumably to avoid a circular
    # import with the proxy server; confirm before moving.
    from litellm.proxy.proxy_server import CommonProxyErrors, premium_user

    if premium_user is not True:
        raise ValueError(
            f"Tag budgets are an Enterprise only feature, {CommonProxyErrors.not_premium_user}"
        )

    if self.tag_budget_config is None:
        self.tag_budget_config = {}

    for _tag, _tag_budget_config in litellm.tag_budget_config.items():
        # Accept raw dicts (e.g. from YAML config) as well as BudgetConfig.
        if isinstance(_tag_budget_config, dict):
            _tag_budget_config = BudgetConfig(**_tag_budget_config)
        _generic_budget_config = GenericBudgetInfo(
            time_period=_tag_budget_config.budget_duration,
            budget_limit=_tag_budget_config.max_budget,
        )
        self.tag_budget_config[_tag] = _generic_budget_config

    verbose_router_logger.debug(
        f"Initialized Tag Budget Config: {self.tag_budget_config}"
    )
|
||||
@@ -0,0 +1,162 @@
|
||||
# Complexity Router
|
||||
|
||||
A rule-based routing strategy that classifies requests by complexity and routes them to appropriate models - with zero API calls and sub-millisecond latency.
|
||||
|
||||
## Overview
|
||||
|
||||
Unlike the semantic `auto_router` which uses embedding-based matching, the `complexity_router` uses weighted rule-based scoring across multiple dimensions to classify request complexity. This approach:
|
||||
|
||||
- **Zero external API calls** - all scoring is local
|
||||
- **Sub-millisecond latency** - typically <1ms per classification
|
||||
- **Predictable behavior** - rule-based scoring is deterministic
|
||||
- **Fully configurable** - weights, thresholds, and keyword lists can be customized
|
||||
|
||||
## How It Works
|
||||
|
||||
The router scores each request across 7 dimensions:
|
||||
|
||||
| Dimension | Description | Weight |
|
||||
|-----------|-------------|--------|
|
||||
| `tokenCount` | Short prompts = simple, long = complex | 0.10 |
|
||||
| `codePresence` | Code keywords (function, class, etc.) | 0.30 |
|
||||
| `reasoningMarkers` | "step by step", "think through", etc. | 0.25 |
|
||||
| `technicalTerms` | Domain complexity indicators | 0.25 |
|
||||
| `simpleIndicators` | "what is", "define" (negative weight) | 0.05 |
|
||||
| `multiStepPatterns` | "first...then", numbered steps | 0.03 |
|
||||
| `questionComplexity` | Multiple question marks | 0.02 |
|
||||
|
||||
The weighted sum is mapped to tiers using configurable boundaries:
|
||||
|
||||
| Tier | Score Range | Typical Use |
|
||||
|------|-------------|-------------|
|
||||
| SIMPLE | < 0.15 | Basic questions, greetings |
|
||||
| MEDIUM | 0.15 - 0.35 | Standard queries |
|
||||
| COMPLEX | 0.35 - 0.60 | Technical, multi-part requests |
|
||||
| REASONING | > 0.60 | Chain-of-thought, analysis |
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: smart-router
|
||||
litellm_params:
|
||||
model: auto_router/complexity_router
|
||||
complexity_router_config:
|
||||
tiers:
|
||||
SIMPLE: gpt-4o-mini
|
||||
MEDIUM: gpt-4o
|
||||
COMPLEX: claude-sonnet-4
|
||||
REASONING: o1-preview
|
||||
```
|
||||
|
||||
### Full Configuration
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: smart-router
|
||||
litellm_params:
|
||||
model: auto_router/complexity_router
|
||||
complexity_router_config:
|
||||
# Tier to model mapping
|
||||
tiers:
|
||||
SIMPLE: gpt-4o-mini
|
||||
MEDIUM: gpt-4o
|
||||
COMPLEX: claude-sonnet-4
|
||||
REASONING: o1-preview
|
||||
|
||||
# Tier boundaries (normalized scores)
|
||||
tier_boundaries:
|
||||
simple_medium: 0.15
|
||||
medium_complex: 0.35
|
||||
complex_reasoning: 0.60
|
||||
|
||||
# Token count thresholds
|
||||
token_thresholds:
|
||||
simple: 15 # Below this = "short" (default: 15)
|
||||
complex: 400 # Above this = "long" (default: 400)
|
||||
|
||||
# Dimension weights (must sum to ~1.0)
|
||||
dimension_weights:
|
||||
tokenCount: 0.10
|
||||
codePresence: 0.30
|
||||
reasoningMarkers: 0.25
|
||||
technicalTerms: 0.25
|
||||
simpleIndicators: 0.05
|
||||
multiStepPatterns: 0.03
|
||||
questionComplexity: 0.02
|
||||
|
||||
# Override default keyword lists
|
||||
code_keywords:
|
||||
- function
|
||||
- class
|
||||
- def
|
||||
- async
|
||||
- database
|
||||
|
||||
reasoning_keywords:
|
||||
- step by step
|
||||
- think through
|
||||
- analyze
|
||||
|
||||
# Fallback model if tier cannot be determined
|
||||
default_model: gpt-4o
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Once configured, use the model name like any other:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
response = litellm.completion(
|
||||
model="smart-router", # Your complexity_router model name
|
||||
messages=[{"role": "user", "content": "What is 2+2?"}]
|
||||
)
|
||||
# Routes to SIMPLE tier (gpt-4o-mini)
|
||||
|
||||
response = litellm.completion(
|
||||
model="smart-router",
|
||||
messages=[{"role": "user", "content": "Think step by step: analyze the performance implications of implementing a distributed consensus algorithm for our microservices architecture."}]
|
||||
)
|
||||
# Routes to REASONING tier (o1-preview)
|
||||
```
|
||||
|
||||
## Special Behaviors
|
||||
|
||||
### Reasoning Override
|
||||
|
||||
If 2+ reasoning markers are detected in the user message, the request is automatically routed to the REASONING tier regardless of the weighted score. This ensures complex reasoning tasks get the appropriate model.
|
||||
|
||||
### System Prompt Handling
|
||||
|
||||
Reasoning markers in the system prompt do **not** trigger the reasoning override. This prevents system prompts like "Think step by step before answering" from forcing all requests to the reasoning tier.
|
||||
|
||||
### Code Detection
|
||||
|
||||
Technical code keywords are detected case-insensitively and include:
|
||||
- Language keywords: `function`, `class`, `def`, `const`, `let`, `var`
|
||||
- Operations: `import`, `export`, `return`, `async`, `await`
|
||||
- Infrastructure: `database`, `api`, `endpoint`, `docker`, `kubernetes`
|
||||
- Actions: `debug`, `implement`, `refactor`, `optimize`
|
||||
|
||||
## Performance
|
||||
|
||||
- **Classification time**: <1ms typical
|
||||
- **Memory usage**: Minimal (compiled regex patterns + keyword sets)
|
||||
- **No external dependencies**: Works offline with no API calls
|
||||
|
||||
## Comparison with auto_router
|
||||
|
||||
| Feature | complexity_router | auto_router |
|
||||
|---------|-------------------|-------------|
|
||||
| Classification | Rule-based scoring | Semantic embedding |
|
||||
| Latency | <1ms | ~100-500ms (embedding API) |
|
||||
| API Calls | None | Requires embedding model |
|
||||
| Training | None | Requires utterance examples |
|
||||
| Customization | Weights, keywords, thresholds | Utterance examples |
|
||||
| Best For | Cost optimization | Intent routing |
|
||||
|
||||
Use `complexity_router` when you want to optimize costs by routing simple queries to cheaper models. Use `auto_router` when you need semantic intent matching (e.g., routing "customer support" queries to a specialized model).
|
||||
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Complexity-based Auto Router
|
||||
|
||||
A rule-based routing strategy that uses weighted scoring across multiple dimensions
|
||||
to classify requests by complexity and route them to appropriate models.
|
||||
|
||||
No external API calls - all scoring is local and <1ms.
|
||||
"""
|
||||
|
||||
from litellm.router_strategy.complexity_router.complexity_router import ComplexityRouter
|
||||
from litellm.router_strategy.complexity_router.config import (
|
||||
ComplexityTier,
|
||||
DEFAULT_COMPLEXITY_CONFIG,
|
||||
ComplexityRouterConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ComplexityRouter",
|
||||
"ComplexityTier",
|
||||
"DEFAULT_COMPLEXITY_CONFIG",
|
||||
"ComplexityRouterConfig",
|
||||
]
|
||||
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Complexity-based Auto Router
|
||||
|
||||
A rule-based routing strategy that uses weighted scoring across multiple dimensions
|
||||
to classify requests by complexity and route them to appropriate models.
|
||||
|
||||
No external API calls - all scoring is local and <1ms.
|
||||
|
||||
Inspired by ClawRouter: https://github.com/BlockRunAI/ClawRouter
|
||||
"""
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
from .config import (
|
||||
DEFAULT_CODE_KEYWORDS,
|
||||
DEFAULT_REASONING_KEYWORDS,
|
||||
DEFAULT_SIMPLE_KEYWORDS,
|
||||
DEFAULT_TECHNICAL_KEYWORDS,
|
||||
ComplexityRouterConfig,
|
||||
ComplexityTier,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router
|
||||
from litellm.types.router import PreRoutingHookResponse
|
||||
else:
|
||||
Router = Any
|
||||
PreRoutingHookResponse = Any
|
||||
|
||||
|
||||
class DimensionScore:
    """A single complexity dimension's score, plus an optional signal label.

    Attributes:
        name: identifier of the scored dimension.
        score: the dimension's score value.
        signal: optional short note describing what triggered the score.
    """

    # Slots keep per-instance memory small; scoring creates many of these.
    __slots__ = ("name", "score", "signal")

    def __init__(self, name: str, score: float, signal: Optional[str] = None):
        self.name = name
        self.score = score
        self.signal = signal
|
||||
|
||||
|
||||
class ComplexityRouter(CustomLogger):
    """
    Rule-based complexity router that classifies requests and routes them to
    tier-appropriate models.

    Handles requests in <1ms with zero external API calls by computing a
    weighted score across several local dimensions:
    - Token count (short=simple, long=complex)
    - Code presence (code keywords → complex)
    - Reasoning markers ("step by step", "think through" → reasoning tier)
    - Technical terms (domain complexity)
    - Simple indicators ("what is", "define" → simple, negative weight)
    - Multi-step patterns ("first...then", numbered steps)
    - Question complexity (multiple question marks)

    The weighted score is mapped onto ComplexityTier boundaries; two or more
    reasoning markers in the user text force the REASONING tier regardless of
    the score (see ``classify``).
    """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
litellm_router_instance: "Router",
|
||||
complexity_router_config: Optional[Dict[str, Any]] = None,
|
||||
default_model: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize ComplexityRouter.
|
||||
|
||||
Args:
|
||||
model_name: The name of the model/deployment using this router.
|
||||
litellm_router_instance: The LiteLLM Router instance.
|
||||
complexity_router_config: Optional configuration dict from proxy config.
|
||||
default_model: Optional default model to use if tier cannot be determined.
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.litellm_router_instance = litellm_router_instance
|
||||
|
||||
# Parse config - always create a new instance to avoid singleton mutation
|
||||
if complexity_router_config:
|
||||
self.config = ComplexityRouterConfig(**complexity_router_config)
|
||||
else:
|
||||
self.config = ComplexityRouterConfig()
|
||||
|
||||
# Override default_model if provided
|
||||
if default_model:
|
||||
self.config.default_model = default_model
|
||||
|
||||
# Build effective keyword lists (use config overrides or defaults)
|
||||
self.code_keywords = self.config.code_keywords or DEFAULT_CODE_KEYWORDS
|
||||
self.reasoning_keywords = (
|
||||
self.config.reasoning_keywords or DEFAULT_REASONING_KEYWORDS
|
||||
)
|
||||
self.technical_keywords = (
|
||||
self.config.technical_keywords or DEFAULT_TECHNICAL_KEYWORDS
|
||||
)
|
||||
self.simple_keywords = self.config.simple_keywords or DEFAULT_SIMPLE_KEYWORDS
|
||||
|
||||
# Pre-compile regex patterns for efficiency
|
||||
# Use non-greedy .*? to prevent ReDoS on pathological inputs
|
||||
self._multi_step_patterns = [
|
||||
re.compile(r"first.*?then", re.IGNORECASE),
|
||||
re.compile(r"step\s*\d", re.IGNORECASE),
|
||||
re.compile(r"\d+\.\s"),
|
||||
re.compile(r"[a-z]\)\s", re.IGNORECASE),
|
||||
]
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"ComplexityRouter initialized for {model_name} with tiers: {self.config.tiers}"
|
||||
)
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
"""
|
||||
Estimate token count from text.
|
||||
Uses a simple heuristic: ~4 characters per token on average.
|
||||
"""
|
||||
return len(text) // 4
|
||||
|
||||
def _score_token_count(self, estimated_tokens: int) -> DimensionScore:
|
||||
"""Score based on token count."""
|
||||
thresholds = self.config.token_thresholds
|
||||
simple_threshold = thresholds.get("simple", 15)
|
||||
complex_threshold = thresholds.get("complex", 400)
|
||||
|
||||
if estimated_tokens < simple_threshold:
|
||||
return DimensionScore(
|
||||
"tokenCount", -1.0, f"short ({estimated_tokens} tokens)"
|
||||
)
|
||||
if estimated_tokens > complex_threshold:
|
||||
return DimensionScore(
|
||||
"tokenCount", 1.0, f"long ({estimated_tokens} tokens)"
|
||||
)
|
||||
return DimensionScore("tokenCount", 0, None)
|
||||
|
||||
def _keyword_matches(self, text: str, keyword: str) -> bool:
|
||||
"""
|
||||
Check if a keyword matches in text using word boundary matching.
|
||||
|
||||
For single-word keywords, uses regex word boundaries to avoid
|
||||
false positives (e.g., "error" matching "terrorism", "class" matching "classical").
|
||||
For multi-word phrases, uses substring matching.
|
||||
"""
|
||||
kw_lower = keyword.lower()
|
||||
|
||||
# For single-word keywords, use word boundary matching to avoid false positives
|
||||
# e.g., "api" should not match "capital", "error" should not match "terrorism"
|
||||
if " " not in kw_lower:
|
||||
pattern = r"\b" + re.escape(kw_lower) + r"\b"
|
||||
return bool(re.search(pattern, text))
|
||||
|
||||
# For multi-word phrases, substring matching is fine
|
||||
return kw_lower in text
|
||||
|
||||
def _score_keyword_match(
|
||||
self,
|
||||
text: str,
|
||||
keywords: List[str],
|
||||
name: str,
|
||||
signal_label: str,
|
||||
thresholds: Tuple[int, int], # (low, high)
|
||||
scores: Tuple[float, float, float], # (none, low, high)
|
||||
) -> Tuple[DimensionScore, int]:
|
||||
"""Score based on keyword matches using word boundary matching.
|
||||
|
||||
Returns:
|
||||
Tuple of (DimensionScore, match_count) so callers can reuse the count.
|
||||
"""
|
||||
low_threshold, high_threshold = thresholds
|
||||
score_none, score_low, score_high = scores
|
||||
|
||||
matches = [kw for kw in keywords if self._keyword_matches(text, kw)]
|
||||
match_count = len(matches)
|
||||
|
||||
if match_count >= high_threshold:
|
||||
return (
|
||||
DimensionScore(
|
||||
name, score_high, f"{signal_label} ({', '.join(matches[:3])})"
|
||||
),
|
||||
match_count,
|
||||
)
|
||||
if match_count >= low_threshold:
|
||||
return (
|
||||
DimensionScore(
|
||||
name, score_low, f"{signal_label} ({', '.join(matches[:3])})"
|
||||
),
|
||||
match_count,
|
||||
)
|
||||
return DimensionScore(name, score_none, None), match_count
|
||||
|
||||
def _score_multi_step(self, text: str) -> DimensionScore:
|
||||
"""Score based on multi-step patterns."""
|
||||
hits = sum(1 for p in self._multi_step_patterns if p.search(text))
|
||||
if hits > 0:
|
||||
return DimensionScore("multiStepPatterns", 0.5, "multi-step")
|
||||
return DimensionScore("multiStepPatterns", 0, None)
|
||||
|
||||
def _score_question_complexity(self, text: str) -> DimensionScore:
|
||||
"""Score based on number of question marks."""
|
||||
count = text.count("?")
|
||||
if count > 3:
|
||||
return DimensionScore("questionComplexity", 0.5, f"{count} questions")
|
||||
return DimensionScore("questionComplexity", 0, None)
|
||||
|
||||
def classify(
|
||||
self, prompt: str, system_prompt: Optional[str] = None
|
||||
) -> Tuple[ComplexityTier, float, List[str]]:
|
||||
"""
|
||||
Classify a prompt by complexity.
|
||||
|
||||
Args:
|
||||
prompt: The user's prompt/message.
|
||||
system_prompt: Optional system prompt for context.
|
||||
|
||||
Returns:
|
||||
Tuple of (tier, score, signals) where:
|
||||
- tier: The ComplexityTier (SIMPLE, MEDIUM, COMPLEX, REASONING)
|
||||
- score: The raw weighted score
|
||||
- signals: List of triggered signals for debugging
|
||||
"""
|
||||
# Combine text for analysis.
|
||||
# System prompt is intentionally included in code/technical/simple scoring
|
||||
# because it provides deployment-level context (e.g., "You are a Python assistant"
|
||||
# signals that code-capable models are appropriate). Reasoning markers use
|
||||
# user_text only to prevent system prompts from forcing REASONING tier.
|
||||
full_text = f"{system_prompt or ''} {prompt}".lower()
|
||||
user_text = prompt.lower()
|
||||
|
||||
# Estimate tokens
|
||||
estimated_tokens = self._estimate_tokens(prompt)
|
||||
|
||||
# Score all dimensions, capturing match counts where needed
|
||||
code_score, _ = self._score_keyword_match(
|
||||
full_text,
|
||||
self.code_keywords,
|
||||
"codePresence",
|
||||
"code",
|
||||
(1, 2),
|
||||
(0, 0.5, 1.0),
|
||||
)
|
||||
reasoning_score, reasoning_match_count = self._score_keyword_match(
|
||||
user_text,
|
||||
self.reasoning_keywords,
|
||||
"reasoningMarkers",
|
||||
"reasoning",
|
||||
(1, 2),
|
||||
(0, 0.7, 1.0),
|
||||
)
|
||||
technical_score, _ = self._score_keyword_match(
|
||||
full_text,
|
||||
self.technical_keywords,
|
||||
"technicalTerms",
|
||||
"technical",
|
||||
(2, 4),
|
||||
(0, 0.5, 1.0),
|
||||
)
|
||||
simple_score, _ = self._score_keyword_match(
|
||||
full_text,
|
||||
self.simple_keywords,
|
||||
"simpleIndicators",
|
||||
"simple",
|
||||
(1, 2),
|
||||
(0, -1.0, -1.0),
|
||||
)
|
||||
|
||||
dimensions: List[DimensionScore] = [
|
||||
self._score_token_count(estimated_tokens),
|
||||
code_score,
|
||||
reasoning_score,
|
||||
technical_score,
|
||||
simple_score,
|
||||
self._score_multi_step(full_text),
|
||||
self._score_question_complexity(prompt),
|
||||
]
|
||||
|
||||
# Collect signals
|
||||
signals = [d.signal for d in dimensions if d.signal is not None]
|
||||
|
||||
# Compute weighted score
|
||||
weights = self.config.dimension_weights
|
||||
weighted_score = sum(d.score * weights.get(d.name, 0) for d in dimensions)
|
||||
|
||||
# Check for reasoning override (2+ reasoning markers)
|
||||
# Reuse match count from _score_keyword_match to avoid scanning twice
|
||||
if reasoning_match_count >= 2:
|
||||
return ComplexityTier.REASONING, weighted_score, signals
|
||||
|
||||
# Map score to tier
|
||||
boundaries = self.config.tier_boundaries
|
||||
simple_medium = boundaries.get("simple_medium", 0.15)
|
||||
medium_complex = boundaries.get("medium_complex", 0.35)
|
||||
complex_reasoning = boundaries.get("complex_reasoning", 0.60)
|
||||
|
||||
if weighted_score < simple_medium:
|
||||
tier = ComplexityTier.SIMPLE
|
||||
elif weighted_score < medium_complex:
|
||||
tier = ComplexityTier.MEDIUM
|
||||
elif weighted_score < complex_reasoning:
|
||||
tier = ComplexityTier.COMPLEX
|
||||
else:
|
||||
tier = ComplexityTier.REASONING
|
||||
|
||||
return tier, weighted_score, signals
|
||||
|
||||
def get_model_for_tier(self, tier: ComplexityTier) -> str:
|
||||
"""
|
||||
Get the model name for a given complexity tier.
|
||||
|
||||
Args:
|
||||
tier: The complexity tier.
|
||||
|
||||
Returns:
|
||||
The model name configured for that tier.
|
||||
"""
|
||||
tier_key = tier.value if isinstance(tier, ComplexityTier) else tier
|
||||
|
||||
# Check config tiers mapping
|
||||
model = self.config.tiers.get(tier_key)
|
||||
if model:
|
||||
return model
|
||||
|
||||
# Fallback to default model if configured
|
||||
if self.config.default_model:
|
||||
return self.config.default_model
|
||||
|
||||
# Last resort: return MEDIUM tier model or error
|
||||
medium_model = self.config.tiers.get(ComplexityTier.MEDIUM.value)
|
||||
if medium_model:
|
||||
return medium_model
|
||||
|
||||
raise ValueError(
|
||||
f"No model configured for tier {tier_key} and no default_model set"
|
||||
)
|
||||
|
||||
async def async_pre_routing_hook(
|
||||
self,
|
||||
model: str,
|
||||
request_kwargs: Dict,
|
||||
messages: Optional[List[Dict[str, Any]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
specific_deployment: Optional[bool] = False,
|
||||
) -> Optional["PreRoutingHookResponse"]:
|
||||
"""
|
||||
Pre-routing hook called before the routing decision.
|
||||
|
||||
Classifies the request by complexity and returns the appropriate model.
|
||||
|
||||
Args:
|
||||
model: The original model name requested.
|
||||
request_kwargs: The request kwargs.
|
||||
messages: The messages in the request.
|
||||
input: Optional input for embeddings.
|
||||
specific_deployment: Whether a specific deployment was requested.
|
||||
|
||||
Returns:
|
||||
PreRoutingHookResponse with the routed model, or None if no routing needed.
|
||||
"""
|
||||
from litellm.types.router import PreRoutingHookResponse
|
||||
|
||||
if messages is None or len(messages) == 0:
|
||||
verbose_router_logger.debug(
|
||||
"ComplexityRouter: No messages provided, skipping routing"
|
||||
)
|
||||
return None
|
||||
|
||||
# Extract the last user message and the last system prompt
|
||||
user_message: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
|
||||
for msg in reversed(messages):
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content") or ""
|
||||
# content may be a list of content parts (e.g. [{"type": "text", "text": "..."}])
|
||||
if isinstance(content, list):
|
||||
text_parts = [
|
||||
part.get("text", "")
|
||||
for part in content
|
||||
if isinstance(part, dict) and part.get("type") == "text"
|
||||
]
|
||||
content = " ".join(text_parts).strip()
|
||||
if isinstance(content, str) and content:
|
||||
if role == "user" and user_message is None:
|
||||
user_message = content
|
||||
elif role == "system" and system_prompt is None:
|
||||
system_prompt = content
|
||||
|
||||
if user_message is None:
|
||||
verbose_router_logger.debug(
|
||||
"ComplexityRouter: No user message found, routing to default model"
|
||||
)
|
||||
return PreRoutingHookResponse(
|
||||
model=self.config.default_model
|
||||
or self.get_model_for_tier(ComplexityTier.MEDIUM),
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
# Classify the request
|
||||
tier, score, signals = self.classify(user_message, system_prompt)
|
||||
|
||||
# Get the model for this tier
|
||||
routed_model = self.get_model_for_tier(tier)
|
||||
|
||||
verbose_router_logger.info(
|
||||
f"ComplexityRouter: tier={tier.value}, score={score:.3f}, "
|
||||
f"signals={signals}, routed_model={routed_model}"
|
||||
)
|
||||
|
||||
return PreRoutingHookResponse(
|
||||
model=routed_model,
|
||||
messages=messages,
|
||||
)
|
||||
@@ -0,0 +1,255 @@
|
||||
"""
|
||||
Configuration for the Complexity Router.
|
||||
|
||||
Contains default keyword lists, weights, tier boundaries, and configuration classes.
|
||||
All values are configurable via proxy config.yaml.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ComplexityTier(str, Enum):
    """Discrete complexity buckets used for routing decisions.

    Subclasses ``str`` so members compare and serialize as their plain
    string values (useful when tiers come from YAML/JSON config).
    """

    SIMPLE = "SIMPLE"
    MEDIUM = "MEDIUM"
    COMPLEX = "COMPLEX"
    REASONING = "REASONING"
|
||||
|
||||
|
||||
# ─── Default Keyword Lists ───
# Keywords should be full words/phrases: the matcher applies word-boundary
# detection to single words, and substring matching to multi-word phrases.

DEFAULT_CODE_KEYWORDS: List[str] = [
    "function", "class", "def", "const", "let", "var", "import", "export",
    "return", "async", "await", "try", "catch", "exception", "error", "debug",
    "api", "endpoint", "request", "response", "database", "sql", "query",
    "schema", "algorithm", "implement", "refactor", "optimize", "python",
    "javascript", "typescript", "java", "rust", "golang", "react", "vue",
    "angular", "node", "docker", "kubernetes", "git", "commit", "merge",
    "branch", "pull request",
]

DEFAULT_REASONING_KEYWORDS: List[str] = [
    "step by step", "think through", "let's think", "reason through",
    "analyze this", "break down", "explain your reasoning", "show your work",
    "chain of thought", "think carefully", "consider all", "evaluate",
    "pros and cons", "compare and contrast", "weigh the options", "logical",
    "deduce", "infer", "conclude",
]

DEFAULT_TECHNICAL_KEYWORDS: List[str] = [
    "architecture", "distributed", "scalable", "microservice",
    "machine learning", "neural network", "deep learning", "encryption",
    "authentication", "authorization", "performance", "latency", "throughput",
    "benchmark", "concurrency", "parallel", "threading", "memory", "cpu",
    "gpu", "optimization", "protocol", "tcp", "http", "grpc", "websocket",
    "container", "orchestration",
    # Note: "async", "kubernetes", "docker" live in DEFAULT_CODE_KEYWORDS.
]

DEFAULT_SIMPLE_KEYWORDS: List[str] = [
    "what is", "what's", "define", "definition of", "who is", "who was",
    "when did", "when was", "where is", "where was", "how many", "how much",
    "yes or no", "true or false", "simple", "brief", "short", "quick",
    "hello", "hi", "hey", "thanks", "thank you", "goodbye", "bye", "okay",
    # Note: "ok" removed due to false positives (matches "token", "book", etc.)
]


# ─── Default Dimension Weights ───
# Weights sum to 1.0; content-based dimensions dominate raw length.

DEFAULT_DIMENSION_WEIGHTS: Dict[str, float] = {
    "tokenCount": 0.10,  # Reduced - length is less important than content
    "codePresence": 0.30,  # High - code requests need capable models
    "reasoningMarkers": 0.25,  # High - explicit reasoning requests
    "technicalTerms": 0.25,  # High - technical content matters
    "simpleIndicators": 0.05,  # Low - don't over-penalize simple patterns
    "multiStepPatterns": 0.03,
    "questionComplexity": 0.02,
}


# ─── Default Tier Boundaries ───

DEFAULT_TIER_BOUNDARIES: Dict[str, float] = {
    "simple_medium": 0.15,  # Lower threshold to catch more MEDIUM cases
    "medium_complex": 0.35,  # Lower threshold to catch technical COMPLEX cases
    "complex_reasoning": 0.60,  # Reasoning tier reserved for explicit reasoning markers
}


# ─── Default Token Thresholds ───

DEFAULT_TOKEN_THRESHOLDS: Dict[str, int] = {
    "simple": 15,  # Only very short prompts (<15 tokens) are penalized
    "complex": 400,  # Long prompts (>400 tokens) get complexity boost
}


# ─── Default Tier to Model Mapping ───

DEFAULT_TIER_MODELS: Dict[str, str] = {
    "SIMPLE": "gpt-4o-mini",
    "MEDIUM": "gpt-4o",
    "COMPLEX": "claude-sonnet-4-20250514",
    "REASONING": "claude-sonnet-4-20250514",
}
|
||||
|
||||
|
||||
class ComplexityRouterConfig(BaseModel):
    """Pydantic configuration model for the ComplexityRouter.

    Every field has a usable default, so an empty config is valid. Mutable
    defaults are built via ``default_factory`` copies so no instance ever
    shares state with the module-level default dicts.
    """

    # Which model serves each complexity tier (keys are tier value strings).
    tiers: Dict[str, str] = Field(
        default_factory=lambda: DEFAULT_TIER_MODELS.copy(),
        description="Mapping of complexity tiers to model names",
    )

    # Normalized-score cut points between adjacent tiers.
    tier_boundaries: Dict[str, float] = Field(
        default_factory=lambda: DEFAULT_TIER_BOUNDARIES.copy(),
        description="Score boundaries between tiers",
    )

    # Estimated-token cutoffs for the length dimension.
    token_thresholds: Dict[str, int] = Field(
        default_factory=lambda: DEFAULT_TOKEN_THRESHOLDS.copy(),
        description="Token count thresholds for simple/complex classification",
    )

    # Per-dimension weights applied to raw dimension scores.
    dimension_weights: Dict[str, float] = Field(
        default_factory=lambda: DEFAULT_DIMENSION_WEIGHTS.copy(),
        description="Weights for each scoring dimension",
    )

    # Optional keyword-list overrides; None means "use the module defaults".
    code_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating code-related content",
    )
    reasoning_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating reasoning-required content",
    )
    technical_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating technical content",
    )
    simple_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating simple/basic queries",
    )

    # Fallback model when no tier mapping applies.
    default_model: Optional[str] = Field(
        default=None,
        description="Default model to use if tier cannot be determined",
    )

    # Tolerate unknown keys coming from a hand-written proxy config.
    model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
# Combined default config.
# NOTE: treat this shared instance as read-only; ComplexityRouter always
# constructs its own config object instead of mutating this one.
DEFAULT_COMPLEXITY_CONFIG = ComplexityRouterConfig()
|
||||
@@ -0,0 +1 @@
|
||||
# Evaluation suite for ComplexityRouter
|
||||
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
Evaluation suite for the ComplexityRouter.
|
||||
|
||||
Tests the router's ability to correctly classify prompts into complexity tiers.
|
||||
Run with: python -m litellm.router_strategy.complexity_router.evals.eval_complexity_router
|
||||
"""
|
||||
import os
|
||||
|
||||
# Add parent to path for imports
|
||||
import sys
|
||||
|
||||
# ruff: noqa: T201
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
|
||||
)
|
||||
|
||||
from litellm.router_strategy.complexity_router.complexity_router import ComplexityRouter
|
||||
from litellm.router_strategy.complexity_router.config import ComplexityTier
|
||||
|
||||
|
||||
@dataclass
class EvalCase:
    """One labelled prompt in the classifier evaluation dataset."""

    # The user prompt fed to classify().
    prompt: str
    # The tier we ideally expect back.
    expected_tier: ComplexityTier
    # Human-readable label printed in the report.
    description: str
    # Optional system prompt passed alongside the user prompt.
    system_prompt: Optional[str] = None
    # Tolerance band: when set, any tier in this list also counts as a pass.
    acceptable_tiers: Optional[List[ComplexityTier]] = None
|
||||
|
||||
|
||||
# ─── Evaluation Dataset ───
|
||||
|
||||
EVAL_CASES: List[EvalCase] = [
|
||||
# === SIMPLE tier cases ===
|
||||
EvalCase(
|
||||
prompt="Hello!",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Basic greeting",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="What is Python?",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Simple definition question",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Who is Elon Musk?",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Simple factual question",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="What's the capital of France?",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Simple geography question",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Thanks for your help!",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Simple thank you",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Define machine learning",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Definition request",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="When was the iPhone released?",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Simple date question",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="How many planets are in our solar system?",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Simple count question",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Yes",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Single word response",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="What time is it in Tokyo?",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Simple time zone question",
|
||||
),
|
||||
# === MEDIUM tier cases ===
|
||||
EvalCase(
|
||||
prompt="Explain how REST APIs work and when to use them",
|
||||
expected_tier=ComplexityTier.MEDIUM,
|
||||
description="Technical explanation",
|
||||
acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Write a short poem about the ocean",
|
||||
expected_tier=ComplexityTier.MEDIUM,
|
||||
description="Creative writing - short",
|
||||
acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Summarize the main differences between SQL and NoSQL databases",
|
||||
expected_tier=ComplexityTier.MEDIUM,
|
||||
description="Technical comparison",
|
||||
acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="What are the benefits of using TypeScript over JavaScript?",
|
||||
expected_tier=ComplexityTier.MEDIUM,
|
||||
description="Technical comparison question",
|
||||
acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Help me debug this error: TypeError: Cannot read property 'map' of undefined",
|
||||
expected_tier=ComplexityTier.MEDIUM,
|
||||
description="Debugging help",
|
||||
acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
|
||||
),
|
||||
# === COMPLEX tier cases ===
|
||||
EvalCase(
|
||||
prompt="Design a distributed microservice architecture for a high-throughput "
|
||||
"real-time data processing pipeline with Kubernetes orchestration, "
|
||||
"implementing proper authentication and encryption protocols",
|
||||
expected_tier=ComplexityTier.COMPLEX,
|
||||
description="Complex architecture design",
|
||||
acceptable_tiers=[ComplexityTier.COMPLEX, ComplexityTier.REASONING],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Write a Python function that implements a binary search tree with "
|
||||
"insert, delete, and search operations. Include proper error handling "
|
||||
"and optimize for memory efficiency.",
|
||||
expected_tier=ComplexityTier.COMPLEX,
|
||||
description="Complex coding task",
|
||||
acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Explain the differences between TCP and UDP protocols, including "
|
||||
"use cases for each, performance implications, and how they handle "
|
||||
"packet loss in distributed systems",
|
||||
expected_tier=ComplexityTier.COMPLEX,
|
||||
description="Deep technical explanation",
|
||||
acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Create a comprehensive database schema for an e-commerce platform "
|
||||
"that handles users, products, orders, payments, shipping, reviews, "
|
||||
"and inventory management with proper indexing strategies",
|
||||
expected_tier=ComplexityTier.COMPLEX,
|
||||
description="Complex database design",
|
||||
acceptable_tiers=[
|
||||
ComplexityTier.MEDIUM,
|
||||
ComplexityTier.COMPLEX,
|
||||
ComplexityTier.REASONING,
|
||||
],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Implement a rate limiter using the token bucket algorithm in Python "
|
||||
"that supports multiple rate limit tiers and can be used across "
|
||||
"distributed systems with Redis as the backend",
|
||||
expected_tier=ComplexityTier.COMPLEX,
|
||||
description="Complex distributed systems coding",
|
||||
acceptable_tiers=[
|
||||
ComplexityTier.MEDIUM,
|
||||
ComplexityTier.COMPLEX,
|
||||
ComplexityTier.REASONING,
|
||||
],
|
||||
),
|
||||
# === REASONING tier cases ===
|
||||
EvalCase(
|
||||
prompt="Think step by step about how to solve this: A farmer has 17 sheep. "
|
||||
"All but 9 die. How many are left? Explain your reasoning.",
|
||||
expected_tier=ComplexityTier.REASONING,
|
||||
description="Explicit reasoning request",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Let's think through this carefully. Analyze the pros and cons of "
|
||||
"microservices vs monolithic architecture for a startup with 5 engineers. "
|
||||
"Consider scalability, development speed, and operational complexity.",
|
||||
expected_tier=ComplexityTier.REASONING,
|
||||
description="Multiple reasoning markers + analysis",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Reason through this problem: If I have a function that's O(n^2) and "
|
||||
"I need to process 1 million items, what are my options to optimize it? "
|
||||
"Walk me through each approach step by step.",
|
||||
expected_tier=ComplexityTier.REASONING,
|
||||
description="Algorithm reasoning",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="I need you to think carefully and analyze this code for potential "
|
||||
"security vulnerabilities. Consider injection attacks, authentication "
|
||||
"bypasses, and data exposure risks. Show your reasoning process.",
|
||||
expected_tier=ComplexityTier.REASONING,
|
||||
description="Security analysis with reasoning",
|
||||
acceptable_tiers=[ComplexityTier.COMPLEX, ComplexityTier.REASONING],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Step by step, explain your reasoning as you evaluate whether we should "
|
||||
"use PostgreSQL or MongoDB for our new project. Consider our requirements: "
|
||||
"complex queries, high write volume, and eventual consistency is acceptable.",
|
||||
expected_tier=ComplexityTier.REASONING,
|
||||
description="Database decision with explicit reasoning",
|
||||
),
|
||||
# === Edge cases / regression tests ===
|
||||
EvalCase(
|
||||
prompt="What is the capital of France?",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Regression: 'capital' should not trigger 'api' keyword",
|
||||
),
|
||||
EvalCase(
|
||||
prompt="I tried to book a flight but the entry form wasn't working",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Regression: 'tried' and 'entry' should not trigger code keywords",
|
||||
acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="The poetry of digital art is fascinating",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Regression: 'poetry' should not trigger 'try' keyword",
|
||||
acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
|
||||
),
|
||||
EvalCase(
|
||||
prompt="Can you recommend a good book about country music history?",
|
||||
expected_tier=ComplexityTier.SIMPLE,
|
||||
description="Regression: 'country' should not trigger 'try' keyword",
|
||||
acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def run_eval() -> Tuple[int, int, List[dict]]:
    """
    Run the evaluation suite over EVAL_CASES and print a report.

    Returns:
        Tuple of (passed, total, failures) where failures is a list of dicts
        describing each misclassified case.
    """
    # classify() never touches the LiteLLM Router, so a MagicMock stand-in
    # is sufficient for evaluation purposes.
    router = ComplexityRouter(
        model_name="eval-router",
        litellm_router_instance=MagicMock(),
    )

    passed = 0
    total = len(EVAL_CASES)
    failures = []

    print("=" * 70)  # noqa: T201
    print("COMPLEXITY ROUTER EVALUATION")  # noqa: T201
    print("=" * 70)  # noqa: T201
    print()  # noqa: T201

    for i, case in enumerate(EVAL_CASES, 1):
        tier, score, signals = router.classify(case.prompt, case.system_prompt)

        # A case passes on an exact tier match, or on any explicitly
        # acceptable tier (the per-case tolerance band).
        is_acceptable = (
            case.acceptable_tiers is not None and tier in case.acceptable_tiers
        )
        is_pass = tier == case.expected_tier or is_acceptable

        if is_pass:
            passed += 1
            status = "✓ PASS"
        else:
            status = "✗ FAIL"
            failures.append(
                {
                    "case": i,
                    "description": case.description,
                    "prompt": case.prompt[:80] + "..."
                    if len(case.prompt) > 80
                    else case.prompt,
                    "expected": case.expected_tier.value,
                    "actual": tier.value,
                    "score": round(score, 3),
                    "signals": signals,
                    "acceptable": [t.value for t in case.acceptable_tiers]
                    if case.acceptable_tiers
                    else None,
                }
            )

        # Per-case report line(s).
        print(f"[{i:2d}] {status} | {case.description}")  # noqa: T201
        print(
            f"     Expected: {case.expected_tier.value:10s} | Got: {tier.value:10s} | Score: {score:+.3f}"
        )  # noqa: T201
        if signals:
            print(f"     Signals: {', '.join(signals)}")  # noqa: T201
        if not is_pass:
            print(f"     Prompt: {case.prompt[:60]}...")  # noqa: T201
        print()  # noqa: T201

    # Summary footer.
    print("=" * 70)  # noqa: T201
    print(f"RESULTS: {passed}/{total} passed ({100*passed/total:.1f}%)")  # noqa: T201
    print("=" * 70)  # noqa: T201

    if failures:
        print("\nFAILURES:")  # noqa: T201
        print("-" * 70)  # noqa: T201
        for f in failures:
            print(f"Case {f['case']}: {f['description']}")  # noqa: T201
            print(
                f"  Expected: {f['expected']}, Got: {f['actual']} (score: {f['score']})"
            )  # noqa: T201
            print(f"  Signals: {f['signals']}")  # noqa: T201
            if f["acceptable"]:
                print(f"  Acceptable: {f['acceptable']}")  # noqa: T201
            print()  # noqa: T201

    return passed, total, failures
|
||||
|
||||
|
||||
def main():
    """Run the eval suite and exit with a status code based on pass rate.

    Exit codes:
        1 — pass rate below the hard 80% threshold
        0 — otherwise (a warning is printed when below 90%)
    """
    passed, total, _failures = run_eval()

    pass_rate = passed / total

    # Guard-clause style: handle the failing tiers first, success last.
    if pass_rate < 0.80:
        print(
            f"\n❌ EVAL FAILED: Pass rate {pass_rate:.1%} is below 80% threshold"
        )  # noqa: T201
        sys.exit(1)

    if pass_rate < 0.90:
        print(
            f"\n⚠️  EVAL WARNING: Pass rate {pass_rate:.1%} is below 90%"
        )  # noqa: T201
        sys.exit(0)

    print(f"\n✅ EVAL PASSED: Pass rate {pass_rate:.1%}")  # noqa: T201
    sys.exit(0)


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,250 @@
|
||||
#### What this does ####
|
||||
# identifies least busy deployment
|
||||
# How is this achieved?
|
||||
# - Before each call, have the router print the state of requests {"deployment": "requests_in_flight"}
|
||||
# - use litellm.input_callbacks to log when a request is just about to be made to a model - {"deployment-id": traffic}
|
||||
# - use litellm.success + failure callbacks to log when a request completed
|
||||
# - in get_available_deployment, for a given model group name -> pick based on traffic
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
class LeastBusyLoggingHandler(CustomLogger):
    """Least-busy routing strategy bookkeeping.

    Maintains an in-flight request counter per deployment, stored in the
    router cache under the key "{model_group}_request_count" as
    {deployment_id: requests_in_flight}:

    - ``log_pre_api_call`` increments the counter just before a request is sent
    - success / failure callbacks decrement it when the request completes
    - ``get_available_deployments`` picks the deployment with the lowest count
    """

    # test_flag: when True, completed callbacks are tallied in
    # logged_success / logged_failure so unit tests can observe them.
    test_flag: bool = False
    logged_success: int = 0
    logged_failure: int = 0

    def __init__(self, router_cache: DualCache):
        self.router_cache = router_cache

    @staticmethod
    def _extract_model_group_and_id(kwargs) -> Optional[tuple]:
        """Pull (model_group, deployment_id) out of callback kwargs.

        Returns None when metadata, model_group or the deployment id is
        missing — callers treat that as "nothing to track". Integer ids are
        normalized to strings so cache keys stay consistent.
        """
        metadata = kwargs["litellm_params"].get("metadata")
        if metadata is None:
            return None
        model_group = metadata.get("model_group", None)
        deployment_id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
        if model_group is None or deployment_id is None:
            return None
        if isinstance(deployment_id, int):
            deployment_id = str(deployment_id)
        return model_group, deployment_id

    @staticmethod
    def _decremented(
        request_count_dict: dict, deployment_id: str
    ) -> Optional[dict]:
        """Decrement the deployment's counter in the dict, in place.

        Returns the dict, or None when the stored value is literally None
        (defensive guard kept from the original implementation).
        """
        request_count_value: Optional[int] = request_count_dict.get(deployment_id, 0)
        if request_count_value is None:
            return None
        request_count_dict[deployment_id] = request_count_value - 1
        return request_count_dict

    def log_pre_api_call(self, model, messages, kwargs):
        """Increment the in-flight counter for the deployment about to be used.

        Caching is keyed on the model group. All exceptions are swallowed:
        this is best-effort bookkeeping and must never break the request path.
        """
        try:
            extracted = self._extract_model_group_and_id(kwargs)
            if extracted is None:
                return
            model_group, deployment_id = extracted

            request_count_api_key = f"{model_group}_request_count"
            # read-modify-write of the counter dict in the router cache
            request_count_dict = (
                self.router_cache.get_cache(key=request_count_api_key) or {}
            )
            request_count_dict[deployment_id] = (
                request_count_dict.get(deployment_id, 0) + 1
            )
            self.router_cache.set_cache(
                key=request_count_api_key, value=request_count_dict
            )
        except Exception:
            pass

    def _handle_completed_request(self, kwargs) -> bool:
        """Sync counter decrement shared by success and failure callbacks.

        Returns True when a counter was actually decremented (used to gate
        the test-only tallies so they match the original behavior).
        """
        extracted = self._extract_model_group_and_id(kwargs)
        if extracted is None:
            return False
        model_group, deployment_id = extracted

        request_count_api_key = f"{model_group}_request_count"
        request_count_dict = (
            self.router_cache.get_cache(key=request_count_api_key) or {}
        )
        updated = self._decremented(request_count_dict, deployment_id)
        if updated is None:
            return False
        self.router_cache.set_cache(key=request_count_api_key, value=updated)
        return True

    async def _async_handle_completed_request(self, kwargs) -> bool:
        """Async variant of ``_handle_completed_request``."""
        extracted = self._extract_model_group_and_id(kwargs)
        if extracted is None:
            return False
        model_group, deployment_id = extracted

        request_count_api_key = f"{model_group}_request_count"
        request_count_dict = (
            await self.router_cache.async_get_cache(key=request_count_api_key) or {}
        )
        updated = self._decremented(request_count_dict, deployment_id)
        if updated is None:
            return False
        await self.router_cache.async_set_cache(
            key=request_count_api_key, value=updated
        )
        return True

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Decrement the in-flight counter when a request succeeds."""
        try:
            if self._handle_completed_request(kwargs) and self.test_flag:
                ### TESTING ###
                self.logged_success += 1
        except Exception:
            pass

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Decrement the in-flight counter when a request fails."""
        try:
            if self._handle_completed_request(kwargs) and self.test_flag:
                ### TESTING ###
                self.logged_failure += 1
        except Exception:
            pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Async: decrement the in-flight counter when a request succeeds."""
        try:
            if await self._async_handle_completed_request(kwargs) and self.test_flag:
                ### TESTING ###
                self.logged_success += 1
        except Exception:
            pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Async: decrement the in-flight counter when a request fails."""
        try:
            if await self._async_handle_completed_request(kwargs) and self.test_flag:
                ### TESTING ###
                self.logged_failure += 1
        except Exception:
            pass

    def _get_available_deployments(
        self,
        healthy_deployments: list,
        all_deployments: dict,
    ):
        """Pick the healthy deployment with the fewest in-flight requests.

        Healthy deployments not yet present in the counter map are seeded
        with 0. If the minimum-count id is stale (no longer among
        healthy_deployments) or no minimum exists, fall back to a random
        healthy deployment.
        """
        for d in healthy_deployments:
            ## if healthy deployment not yet used
            if d["model_info"]["id"] not in all_deployments:
                all_deployments[d["model_info"]["id"]] = 0

        # pick least busy deployment
        min_traffic = float("inf")
        min_deployment = None
        for deployment_id, traffic in all_deployments.items():
            if traffic < min_traffic:
                min_traffic = traffic
                min_deployment = deployment_id

        if min_deployment is not None:
            for m in healthy_deployments:
                if m["model_info"]["id"] == min_deployment:
                    return m
            # stale cache entry: id is not among the healthy deployments
            min_deployment = random.choice(healthy_deployments)
        else:
            min_deployment = random.choice(healthy_deployments)
        return min_deployment

    def get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
    ):
        """
        Sync helper to get deployments using least busy strategy
        """
        request_count_api_key = f"{model_group}_request_count"
        all_deployments = self.router_cache.get_cache(key=request_count_api_key) or {}
        return self._get_available_deployments(
            healthy_deployments=healthy_deployments,
            all_deployments=all_deployments,
        )

    async def async_get_available_deployments(
        self, model_group: str, healthy_deployments: list
    ):
        """
        Async helper to get deployments using least busy strategy
        """
        request_count_api_key = f"{model_group}_request_count"
        all_deployments = (
            await self.router_cache.async_get_cache(key=request_count_api_key) or {}
        )
        return self._get_available_deployments(
            healthy_deployments=healthy_deployments,
            all_deployments=all_deployments,
        )
|
||||
@@ -0,0 +1,330 @@
|
||||
#### What this does ####
|
||||
# picks based on response time (for streaming, this is time to first token)
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm import ModelResponse, token_counter, verbose_logger
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
|
||||
class LowestCostLoggingHandler(CustomLogger):
    """Lowest-cost routing strategy.

    Tracks per-minute usage per deployment in the router cache under the
    key "{model_group}_map"::

        {
            "{model_group}_map": {
                deployment_id: {
                    "{YYYY-MM-DD}-{HH}-{MM}": {"tpm": ..., "rpm": ...}
                }
            }
        }

    ``async_get_available_deployments`` then returns the cheapest deployment
    (input + output cost per token) whose tpm/rpm limits would not be
    exceeded by the current request.
    """

    # test_flag: when True, successful callbacks are tallied in
    # logged_success so unit tests can observe them.
    test_flag: bool = False
    logged_success: int = 0
    logged_failure: int = 0

    def __init__(self, router_cache: DualCache, routing_args: Optional[dict] = None):
        # routing_args is accepted for interface parity with the other routing
        # strategies but is currently unused by this handler.
        # (Default changed from a mutable `{}` to None; behavior unchanged.)
        self.router_cache = router_cache

    @staticmethod
    def _extract_model_group_and_id(kwargs) -> Optional[tuple]:
        """(model_group, deployment_id) from callback kwargs, or None when
        metadata / model_group / id is missing. Integer ids become strings."""
        metadata = kwargs["litellm_params"].get("metadata")
        if metadata is None:
            return None
        model_group = metadata.get("model_group", None)
        deployment_id = kwargs["litellm_params"].get("model_info", {}).get("id", None)
        if model_group is None or deployment_id is None:
            return None
        if isinstance(deployment_id, int):
            deployment_id = str(deployment_id)
        return model_group, deployment_id

    @staticmethod
    def _current_precise_minute() -> str:
        """Minute-granularity bucket key, e.g. '2024-01-31-14-05'."""
        now = datetime.now()
        current_date = now.strftime("%Y-%m-%d")
        current_hour = now.strftime("%H")
        current_minute = now.strftime("%M")
        return f"{current_date}-{current_hour}-{current_minute}"

    @staticmethod
    def _total_tokens(response_obj) -> int:
        """Total tokens from a ModelResponse's usage; 0 when unavailable."""
        if isinstance(response_obj, ModelResponse):
            _usage = getattr(response_obj, "usage", None)
            if _usage is not None and isinstance(_usage, litellm.Usage):
                return _usage.total_tokens
        return 0

    @staticmethod
    def _record_usage(
        request_count_dict: dict,
        deployment_id: str,
        precise_minute: str,
        total_tokens: int,
    ) -> None:
        """Bump tpm/rpm for the deployment in the given minute bucket (in place)."""
        minute_map = request_count_dict.setdefault(deployment_id, {}).setdefault(
            precise_minute, {}
        )
        ## TPM
        minute_map["tpm"] = minute_map.get("tpm", 0) + total_tokens
        ## RPM
        minute_map["rpm"] = minute_map.get("rpm", 0) + 1

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Update usage on success.

        Fix: removed a dead `float(response_ms.total_seconds() /
        completion_tokens)` expression — it could raise ZeroDivisionError for
        zero-completion responses and silently abort the usage update.
        """
        try:
            extracted = self._extract_model_group_and_id(kwargs)
            if extracted is None:
                return
            model_group, deployment_id = extracted

            cost_key = f"{model_group}_map"
            precise_minute = self._current_precise_minute()
            total_tokens = self._total_tokens(response_obj)

            # ------------
            # Update usage
            # ------------
            request_count_dict = self.router_cache.get_cache(key=cost_key) or {}
            self._record_usage(
                request_count_dict, deployment_id, precise_minute, total_tokens
            )
            self.router_cache.set_cache(key=cost_key, value=request_count_dict)

            ### TESTING ###
            if self.test_flag:
                self.logged_success += 1
        except Exception as e:
            verbose_logger.exception(
                "litellm.router_strategy.lowest_cost.py::log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Async variant: update cost usage on success.

        Fixes as in ``log_success_event``; the exception log message was also
        corrected (it previously named prompt_injection_detection).
        """
        try:
            extracted = self._extract_model_group_and_id(kwargs)
            if extracted is None:
                return
            model_group, deployment_id = extracted

            cost_key = f"{model_group}_map"
            precise_minute = self._current_precise_minute()
            total_tokens = self._total_tokens(response_obj)

            # ------------
            # Update usage
            # ------------
            request_count_dict = (
                await self.router_cache.async_get_cache(key=cost_key) or {}
            )
            self._record_usage(
                request_count_dict, deployment_id, precise_minute, total_tokens
            )
            await self.router_cache.async_set_cache(
                key=cost_key, value=request_count_dict
            )  # reset map within window

            ### TESTING ###
            if self.test_flag:
                self.logged_success += 1
        except Exception as e:
            verbose_logger.exception(
                "litellm.router_strategy.lowest_cost.py::async_log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )

    async def async_get_available_deployments(  # noqa: PLR0915
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        request_kwargs: Optional[Dict] = None,
    ):
        """
        Returns the deployment with the lowest total cost per token that
        would stay within its tpm/rpm limits for this request, or None when
        every deployment is over its limits.

        Fix: removed dead statements (`float("inf")`, an unused `_items`
        binding and an unreachable `request_count_dict is None` check).
        """
        cost_key = f"{model_group}_map"

        request_count_dict = await self.router_cache.async_get_cache(key=cost_key) or {}

        precise_minute = self._current_precise_minute()

        all_deployments = request_count_dict
        for d in healthy_deployments:
            ## seed unseen healthy deployments with zero usage this minute
            if d["model_info"]["id"] not in all_deployments:
                all_deployments[d["model_info"]["id"]] = {
                    precise_minute: {"tpm": 0, "rpm": 0},
                }

        try:
            input_tokens = token_counter(messages=messages, text=input)
        except Exception:
            input_tokens = 0

        ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits
        potential_deployments = []
        _cost_per_deployment = {}
        for item, item_map in all_deployments.items():
            ## find the deployment config for this id
            _deployment = None
            for m in healthy_deployments:
                if item == m["model_info"]["id"]:
                    _deployment = m

            if _deployment is None:
                continue  # stale cache id with no matching healthy deployment

            # tpm/rpm limits may live on the deployment, its litellm_params,
            # or its model_info; default to unlimited
            _deployment_tpm = (
                _deployment.get("tpm", None)
                or _deployment.get("litellm_params", {}).get("tpm", None)
                or _deployment.get("model_info", {}).get("tpm", None)
                or float("inf")
            )

            _deployment_rpm = (
                _deployment.get("rpm", None)
                or _deployment.get("litellm_params", {}).get("rpm", None)
                or _deployment.get("model_info", {}).get("rpm", None)
                or float("inf")
            )
            item_litellm_model_name = _deployment.get("litellm_params", {}).get("model")
            item_litellm_model_cost_map = litellm.model_cost.get(
                item_litellm_model_name, {}
            )

            # check if user provided input_cost_per_token and output_cost_per_token in litellm_params
            item_input_cost = None
            item_output_cost = None
            if _deployment.get("litellm_params", {}).get("input_cost_per_token", None):
                item_input_cost = _deployment.get("litellm_params", {}).get(
                    "input_cost_per_token"
                )

            if _deployment.get("litellm_params", {}).get("output_cost_per_token", None):
                item_output_cost = _deployment.get("litellm_params", {}).get(
                    "output_cost_per_token"
                )

            # models missing from the cost map default to 5.0 per token each
            if item_input_cost is None:
                item_input_cost = item_litellm_model_cost_map.get(
                    "input_cost_per_token", 5.0
                )

            if item_output_cost is None:
                item_output_cost = item_litellm_model_cost_map.get(
                    "output_cost_per_token", 5.0
                )

            item_cost = item_input_cost + item_output_cost

            item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
            item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)

            verbose_router_logger.debug(
                f"item_cost: {item_cost}, item_tpm: {item_tpm}, item_rpm: {item_rpm}, model_id: {_deployment.get('model_info', {}).get('id')}"
            )

            # -------------- #
            # Debugging Logic
            # -------------- #
            # _cost_per_deployment is logged to langfuse/slack only; it is not
            # used in the routing decision — it helps a user debug why the
            # router picked a specific deployment.
            _deployment_api_base = _deployment.get("litellm_params", {}).get(
                "api_base", ""
            )
            if _deployment_api_base is not None:
                _cost_per_deployment[_deployment_api_base] = item_cost
            # -------------- #
            # End of Debugging Logic
            # -------------- #

            if (
                item_tpm + input_tokens > _deployment_tpm
                or item_rpm + 1 > _deployment_rpm
            ):  # if user passed in tpm / rpm in the model_list
                continue
            else:
                potential_deployments.append((_deployment, item_cost))

        if len(potential_deployments) == 0:
            return None

        # cheapest first
        potential_deployments = sorted(potential_deployments, key=lambda x: x[1])

        selected_deployment = potential_deployments[0][0]
        return selected_deployment
|
||||
@@ -0,0 +1,627 @@
|
||||
#### What this does ####
|
||||
# picks based on response time (for streaming, this is time to first token)
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm import ModelResponse, token_counter, verbose_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.core_helpers import safe_divide_seconds
|
||||
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
|
||||
from litellm.types.utils import LiteLLMPydanticObjectBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class RoutingArgs(LiteLLMPydanticObjectBase):
    """Tunable parameters for the lowest-latency routing strategy."""

    # How long (seconds) the per-model-group latency map lives in the cache.
    ttl: float = 1 * 60 * 60  # 1 hour
    # Tolerance band above the lowest observed latency; deployments within
    # this buffer are presumably still considered "lowest" — TODO confirm
    # against the consumer of this value.
    lowest_latency_buffer: float = 0
    # Cap on how many latency / ttft samples are kept per deployment.
    max_latency_list_size: int = 10
|
||||
|
||||
|
||||
class LowestLatencyLoggingHandler(CustomLogger):
|
||||
test_flag: bool = False
|
||||
logged_success: int = 0
|
||||
logged_failure: int = 0
|
||||
|
||||
def __init__(self, router_cache: DualCache, routing_args: dict = {}):
|
||||
self.router_cache = router_cache
|
||||
self.routing_args = RoutingArgs(**routing_args)
|
||||
|
||||
def log_success_event( # noqa: PLR0915
|
||||
self, kwargs, response_obj, start_time, end_time
|
||||
):
|
||||
try:
|
||||
"""
|
||||
Update latency usage on success
|
||||
"""
|
||||
metadata_field = self._select_metadata_field(kwargs)
|
||||
if kwargs["litellm_params"].get(metadata_field) is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"][metadata_field].get(
|
||||
"model_group", None
|
||||
)
|
||||
|
||||
id = (kwargs["litellm_params"].get("model_info") or {}).get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
"""
|
||||
{
|
||||
{model_group}_map: {
|
||||
id: {
|
||||
"latency": [..]
|
||||
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
latency_key = f"{model_group}_map"
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
|
||||
response_ms = end_time - start_time
|
||||
time_to_first_token_response_time = None
|
||||
|
||||
if kwargs.get("stream", None) is not None and kwargs["stream"] is True:
|
||||
# only log ttft for streaming request
|
||||
time_to_first_token_response_time = (
|
||||
kwargs.get("completion_start_time", end_time) - start_time
|
||||
)
|
||||
|
||||
final_value: Union[float, timedelta] = response_ms
|
||||
time_to_first_token: Optional[float] = None
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, ModelResponse):
|
||||
_usage = getattr(response_obj, "usage", None)
|
||||
if _usage is not None:
|
||||
completion_tokens = _usage.completion_tokens
|
||||
total_tokens = _usage.total_tokens
|
||||
|
||||
# Handle both timedelta and float response times
|
||||
if isinstance(response_ms, timedelta):
|
||||
response_seconds = response_ms.total_seconds()
|
||||
else:
|
||||
response_seconds = response_ms
|
||||
|
||||
final_value = safe_divide_seconds(
|
||||
response_seconds, completion_tokens
|
||||
)
|
||||
if final_value is not None:
|
||||
final_value = float(final_value)
|
||||
else:
|
||||
final_value = response_seconds
|
||||
|
||||
if time_to_first_token_response_time is not None:
|
||||
if isinstance(time_to_first_token_response_time, timedelta):
|
||||
ttft_seconds = (
|
||||
time_to_first_token_response_time.total_seconds()
|
||||
)
|
||||
else:
|
||||
ttft_seconds = time_to_first_token_response_time
|
||||
time_to_first_token = safe_divide_seconds(
|
||||
ttft_seconds, completion_tokens
|
||||
)
|
||||
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
request_count_dict = (
|
||||
self.router_cache.get_cache(
|
||||
key=latency_key, parent_otel_span=parent_otel_span
|
||||
)
|
||||
or {}
|
||||
)
|
||||
|
||||
if id not in request_count_dict:
|
||||
request_count_dict[id] = {}
|
||||
|
||||
## Latency
|
||||
if (
|
||||
len(request_count_dict[id].get("latency", []))
|
||||
< self.routing_args.max_latency_list_size
|
||||
):
|
||||
request_count_dict[id].setdefault("latency", []).append(final_value)
|
||||
else:
|
||||
request_count_dict[id]["latency"] = request_count_dict[id][
|
||||
"latency"
|
||||
][: self.routing_args.max_latency_list_size - 1] + [final_value]
|
||||
|
||||
## Time to first token
|
||||
if time_to_first_token is not None:
|
||||
if (
|
||||
len(request_count_dict[id].get("time_to_first_token", []))
|
||||
< self.routing_args.max_latency_list_size
|
||||
):
|
||||
request_count_dict[id].setdefault(
|
||||
"time_to_first_token", []
|
||||
).append(time_to_first_token)
|
||||
else:
|
||||
request_count_dict[id][
|
||||
"time_to_first_token"
|
||||
] = request_count_dict[id]["time_to_first_token"][
|
||||
: self.routing_args.max_latency_list_size - 1
|
||||
] + [
|
||||
time_to_first_token
|
||||
]
|
||||
|
||||
if precise_minute not in request_count_dict[id]:
|
||||
request_count_dict[id][precise_minute] = {}
|
||||
|
||||
## TPM
|
||||
request_count_dict[id][precise_minute]["tpm"] = (
|
||||
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
|
||||
)
|
||||
|
||||
## RPM
|
||||
request_count_dict[id][precise_minute]["rpm"] = (
|
||||
request_count_dict[id][precise_minute].get("rpm", 0) + 1
|
||||
)
|
||||
|
||||
self.router_cache.set_cache(
|
||||
key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
|
||||
) # reset map within window
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
"""
|
||||
Check if Timeout Error, if timeout set deployment latency -> 100
|
||||
"""
|
||||
try:
|
||||
metadata_field = self._select_metadata_field(kwargs)
|
||||
_exception = kwargs.get("exception", None)
|
||||
if isinstance(_exception, litellm.Timeout):
|
||||
if kwargs["litellm_params"].get(metadata_field) is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"][metadata_field].get(
|
||||
"model_group", None
|
||||
)
|
||||
|
||||
id = (kwargs["litellm_params"].get("model_info") or {}).get(
|
||||
"id", None
|
||||
)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
"""
|
||||
{
|
||||
{model_group}_map: {
|
||||
id: {
|
||||
"latency": [..]
|
||||
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
latency_key = f"{model_group}_map"
|
||||
request_count_dict = (
|
||||
await self.router_cache.async_get_cache(key=latency_key) or {}
|
||||
)
|
||||
|
||||
if id not in request_count_dict:
|
||||
request_count_dict[id] = {}
|
||||
|
||||
## Latency - give 1000s penalty for failing
|
||||
if (
|
||||
len(request_count_dict[id].get("latency", []))
|
||||
< self.routing_args.max_latency_list_size
|
||||
):
|
||||
request_count_dict[id].setdefault("latency", []).append(1000.0)
|
||||
else:
|
||||
request_count_dict[id]["latency"] = request_count_dict[id][
|
||||
"latency"
|
||||
][: self.routing_args.max_latency_list_size - 1] + [1000.0]
|
||||
|
||||
await self.router_cache.async_set_cache(
|
||||
key=latency_key,
|
||||
value=request_count_dict,
|
||||
ttl=self.routing_args.ttl,
|
||||
) # reset map within window
|
||||
else:
|
||||
# do nothing if it's not a timeout error
|
||||
return
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
pass
|
||||
|
||||
async def async_log_success_event( # noqa: PLR0915
|
||||
self, kwargs, response_obj, start_time, end_time
|
||||
):
|
||||
try:
|
||||
"""
|
||||
Update latency usage on success
|
||||
"""
|
||||
metadata_field = self._select_metadata_field(kwargs)
|
||||
if kwargs["litellm_params"].get(metadata_field) is None:
|
||||
pass
|
||||
else:
|
||||
model_group = kwargs["litellm_params"][metadata_field].get(
|
||||
"model_group", None
|
||||
)
|
||||
|
||||
id = (kwargs["litellm_params"].get("model_info") or {}).get("id", None)
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
"""
|
||||
{
|
||||
{model_group}_map: {
|
||||
id: {
|
||||
"latency": [..]
|
||||
"time_to_first_token": [..]
|
||||
f"{date:hour:minute}" : {"tpm": 34, "rpm": 3}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
latency_key = f"{model_group}_map"
|
||||
|
||||
current_date = datetime.now().strftime("%Y-%m-%d")
|
||||
current_hour = datetime.now().strftime("%H")
|
||||
current_minute = datetime.now().strftime("%M")
|
||||
precise_minute = f"{current_date}-{current_hour}-{current_minute}"
|
||||
|
||||
response_ms = end_time - start_time
|
||||
time_to_first_token_response_time = None
|
||||
if kwargs.get("stream", None) is not None and kwargs["stream"] is True:
|
||||
# only log ttft for streaming request
|
||||
time_to_first_token_response_time = (
|
||||
kwargs.get("completion_start_time", end_time) - start_time
|
||||
)
|
||||
|
||||
final_value: Union[float, timedelta] = response_ms
|
||||
total_tokens = 0
|
||||
time_to_first_token: Optional[float] = None
|
||||
|
||||
if isinstance(response_obj, ModelResponse):
|
||||
_usage = getattr(response_obj, "usage", None)
|
||||
if _usage is not None:
|
||||
completion_tokens = _usage.completion_tokens
|
||||
total_tokens = _usage.total_tokens
|
||||
|
||||
# Handle both timedelta and float response times
|
||||
if isinstance(response_ms, timedelta):
|
||||
response_seconds = response_ms.total_seconds()
|
||||
else:
|
||||
response_seconds = response_ms
|
||||
|
||||
final_value = safe_divide_seconds(
|
||||
response_seconds, completion_tokens
|
||||
)
|
||||
if final_value is not None:
|
||||
final_value = float(final_value)
|
||||
else:
|
||||
final_value = response_ms
|
||||
|
||||
if time_to_first_token_response_time is not None:
|
||||
if isinstance(time_to_first_token_response_time, timedelta):
|
||||
ttft_seconds = (
|
||||
time_to_first_token_response_time.total_seconds()
|
||||
)
|
||||
else:
|
||||
ttft_seconds = time_to_first_token_response_time
|
||||
time_to_first_token = safe_divide_seconds(
|
||||
ttft_seconds, completion_tokens
|
||||
)
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
request_count_dict = (
|
||||
await self.router_cache.async_get_cache(
|
||||
key=latency_key,
|
||||
parent_otel_span=parent_otel_span,
|
||||
local_only=True,
|
||||
)
|
||||
or {}
|
||||
)
|
||||
|
||||
if id not in request_count_dict:
|
||||
request_count_dict[id] = {}
|
||||
|
||||
## Latency
|
||||
if (
|
||||
len(request_count_dict[id].get("latency", []))
|
||||
< self.routing_args.max_latency_list_size
|
||||
):
|
||||
request_count_dict[id].setdefault("latency", []).append(final_value)
|
||||
else:
|
||||
request_count_dict[id]["latency"] = request_count_dict[id][
|
||||
"latency"
|
||||
][: self.routing_args.max_latency_list_size - 1] + [final_value]
|
||||
|
||||
## Time to first token
|
||||
if time_to_first_token is not None:
|
||||
if (
|
||||
len(request_count_dict[id].get("time_to_first_token", []))
|
||||
< self.routing_args.max_latency_list_size
|
||||
):
|
||||
request_count_dict[id].setdefault(
|
||||
"time_to_first_token", []
|
||||
).append(time_to_first_token)
|
||||
else:
|
||||
request_count_dict[id][
|
||||
"time_to_first_token"
|
||||
] = request_count_dict[id]["time_to_first_token"][
|
||||
: self.routing_args.max_latency_list_size - 1
|
||||
] + [
|
||||
time_to_first_token
|
||||
]
|
||||
|
||||
if precise_minute not in request_count_dict[id]:
|
||||
request_count_dict[id][precise_minute] = {}
|
||||
|
||||
## TPM
|
||||
request_count_dict[id][precise_minute]["tpm"] = (
|
||||
request_count_dict[id][precise_minute].get("tpm", 0) + total_tokens
|
||||
)
|
||||
|
||||
## RPM
|
||||
request_count_dict[id][precise_minute]["rpm"] = (
|
||||
request_count_dict[id][precise_minute].get("rpm", 0) + 1
|
||||
)
|
||||
|
||||
await self.router_cache.async_set_cache(
|
||||
key=latency_key, value=request_count_dict, ttl=self.routing_args.ttl
|
||||
) # reset map within window
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"litellm.router_strategy.lowest_latency.py::async_log_success_event(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
pass
|
||||
|
||||
    def _get_available_deployments(  # noqa: PLR0915
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        request_kwargs: Optional[Dict] = None,
        request_count_dict: Optional[Dict] = None,
    ):
        """Common logic for both sync and async get_available_deployments

        Picks the deployment with the lowest average latency (or lowest average
        time-to-first-token for streaming requests) among healthy deployments
        that are still under their tpm/rpm limits.

        Args:
            model_group: name of the model group being routed.
            healthy_deployments: deployments eligible for routing; each has a
                ``model_info.id`` key.
            messages/input: used only to estimate input tokens for tpm checks.
            request_kwargs: original request kwargs; ``stream`` toggles the
                time-to-first-token path, and the chosen latency map is written
                back into its metadata for debugging.
            request_count_dict: the cached ``{model_id: usage}`` map; ``None``
                means no usage has been recorded yet (returns None).

        Returns:
            The chosen deployment dict, or None if nothing qualifies.
        """

        # -----------------------
        # Find lowest used model
        # ----------------------
        _latency_per_deployment = {}
        lowest_latency = float("inf")

        # per-minute bucket key used for the tpm/rpm counters in the cache
        current_date = datetime.now().strftime("%Y-%m-%d")
        current_hour = datetime.now().strftime("%H")
        current_minute = datetime.now().strftime("%M")
        precise_minute = f"{current_date}-{current_hour}-{current_minute}"

        deployment = None

        if request_count_dict is None:  # base case
            return

        all_deployments = request_count_dict
        for d in healthy_deployments:
            ## if healthy deployment not yet used
            # seed unused deployments with zero latency so they are eligible
            if d["model_info"]["id"] not in all_deployments:
                all_deployments[d["model_info"]["id"]] = {
                    "latency": [0],
                    precise_minute: {"tpm": 0, "rpm": 0},
                }

        try:
            input_tokens = token_counter(messages=messages, text=input)
        except Exception:
            input_tokens = 0

        # randomly sample from all_deployments, incase all deployments have latency=0.0
        _items = all_deployments.items()

        _all_deployments = random.sample(list(_items), len(_items))
        all_deployments = dict(_all_deployments)
        ### GET AVAILABLE DEPLOYMENTS ### filter out any deployments > tpm/rpm limits

        potential_deployments = []
        for item, item_map in all_deployments.items():
            ## get the item from model list
            _deployment = None
            for m in healthy_deployments:
                if item == m["model_info"]["id"]:
                    _deployment = m

            if _deployment is None:
                continue  # skip to next one

            # limit lookup order: top-level -> litellm_params -> model_info -> unlimited
            # NOTE(review): the `or` chain treats an explicit 0 limit as "unset" — confirm intended
            _deployment_tpm = (
                _deployment.get("tpm", None)
                or _deployment.get("litellm_params", {}).get("tpm", None)
                or _deployment.get("model_info", {}).get("tpm", None)
                or float("inf")
            )

            _deployment_rpm = (
                _deployment.get("rpm", None)
                or _deployment.get("litellm_params", {}).get("rpm", None)
                or _deployment.get("model_info", {}).get("rpm", None)
                or float("inf")
            )
            item_latency = item_map.get("latency", [])
            item_ttft_latency = item_map.get("time_to_first_token", [])
            item_rpm = item_map.get(precise_minute, {}).get("rpm", 0)
            item_tpm = item_map.get(precise_minute, {}).get("tpm", 0)

            # get average latency or average ttft (depending on streaming/non-streaming)
            total: float = 0.0
            use_ttft = (
                request_kwargs is not None
                and request_kwargs.get("stream", None) is not None
                and request_kwargs["stream"] is True
                and len(item_ttft_latency) > 0
            )
            if use_ttft:
                for _call_latency in item_ttft_latency:
                    if isinstance(_call_latency, float):
                        total += _call_latency
                item_latency = total / len(item_ttft_latency)
            else:
                for _call_latency in item_latency:
                    if isinstance(_call_latency, float):
                        total += _call_latency
                # NOTE(review): raises ZeroDivisionError if a cached entry has an
                # empty "latency" list — TODO confirm the logger always appends one
                item_latency = total / len(item_latency)

            # -------------- #
            # Debugging Logic
            # -------------- #
            # We use _latency_per_deployment to log to langfuse, slack - this is not used to make a decision on routing
            # this helps a user to debug why the router picked a specfic deployment #
            _deployment_api_base = _deployment.get("litellm_params", {}).get(
                "api_base", ""
            )
            if _deployment_api_base is not None:
                _latency_per_deployment[_deployment_api_base] = item_latency
            # -------------- #
            # End of Debugging Logic
            # -------------- #

            if (
                item_tpm + input_tokens > _deployment_tpm
                or item_rpm + 1 > _deployment_rpm
            ):  # if user passed in tpm / rpm in the model_list
                continue
            else:
                potential_deployments.append((_deployment, item_latency))

        if len(potential_deployments) == 0:
            return None

        # Sort potential deployments by latency
        sorted_deployments = sorted(potential_deployments, key=lambda x: x[1])

        # Find lowest latency deployment
        lowest_latency = sorted_deployments[0][1]

        # Find deployments within buffer of lowest latency
        buffer = self.routing_args.lowest_latency_buffer * lowest_latency

        valid_deployments = [
            x for x in sorted_deployments if x[1] <= lowest_latency + buffer
        ]

        # Pick a random deployment from valid deployments
        random_valid_deployment = random.choice(valid_deployments)
        deployment = random_valid_deployment[0]
        # expose the per-deployment latency map on the request for debugging/logging
        metadata_field = self._select_metadata_field(request_kwargs)
        if request_kwargs is not None and metadata_field in request_kwargs:
            request_kwargs[metadata_field][
                "_latency_per_deployment"
            ] = _latency_per_deployment
        return deployment
|
||||
|
||||
async def async_get_available_deployments(
|
||||
self,
|
||||
model_group: str,
|
||||
healthy_deployments: list,
|
||||
messages: Optional[List[Dict[str, str]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
request_kwargs: Optional[Dict] = None,
|
||||
):
|
||||
# get list of potential deployments
|
||||
latency_key = f"{model_group}_map"
|
||||
|
||||
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
|
||||
request_kwargs
|
||||
)
|
||||
request_count_dict = (
|
||||
await self.router_cache.async_get_cache(
|
||||
key=latency_key, parent_otel_span=parent_otel_span
|
||||
)
|
||||
or {}
|
||||
)
|
||||
|
||||
return self._get_available_deployments(
|
||||
model_group,
|
||||
healthy_deployments,
|
||||
messages,
|
||||
input,
|
||||
request_kwargs,
|
||||
request_count_dict,
|
||||
)
|
||||
|
||||
def get_available_deployments(
|
||||
self,
|
||||
model_group: str,
|
||||
healthy_deployments: list,
|
||||
messages: Optional[List[Dict[str, str]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
request_kwargs: Optional[Dict] = None,
|
||||
):
|
||||
"""
|
||||
Returns a deployment with the lowest latency
|
||||
"""
|
||||
# get list of potential deployments
|
||||
latency_key = f"{model_group}_map"
|
||||
|
||||
parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs(
|
||||
request_kwargs
|
||||
)
|
||||
request_count_dict = (
|
||||
self.router_cache.get_cache(
|
||||
key=latency_key, parent_otel_span=parent_otel_span
|
||||
)
|
||||
or {}
|
||||
)
|
||||
|
||||
return self._get_available_deployments(
|
||||
model_group,
|
||||
healthy_deployments,
|
||||
messages,
|
||||
input,
|
||||
request_kwargs,
|
||||
request_count_dict,
|
||||
)
|
||||
@@ -0,0 +1,249 @@
|
||||
#### What this does ####
|
||||
# identifies lowest tpm deployment
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from litellm import token_counter
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import LiteLLMPydanticObjectBase
|
||||
from litellm.utils import print_verbose
|
||||
|
||||
|
||||
class RoutingArgs(LiteLLMPydanticObjectBase):
    """Tunable parameters for the lowest-TPM routing strategy."""

    # seconds before cached per-minute TPM/RPM counters expire
    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)
|
||||
|
||||
|
||||
class LowestTPMLoggingHandler(CustomLogger):
    """
    Routing strategy that records per-deployment TPM/RPM usage in the router
    cache and picks the healthy deployment with the lowest current TPM that
    stays under its configured tpm/rpm limits.
    """

    # test hooks: when test_flag is True, successful/failed log events are counted
    test_flag: bool = False
    logged_success: int = 0
    logged_failure: int = 0
    default_cache_time_seconds: int = 1 * 60 * 60  # 1 hour

    def __init__(self, router_cache: DualCache, routing_args: Optional[dict] = None):
        """
        Args:
            router_cache: dual (in-memory + redis) cache holding usage counters.
            routing_args: optional overrides for RoutingArgs, e.g. {"ttl": 120}.
                None (the default) means "use RoutingArgs defaults"; avoids the
                mutable-default-argument pitfall of ``routing_args: dict = {}``.
        """
        self.router_cache = router_cache
        self.routing_args = RoutingArgs(**(routing_args or {}))

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """
        Update TPM/RPM usage on success (sync path).

        Mirrors async_log_success_event: guards against a missing/None
        ``model_info`` and a response without ``usage`` instead of raising
        and logging a spurious error.
        """
        try:
            if "litellm_params" not in kwargs or kwargs["litellm_params"] is None:
                return
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )

                # model_info may be absent or not a dict — same guard as the async twin
                model_info = kwargs["litellm_params"].get("model_info")
                id = None
                if model_info is not None and isinstance(model_info, dict):
                    id = model_info.get("id", None)
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)

                # responses without usage (e.g. partial objects) carry nothing to record
                if "usage" not in response_obj:
                    return
                total_tokens = response_obj["usage"]["total_tokens"]

                # ------------
                # Setup values
                # ------------
                current_minute = datetime.now().strftime("%H-%M")
                tpm_key = f"{model_group}:tpm:{current_minute}"
                rpm_key = f"{model_group}:rpm:{current_minute}"

                # ------------
                # Update usage
                # ------------

                ## TPM
                request_count_dict = self.router_cache.get_cache(key=tpm_key) or {}
                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens

                self.router_cache.set_cache(
                    key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )

                ## RPM
                request_count_dict = self.router_cache.get_cache(key=rpm_key) or {}
                request_count_dict[id] = request_count_dict.get(id, 0) + 1

                self.router_cache.set_cache(
                    key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )

                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            # fixed: this is the sync handler; the old message named the async one
            verbose_router_logger.error(
                "litellm.router_strategy.lowest_tpm_rpm.py::log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_router_logger.debug(traceback.format_exc())
            pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Update TPM/RPM usage on success (async path)."""
        try:
            if "litellm_params" not in kwargs or kwargs["litellm_params"] is None:
                return
            if kwargs["litellm_params"].get("metadata") is None:
                pass
            else:
                model_group = kwargs["litellm_params"]["metadata"].get(
                    "model_group", None
                )

                model_info = kwargs["litellm_params"].get("model_info")
                id = None
                if model_info is not None and isinstance(model_info, dict):
                    id = model_info.get("id", None)
                if model_group is None or id is None:
                    return
                elif isinstance(id, int):
                    id = str(id)

                if "usage" not in response_obj:
                    return
                total_tokens = response_obj["usage"]["total_tokens"]

                # ------------
                # Setup values
                # ------------
                current_minute = datetime.now().strftime("%H-%M")
                tpm_key = f"{model_group}:tpm:{current_minute}"
                rpm_key = f"{model_group}:rpm:{current_minute}"

                # ------------
                # Update usage
                # ------------
                # update cache

                ## TPM
                request_count_dict = (
                    await self.router_cache.async_get_cache(key=tpm_key) or {}
                )
                request_count_dict[id] = request_count_dict.get(id, 0) + total_tokens

                await self.router_cache.async_set_cache(
                    key=tpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )

                ## RPM
                request_count_dict = (
                    await self.router_cache.async_get_cache(key=rpm_key) or {}
                )
                request_count_dict[id] = request_count_dict.get(id, 0) + 1

                await self.router_cache.async_set_cache(
                    key=rpm_key, value=request_count_dict, ttl=self.routing_args.ttl
                )

                ### TESTING ###
                if self.test_flag:
                    self.logged_success += 1
        except Exception as e:
            verbose_router_logger.exception(
                "litellm.router_strategy.lowest_tpm_rpm.py::async_log_success_event(): Exception occured - {}".format(
                    str(e)
                )
            )
            verbose_router_logger.debug(traceback.format_exc())
            pass

    def get_available_deployments(
        self,
        model_group: str,
        healthy_deployments: list,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
    ):
        """
        Returns a deployment with the lowest TPM/RPM usage.

        Args:
            model_group: model group whose cached counters are consulted.
            healthy_deployments: candidate deployments (each has model_info.id).
            messages/input: used to estimate input tokens for the tpm check.

        Returns:
            The deployment dict with the lowest current TPM that stays under
            its tpm/rpm limits, or None if none qualifies.
        """
        # get list of potential deployments
        verbose_router_logger.debug(
            f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
        )
        current_minute = datetime.now().strftime("%H-%M")
        tpm_key = f"{model_group}:tpm:{current_minute}"
        rpm_key = f"{model_group}:rpm:{current_minute}"

        tpm_dict = self.router_cache.get_cache(key=tpm_key)
        rpm_dict = self.router_cache.get_cache(key=rpm_key)

        verbose_router_logger.debug(
            f"tpm_key={tpm_key}, tpm_dict: {tpm_dict}, rpm_dict: {rpm_dict}"
        )
        try:
            input_tokens = token_counter(messages=messages, text=input)
        except Exception:
            input_tokens = 0
        verbose_router_logger.debug(f"input_tokens={input_tokens}")
        # -----------------------
        # Find lowest used model
        # ----------------------
        lowest_tpm = float("inf")

        if tpm_dict is None:  # base case - none of the deployments have been used
            # initialize a tpm dict with {model_id: 0}
            tpm_dict = {}
            for deployment in healthy_deployments:
                tpm_dict[deployment["model_info"]["id"]] = 0
        else:
            for d in healthy_deployments:
                ## if healthy deployment not yet used
                if d["model_info"]["id"] not in tpm_dict:
                    tpm_dict[d["model_info"]["id"]] = 0

        all_deployments = tpm_dict

        deployment = None
        for item, item_tpm in all_deployments.items():
            ## get the item from model list
            _deployment = None
            for m in healthy_deployments:
                if item == m["model_info"]["id"]:
                    _deployment = m

            if _deployment is None:
                continue  # skip to next one

            # limit lookup order: top-level -> litellm_params -> model_info -> unlimited
            # (explicit is-None chain preserves an explicit 0 limit)
            _deployment_tpm = _deployment.get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = _deployment.get("model_info", {}).get("tpm")
            if _deployment_tpm is None:
                _deployment_tpm = float("inf")

            _deployment_rpm = _deployment.get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = _deployment.get("model_info", {}).get("rpm")
            if _deployment_rpm is None:
                _deployment_rpm = float("inf")

            if item_tpm + input_tokens > _deployment_tpm:
                continue
            elif (rpm_dict is not None and item in rpm_dict) and (
                rpm_dict[item] + 1 >= _deployment_rpm
            ):
                continue
            elif item_tpm < lowest_tpm:
                lowest_tpm = item_tpm
                deployment = _deployment
        print_verbose("returning picked lowest tpm/rpm deployment.")
        return deployment
|
||||
@@ -0,0 +1,668 @@
|
||||
#### What this does ####
|
||||
# identifies lowest tpm deployment
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import token_counter
|
||||
from litellm._logging import verbose_logger, verbose_router_logger
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.litellm_core_utils.core_helpers import _get_parent_otel_span_from_kwargs
|
||||
from litellm.types.router import RouterErrors
|
||||
from litellm.types.utils import LiteLLMPydanticObjectBase, StandardLoggingPayload
|
||||
from litellm.utils import get_utc_datetime, print_verbose
|
||||
|
||||
from .base_routing_strategy import BaseRoutingStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
|
||||
Span = Union[_Span, Any]
|
||||
else:
|
||||
Span = Any
|
||||
|
||||
|
||||
class RoutingArgs(LiteLLMPydanticObjectBase):
    """Tunable parameters for the v2 lowest-TPM routing strategy."""

    # seconds before cached per-minute TPM/RPM counters expire
    ttl: int = 1 * 60  # 1min (RPM/TPM expire key)
|
||||
|
||||
|
||||
class LowestTPMLoggingHandler_v2(BaseRoutingStrategy, CustomLogger):
|
||||
"""
|
||||
Updated version of TPM/RPM Logging.
|
||||
|
||||
Meant to work across instances.
|
||||
|
||||
Caches individual models, not model_groups
|
||||
|
||||
Uses batch get (redis.mget)
|
||||
|
||||
Increments tpm/rpm limit using redis.incr
|
||||
"""
|
||||
|
||||
test_flag: bool = False
|
||||
logged_success: int = 0
|
||||
logged_failure: int = 0
|
||||
default_cache_time_seconds: int = 1 * 60 * 60 # 1 hour
|
||||
|
||||
def __init__(self, router_cache: DualCache, routing_args: dict = {}):
|
||||
self.router_cache = router_cache
|
||||
self.routing_args = RoutingArgs(**routing_args)
|
||||
BaseRoutingStrategy.__init__(
|
||||
self,
|
||||
dual_cache=router_cache,
|
||||
should_batch_redis_writes=True,
|
||||
default_sync_interval=0.1,
|
||||
)
|
||||
|
||||
    def pre_call_check(self, deployment: Dict) -> Optional[Dict]:
        """
        Pre-call check + update model rpm

        Checks the in-memory counter first (cheap), then increments the shared
        counter; either value above the deployment's RPM limit rejects the call.

        Returns - deployment

        Raises - RateLimitError if deployment over defined RPM limit
        """
        try:
            # ------------
            # Setup values
            # ------------

            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            model_id = deployment.get("model_info", {}).get("id")
            deployment_name = deployment.get("litellm_params", {}).get("model")
            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"

            local_result = self.router_cache.get_cache(
                key=rpm_key, local_only=True
            )  # check local result first

            # RPM limit lookup order: top-level -> litellm_params -> model_info -> unlimited
            deployment_rpm = None
            if deployment_rpm is None:
                deployment_rpm = deployment.get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("model_info", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = float("inf")

            if local_result is not None and local_result >= deployment_rpm:
                # already over limit per local cache - reject without touching redis
                raise litellm.RateLimitError(
                    message="Deployment over defined rpm limit={}. current usage={}".format(
                        deployment_rpm, local_result
                    ),
                    llm_provider="",
                    model=deployment.get("litellm_params", {}).get("model"),
                    response=httpx.Response(
                        status_code=429,
                        content="{} rpm limit={}. current usage={}. id={}, model_group={}. Get the model info by calling 'router.get_model_info(id)".format(
                            RouterErrors.user_defined_ratelimit_error.value,
                            deployment_rpm,
                            local_result,
                            model_id,
                            deployment.get("model_name", ""),
                        ),
                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                )
            else:
                # if local result below limit, check redis ## prevent unnecessary redis checks

                result = self.router_cache.increment_cache(
                    key=rpm_key, value=1, ttl=self.routing_args.ttl
                )
                if result is not None and result > deployment_rpm:
                    raise litellm.RateLimitError(
                        message="Deployment over defined rpm limit={}. current usage={}".format(
                            deployment_rpm, result
                        ),
                        llm_provider="",
                        model=deployment.get("litellm_params", {}).get("model"),
                        response=httpx.Response(
                            status_code=429,
                            content="{} rpm limit={}. current usage={}".format(
                                RouterErrors.user_defined_ratelimit_error.value,
                                deployment_rpm,
                                result,
                            ),
                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                        ),
                    )
            return deployment
        except Exception as e:
            # fail open: only a genuine rate-limit rejection propagates
            if isinstance(e, litellm.RateLimitError):
                raise e
            return deployment  # don't fail calls if eg. redis fails to connect
|
||||
|
||||
    async def async_pre_call_check(
        self, deployment: Dict, parent_otel_span: Optional[Span]
    ) -> Optional[Dict]:
        """
        Pre-call check + update model rpm
        - Used inside semaphore
        - raise rate limit error if deployment over limit

        Why? solves concurrency issue - https://github.com/BerriAI/litellm/issues/2994

        Returns - deployment

        Raises - RateLimitError if deployment over defined RPM limit
        """
        try:
            # ------------
            # Setup values
            # ------------
            dt = get_utc_datetime()
            current_minute = dt.strftime("%H-%M")
            model_id = deployment.get("model_info", {}).get("id")
            deployment_name = deployment.get("litellm_params", {}).get("model")

            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
            local_result = await self.router_cache.async_get_cache(
                key=rpm_key, local_only=True
            )  # check local result first

            # RPM limit lookup order: top-level -> litellm_params -> model_info -> unlimited
            deployment_rpm = None
            if deployment_rpm is None:
                deployment_rpm = deployment.get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("litellm_params", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = deployment.get("model_info", {}).get("rpm")
            if deployment_rpm is None:
                deployment_rpm = float("inf")
            if local_result is not None and local_result >= deployment_rpm:
                # already over limit per local cache - reject without touching redis
                raise litellm.RateLimitError(
                    message="Deployment over defined rpm limit={}. current usage={}".format(
                        deployment_rpm, local_result
                    ),
                    llm_provider="",
                    model=deployment.get("litellm_params", {}).get("model"),
                    response=httpx.Response(
                        status_code=429,
                        content="{} rpm limit={}. current usage={}".format(
                            RouterErrors.user_defined_ratelimit_error.value,
                            deployment_rpm,
                            local_result,
                        ),
                        headers={"retry-after": str(60)},  # type: ignore
                        request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                    ),
                    num_retries=deployment.get("num_retries"),
                )
            else:
                # if local result below limit, check redis ## prevent unnecessary redis checks
                result = await self._increment_value_in_current_window(
                    key=rpm_key, value=1, ttl=self.routing_args.ttl
                )
                if result is not None and result > deployment_rpm:
                    raise litellm.RateLimitError(
                        message="Deployment over defined rpm limit={}. current usage={}".format(
                            deployment_rpm, result
                        ),
                        llm_provider="",
                        model=deployment.get("litellm_params", {}).get("model"),
                        response=httpx.Response(
                            status_code=429,
                            content="{} rpm limit={}. current usage={}".format(
                                RouterErrors.user_defined_ratelimit_error.value,
                                deployment_rpm,
                                result,
                            ),
                            headers={"retry-after": str(60)},  # type: ignore
                            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
                        ),
                        num_retries=deployment.get("num_retries"),
                    )
            return deployment
        except Exception as e:
            # fail open: only a genuine rate-limit rejection propagates
            if isinstance(e, litellm.RateLimitError):
                raise e
            return deployment  # don't fail calls if eg. redis fails to connect
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
"""
|
||||
Update TPM/RPM usage on success
|
||||
"""
|
||||
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
if standard_logging_object is None:
|
||||
raise ValueError("standard_logging_object not passed in.")
|
||||
model_group = standard_logging_object.get("model_group")
|
||||
model = standard_logging_object["hidden_params"].get("litellm_model_name")
|
||||
id = standard_logging_object.get("model_id")
|
||||
if model_group is None or id is None or model is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
|
||||
total_tokens = standard_logging_object.get("total_tokens")
|
||||
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime(
|
||||
"%H-%M"
|
||||
) # use the same timezone regardless of system clock
|
||||
|
||||
tpm_key = f"{id}:{model}:tpm:{current_minute}"
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
# update cache
|
||||
|
||||
## TPM
|
||||
self.router_cache.increment_cache(
|
||||
key=tpm_key, value=total_tokens, ttl=self.routing_args.ttl
|
||||
)
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"litellm.proxy.hooks.lowest_tpm_rpm_v2.py::log_success_event(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
"""
|
||||
Update TPM usage on success
|
||||
"""
|
||||
standard_logging_object: Optional[StandardLoggingPayload] = kwargs.get(
|
||||
"standard_logging_object"
|
||||
)
|
||||
if standard_logging_object is None:
|
||||
raise ValueError("standard_logging_object not passed in.")
|
||||
model_group = standard_logging_object.get("model_group")
|
||||
model = standard_logging_object["hidden_params"]["litellm_model_name"]
|
||||
id = standard_logging_object.get("model_id")
|
||||
if model_group is None or id is None:
|
||||
return
|
||||
elif isinstance(id, int):
|
||||
id = str(id)
|
||||
total_tokens = standard_logging_object.get("total_tokens")
|
||||
# ------------
|
||||
# Setup values
|
||||
# ------------
|
||||
dt = get_utc_datetime()
|
||||
current_minute = dt.strftime(
|
||||
"%H-%M"
|
||||
) # use the same timezone regardless of system clock
|
||||
|
||||
tpm_key = f"{id}:{model}:tpm:{current_minute}"
|
||||
# ------------
|
||||
# Update usage
|
||||
# ------------
|
||||
# update cache
|
||||
parent_otel_span = _get_parent_otel_span_from_kwargs(kwargs)
|
||||
## TPM
|
||||
await self.router_cache.async_increment_cache(
|
||||
key=tpm_key,
|
||||
value=total_tokens,
|
||||
ttl=self.routing_args.ttl,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
### TESTING ###
|
||||
if self.test_flag:
|
||||
self.logged_success += 1
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"litellm.proxy.hooks.lowest_tpm_rpm_v2.py::async_log_success_event(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
pass
|
||||
|
||||
def _return_potential_deployments(
|
||||
self,
|
||||
healthy_deployments: List[Dict],
|
||||
all_deployments: Dict,
|
||||
input_tokens: int,
|
||||
rpm_dict: Dict,
|
||||
):
|
||||
lowest_tpm = float("inf")
|
||||
potential_deployments = [] # if multiple deployments have the same low value
|
||||
deployment_lookup = {
|
||||
deployment.get("model_info", {}).get("id"): deployment
|
||||
for deployment in healthy_deployments
|
||||
}
|
||||
for item, item_tpm in all_deployments.items():
|
||||
## get the item from model list
|
||||
item = item.split(":")[0]
|
||||
_deployment = deployment_lookup.get(item)
|
||||
if _deployment is None:
|
||||
continue # skip to next one
|
||||
elif item_tpm is None:
|
||||
continue # skip if unhealthy deployment
|
||||
|
||||
_deployment_tpm = None
|
||||
if _deployment_tpm is None:
|
||||
_deployment_tpm = _deployment.get("tpm")
|
||||
if _deployment_tpm is None:
|
||||
_deployment_tpm = _deployment.get("litellm_params", {}).get("tpm")
|
||||
if _deployment_tpm is None:
|
||||
_deployment_tpm = _deployment.get("model_info", {}).get("tpm")
|
||||
if _deployment_tpm is None:
|
||||
_deployment_tpm = float("inf")
|
||||
|
||||
_deployment_rpm = None
|
||||
if _deployment_rpm is None:
|
||||
_deployment_rpm = _deployment.get("rpm")
|
||||
if _deployment_rpm is None:
|
||||
_deployment_rpm = _deployment.get("litellm_params", {}).get("rpm")
|
||||
if _deployment_rpm is None:
|
||||
_deployment_rpm = _deployment.get("model_info", {}).get("rpm")
|
||||
if _deployment_rpm is None:
|
||||
_deployment_rpm = float("inf")
|
||||
if item_tpm + input_tokens > _deployment_tpm:
|
||||
continue
|
||||
elif (
|
||||
(rpm_dict is not None and item in rpm_dict)
|
||||
and rpm_dict[item] is not None
|
||||
and (rpm_dict[item] + 1 >= _deployment_rpm)
|
||||
):
|
||||
continue
|
||||
elif item_tpm == lowest_tpm:
|
||||
potential_deployments.append(_deployment)
|
||||
elif item_tpm < lowest_tpm:
|
||||
lowest_tpm = item_tpm
|
||||
potential_deployments = [_deployment]
|
||||
return potential_deployments
|
||||
|
||||
def _common_checks_available_deployment( # noqa: PLR0915
|
||||
self,
|
||||
model_group: str,
|
||||
healthy_deployments: list,
|
||||
tpm_keys: list,
|
||||
tpm_values: Optional[list],
|
||||
rpm_keys: list,
|
||||
rpm_values: Optional[list],
|
||||
messages: Optional[List[Dict[str, str]]] = None,
|
||||
input: Optional[Union[str, List]] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Common checks for get available deployment, across sync + async implementations
|
||||
"""
|
||||
|
||||
if tpm_values is None or rpm_values is None:
|
||||
return None
|
||||
|
||||
tpm_dict = {} # {model_id: 1, ..}
|
||||
for idx, key in enumerate(tpm_keys):
|
||||
tpm_dict[tpm_keys[idx].split(":")[0]] = tpm_values[idx]
|
||||
|
||||
rpm_dict = {} # {model_id: 1, ..}
|
||||
for idx, key in enumerate(rpm_keys):
|
||||
rpm_dict[rpm_keys[idx].split(":")[0]] = rpm_values[idx]
|
||||
|
||||
try:
|
||||
input_tokens = token_counter(messages=messages, text=input)
|
||||
except Exception:
|
||||
input_tokens = 0
|
||||
verbose_router_logger.debug(f"input_tokens={input_tokens}")
|
||||
# -----------------------
|
||||
# Find lowest used model
|
||||
# ----------------------
|
||||
|
||||
if tpm_dict is None: # base case - none of the deployments have been used
|
||||
# initialize a tpm dict with {model_id: 0}
|
||||
tpm_dict = {}
|
||||
for deployment in healthy_deployments:
|
||||
tpm_dict[deployment["model_info"]["id"]] = 0
|
||||
else:
|
||||
for d in healthy_deployments:
|
||||
## if healthy deployment not yet used
|
||||
tpm_key = d["model_info"]["id"]
|
||||
if tpm_key not in tpm_dict or tpm_dict[tpm_key] is None:
|
||||
tpm_dict[tpm_key] = 0
|
||||
|
||||
all_deployments = tpm_dict
|
||||
potential_deployments = self._return_potential_deployments(
|
||||
healthy_deployments=healthy_deployments,
|
||||
all_deployments=all_deployments,
|
||||
input_tokens=input_tokens,
|
||||
rpm_dict=rpm_dict,
|
||||
)
|
||||
print_verbose("returning picked lowest tpm/rpm deployment.")
|
||||
|
||||
if len(potential_deployments) > 0:
|
||||
return random.choice(potential_deployments)
|
||||
else:
|
||||
return None
|
||||
|
||||
async def async_get_available_deployments(
    self,
    model_group: str,
    healthy_deployments: list,
    messages: Optional[List[Dict[str, str]]] = None,
    input: Optional[Union[str, List]] = None,
):
    """
    Async implementation of get deployments.

    Reduces time to retrieve the tpm/rpm values from cache by reading all
    tpm + rpm keys in a single batched cache call.

    Args:
        model_group: Model group the caller requested (used in error messages).
        healthy_deployments: Candidate deployment dicts (each carries
            'model_info' and 'litellm_params').
        messages: Optional chat messages, used to estimate input tokens.
        input: Optional raw input (str or list), used to estimate input tokens.

    Returns:
        dict: The deployment with the lowest TPM/RPM usage.

    Raises:
        litellm.RateLimitError: If no deployment is available within its
            tpm/rpm limits (HTTP 429 semantics, retry-after 60s).
    """
    # get list of potential deployments
    verbose_router_logger.debug(
        f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
    )

    dt = get_utc_datetime()
    current_minute = dt.strftime("%H-%M")

    # build per-deployment usage-cache keys for the current minute
    tpm_keys = []
    rpm_keys = []
    for m in healthy_deployments:
        if isinstance(m, dict):
            deployment_id = m.get("model_info", {}).get(
                "id"
            )  # a deployment should always have an 'id'. this is set in router.py
            deployment_name = m.get("litellm_params", {}).get("model")
            tpm_keys.append(
                "{}:{}:tpm:{}".format(deployment_id, deployment_name, current_minute)
            )
            rpm_keys.append(
                "{}:{}:rpm:{}".format(deployment_id, deployment_name, current_minute)
            )

    # one batched cache read for all tpm + rpm keys - halves cache round-trips
    combined_tpm_rpm_values = await self.router_cache.async_batch_get_cache(
        keys=tpm_keys + rpm_keys
    )  # [1, 2, None, ..]

    if combined_tpm_rpm_values is not None:
        tpm_values = combined_tpm_rpm_values[: len(tpm_keys)]
        rpm_values = combined_tpm_rpm_values[len(tpm_keys) :]
    else:
        tpm_values = None
        rpm_values = None

    deployment = self._common_checks_available_deployment(
        model_group=model_group,
        healthy_deployments=healthy_deployments,
        tpm_keys=tpm_keys,
        tpm_values=tpm_values,
        rpm_keys=rpm_keys,
        rpm_values=rpm_values,
        messages=messages,
        input=input,
    )

    # BUGFIX: previously `assert deployment is not None` inside a try/except
    # drove this branch - under `python -O` asserts are stripped, so the
    # method silently returned None instead of raising. Use an explicit check.
    if deployment is not None:
        return deployment

    ### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
    def _get_limit(_deployment: dict, limit_name: str):
        # A limit may be set at the top level, under litellm_params, or under
        # model_info; the first non-None value wins, else "no limit".
        for container in (
            _deployment,
            _deployment.get("litellm_params", {}),
            _deployment.get("model_info", {}),
        ):
            limit = container.get(limit_name, None)
            if limit is not None:
                return limit
        return float("inf")

    deployment_dict = {}
    for index, _deployment in enumerate(healthy_deployments):
        if isinstance(_deployment, dict):
            deployment_id = _deployment.get("model_info", {}).get("id")
            deployment_dict[deployment_id] = {
                "current_tpm": tpm_values[index] if tpm_values else 0,
                "tpm_limit": _get_limit(_deployment, "tpm"),
                "current_rpm": rpm_values[index] if rpm_values else 0,
                "rpm_limit": _get_limit(_deployment, "rpm"),
            }
    raise litellm.RateLimitError(
        message=f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}",
        llm_provider="",
        model=model_group,
        response=httpx.Response(
            status_code=429,
            content="",
            headers={"retry-after": str(60)},  # type: ignore
            request=httpx.Request(method="tpm_rpm_limits", url="https://github.com/BerriAI/litellm"),  # type: ignore
        ),
    )
|
||||
|
||||
def get_available_deployments(
    self,
    model_group: str,
    healthy_deployments: list,
    messages: Optional[List[Dict[str, str]]] = None,
    input: Optional[Union[str, List]] = None,
    parent_otel_span: Optional[Span] = None,
):
    """
    Returns a deployment with the lowest TPM/RPM usage.

    Sync counterpart of async_get_available_deployments; reads tpm and rpm
    values with two separate (sync) batched cache calls.

    Args:
        model_group: Model group the caller requested (used in error messages).
        healthy_deployments: Candidate deployment dicts (each carries
            'model_info' and 'litellm_params').
        messages: Optional chat messages, used to estimate input tokens.
        input: Optional raw input (str or list), used to estimate input tokens.
        parent_otel_span: Optional OpenTelemetry span for tracing cache reads.

    Returns:
        dict: The deployment with the lowest TPM/RPM usage.

    Raises:
        ValueError: If no deployment is available within its tpm/rpm limits.
    """
    # get list of potential deployments
    verbose_router_logger.debug(
        f"get_available_deployments - Usage Based. model_group: {model_group}, healthy_deployments: {healthy_deployments}"
    )

    dt = get_utc_datetime()
    current_minute = dt.strftime("%H-%M")

    # build per-deployment usage-cache keys for the current minute
    tpm_keys = []
    rpm_keys = []
    for m in healthy_deployments:
        if isinstance(m, dict):
            deployment_id = m.get("model_info", {}).get(
                "id"
            )  # a deployment should always have an 'id'. this is set in router.py
            deployment_name = m.get("litellm_params", {}).get("model")
            tpm_keys.append(
                "{}:{}:tpm:{}".format(deployment_id, deployment_name, current_minute)
            )
            rpm_keys.append(
                "{}:{}:rpm:{}".format(deployment_id, deployment_name, current_minute)
            )

    tpm_values = self.router_cache.batch_get_cache(
        keys=tpm_keys, parent_otel_span=parent_otel_span
    )  # [1, 2, None, ..]
    rpm_values = self.router_cache.batch_get_cache(
        keys=rpm_keys, parent_otel_span=parent_otel_span
    )  # [1, 2, None, ..]

    deployment = self._common_checks_available_deployment(
        model_group=model_group,
        healthy_deployments=healthy_deployments,
        tpm_keys=tpm_keys,
        tpm_values=tpm_values,
        rpm_keys=rpm_keys,
        rpm_values=rpm_values,
        messages=messages,
        input=input,
    )

    # BUGFIX: previously `assert deployment is not None` inside a try/except
    # drove this branch - under `python -O` asserts are stripped, so the
    # method silently returned None instead of raising. Use an explicit check.
    if deployment is not None:
        return deployment

    ### GET THE DICT OF TPM / RPM + LIMITS PER DEPLOYMENT ###
    def _get_limit(_deployment: dict, limit_name: str):
        # A limit may be set at the top level, under litellm_params, or under
        # model_info; the first non-None value wins, else "no limit".
        for container in (
            _deployment,
            _deployment.get("litellm_params", {}),
            _deployment.get("model_info", {}),
        ):
            limit = container.get(limit_name, None)
            if limit is not None:
                return limit
        return float("inf")

    deployment_dict = {}
    for index, _deployment in enumerate(healthy_deployments):
        if isinstance(_deployment, dict):
            deployment_id = _deployment.get("model_info", {}).get("id")
            deployment_dict[deployment_id] = {
                "current_tpm": tpm_values[index] if tpm_values else 0,
                "tpm_limit": _get_limit(_deployment, "tpm"),
                "current_rpm": rpm_values[index] if rpm_values else 0,
                "rpm_limit": _get_limit(_deployment, "rpm"),
            }
    raise ValueError(
        f"{RouterErrors.no_deployments_available.value}. Passed model={model_group}. Deployments={deployment_dict}"
    )
|
||||
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
Returns a random deployment from the list of healthy deployments.
|
||||
|
||||
If weights are provided, it will return a deployment based on the weights.
|
||||
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
|
||||
|
||||
def simple_shuffle(
|
||||
llm_router_instance: LitellmRouter,
|
||||
healthy_deployments: Union[List[Any], Dict[Any, Any]],
|
||||
model: str,
|
||||
) -> Dict:
|
||||
"""
|
||||
Returns a random deployment from the list of healthy deployments.
|
||||
|
||||
If weights are provided, it will return a deployment based on the weights.
|
||||
|
||||
If users pass `rpm` or `tpm`, we do a random weighted pick - based on `rpm`/`tpm`.
|
||||
|
||||
Args:
|
||||
llm_router_instance: LitellmRouter instance
|
||||
healthy_deployments: List of healthy deployments
|
||||
model: Model name
|
||||
|
||||
Returns:
|
||||
Dict: A single healthy deployment
|
||||
"""
|
||||
|
||||
############## Check if 'weight' or 'rpm' or 'tpm' param set for a weighted pick #################
|
||||
for weight_by in ["weight", "rpm", "tpm"]:
|
||||
weight = healthy_deployments[0].get("litellm_params").get(weight_by, None)
|
||||
if weight is not None:
|
||||
weights = [
|
||||
m["litellm_params"].get(weight_by, 0) for m in healthy_deployments
|
||||
]
|
||||
verbose_router_logger.debug(f"\nweight {weights}")
|
||||
total_weight = sum(weights)
|
||||
weights = [weight / total_weight for weight in weights]
|
||||
verbose_router_logger.debug(f"\n weights {weights} by {weight_by}")
|
||||
# Perform weighted random pick
|
||||
selected_index = random.choices(range(len(weights)), weights=weights)[0]
|
||||
verbose_router_logger.debug(f"\n selected index, {selected_index}")
|
||||
deployment = healthy_deployments[selected_index]
|
||||
verbose_router_logger.info(
|
||||
f"get_available_deployment for model: {model}, Selected deployment: {llm_router_instance.print_deployment(deployment) or deployment[0]} for model: {model}"
|
||||
)
|
||||
return deployment or deployment[0]
|
||||
|
||||
############## No RPM/TPM passed, we do a random pick #################
|
||||
item = random.choice(healthy_deployments)
|
||||
return item or item[0]
|
||||
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Use this to route requests between Teams
|
||||
|
||||
- If tags in request is a subset of tags in deployment, return deployment
|
||||
- if deployments are set with default tags, return all default deployment
|
||||
- If no default_deployments are set, return all deployments
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.types.router import RouterErrors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
LitellmRouter = _Router
|
||||
else:
|
||||
LitellmRouter = Any
|
||||
|
||||
|
||||
def is_valid_deployment_tag(
    deployment_tags: List[str], request_tags: List[str], match_any: bool = True
) -> bool:
    """
    Decide whether a deployment's tags satisfy the tags on a request.

    With match_any=True the deployment qualifies when it shares at least one
    tag with the request; with match_any=False every request tag must be
    present on the deployment. A request with no tags never matches.
    """
    if not request_tags:
        return False

    deployment_tag_set = set(deployment_tags)
    request_tag_set = set(request_tags)

    if match_any:
        matched = not request_tag_set.isdisjoint(deployment_tag_set)
    else:
        matched = request_tag_set <= deployment_tag_set

    if not matched:
        return False

    verbose_logger.debug(
        "adding deployment with tags: %s, request tags: %s for match_any=%s",
        deployment_tags,
        request_tags,
        match_any,
    )
    return True
|
||||
|
||||
|
||||
async def get_deployments_for_tag(
    llm_router_instance: LitellmRouter,
    model: str,  # used to raise the correct error
    healthy_deployments: Union[List[Any], Dict[Any, Any]],
    request_kwargs: Optional[Dict[Any, Any]] = None,
    metadata_variable_name: Literal["metadata", "litellm_metadata"] = "metadata",
):
    """
    Returns a list of deployments that match the requested model and tags in the request.

    Executes tag based filtering based on the tags in request metadata and the tags on the deployments

    - If request tags match deployment tags (any/all per router config), return those deployments
    - For untagged requests, return deployments tagged "default" if any exist
    - Otherwise, return all healthy deployments unchanged

    Raises:
        ValueError: when the request carries tags but no deployment (and no
            "default"-tagged deployment) matches them.
    """
    # tag filtering is opt-in; pass-through when disabled
    if llm_router_instance.enable_tag_filtering is not True:
        return healthy_deployments

    if request_kwargs is None:
        verbose_logger.debug(
            "get_deployments_for_tag: request_kwargs is None returning healthy_deployments: %s",
            healthy_deployments,
        )
        return healthy_deployments

    if healthy_deployments is None:
        verbose_logger.debug(
            "get_deployments_for_tag: healthy_deployments is None returning healthy_deployments"
        )
        return healthy_deployments

    verbose_logger.debug(
        "request metadata: %s", request_kwargs.get(metadata_variable_name)
    )
    if metadata_variable_name in request_kwargs:
        # BUGFIX: metadata may be explicitly set to None by callers - guard
        # with `or {}` (like _get_tags_from_request_kwargs does) instead of
        # raising AttributeError on `.get`.
        metadata = request_kwargs[metadata_variable_name] or {}
        request_tags = metadata.get("tags")
        match_any = llm_router_instance.tag_filtering_match_any

        new_healthy_deployments = []
        default_deployments = []
        if request_tags:
            verbose_logger.debug(
                "get_deployments_for_tag routing: router_keys: %s", request_tags
            )
            # example this can be router_keys=["free", "custom"]
            for deployment in healthy_deployments:
                deployment_litellm_params = deployment.get("litellm_params")
                deployment_tags = deployment_litellm_params.get("tags")

                verbose_logger.debug(
                    "deployment: %s, deployment_router_keys: %s",
                    deployment,
                    deployment_tags,
                )

                if deployment_tags is None:
                    continue

                if is_valid_deployment_tag(deployment_tags, request_tags, match_any):
                    new_healthy_deployments.append(deployment)

                # remember "default"-tagged deployments as a fallback
                if "default" in deployment_tags:
                    default_deployments.append(deployment)

            if len(new_healthy_deployments) == 0 and len(default_deployments) == 0:
                raise ValueError(
                    f"{RouterErrors.no_deployments_with_tag_routing.value}. Passed model={model} and tags={request_tags}"
                )

            return (
                new_healthy_deployments
                if len(new_healthy_deployments) > 0
                else default_deployments
            )

        # for Untagged requests use default deployments if set
        _default_deployments_with_tags = []
        for deployment in healthy_deployments:
            if "default" in deployment.get("litellm_params", {}).get("tags", []):
                _default_deployments_with_tags.append(deployment)

        if len(_default_deployments_with_tags) > 0:
            return _default_deployments_with_tags

    # if no default deployment is found, return healthy_deployments
    verbose_logger.debug(
        "no tier found in metadata, returning healthy_deployments: %s",
        healthy_deployments,
    )
    return healthy_deployments
|
||||
|
||||
|
||||
def _get_tags_from_request_kwargs(
|
||||
request_kwargs: Optional[Dict[Any, Any]] = None,
|
||||
metadata_variable_name: Literal["metadata", "litellm_metadata"] = "metadata",
|
||||
) -> List[str]:
|
||||
"""
|
||||
Helper to get tags from request kwargs
|
||||
|
||||
Args:
|
||||
request_kwargs: The request kwargs to get tags from
|
||||
|
||||
Returns:
|
||||
List[str]: The tags from the request kwargs
|
||||
"""
|
||||
if request_kwargs is None:
|
||||
return []
|
||||
if metadata_variable_name in request_kwargs:
|
||||
metadata = request_kwargs[metadata_variable_name] or {}
|
||||
tags = metadata.get("tags", [])
|
||||
return tags if tags is not None else []
|
||||
elif "litellm_params" in request_kwargs:
|
||||
litellm_params = request_kwargs["litellm_params"] or {}
|
||||
_metadata = litellm_params.get(metadata_variable_name, {}) or {}
|
||||
tags = _metadata.get("tags", [])
|
||||
return tags if tags is not None else []
|
||||
return []
|
||||
Reference in New Issue
Block a user