chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,397 @@
|
||||
# This file runs a health check for the LLM, used on litellm/proxy
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
import litellm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from litellm.constants import DEFAULT_HEALTH_CHECK_PROMPT, HEALTH_CHECK_TIMEOUT_SECONDS
|
||||
|
||||
# Params that must never be echoed back in health check responses: either
# they hold secrets (api keys, GCP/AWS credentials) or they are large
# request payloads (messages / prompt / input) not useful for display.
ILLEGAL_DISPLAY_PARAMS = [
    "messages",
    "api_key",
    "prompt",
    "input",
    "vertex_credentials",
    "aws_access_key_id",
    "aws_secret_access_key",
]

# Whitelist of keys kept when the caller asks for a non-detailed health
# check (`details=False` in `_clean_endpoint_data`).
MINIMAL_DISPLAY_PARAMS = ["model", "mode_error"]
|
||||
|
||||
|
||||
def _get_process_rss_mb() -> Optional[float]:
|
||||
"""
|
||||
Get process RSS memory in MB.
|
||||
On Linux, ru_maxrss is in KB. On macOS, ru_maxrss is in bytes.
|
||||
"""
|
||||
try:
|
||||
import resource
|
||||
|
||||
ru_maxrss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
if sys.platform == "darwin":
|
||||
return float(ru_maxrss) / (1024 * 1024)
|
||||
return float(ru_maxrss) / 1024
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _rss_mb_for_log() -> str:
    """Format current process RSS for log lines; "unknown" when unavailable."""
    rss = _get_process_rss_mb()
    return "unknown" if rss is None else f"{rss:.2f}"
|
||||
|
||||
|
||||
def _get_random_llm_message():
|
||||
"""
|
||||
Get a random message from the LLM.
|
||||
"""
|
||||
messages = ["Hey how's it going?", "What's 1 + 1?"]
|
||||
|
||||
return [{"role": "user", "content": random.choice(messages)}]
|
||||
|
||||
|
||||
def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
    """
    Clean the endpoint data for display to users.

    Args:
        endpoint_data: merged litellm params / health-check response dict.
        details: when False, only MINIMAL_DISPLAY_PARAMS keys are kept;
            otherwise everything except ILLEGAL_DISPLAY_PARAMS is kept.

    Returns:
        A new dict safe for display. The input dict is NOT mutated
        (fix: the previous implementation popped "litellm_logging_obj"
        in place, which also deleted it from a caller's litellm_params
        when that dict was passed directly).
    """
    if details is False:
        return {k: v for k, v in endpoint_data.items() if k in MINIMAL_DISPLAY_PARAMS}
    return {
        k: v
        for k, v in endpoint_data.items()
        if k not in ILLEGAL_DISPLAY_PARAMS and k != "litellm_logging_obj"
    }
|
||||
|
||||
|
||||
def filter_deployments_by_id(
    model_list: List,
) -> List:
    """Keep the first deployment per ``model_info.id``, preserving order.

    Deployments with no id at all are dropped entirely.
    """
    unique_deployments: List = []
    encountered_ids = set()

    for entry in model_list:
        deployment_id = (entry.get("model_info") or {}).get("id") or None
        if deployment_id is None or deployment_id in encountered_ids:
            continue
        encountered_ids.add(deployment_id)
        unique_deployments.append(entry)

    return unique_deployments
|
||||
|
||||
|
||||
async def run_with_timeout(task, timeout):
    """Await ``task`` with a time limit.

    Returns the awaited result, or ``{"error": "Timeout exceeded"}`` when
    the limit elapses. ``asyncio.wait_for()`` already cancels only the
    awaited task on timeout, so unrelated sibling health-check tasks are
    never cancelled here.
    """
    try:
        result = await asyncio.wait_for(task, timeout)
    except asyncio.TimeoutError:
        return {"error": "Timeout exceeded"}
    return result
|
||||
|
||||
|
||||
async def _run_model_health_check(model: dict):
    """Run one deployment's health check, bounded by its per-model timeout.

    Uses ``health_check_timeout`` from model_info when set, otherwise the
    global ``HEALTH_CHECK_TIMEOUT_SECONDS`` default.
    """
    raw_params = model["litellm_params"]
    model_info = model.get("model_info", {})
    check_params = _update_litellm_params_for_health_check(model_info, raw_params)
    check_timeout = (
        model_info.get("health_check_timeout") or HEALTH_CHECK_TIMEOUT_SECONDS
    )

    health_coro = litellm.ahealth_check(
        check_params,
        mode=model_info.get("mode", None),
        prompt=DEFAULT_HEALTH_CHECK_PROMPT,
        input=["test from litellm"],
    )
    return await run_with_timeout(health_coro, check_timeout)
