# This file runs a health check for the LLM, used on litellm/proxy
import asyncio
import logging
import random
import sys
import threading
import time
from typing import List, Optional

import litellm
from litellm.constants import DEFAULT_HEALTH_CHECK_PROMPT, HEALTH_CHECK_TIMEOUT_SECONDS

logger = logging.getLogger(__name__)

# Params that must never be echoed back to callers in health check results
# (credentials and request payloads).
ILLEGAL_DISPLAY_PARAMS = [
    "messages",
    "api_key",
    "prompt",
    "input",
    "vertex_credentials",
    "aws_access_key_id",
    "aws_secret_access_key",
]

# The only params shown when details=False is requested.
MINIMAL_DISPLAY_PARAMS = ["model", "mode_error"]


def _get_process_rss_mb() -> Optional[float]:
    """
    Get process RSS memory in MB.

    On Linux, ru_maxrss is in KB. On macOS, ru_maxrss is in bytes.

    Returns:
        RSS in megabytes, or None if the `resource` module is unavailable
        (e.g. on Windows) or the lookup fails.
    """
    try:
        import resource

        ru_maxrss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        if sys.platform == "darwin":
            # macOS reports bytes.
            return float(ru_maxrss) / (1024 * 1024)
        # Linux reports kilobytes.
        return float(ru_maxrss) / 1024
    except Exception:
        return None


def _rss_mb_for_log() -> str:
    """Return RSS in MB formatted for log lines, or "unknown" when unavailable."""
    rss_mb = _get_process_rss_mb()
    if rss_mb is None:
        return "unknown"
    return f"{rss_mb:.2f}"


def _get_random_llm_message():
    """
    Get a random message from the LLM.

    Returns a single-element chat `messages` list suitable for a cheap
    health-check completion call.
    """
    messages = ["Hey how's it going?", "What's 1 + 1?"]
    return [{"role": "user", "content": random.choice(messages)}]


def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True):
    """
    Clean the endpoint data for display to users.

    Strips the non-serializable logging object, then either removes
    sensitive params (details truthy / None) or keeps only the minimal
    display params (details is False).
    """
    endpoint_data.pop("litellm_logging_obj", None)
    return (
        {k: v for k, v in endpoint_data.items() if k not in ILLEGAL_DISPLAY_PARAMS}
        if details is not False
        else {k: v for k, v in endpoint_data.items() if k in MINIMAL_DISPLAY_PARAMS}
    )


def filter_deployments_by_id(
    model_list: List,
) -> List:
    """
    Drop duplicate deployments, keeping the first occurrence of each
    `model_info.id`. Deployments without an id are dropped entirely.
    """
    seen_ids = set()
    filtered_deployments = []

    for deployment in model_list:
        _model_info = deployment.get("model_info") or {}
        _id = _model_info.get("id") or None
        if _id is None:
            continue

        if _id not in seen_ids:
            seen_ids.add(_id)
            filtered_deployments.append(deployment)

    return filtered_deployments


async def run_with_timeout(task, timeout):
    """
    Await `task` with a timeout; on timeout return an error dict instead of
    raising, so one slow model cannot fail the whole health-check cycle.
    """
    try:
        return await asyncio.wait_for(task, timeout)
    except asyncio.TimeoutError:
        # `asyncio.wait_for()` already cancels only the awaited task on timeout.
        # Do not cancel unrelated sibling health check tasks.
        return {"error": "Timeout exceeded"}


async def _run_model_health_check(model: dict):
    """
    Run a single deployment's health check with its configured (or default)
    timeout. Returns the `litellm.ahealth_check` result dict, or a
    `{"error": ...}` dict on timeout.
    """
    litellm_params = model["litellm_params"]
    model_info = model.get("model_info", {})
    mode = model_info.get("mode", None)
    litellm_params = _update_litellm_params_for_health_check(model_info, litellm_params)
    timeout = model_info.get("health_check_timeout") or HEALTH_CHECK_TIMEOUT_SECONDS

    return await run_with_timeout(
        litellm.ahealth_check(
            litellm_params,
            mode=mode,
            prompt=DEFAULT_HEALTH_CHECK_PROMPT,
            input=["test from litellm"],
        ),
        timeout,
    )


