chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,162 @@
|
||||
# Complexity Router
|
||||
|
||||
A rule-based routing strategy that classifies requests by complexity and routes them to appropriate models - with zero API calls and sub-millisecond latency.
|
||||
|
||||
## Overview
|
||||
|
||||
Unlike the semantic `auto_router`, which uses embedding-based matching, the `complexity_router` uses weighted rule-based scoring across multiple dimensions to classify request complexity. This approach offers:
|
||||
|
||||
- **Zero external API calls** - all scoring is local
|
||||
- **Sub-millisecond latency** - typically <1ms per classification
|
||||
- **Predictable behavior** - rule-based scoring is deterministic
|
||||
- **Fully configurable** - weights, thresholds, and keyword lists can be customized
|
||||
|
||||
## How It Works
|
||||
|
||||
The router scores each request across 7 dimensions:
|
||||
|
||||
| Dimension | Description | Weight |
|
||||
|-----------|-------------|--------|
|
||||
| `tokenCount` | Short prompts = simple, long = complex | 0.10 |
|
||||
| `codePresence` | Code keywords (function, class, etc.) | 0.30 |
|
||||
| `reasoningMarkers` | "step by step", "think through", etc. | 0.25 |
|
||||
| `technicalTerms` | Domain complexity indicators | 0.25 |
|
||||
| `simpleIndicators` | "what is", "define" (negative weight) | 0.05 |
|
||||
| `multiStepPatterns` | "first...then", numbered steps | 0.03 |
|
||||
| `questionComplexity` | Multiple question marks | 0.02 |
|
||||
|
||||
The weighted sum is mapped to tiers using configurable boundaries:
|
||||
|
||||
| Tier | Score Range | Typical Use |
|
||||
|------|-------------|-------------|
|
||||
| SIMPLE | < 0.15 | Basic questions, greetings |
|
||||
| MEDIUM | 0.15 - 0.35 | Standard queries |
|
||||
| COMPLEX | 0.35 - 0.60 | Technical, multi-part requests |
|
||||
| REASONING | > 0.60 | Chain-of-thought, analysis |
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: smart-router
|
||||
litellm_params:
|
||||
model: auto_router/complexity_router
|
||||
complexity_router_config:
|
||||
tiers:
|
||||
SIMPLE: gpt-4o-mini
|
||||
MEDIUM: gpt-4o
|
||||
COMPLEX: claude-sonnet-4
|
||||
REASONING: o1-preview
|
||||
```
|
||||
|
||||
### Full Configuration
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: smart-router
|
||||
litellm_params:
|
||||
model: auto_router/complexity_router
|
||||
complexity_router_config:
|
||||
# Tier to model mapping
|
||||
tiers:
|
||||
SIMPLE: gpt-4o-mini
|
||||
MEDIUM: gpt-4o
|
||||
COMPLEX: claude-sonnet-4
|
||||
REASONING: o1-preview
|
||||
|
||||
# Tier boundaries (normalized scores)
|
||||
tier_boundaries:
|
||||
simple_medium: 0.15
|
||||
medium_complex: 0.35
|
||||
complex_reasoning: 0.60
|
||||
|
||||
# Token count thresholds
|
||||
token_thresholds:
|
||||
simple: 15 # Below this = "short" (default: 15)
|
||||
complex: 400 # Above this = "long" (default: 400)
|
||||
|
||||
# Dimension weights (must sum to ~1.0)
|
||||
dimension_weights:
|
||||
tokenCount: 0.10
|
||||
codePresence: 0.30
|
||||
reasoningMarkers: 0.25
|
||||
technicalTerms: 0.25
|
||||
simpleIndicators: 0.05
|
||||
multiStepPatterns: 0.03
|
||||
questionComplexity: 0.02
|
||||
|
||||
# Override default keyword lists
|
||||
code_keywords:
|
||||
- function
|
||||
- class
|
||||
- def
|
||||
- async
|
||||
- database
|
||||
|
||||
reasoning_keywords:
|
||||
- step by step
|
||||
- think through
|
||||
- analyze
|
||||
|
||||
# Fallback model if tier cannot be determined
|
||||
default_model: gpt-4o
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Once configured, use the model name like any other:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
|
||||
response = litellm.completion(
|
||||
model="smart-router", # Your complexity_router model name
|
||||
messages=[{"role": "user", "content": "What is 2+2?"}]
|
||||
)
|
||||
# Routes to SIMPLE tier (gpt-4o-mini)
|
||||
|
||||
response = litellm.completion(
|
||||
model="smart-router",
|
||||
messages=[{"role": "user", "content": "Think step by step: analyze the performance implications of implementing a distributed consensus algorithm for our microservices architecture."}]
|
||||
)
|
||||
# Routes to REASONING tier (o1-preview)
|
||||
```
|
||||
|
||||
## Special Behaviors
|
||||
|
||||
### Reasoning Override
|
||||
|
||||
If 2+ reasoning markers are detected in the user message, the request is automatically routed to the REASONING tier regardless of the weighted score. This ensures complex reasoning tasks get the appropriate model.
|
||||
|
||||
### System Prompt Handling
|
||||
|
||||
Reasoning markers in the system prompt do **not** trigger the reasoning override. This prevents system prompts like "Think step by step before answering" from forcing all requests to the reasoning tier.
|
||||
|
||||
### Code Detection
|
||||
|
||||
Technical code keywords are detected case-insensitively and include:
|
||||
- Language keywords: `function`, `class`, `def`, `const`, `let`, `var`
|
||||
- Operations: `import`, `export`, `return`, `async`, `await`
|
||||
- Infrastructure: `database`, `api`, `endpoint`, `docker`, `kubernetes`
|
||||
- Actions: `debug`, `implement`, `refactor`, `optimize`
|
||||
|
||||
## Performance
|
||||
|
||||
- **Classification time**: <1ms typical
|
||||
- **Memory usage**: Minimal (compiled regex patterns + keyword sets)
|
||||
- **No external dependencies**: Works offline with no API calls
|
||||
|
||||
## Comparison with auto_router
|
||||
|
||||
| Feature | complexity_router | auto_router |
|
||||
|---------|-------------------|-------------|
|
||||
| Classification | Rule-based scoring | Semantic embedding |
|
||||
| Latency | <1ms | ~100-500ms (embedding API) |
|
||||
| API Calls | None | Requires embedding model |
|
||||
| Training | None | Requires utterance examples |
|
||||
| Customization | Weights, keywords, thresholds | Utterance examples |
|
||||
| Best For | Cost optimization | Intent routing |
|
||||
|
||||
Use `complexity_router` when you want to optimize costs by routing simple queries to cheaper models. Use `auto_router` when you need semantic intent matching (e.g., routing "customer support" queries to a specialized model).
|
||||
@@ -0,0 +1,22 @@
|
||||
"""
|
||||
Complexity-based Auto Router
|
||||
|
||||
A rule-based routing strategy that uses weighted scoring across multiple dimensions
|
||||
to classify requests by complexity and route them to appropriate models.
|
||||
|
||||
No external API calls - all scoring is local and <1ms.
|
||||
"""
|
||||
|
||||
from litellm.router_strategy.complexity_router.complexity_router import ComplexityRouter
|
||||
from litellm.router_strategy.complexity_router.config import (
|
||||
ComplexityTier,
|
||||
DEFAULT_COMPLEXITY_CONFIG,
|
||||
ComplexityRouterConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ComplexityRouter",
|
||||
"ComplexityTier",
|
||||
"DEFAULT_COMPLEXITY_CONFIG",
|
||||
"ComplexityRouterConfig",
|
||||
]
|
||||
@@ -0,0 +1,410 @@
|
||||
"""
|
||||
Complexity-based Auto Router
|
||||
|
||||
A rule-based routing strategy that uses weighted scoring across multiple dimensions
|
||||
to classify requests by complexity and route them to appropriate models.
|
||||
|
||||
No external API calls - all scoring is local and <1ms.
|
||||
|
||||
Inspired by ClawRouter: https://github.com/BlockRunAI/ClawRouter
|
||||
"""
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
from .config import (
|
||||
DEFAULT_CODE_KEYWORDS,
|
||||
DEFAULT_REASONING_KEYWORDS,
|
||||
DEFAULT_SIMPLE_KEYWORDS,
|
||||
DEFAULT_TECHNICAL_KEYWORDS,
|
||||
ComplexityRouterConfig,
|
||||
ComplexityTier,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router
|
||||
from litellm.types.router import PreRoutingHookResponse
|
||||
else:
|
||||
Router = Any
|
||||
PreRoutingHookResponse = Any
|
||||
|
||||
|
||||
class DimensionScore:
    """A single dimension's contribution to the complexity score.

    Attributes:
        name: Dimension identifier (matches a key in the weight map).
        score: Raw score for this dimension, before weighting.
        signal: Optional human-readable reason the score was assigned.
    """

    # Keep instances lightweight: no per-instance __dict__.
    __slots__ = ("name", "score", "signal")

    def __init__(self, name: str, score: float, signal: Optional[str] = None):
        self.signal = signal
        self.score = score
        self.name = name
|
||||
|
||||
|
||||
class ComplexityRouter(CustomLogger):
    """
    Rule-based complexity router that classifies requests and routes to appropriate models.

    Handles requests in <1ms with zero external API calls by using weighted scoring
    across multiple dimensions:
    - Token count (short=simple, long=complex)
    - Code presence (code keywords → complex)
    - Reasoning markers ("step by step", "think through" → reasoning tier)
    - Technical terms (domain complexity)
    - Simple indicators ("what is", "define" → simple, negative weight)
    - Multi-step patterns ("first...then", numbered steps)
    - Question complexity (multiple questions)

    Note: two or more reasoning markers in the *user* message force the
    REASONING tier regardless of the weighted score (see ``classify``).
    """
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
litellm_router_instance: "Router",
|
||||
complexity_router_config: Optional[Dict[str, Any]] = None,
|
||||
default_model: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Initialize ComplexityRouter.
|
||||
|
||||
Args:
|
||||
model_name: The name of the model/deployment using this router.
|
||||
litellm_router_instance: The LiteLLM Router instance.
|
||||
complexity_router_config: Optional configuration dict from proxy config.
|
||||
default_model: Optional default model to use if tier cannot be determined.
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.litellm_router_instance = litellm_router_instance
|
||||
|
||||
# Parse config - always create a new instance to avoid singleton mutation
|
||||
if complexity_router_config:
|
||||
self.config = ComplexityRouterConfig(**complexity_router_config)
|
||||
else:
|
||||
self.config = ComplexityRouterConfig()
|
||||
|
||||
# Override default_model if provided
|
||||
if default_model:
|
||||
self.config.default_model = default_model
|
||||
|
||||
# Build effective keyword lists (use config overrides or defaults)
|
||||
self.code_keywords = self.config.code_keywords or DEFAULT_CODE_KEYWORDS
|
||||
self.reasoning_keywords = (
|
||||
self.config.reasoning_keywords or DEFAULT_REASONING_KEYWORDS
|
||||
)
|
||||
self.technical_keywords = (
|
||||
self.config.technical_keywords or DEFAULT_TECHNICAL_KEYWORDS
|
||||
)
|
||||
self.simple_keywords = self.config.simple_keywords or DEFAULT_SIMPLE_KEYWORDS
|
||||
|
||||
# Pre-compile regex patterns for efficiency
|
||||
# Use non-greedy .*? to prevent ReDoS on pathological inputs
|
||||
self._multi_step_patterns = [
|
||||
re.compile(r"first.*?then", re.IGNORECASE),
|
||||
re.compile(r"step\s*\d", re.IGNORECASE),
|
||||
re.compile(r"\d+\.\s"),
|
||||
re.compile(r"[a-z]\)\s", re.IGNORECASE),
|
||||
]
|
||||
|
||||
verbose_router_logger.debug(
|
||||
f"ComplexityRouter initialized for {model_name} with tiers: {self.config.tiers}"
|
||||
)
|
||||
|
||||
def _estimate_tokens(self, text: str) -> int:
|
||||
"""
|
||||
Estimate token count from text.
|
||||
Uses a simple heuristic: ~4 characters per token on average.
|
||||
"""
|
||||
return len(text) // 4
|
||||
|
||||
def _score_token_count(self, estimated_tokens: int) -> DimensionScore:
|
||||
"""Score based on token count."""
|
||||
thresholds = self.config.token_thresholds
|
||||
simple_threshold = thresholds.get("simple", 15)
|
||||
complex_threshold = thresholds.get("complex", 400)
|
||||
|
||||
if estimated_tokens < simple_threshold:
|
||||
return DimensionScore(
|
||||
"tokenCount", -1.0, f"short ({estimated_tokens} tokens)"
|
||||
)
|
||||
if estimated_tokens > complex_threshold:
|
||||
return DimensionScore(
|
||||
"tokenCount", 1.0, f"long ({estimated_tokens} tokens)"
|
||||
)
|
||||
return DimensionScore("tokenCount", 0, None)
|
||||
|
||||
def _keyword_matches(self, text: str, keyword: str) -> bool:
|
||||
"""
|
||||
Check if a keyword matches in text using word boundary matching.
|
||||
|
||||
For single-word keywords, uses regex word boundaries to avoid
|
||||
false positives (e.g., "error" matching "terrorism", "class" matching "classical").
|
||||
For multi-word phrases, uses substring matching.
|
||||
"""
|
||||
kw_lower = keyword.lower()
|
||||
|
||||
# For single-word keywords, use word boundary matching to avoid false positives
|
||||
# e.g., "api" should not match "capital", "error" should not match "terrorism"
|
||||
if " " not in kw_lower:
|
||||
pattern = r"\b" + re.escape(kw_lower) + r"\b"
|
||||
return bool(re.search(pattern, text))
|
||||
|
||||
# For multi-word phrases, substring matching is fine
|
||||
return kw_lower in text
|
||||
|
||||
def _score_keyword_match(
|
||||
self,
|
||||
text: str,
|
||||
keywords: List[str],
|
||||
name: str,
|
||||
signal_label: str,
|
||||
thresholds: Tuple[int, int], # (low, high)
|
||||
scores: Tuple[float, float, float], # (none, low, high)
|
||||
) -> Tuple[DimensionScore, int]:
|
||||
"""Score based on keyword matches using word boundary matching.
|
||||
|
||||
Returns:
|
||||
Tuple of (DimensionScore, match_count) so callers can reuse the count.
|
||||
"""
|
||||
low_threshold, high_threshold = thresholds
|
||||
score_none, score_low, score_high = scores
|
||||
|
||||
matches = [kw for kw in keywords if self._keyword_matches(text, kw)]
|
||||
match_count = len(matches)
|
||||
|
||||
if match_count >= high_threshold:
|
||||
return (
|
||||
DimensionScore(
|
||||
name, score_high, f"{signal_label} ({', '.join(matches[:3])})"
|
||||
),
|
||||
match_count,
|
||||
)
|
||||
if match_count >= low_threshold:
|
||||
return (
|
||||
DimensionScore(
|
||||
name, score_low, f"{signal_label} ({', '.join(matches[:3])})"
|
||||
),
|
||||
match_count,
|
||||
)
|
||||
return DimensionScore(name, score_none, None), match_count
|
||||
|
||||
def _score_multi_step(self, text: str) -> DimensionScore:
|
||||
"""Score based on multi-step patterns."""
|
||||
hits = sum(1 for p in self._multi_step_patterns if p.search(text))
|
||||
if hits > 0:
|
||||
return DimensionScore("multiStepPatterns", 0.5, "multi-step")
|
||||
return DimensionScore("multiStepPatterns", 0, None)
|
||||
|
||||
def _score_question_complexity(self, text: str) -> DimensionScore:
|
||||
"""Score based on number of question marks."""
|
||||
count = text.count("?")
|
||||
if count > 3:
|
||||
return DimensionScore("questionComplexity", 0.5, f"{count} questions")
|
||||
return DimensionScore("questionComplexity", 0, None)
|
||||
|
||||
    def classify(
        self, prompt: str, system_prompt: Optional[str] = None
    ) -> Tuple[ComplexityTier, float, List[str]]:
        """
        Classify a prompt by complexity.

        Args:
            prompt: The user's prompt/message.
            system_prompt: Optional system prompt for context.

        Returns:
            Tuple of (tier, score, signals) where:
            - tier: The ComplexityTier (SIMPLE, MEDIUM, COMPLEX, REASONING)
            - score: The raw weighted score
            - signals: List of triggered signals for debugging
        """
        # Combine text for analysis.
        # System prompt is intentionally included in code/technical/simple scoring
        # because it provides deployment-level context (e.g., "You are a Python assistant"
        # signals that code-capable models are appropriate). Reasoning markers use
        # user_text only to prevent system prompts from forcing REASONING tier.
        full_text = f"{system_prompt or ''} {prompt}".lower()
        user_text = prompt.lower()

        # Estimate tokens from the user prompt only (length heuristic).
        estimated_tokens = self._estimate_tokens(prompt)

        # Score all dimensions, capturing match counts where needed.
        # Each call returns (DimensionScore, match_count); thresholds are
        # (low, high) hit counts and scores are (none, low, high) levels.
        code_score, _ = self._score_keyword_match(
            full_text,
            self.code_keywords,
            "codePresence",
            "code",
            (1, 2),
            (0, 0.5, 1.0),
        )
        reasoning_score, reasoning_match_count = self._score_keyword_match(
            user_text,
            self.reasoning_keywords,
            "reasoningMarkers",
            "reasoning",
            (1, 2),
            (0, 0.7, 1.0),
        )
        technical_score, _ = self._score_keyword_match(
            full_text,
            self.technical_keywords,
            "technicalTerms",
            "technical",
            (2, 4),
            (0, 0.5, 1.0),
        )
        simple_score, _ = self._score_keyword_match(
            full_text,
            self.simple_keywords,
            "simpleIndicators",
            "simple",
            (1, 2),
            (0, -1.0, -1.0),
        )

        # Question counting uses the raw prompt; "?" is unaffected by casing.
        dimensions: List[DimensionScore] = [
            self._score_token_count(estimated_tokens),
            code_score,
            reasoning_score,
            technical_score,
            simple_score,
            self._score_multi_step(full_text),
            self._score_question_complexity(prompt),
        ]

        # Collect human-readable signals for debugging/logging.
        signals = [d.signal for d in dimensions if d.signal is not None]

        # Compute weighted score; unknown dimension names get weight 0.
        weights = self.config.dimension_weights
        weighted_score = sum(d.score * weights.get(d.name, 0) for d in dimensions)

        # Check for reasoning override (2+ reasoning markers in user text).
        # Reuse match count from _score_keyword_match to avoid scanning twice.
        if reasoning_match_count >= 2:
            return ComplexityTier.REASONING, weighted_score, signals

        # Map score to tier using configured (or default) boundaries.
        boundaries = self.config.tier_boundaries
        simple_medium = boundaries.get("simple_medium", 0.15)
        medium_complex = boundaries.get("medium_complex", 0.35)
        complex_reasoning = boundaries.get("complex_reasoning", 0.60)

        if weighted_score < simple_medium:
            tier = ComplexityTier.SIMPLE
        elif weighted_score < medium_complex:
            tier = ComplexityTier.MEDIUM
        elif weighted_score < complex_reasoning:
            tier = ComplexityTier.COMPLEX
        else:
            tier = ComplexityTier.REASONING

        return tier, weighted_score, signals
|
||||
|
||||
def get_model_for_tier(self, tier: ComplexityTier) -> str:
|
||||
"""
|
||||
Get the model name for a given complexity tier.
|
||||
|
||||
Args:
|
||||
tier: The complexity tier.
|
||||
|
||||
Returns:
|
||||
The model name configured for that tier.
|
||||
"""
|
||||
tier_key = tier.value if isinstance(tier, ComplexityTier) else tier
|
||||
|
||||
# Check config tiers mapping
|
||||
model = self.config.tiers.get(tier_key)
|
||||
if model:
|
||||
return model
|
||||
|
||||
# Fallback to default model if configured
|
||||
if self.config.default_model:
|
||||
return self.config.default_model
|
||||
|
||||
# Last resort: return MEDIUM tier model or error
|
||||
medium_model = self.config.tiers.get(ComplexityTier.MEDIUM.value)
|
||||
if medium_model:
|
||||
return medium_model
|
||||
|
||||
raise ValueError(
|
||||
f"No model configured for tier {tier_key} and no default_model set"
|
||||
)
|
||||
|
||||
    async def async_pre_routing_hook(
        self,
        model: str,
        request_kwargs: Dict,
        messages: Optional[List[Dict[str, Any]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
    ) -> Optional["PreRoutingHookResponse"]:
        """
        Pre-routing hook called before the routing decision.

        Classifies the request by complexity and returns the appropriate model.

        Args:
            model: The original model name requested.
            request_kwargs: The request kwargs.
            messages: The messages in the request.
            input: Optional input for embeddings.
            specific_deployment: Whether a specific deployment was requested.

        Returns:
            PreRoutingHookResponse with the routed model, or None if no routing needed.
        """
        # Imported lazily inside the hook — presumably to avoid an import
        # cycle at module load time (TODO confirm).
        from litellm.types.router import PreRoutingHookResponse

        # Nothing to classify without messages; let default routing proceed.
        if messages is None or len(messages) == 0:
            verbose_router_logger.debug(
                "ComplexityRouter: No messages provided, skipping routing"
            )
            return None

        # Extract the last user message and the last system prompt
        # (iterating in reverse so the most recent of each wins).
        user_message: Optional[str] = None
        system_prompt: Optional[str] = None

        for msg in reversed(messages):
            role = msg.get("role", "")
            content = msg.get("content") or ""
            # content may be a list of content parts (e.g. [{"type": "text", "text": "..."}])
            if isinstance(content, list):
                # Keep only text parts; image/audio parts are ignored.
                text_parts = [
                    part.get("text", "")
                    for part in content
                    if isinstance(part, dict) and part.get("type") == "text"
                ]
                content = " ".join(text_parts).strip()
            # Only non-empty string content counts; empty content is skipped.
            if isinstance(content, str) and content:
                if role == "user" and user_message is None:
                    user_message = content
                elif role == "system" and system_prompt is None:
                    system_prompt = content

        # No user text to score: fall back to default_model (or MEDIUM tier).
        if user_message is None:
            verbose_router_logger.debug(
                "ComplexityRouter: No user message found, routing to default model"
            )
            return PreRoutingHookResponse(
                model=self.config.default_model
                or self.get_model_for_tier(ComplexityTier.MEDIUM),
                messages=messages,
            )

        # Classify the request (system prompt used for context only).
        tier, score, signals = self.classify(user_message, system_prompt)

        # Get the model for this tier
        routed_model = self.get_model_for_tier(tier)

        verbose_router_logger.info(
            f"ComplexityRouter: tier={tier.value}, score={score:.3f}, "
            f"signals={signals}, routed_model={routed_model}"
        )

        return PreRoutingHookResponse(
            model=routed_model,
            messages=messages,
        )
|
||||
@@ -0,0 +1,255 @@
|
||||
"""
|
||||
Configuration for the Complexity Router.
|
||||
|
||||
Contains default keyword lists, weights, tier boundaries, and configuration classes.
|
||||
All values are configurable via proxy config.yaml.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ComplexityTier(str, Enum):
    """Complexity tiers for routing decisions.

    Inherits from ``str`` so members compare equal to their raw string
    values (useful when tiers come from YAML config as plain strings).
    """

    SIMPLE = "SIMPLE"  # basic questions, greetings
    MEDIUM = "MEDIUM"  # standard queries
    COMPLEX = "COMPLEX"  # technical / multi-part requests
    REASONING = "REASONING"  # explicit chain-of-thought / analysis requests
|
||||
|
||||
|
||||
# ─── Default Keyword Lists ───
|
||||
# Note: Keywords should be full words/phrases to avoid substring false positives.
|
||||
# The matching logic uses word boundary detection for single-word keywords.
|
||||
|
||||
# Keywords whose presence suggests code-related content (scored on the
# combined system + user text). List order matters: the first three matches
# are echoed in the debug signal string.
DEFAULT_CODE_KEYWORDS: List[str] = [
    "function",
    "class",
    "def",
    "const",
    "let",
    "var",
    "import",
    "export",
    "return",
    "async",
    "await",
    "try",
    "catch",
    "exception",
    "error",
    "debug",
    "api",
    "endpoint",
    "request",
    "response",
    "database",
    "sql",
    "query",
    "schema",
    "algorithm",
    "implement",
    "refactor",
    "optimize",
    "python",
    "javascript",
    "typescript",
    "java",
    "rust",
    "golang",
    "react",
    "vue",
    "angular",
    "node",
    "docker",
    "kubernetes",
    "git",
    "commit",
    "merge",
    "branch",
    "pull request",
]

# Phrases/words signalling an explicit reasoning request; matched against the
# USER message only (system prompts never trigger the reasoning override).
DEFAULT_REASONING_KEYWORDS: List[str] = [
    "step by step",
    "think through",
    "let's think",
    "reason through",
    "analyze this",
    "break down",
    "explain your reasoning",
    "show your work",
    "chain of thought",
    "think carefully",
    "consider all",
    "evaluate",
    "pros and cons",
    "compare and contrast",
    "weigh the options",
    "logical",
    "deduce",
    "infer",
    "conclude",
]

# Domain-complexity indicators (scored on the combined system + user text).
DEFAULT_TECHNICAL_KEYWORDS: List[str] = [
    "architecture",
    "distributed",
    "scalable",
    "microservice",
    "machine learning",
    "neural network",
    "deep learning",
    "encryption",
    "authentication",
    "authorization",
    "performance",
    "latency",
    "throughput",
    "benchmark",
    "concurrency",
    "parallel",
    "threading",
    "memory",
    "cpu",
    "gpu",
    "optimization",
    "protocol",
    "tcp",
    "http",
    "grpc",
    "websocket",
    "container",
    "orchestration",
    # Note: "async", "kubernetes", "docker" are in DEFAULT_CODE_KEYWORDS
]

# Indicators of simple/basic queries; these carry a NEGATIVE weight and pull
# the score toward the SIMPLE tier.
DEFAULT_SIMPLE_KEYWORDS: List[str] = [
    "what is",
    "what's",
    "define",
    "definition of",
    "who is",
    "who was",
    "when did",
    "when was",
    "where is",
    "where was",
    "how many",
    "how much",
    "yes or no",
    "true or false",
    "simple",
    "brief",
    "short",
    "quick",
    "hello",
    "hi",
    "hey",
    "thanks",
    "thank you",
    "goodbye",
    "bye",
    "okay",
    # Note: "ok" removed due to false positives (matches "token", "book", etc.)
]


# ─── Default Dimension Weights ───
# Keys must match the DimensionScore names produced by ComplexityRouter;
# values are intended to sum to ~1.0.

DEFAULT_DIMENSION_WEIGHTS: Dict[str, float] = {
    "tokenCount": 0.10,  # Reduced - length is less important than content
    "codePresence": 0.30,  # High - code requests need capable models
    "reasoningMarkers": 0.25,  # High - explicit reasoning requests
    "technicalTerms": 0.25,  # High - technical content matters
    "simpleIndicators": 0.05,  # Low - don't over-penalize simple patterns
    "multiStepPatterns": 0.03,
    "questionComplexity": 0.02,
}


# ─── Default Tier Boundaries ───
# Weighted scores below a boundary fall into the lower tier.

DEFAULT_TIER_BOUNDARIES: Dict[str, float] = {
    "simple_medium": 0.15,  # Lower threshold to catch more MEDIUM cases
    "medium_complex": 0.35,  # Lower threshold to catch technical COMPLEX cases
    "complex_reasoning": 0.60,  # Reasoning tier reserved for explicit reasoning markers
}


# ─── Default Token Thresholds ───

DEFAULT_TOKEN_THRESHOLDS: Dict[str, int] = {
    "simple": 15,  # Only very short prompts (<15 tokens) are penalized
    "complex": 400,  # Long prompts (>400 tokens) get complexity boost
}


# ─── Default Tier to Model Mapping ───

DEFAULT_TIER_MODELS: Dict[str, str] = {
    "SIMPLE": "gpt-4o-mini",
    "MEDIUM": "gpt-4o",
    "COMPLEX": "claude-sonnet-4-20250514",
    "REASONING": "claude-sonnet-4-20250514",
}
|
||||
|
||||
|
||||
class ComplexityRouterConfig(BaseModel):
    """Configuration for the ComplexityRouter.

    All fields default to the module-level DEFAULT_* values; each default
    uses a factory with ``.copy()`` so instances never share mutable state.
    """

    # Tier to model mapping (keys are ComplexityTier values, e.g. "SIMPLE")
    tiers: Dict[str, str] = Field(
        default_factory=lambda: DEFAULT_TIER_MODELS.copy(),
        description="Mapping of complexity tiers to model names",
    )

    # Tier boundaries (normalized scores)
    tier_boundaries: Dict[str, float] = Field(
        default_factory=lambda: DEFAULT_TIER_BOUNDARIES.copy(),
        description="Score boundaries between tiers",
    )

    # Token count thresholds
    token_thresholds: Dict[str, int] = Field(
        default_factory=lambda: DEFAULT_TOKEN_THRESHOLDS.copy(),
        description="Token count thresholds for simple/complex classification",
    )

    # Dimension weights (intended to sum to ~1.0)
    dimension_weights: Dict[str, float] = Field(
        default_factory=lambda: DEFAULT_DIMENSION_WEIGHTS.copy(),
        description="Weights for each scoring dimension",
    )

    # Keyword lists (None means "use the module-level default list")
    code_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating code-related content",
    )
    reasoning_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating reasoning-required content",
    )
    technical_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating technical content",
    )
    simple_keywords: Optional[List[str]] = Field(
        default=None,
        description="Keywords indicating simple/basic queries",
    )

    # Default model if scoring fails
    default_model: Optional[str] = Field(
        default=None,
        description="Default model to use if tier cannot be determined",
    )

    model_config = ConfigDict(extra="allow")  # Allow additional fields


# Combined default config. NOTE: module-level shared instance — callers that
# need a mutable config should construct their own ComplexityRouterConfig().
DEFAULT_COMPLEXITY_CONFIG = ComplexityRouterConfig()
|
||||
@@ -0,0 +1 @@
|
||||
# Evaluation suite for ComplexityRouter
|
||||
@@ -0,0 +1,343 @@
|
||||
"""
|
||||
Evaluation suite for the ComplexityRouter.
|
||||
|
||||
Tests the router's ability to correctly classify prompts into complexity tiers.
|
||||
Run with: python -m litellm.router_strategy.complexity_router.evals.eval_complexity_router
|
||||
"""
|
||||
import os
|
||||
|
||||
# Add parent to path for imports
|
||||
import sys
|
||||
|
||||
# ruff: noqa: T201
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
|
||||
)
|
||||
|
||||
from litellm.router_strategy.complexity_router.complexity_router import ComplexityRouter
|
||||
from litellm.router_strategy.complexity_router.config import ComplexityTier
|
||||
|
||||
|
||||
@dataclass
class EvalCase:
    """A single evaluation case."""

    # The user prompt to classify.
    prompt: str
    # The tier the router is expected to choose.
    expected_tier: ComplexityTier
    # Human-readable description of the case.
    description: str
    # Optional system prompt supplied alongside the user prompt.
    system_prompt: Optional[str] = None
    # Allow some flexibility - if actual tier is in acceptable_tiers, still passes
    acceptable_tiers: Optional[List[ComplexityTier]] = None
|
||||
|
||||
|
||||
# ─── Evaluation Dataset ───
|
||||
|
||||
# Each entry is classified by run_eval(); a case passes when the router's
# actual tier equals expected_tier or appears in the case's acceptable_tiers.
EVAL_CASES: List[EvalCase] = [
    # === SIMPLE tier cases ===
    EvalCase(
        prompt="Hello!",
        expected_tier=ComplexityTier.SIMPLE,
        description="Basic greeting",
    ),
    EvalCase(
        prompt="What is Python?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple definition question",
    ),
    EvalCase(
        prompt="Who is Elon Musk?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple factual question",
    ),
    EvalCase(
        prompt="What's the capital of France?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple geography question",
    ),
    EvalCase(
        prompt="Thanks for your help!",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple thank you",
    ),
    EvalCase(
        prompt="Define machine learning",
        expected_tier=ComplexityTier.SIMPLE,
        description="Definition request",
    ),
    EvalCase(
        prompt="When was the iPhone released?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple date question",
    ),
    EvalCase(
        prompt="How many planets are in our solar system?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple count question",
    ),
    EvalCase(
        prompt="Yes",
        expected_tier=ComplexityTier.SIMPLE,
        description="Single word response",
    ),
    EvalCase(
        prompt="What time is it in Tokyo?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Simple time zone question",
    ),
    # === MEDIUM tier cases ===
    EvalCase(
        prompt="Explain how REST APIs work and when to use them",
        expected_tier=ComplexityTier.MEDIUM,
        description="Technical explanation",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
    EvalCase(
        prompt="Write a short poem about the ocean",
        expected_tier=ComplexityTier.MEDIUM,
        description="Creative writing - short",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
    EvalCase(
        prompt="Summarize the main differences between SQL and NoSQL databases",
        expected_tier=ComplexityTier.MEDIUM,
        description="Technical comparison",
        acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
    ),
    EvalCase(
        prompt="What are the benefits of using TypeScript over JavaScript?",
        expected_tier=ComplexityTier.MEDIUM,
        description="Technical comparison question",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
    EvalCase(
        prompt="Help me debug this error: TypeError: Cannot read property 'map' of undefined",
        expected_tier=ComplexityTier.MEDIUM,
        description="Debugging help",
        acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
    ),
    # === COMPLEX tier cases ===
    EvalCase(
        prompt="Design a distributed microservice architecture for a high-throughput "
        "real-time data processing pipeline with Kubernetes orchestration, "
        "implementing proper authentication and encryption protocols",
        expected_tier=ComplexityTier.COMPLEX,
        description="Complex architecture design",
        acceptable_tiers=[ComplexityTier.COMPLEX, ComplexityTier.REASONING],
    ),
    EvalCase(
        prompt="Write a Python function that implements a binary search tree with "
        "insert, delete, and search operations. Include proper error handling "
        "and optimize for memory efficiency.",
        expected_tier=ComplexityTier.COMPLEX,
        description="Complex coding task",
        acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
    ),
    EvalCase(
        prompt="Explain the differences between TCP and UDP protocols, including "
        "use cases for each, performance implications, and how they handle "
        "packet loss in distributed systems",
        expected_tier=ComplexityTier.COMPLEX,
        description="Deep technical explanation",
        acceptable_tiers=[ComplexityTier.MEDIUM, ComplexityTier.COMPLEX],
    ),
    EvalCase(
        prompt="Create a comprehensive database schema for an e-commerce platform "
        "that handles users, products, orders, payments, shipping, reviews, "
        "and inventory management with proper indexing strategies",
        expected_tier=ComplexityTier.COMPLEX,
        description="Complex database design",
        acceptable_tiers=[
            ComplexityTier.MEDIUM,
            ComplexityTier.COMPLEX,
            ComplexityTier.REASONING,
        ],
    ),
    EvalCase(
        prompt="Implement a rate limiter using the token bucket algorithm in Python "
        "that supports multiple rate limit tiers and can be used across "
        "distributed systems with Redis as the backend",
        expected_tier=ComplexityTier.COMPLEX,
        description="Complex distributed systems coding",
        acceptable_tiers=[
            ComplexityTier.MEDIUM,
            ComplexityTier.COMPLEX,
            ComplexityTier.REASONING,
        ],
    ),
    # === REASONING tier cases ===
    EvalCase(
        prompt="Think step by step about how to solve this: A farmer has 17 sheep. "
        "All but 9 die. How many are left? Explain your reasoning.",
        expected_tier=ComplexityTier.REASONING,
        description="Explicit reasoning request",
    ),
    EvalCase(
        prompt="Let's think through this carefully. Analyze the pros and cons of "
        "microservices vs monolithic architecture for a startup with 5 engineers. "
        "Consider scalability, development speed, and operational complexity.",
        expected_tier=ComplexityTier.REASONING,
        description="Multiple reasoning markers + analysis",
    ),
    EvalCase(
        prompt="Reason through this problem: If I have a function that's O(n^2) and "
        "I need to process 1 million items, what are my options to optimize it? "
        "Walk me through each approach step by step.",
        expected_tier=ComplexityTier.REASONING,
        description="Algorithm reasoning",
    ),
    EvalCase(
        prompt="I need you to think carefully and analyze this code for potential "
        "security vulnerabilities. Consider injection attacks, authentication "
        "bypasses, and data exposure risks. Show your reasoning process.",
        expected_tier=ComplexityTier.REASONING,
        description="Security analysis with reasoning",
        acceptable_tiers=[ComplexityTier.COMPLEX, ComplexityTier.REASONING],
    ),
    EvalCase(
        prompt="Step by step, explain your reasoning as you evaluate whether we should "
        "use PostgreSQL or MongoDB for our new project. Consider our requirements: "
        "complex queries, high write volume, and eventual consistency is acceptable.",
        expected_tier=ComplexityTier.REASONING,
        description="Database decision with explicit reasoning",
    ),
    # === Edge cases / regression tests ===
    EvalCase(
        prompt="What is the capital of France?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Regression: 'capital' should not trigger 'api' keyword",
    ),
    EvalCase(
        prompt="I tried to book a flight but the entry form wasn't working",
        expected_tier=ComplexityTier.SIMPLE,
        description="Regression: 'tried' and 'entry' should not trigger code keywords",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
    EvalCase(
        prompt="The poetry of digital art is fascinating",
        expected_tier=ComplexityTier.SIMPLE,
        description="Regression: 'poetry' should not trigger 'try' keyword",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
    EvalCase(
        prompt="Can you recommend a good book about country music history?",
        expected_tier=ComplexityTier.SIMPLE,
        description="Regression: 'country' should not trigger 'try' keyword",
        acceptable_tiers=[ComplexityTier.SIMPLE, ComplexityTier.MEDIUM],
    ),
]
|
||||
|
||||
|
||||
def _truncate(text: str, limit: int) -> str:
    """Return *text* shortened to *limit* characters, with an ellipsis only if cut."""
    return text[:limit] + "..." if len(text) > limit else text


def run_eval() -> Tuple[int, int, List[dict]]:
    """
    Run the evaluation suite.

    Classifies every prompt in EVAL_CASES with a ComplexityRouter and compares
    the resulting tier against the case's expectation. A case passes on an
    exact tier match, or when the actual tier is listed in the case's
    acceptable_tiers. Results are printed as a per-case report plus a summary.

    Returns:
        Tuple of (passed, total, failures) where failures is a list of dicts
        describing each failed case (description, prompt, expected/actual tier,
        score, signals, acceptable tiers).
    """
    # Create router with default config. The parent litellm Router is mocked;
    # classify() is purely local, so no real router behavior is needed here.
    mock_router = MagicMock()
    router = ComplexityRouter(
        model_name="eval-router",
        litellm_router_instance=mock_router,
    )

    passed = 0
    total = len(EVAL_CASES)
    failures: List[dict] = []

    # Per-line "# noqa: T201" comments were dropped: the file-level
    # "# ruff: noqa: T201" directive already suppresses print warnings.
    print("=" * 70)
    print("COMPLEXITY ROUTER EVALUATION")
    print("=" * 70)
    print()

    for i, case in enumerate(EVAL_CASES, 1):
        tier, score, signals = router.classify(case.prompt, case.system_prompt)

        # Pass on an exact match, or when the actual tier falls inside the
        # case's explicitly allowed alternatives.
        is_exact_match = tier == case.expected_tier
        is_acceptable = (
            case.acceptable_tiers is not None and tier in case.acceptable_tiers
        )
        is_pass = is_exact_match or is_acceptable

        if is_pass:
            passed += 1
            status = "✓ PASS"
        else:
            status = "✗ FAIL"
            failures.append(
                {
                    "case": i,
                    "description": case.description,
                    "prompt": _truncate(case.prompt, 80),
                    "expected": case.expected_tier.value,
                    "actual": tier.value,
                    "score": round(score, 3),
                    "signals": signals,
                    "acceptable": [t.value for t in case.acceptable_tiers]
                    if case.acceptable_tiers
                    else None,
                }
            )

        # Print per-case result
        print(f"[{i:2d}] {status} | {case.description}")
        print(
            f"     Expected: {case.expected_tier.value:10s} | Got: {tier.value:10s} | Score: {score:+.3f}"
        )
        if signals:
            print(f"     Signals: {', '.join(signals)}")
        if not is_pass:
            # Show the offending prompt; truncate only when it is actually long.
            print(f"     Prompt: {_truncate(case.prompt, 60)}")
        print()

    # Summary
    print("=" * 70)
    print(f"RESULTS: {passed}/{total} passed ({100*passed/total:.1f}%)")
    print("=" * 70)

    if failures:
        print("\nFAILURES:")
        print("-" * 70)
        for failure in failures:
            print(f"Case {failure['case']}: {failure['description']}")
            print(
                f"  Expected: {failure['expected']}, Got: {failure['actual']} (score: {failure['score']})"
            )
            print(f"  Signals: {failure['signals']}")
            if failure["acceptable"]:
                print(f"  Acceptable: {failure['acceptable']}")
            print()

    return passed, total, failures
|
||||
|
||||
|
||||
def main():
    """Main entry point."""
    passed, total, _failures = run_eval()
    pass_rate = passed / total

    # Hard failure below 80%: exit non-zero so CI marks the run as failed.
    if pass_rate < 0.80:
        print(
            f"\n❌ EVAL FAILED: Pass rate {pass_rate:.1%} is below 80% threshold"
        )
        sys.exit(1)

    # Soft warning between 80% and 90%: still a successful exit.
    if pass_rate < 0.90:
        print(
            f"\n⚠️ EVAL WARNING: Pass rate {pass_rate:.1%} is below 90%"
        )
        sys.exit(0)

    print(f"\n✅ EVAL PASSED: Pass rate {pass_rate:.1%}")
    sys.exit(0)
|
||||
|
||||
|
||||
# Script entry point: run the eval when invoked directly (see module docstring
# for the `python -m ...` invocation).
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user