chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Auto-Routing Strategy that works with a Semantic Router Config
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from semantic_router.routers.base import Route
|
||||
|
||||
from litellm.router import Router
|
||||
from litellm.types.router import PreRoutingHookResponse
|
||||
else:
|
||||
Router = Any
|
||||
PreRoutingHookResponse = Any
|
||||
Route = Any
|
||||
|
||||
|
||||
class AutoRouter(CustomLogger):
    """Auto-Router that uses a semantic router to route requests to the appropriate model.

    On each request, the last message's content is embedded and matched against the
    configured semantic routes; the matched route name (if any) replaces the model,
    otherwise ``default_model`` is used.
    """

    # auto_sync mode handed to SemanticRouter when the route layer is built.
    DEFAULT_AUTO_SYNC_VALUE = "local"

    def __init__(
        self,
        model_name: str,
        default_model: str,
        embedding_model: str,
        litellm_router_instance: "Router",
        auto_router_config_path: Optional[str] = None,
        auto_router_config: Optional[str] = None,
    ):
        """
        Auto-Router class that uses a semantic router to route requests to the appropriate model.

        Args:
            model_name: The name of the model to use for the auto-router. eg. if model = "auto-router1" then use this router.
            default_model: The default model to use if no route is found.
            embedding_model: The embedding model to use for the auto-router.
            litellm_router_instance: The instance of the LiteLLM Router.
            auto_router_config_path: The path to the router config file.
            auto_router_config: The config (JSON string) to use for the auto-router. You can
                either use this or auto_router_config_path, not both; if both are provided,
                auto_router_config_path takes precedence.

        Raises:
            ValueError: If neither auto_router_config_path nor auto_router_config is provided.
        """
        # Imported lazily so semantic_router is only required when an auto-router is used.
        from semantic_router.routers import SemanticRouter

        self.auto_router_config_path: Optional[str] = auto_router_config_path
        self.auto_router_config: Optional[str] = auto_router_config
        self.auto_sync_value = self.DEFAULT_AUTO_SYNC_VALUE
        self.loaded_routes: List[Route] = self._load_semantic_routing_routes()
        # Built lazily on the first routing call (see async_pre_routing_hook).
        self.routelayer: Optional[SemanticRouter] = None
        self.default_model = default_model
        self.embedding_model: str = embedding_model
        self.litellm_router_instance: "Router" = litellm_router_instance

    def _load_semantic_routing_routes(self) -> List[Route]:
        """Load semantic routes from the config path or the inline config JSON.

        Returns:
            The list of semantic-router ``Route`` objects.

        Raises:
            ValueError: If no router config was provided.
        """
        from semantic_router.routers import SemanticRouter

        if self.auto_router_config_path:
            # Path takes precedence over the inline config.
            return SemanticRouter.from_json(self.auto_router_config_path).routes
        elif self.auto_router_config:
            return self._load_auto_router_routes_from_config_json()
        else:
            raise ValueError("No router config provided")

    def _load_auto_router_routes_from_config_json(self) -> List[Route]:
        """Parse ``self.auto_router_config`` (a JSON string) into Route objects.

        Expected shape: ``{"routes": [{"name": ..., "description": ...,
        "utterances": [...], "score_threshold": ...}, ...]}``.

        Raises:
            ValueError: If ``self.auto_router_config`` is None.
        """
        import json

        from semantic_router.routers.base import Route

        if self.auto_router_config is None:
            raise ValueError("No auto router config provided")
        auto_router_routes: List[Route] = []
        loaded_config = json.loads(self.auto_router_config)
        for route in loaded_config.get("routes", []):
            auto_router_routes.append(
                Route(
                    name=route.get("name"),
                    description=route.get("description"),
                    utterances=route.get("utterances", []),
                    score_threshold=route.get("score_threshold"),
                )
            )
        return auto_router_routes

    async def async_pre_routing_hook(
        self,
        model: str,
        request_kwargs: Dict,
        messages: Optional[List[Dict[str, str]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
    ) -> Optional["PreRoutingHookResponse"]:
        """
        This hook is called before the routing decision is made.

        Used for the litellm auto-router to modify the request before the routing decision is made.

        Args:
            model: The model name requested by the caller.
            request_kwargs: The raw request kwargs (unused here).
            messages: The chat messages; the last message's content is routed on.
            input: Alternative input form (unused here).
            specific_deployment: Whether a specific deployment was requested (unused here).

        Returns:
            A PreRoutingHookResponse with the resolved model, or None when there is
            nothing to route on.
        """
        from semantic_router.routers import SemanticRouter
        from semantic_router.schema import RouteChoice

        from litellm.router_strategy.auto_router.litellm_encoder import (
            LiteLLMRouterEncoder,
        )
        from litellm.types.router import PreRoutingHookResponse

        # Guard both None and empty list: routing reads messages[-1] below,
        # which would raise IndexError on an empty list.
        if not messages:
            # do nothing, return same inputs
            return None

        if self.routelayer is None:
            #######################
            # Create the route layer lazily on first use
            #######################
            self.routelayer = SemanticRouter(
                routes=self.loaded_routes,
                encoder=LiteLLMRouterEncoder(
                    litellm_router_instance=self.litellm_router_instance,
                    model_name=self.embedding_model,
                ),
                auto_sync=self.auto_sync_value,
            )

        # Route on the most recent message only.
        user_message: Dict[str, str] = messages[-1]
        message_content: str = user_message.get("content", "")
        route_choice: Optional[Union[RouteChoice, List[RouteChoice]]] = self.routelayer(
            text=message_content
        )
        verbose_router_logger.debug(f"route_choice: {route_choice}")
        if isinstance(route_choice, RouteChoice):
            model = route_choice.name or self.default_model
        elif isinstance(route_choice, list):
            # Multiple matches: take the top-ranked route.
            model = route_choice[0].name or self.default_model

        return PreRoutingHookResponse(
            model=model,
            messages=messages,
        )
|
||||
@@ -0,0 +1,139 @@
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from semantic_router.encoders import DenseEncoder
|
||||
from semantic_router.encoders.base import AsymmetricDenseMixin
|
||||
|
||||
import litellm
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router
|
||||
else:
|
||||
Router = Any
|
||||
|
||||
|
||||
def litellm_to_list(embeds: litellm.EmbeddingResponse) -> list[list[float]]:
    """Convert a LiteLLM embedding response to a list of embeddings.

    :param embeds: The LiteLLM embedding response.
    :return: A list of embeddings.
    :raises ValueError: If the response is missing, of the wrong type, or has no data.
    """
    has_embeddings = (
        bool(embeds)
        and isinstance(embeds, litellm.EmbeddingResponse)
        and bool(embeds.data)
    )
    if not has_embeddings:
        raise ValueError("No embeddings found in LiteLLM embedding response.")
    return [item["embedding"] for item in embeds.data]
|
||||
|
||||
|
||||
class CustomDenseEncoder(DenseEncoder):
    """DenseEncoder subclass that carries an optional LiteLLM Router reference.

    Accepts ``litellm_router_instance`` either as a named argument or inside
    ``**kwargs``; either way it is kept off the pydantic base-class init.
    """

    model_config = ConfigDict(extra="allow")

    def __init__(self, litellm_router_instance: Optional["Router"] = None, **kwargs):
        # A kwargs-supplied router instance wins over the named parameter and
        # must be removed before the remaining kwargs reach DenseEncoder.
        litellm_router_instance = kwargs.pop(
            "litellm_router_instance", litellm_router_instance
        )
        super().__init__(**kwargs)
        self.litellm_router_instance = litellm_router_instance
|
||||
|
||||
|
||||
class LiteLLMRouterEncoder(CustomDenseEncoder, AsymmetricDenseMixin):
    """LiteLLM encoder class for generating embeddings using LiteLLM.

    The LiteLLMRouterEncoder class is a subclass of DenseEncoder and utilizes the LiteLLM Router SDK
    to generate embeddings for given documents. It supports all encoders supported by LiteLLM
    and supports customization of the score threshold for filtering or processing the embeddings.
    """

    type: str = "internal_litellm_router"

    def __init__(
        self,
        litellm_router_instance: "Router",
        model_name: str,
        score_threshold: Union[float, None] = None,
    ):
        """Initialize the LiteLLMEncoder.

        :param litellm_router_instance: The instance of the LiteLLM Router.
        :type litellm_router_instance: Router
        :param model_name: The name of the embedding model to use. Must use LiteLLM naming
            convention (e.g. "openai/text-embedding-3-small" or "mistral/mistral-embed").
        :type model_name: str
        :param score_threshold: The score threshold for the embeddings. Defaults to 0.3
            when not provided.
        :type score_threshold: float
        """
        super().__init__(
            name=model_name,
            score_threshold=score_threshold if score_threshold is not None else 0.3,
        )
        self.model_name = model_name
        self.litellm_router_instance = litellm_router_instance

    def _embed_sync(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Shared synchronous embedding path for queries and documents.

        :raises ValueError: If no router instance is set or the embedding call fails.
        """
        if self.litellm_router_instance is None:
            raise ValueError("litellm_router_instance is not set")
        try:
            embeds = self.litellm_router_instance.embedding(
                input=docs, model=self.model_name, **kwargs
            )
            return litellm_to_list(embeds)
        except Exception as e:
            raise ValueError(
                f"{self.type.capitalize()} API call failed. Error: {e}"
            ) from e

    async def _embed_async(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Shared asynchronous embedding path for queries and documents.

        :raises ValueError: If no router instance is set or the embedding call fails.
        """
        if self.litellm_router_instance is None:
            raise ValueError("litellm_router_instance is not set")
        try:
            embeds = await self.litellm_router_instance.aembedding(
                input=docs, model=self.model_name, **kwargs
            )
            return litellm_to_list(embeds)
        except Exception as e:
            raise ValueError(
                f"{self.type.capitalize()} API call failed. Error: {e}"
            ) from e

    def __call__(self, docs: list[Any], **kwargs) -> list[list[float]]:
        """Encode a list of text documents into embeddings using LiteLLM.

        :param docs: List of text documents to encode.
        :return: List of embeddings for each document."""
        return self.encode_queries(docs, **kwargs)

    async def acall(self, docs: list[Any], **kwargs) -> list[list[float]]:
        """Encode a list of documents into embeddings using LiteLLM asynchronously.

        :param docs: List of documents to encode.
        :return: List of embeddings for each document."""
        return await self.aencode_queries(docs, **kwargs)

    def encode_queries(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Synchronously embed query texts."""
        return self._embed_sync(docs, **kwargs)

    def encode_documents(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Synchronously embed document texts."""
        return self._embed_sync(docs, **kwargs)

    async def aencode_queries(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Asynchronously embed query texts."""
        return await self._embed_async(docs, **kwargs)

    async def aencode_documents(self, docs: list[str], **kwargs) -> list[list[float]]:
        """Asynchronously embed document texts."""
        return await self._embed_async(docs, **kwargs)
|
||||
Reference in New Issue
Block a user