|
||||
|
||||
|
||||
async def _run_health_checks_with_bounded_concurrency(
    models: list, concurrency_limit: int
) -> tuple[list, int]:
    """
    Run health checks keeping at most ``concurrency_limit`` tasks in flight.

    Returns ``(results, peak_in_flight)``. ``results`` is ordered to match
    ``models``; a task that raised has its exception object stored in place
    of a result.
    """
    ordered_results: list = [None] * len(models)
    in_flight: dict[asyncio.Task, int] = {}
    pending_models = iter(enumerate(models))
    high_water_mark = 0

    def _launch_one() -> bool:
        # Start the next unscheduled model's check; False when none remain.
        nonlocal high_water_mark
        try:
            position, candidate = next(pending_models)
        except StopIteration:
            return False
        new_task = asyncio.create_task(_run_model_health_check(candidate))
        in_flight[new_task] = position
        high_water_mark = max(high_water_mark, len(in_flight))
        return True

    # Prime the window up to the concurrency limit.
    for _ in range(min(concurrency_limit, len(models))):
        _launch_one()

    # Each completion frees one slot; refill it immediately.
    while in_flight:
        finished, _ = await asyncio.wait(
            set(in_flight.keys()),
            return_when=asyncio.FIRST_COMPLETED,
        )
        for completed in finished:
            slot = in_flight.pop(completed)
            try:
                ordered_results[slot] = completed.result()
            except Exception as captured:
                ordered_results[slot] = captured
            _launch_one()

    return ordered_results, high_water_mark
|
||||
|
||||
|
||||
async def _perform_health_check(
    model_list: list,
    details: Optional[bool] = True,
    max_concurrency: Optional[int] = None,
    instrumentation_context: Optional[dict] = None,
):
    """
    Perform a health check for each model in the list.

    Args:
        model_list: deployments to check.
        details: passed through to `_clean_endpoint_data`.
        max_concurrency: optional limit on concurrent health check requests;
            when unset (or not a positive int), all checks run at once.
        instrumentation_context: optional dict with "enabled"/"cycle_id"/"source"
            keys controlling debug dispatch logging.

    Returns:
        (healthy_endpoints, unhealthy_endpoints) as lists of cleaned dicts.
    """
    ctx = instrumentation_context or {}
    instrumentation_enabled = bool(ctx.get("enabled", False))

    bounded = isinstance(max_concurrency, int) and max_concurrency > 0
    if bounded:
        results, peak_in_flight = await _run_health_checks_with_bounded_concurrency(
            model_list, max_concurrency
        )
    else:
        # Unbounded: dispatch every check at once and collect exceptions
        # in place so one failure never hides the others.
        pending = [
            asyncio.create_task(_run_model_health_check(model)) for model in model_list
        ]
        peak_in_flight = len(pending)
        results = await asyncio.gather(*pending, return_exceptions=True)

    if instrumentation_enabled:
        logger.debug(
            "health_check_dispatch_summary source=%s cycle_id=%s mode=%s model_count=%d max_concurrency=%s peak_in_flight=%d thread_count=%d rss_mb=%s",
            ctx.get("source", "unknown"),
            ctx.get("cycle_id", "unknown"),
            "bounded" if bounded else "unbounded",
            len(model_list),
            max_concurrency,
            peak_in_flight,
            threading.active_count(),
            _rss_mb_for_log(),
        )

    healthy_endpoints: list = []
    unhealthy_endpoints: list = []

    for outcome, model in zip(results, model_list):
        litellm_params = model["litellm_params"]

        if not isinstance(outcome, dict):
            # An exception object captured by gather()/the bounded runner.
            unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details))
            continue

        cleaned = _clean_endpoint_data({**litellm_params, **outcome}, details)
        if "error" in outcome:
            unhealthy_endpoints.append(cleaned)
        else:
            healthy_endpoints.append(cleaned)

    return healthy_endpoints, unhealthy_endpoints
|
||||
|
||||
|
||||
def _update_litellm_params_for_health_check(
    model_info: dict, litellm_params: dict
) -> dict:
    """
    Update the litellm params for health check (mutates and returns ``litellm_params``).

    - gets a short `messages` param for health check
    - updates the `model` param with the `health_check_model` if it exists Doc: https://docs.litellm.ai/docs/proxy/health#wildcard-routes
    - updates the `voice` param with the `health_check_voice` for `audio_speech` mode if it exists Doc: https://docs.litellm.ai/docs/proxy/health#text-to-speech-models
    - for Bedrock models with region routing (bedrock/region/model), strips the litellm routing prefix but preserves the model ID
    """
    litellm_params["messages"] = _get_random_llm_message()

    _health_check_max_tokens = model_info.get("health_check_max_tokens", None)
    if _health_check_max_tokens is not None:
        litellm_params["max_tokens"] = _health_check_max_tokens
    elif "*" not in (
        model_info.get("health_check_model") or litellm_params.get("model") or ""
    ):
        # Keep non-wildcard health checks as cheap as possible.
        litellm_params["max_tokens"] = 1

    _health_check_model = model_info.get("health_check_model", None)
    if _health_check_model is not None:
        litellm_params["model"] = _health_check_model
    if model_info.get("mode", None) == "audio_speech":
        litellm_params["voice"] = model_info.get("health_check_voice", "alloy")

    # Handle Bedrock region routing format: bedrock/region/model
    # This is needed because health checks bypass get_llm_provider() for the model param
    # Issue #15807: Without this, health checks send "region/model" as the model ID to AWS
    # which causes: "bedrock-runtime.../model/us-west-2/mistral.../invoke" (region in model ID)
    #
    # However, we must preserve cross-region inference profile prefixes like "us.", "eu.", etc.
    # Issue: Stripping these breaks AWS requirement for inference profile IDs
    #
    # Must also preserve route prefixes (converse/, invoke/) and handlers (llama/, deepseek_r1/, etc.)
    #
    # Fix: use .get() here — the max_tokens branch above deliberately tolerates
    # a missing "model" key, so indexing litellm_params["model"] would raise
    # KeyError for param sets with no model and no health_check_model.
    if (litellm_params.get("model") or "").startswith("bedrock/"):
        from litellm.llms.bedrock.common_utils import BedrockModelInfo

        # Strip only the bedrock/ prefix (preserve routes like converse/, invoke/)
        model = litellm_params["model"].removeprefix("bedrock/")

        # Now check for region routing and strip it if present
        # Need to handle formats like:
        # - "us-west-2/model" → "model"
        # - "converse/us-west-2/model" → "converse/model"
        # - "llama/arn:..." → "llama/arn:..." (preserve handler)
        #
        # Strategy: check each path segment, remove regions, preserve everything else
        filtered_parts = [
            part
            for part in model.split("/")
            if part not in BedrockModelInfo.all_global_regions
        ]
        litellm_params["model"] = "/".join(filtered_parts)

    return litellm_params