async def _run_health_checks_with_bounded_concurrency(
    models: list, concurrency_limit: int
) -> tuple[list, int]:
    """
    Run health checks with at most `concurrency_limit` active tasks.
    Preserves result ordering to match `models`.

    Returns:
        (results, peak_in_flight) — results[i] is the health-check result
        (or the raised exception object) for models[i]; peak_in_flight is
        the maximum number of concurrently-running tasks observed.
    """
    results: list = [None] * len(models)
    tasks_to_index: dict[asyncio.Task, int] = {}
    model_iter = iter(enumerate(models))
    peak_in_flight = 0

    def _schedule_next() -> bool:
        # Start the next pending health check, if any remain.
        nonlocal peak_in_flight
        try:
            idx, next_model = next(model_iter)
        except StopIteration:
            return False
        task = asyncio.create_task(_run_model_health_check(next_model))
        tasks_to_index[task] = idx
        peak_in_flight = max(peak_in_flight, len(tasks_to_index))
        return True

    # Prime the window up to the concurrency limit.
    for _ in range(min(concurrency_limit, len(models))):
        _schedule_next()

    while tasks_to_index:
        done, _ = await asyncio.wait(
            set(tasks_to_index.keys()),
            return_when=asyncio.FIRST_COMPLETED,
        )
        for task in done:
            idx = tasks_to_index.pop(task)
            try:
                results[idx] = task.result()
            except Exception as e:
                # Mirror asyncio.gather(return_exceptions=True): store the
                # exception object in place of a result.
                results[idx] = e
            _schedule_next()

    return results, peak_in_flight


async def _perform_health_check(
    model_list: list,
    details: Optional[bool] = True,
    max_concurrency: Optional[int] = None,
    instrumentation_context: Optional[dict] = None,
):
    """
    Perform a health check for each model in the list.

    max_concurrency: Optional limit on concurrent health check requests.

    Returns:
        (healthy_endpoints, unhealthy_endpoints) — lists of cleaned
        endpoint-data dicts.
    """
    instrumentation_context = instrumentation_context or {}
    instrumentation_enabled = bool(instrumentation_context.get("enabled", False))
    cycle_id = instrumentation_context.get("cycle_id", "unknown")
    source = instrumentation_context.get("source", "unknown")
    dispatch_mode = "unbounded"
    peak_in_flight = 0

    if isinstance(max_concurrency, int) and max_concurrency > 0:
        dispatch_mode = "bounded"
        results, peak_in_flight = await _run_health_checks_with_bounded_concurrency(
            model_list, max_concurrency
        )
    else:
        # No limit configured: fire everything at once.
        tasks = [
            asyncio.create_task(_run_model_health_check(model)) for model in model_list
        ]
        peak_in_flight = len(tasks)
        results = await asyncio.gather(*tasks, return_exceptions=True)

    if instrumentation_enabled:
        logger.debug(
            "health_check_dispatch_summary source=%s cycle_id=%s mode=%s model_count=%d max_concurrency=%s peak_in_flight=%d thread_count=%d rss_mb=%s",
            source,
            cycle_id,
            dispatch_mode,
            len(model_list),
            max_concurrency,
            peak_in_flight,
            threading.active_count(),
            _rss_mb_for_log(),
        )

    healthy_endpoints = []
    unhealthy_endpoints = []

    for is_healthy, model in zip(results, model_list):
        litellm_params = model["litellm_params"]

        if isinstance(is_healthy, dict) and "error" not in is_healthy:
            healthy_endpoints.append(
                _clean_endpoint_data({**litellm_params, **is_healthy}, details)
            )
        elif isinstance(is_healthy, dict):
            unhealthy_endpoints.append(
                _clean_endpoint_data({**litellm_params, **is_healthy}, details)
            )
        else:
            # The check raised (result is an exception object); report the
            # deployment's params alone as unhealthy.
            unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details))

    return healthy_endpoints, unhealthy_endpoints


def _update_litellm_params_for_health_check(
    model_info: dict, litellm_params: dict
) -> dict:
    """
    Update the litellm params for health check.

    - gets a short `messages` param for health check
    - updates the `model` param with the `health_check_model` if it exists
      Doc: https://docs.litellm.ai/docs/proxy/health#wildcard-routes
    - updates the `voice` param with the `health_check_voice` for
      `audio_speech` mode if it exists
      Doc: https://docs.litellm.ai/docs/proxy/health#text-to-speech-models
    - for Bedrock models with region routing (bedrock/region/model), strips
      the litellm routing prefix but preserves the model ID
    """
    litellm_params["messages"] = _get_random_llm_message()
    _health_check_max_tokens = model_info.get("health_check_max_tokens", None)
    if _health_check_max_tokens is not None:
        litellm_params["max_tokens"] = _health_check_max_tokens
    elif "*" not in (
        model_info.get("health_check_model") or litellm_params.get("model") or ""
    ):
        # Keep the probe cheap: 1 token unless a wildcard route makes the
        # real model unknown.
        litellm_params["max_tokens"] = 1
    _health_check_model = model_info.get("health_check_model", None)
    if _health_check_model is not None:
        litellm_params["model"] = _health_check_model
    if model_info.get("mode", None) == "audio_speech":
        litellm_params["voice"] = model_info.get("health_check_voice", "alloy")

    # Handle Bedrock region routing format: bedrock/region/model
    # This is needed because health checks bypass get_llm_provider() for the model param
    # Issue #15807: Without this, health checks send "region/model" as the model ID to AWS
    # which causes: "bedrock-runtime.../model/us-west-2/mistral.../invoke" (region in model ID)
    #
    # However, we must preserve cross-region inference profile prefixes like "us.", "eu.", etc.
    # Issue: Stripping these breaks AWS requirement for inference profile IDs
    #
    # Must also preserve route prefixes (converse/, invoke/) and handlers (llama/, deepseek_r1/, etc.)
    if litellm_params["model"].startswith("bedrock/"):
        from litellm.llms.bedrock.common_utils import BedrockModelInfo

        # Strip only the leading "bedrock/" prefix (preserve routes like
        # converse/ and invoke/, and handlers like llama/).
        model = litellm_params["model"].removeprefix("bedrock/")

        # Drop AWS region path segments while preserving everything else:
        #   "us-west-2/model"          -> "model"
        #   "converse/us-west-2/model" -> "converse/model"
        #   "llama/arn:..."           -> "llama/arn:..." (handler preserved)
        # Cross-region inference profile prefixes ("us.", "eu.", ...) are part
        # of a segment, not a whole segment, so they survive the filter.
        litellm_params["model"] = "/".join(
            part
            for part in model.split("/")
            if part not in BedrockModelInfo.all_global_regions
        )

    return litellm_params