|
||||
|
||||
|
||||
async def perform_health_check(
    model_list: list,
    model: Optional[str] = None,
    cli_model: Optional[str] = None,
    details: Optional[bool] = True,
    model_id: Optional[str] = None,
    max_concurrency: Optional[int] = None,
    instrumentation_context: Optional[dict] = None,
):
    """
    Perform a health check on the system.

    When model_id is provided, only the deployment with that id is checked
    (so models that share the same name but have different ids are checked separately).
    When model (name) is provided, all deployments matching that name are checked.

    Returns:
        (healthy_endpoints, unhealthy_endpoints): lists of cleaned endpoint
        dicts for deployments that passed / failed the check.
    """
    ctx = instrumentation_context or {}
    instrumentation_enabled = bool(ctx.get("enabled", False))
    cycle_id = ctx.get("cycle_id", "unknown")
    source = ctx.get("source", "unknown")

    if not model_list:
        if not cli_model:
            # Nothing to check and no CLI fallback model: skip the cycle.
            if instrumentation_enabled:
                logger.debug(
                    "health_check_cycle_skipped source=%s cycle_id=%s reason=no_models",
                    source,
                    cycle_id,
                )
            return [], []
        model_list = [{"model_name": cli_model, "litellm_params": {"model": cli_model}}]

    cycle_start_time = time.monotonic()
    requested_model_count = len(model_list)

    # Filter by model_id first so a single deployment is checked when id is specified
    if model_id is not None:
        matching_by_id = [
            x for x in model_list if (x.get("model_info") or {}).get("id") == model_id
        ]
        # Only narrow when the id actually matched something.
        if matching_by_id:
            model_list = matching_by_id
    elif model is not None:
        matching_by_name = [
            x for x in model_list if x["litellm_params"]["model"] == model
        ]
        if not matching_by_name:
            # Fall back to the user-facing model_name alias.
            matching_by_name = [x for x in model_list if x["model_name"] == model]
        model_list = matching_by_name

    post_filter_model_count = len(model_list)
    # filter duplicate deployments (e.g. when model alias'es are used)
    model_list = filter_deployments_by_id(model_list=model_list)
    deduped_model_count = len(model_list)

    if instrumentation_enabled:
        logger.debug(
            "health_check_cycle_start source=%s cycle_id=%s requested_model_count=%d post_model_filter_count=%d deduped_model_count=%d max_concurrency=%s thread_count=%d rss_mb=%s",
            source,
            cycle_id,
            requested_model_count,
            post_filter_model_count,
            deduped_model_count,
            max_concurrency,
            threading.active_count(),
            _rss_mb_for_log(),
        )

    try:
        healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
            model_list,
            details,
            max_concurrency=max_concurrency,
            instrumentation_context=ctx,
        )
    except Exception:
        if instrumentation_enabled:
            logger.exception(
                "health_check_cycle_failed source=%s cycle_id=%s model_count=%d duration_ms=%.2f thread_count=%d rss_mb=%s",
                source,
                cycle_id,
                deduped_model_count,
                (time.monotonic() - cycle_start_time) * 1000,
                threading.active_count(),
                _rss_mb_for_log(),
            )
        raise

    if instrumentation_enabled:
        logger.debug(
            "health_check_cycle_complete source=%s cycle_id=%s model_count=%d healthy_count=%d unhealthy_count=%d duration_ms=%.2f thread_count=%d rss_mb=%s",
            source,
            cycle_id,
            deduped_model_count,
            len(healthy_endpoints),
            len(unhealthy_endpoints),
            (time.monotonic() - cycle_start_time) * 1000,
            threading.active_count(),
            _rss_mb_for_log(),
        )

    return healthy_endpoints, unhealthy_endpoints
|
||||
Reference in New Issue
Block a user