async def perform_health_check(
    model_list: list,
    model: Optional[str] = None,
    cli_model: Optional[str] = None,
    details: Optional[bool] = True,
    model_id: Optional[str] = None,
    max_concurrency: Optional[int] = None,
    instrumentation_context: Optional[dict] = None,
):
    """
    Perform a health check on the system.

    When model_id is provided, only the deployment with that id is checked
    (so models that share the same name but have different ids are checked
    separately). When model (name) is provided, all deployments matching
    that name are checked.

    Returns:
        (bool): True if the health check passes, False otherwise.
    """
    instrumentation_context = instrumentation_context or {}
    instrumentation_enabled = bool(instrumentation_context.get("enabled", False))
    cycle_id = instrumentation_context.get("cycle_id", "unknown")
    source = instrumentation_context.get("source", "unknown")

    if not model_list:
        if cli_model:
            # CLI mode: synthesize a single-deployment list from the flag.
            model_list = [
                {"model_name": cli_model, "litellm_params": {"model": cli_model}}
            ]
        else:
            if instrumentation_enabled:
                logger.debug(
                    "health_check_cycle_skipped source=%s cycle_id=%s reason=no_models",
                    source,
                    cycle_id,
                )
            return [], []

    cycle_start_time = time.monotonic()
    requested_model_count = len(model_list)

    # Filter by model_id first so a single deployment is checked when id is specified
    if model_id is not None:
        _by_id = [
            x for x in model_list if (x.get("model_info") or {}).get("id") == model_id
        ]
        if _by_id:
            model_list = _by_id
    elif model is not None:
        _new_model_list = [
            x for x in model_list if x["litellm_params"]["model"] == model
        ]
        if _new_model_list == []:
            # Fall back to matching on the public model_name alias.
            _new_model_list = [x for x in model_list if x["model_name"] == model]
        model_list = _new_model_list

    post_filter_model_count = len(model_list)
    model_list = filter_deployments_by_id(
        model_list=model_list
    )  # filter duplicate deployments (e.g. when model alias'es are used)
    deduped_model_count = len(model_list)

    if instrumentation_enabled:
        logger.debug(
            "health_check_cycle_start source=%s cycle_id=%s requested_model_count=%d post_model_filter_count=%d deduped_model_count=%d max_concurrency=%s thread_count=%d rss_mb=%s",
            source,
            cycle_id,
            requested_model_count,
            post_filter_model_count,
            deduped_model_count,
            max_concurrency,
            threading.active_count(),
            _rss_mb_for_log(),
        )

    try:
        healthy_endpoints, unhealthy_endpoints = await _perform_health_check(
            model_list,
            details,
            max_concurrency=max_concurrency,
            instrumentation_context=instrumentation_context,
        )
    except Exception:
        if instrumentation_enabled:
            logger.exception(
                "health_check_cycle_failed source=%s cycle_id=%s model_count=%d duration_ms=%.2f thread_count=%d rss_mb=%s",
                source,
                cycle_id,
                deduped_model_count,
                (time.monotonic() - cycle_start_time) * 1000,
                threading.active_count(),
                _rss_mb_for_log(),
            )
        raise

    if instrumentation_enabled:
        logger.debug(
            "health_check_cycle_complete source=%s cycle_id=%s model_count=%d healthy_count=%d unhealthy_count=%d duration_ms=%.2f thread_count=%d rss_mb=%s",
            source,
            cycle_id,
            deduped_model_count,
            len(healthy_endpoints),
            len(unhealthy_endpoints),
            (time.monotonic() - cycle_start_time) * 1000,
            threading.active_count(),
            _rss_mb_for_log(),
        )

    return healthy_endpoints, unhealthy_endpoints