chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

View File

@@ -0,0 +1,39 @@
from typing import Dict, List, Optional
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
from litellm.proxy._types import UserAPIKeyAuth
class MCPAuthenticatedUser(AuthenticatedUser):
    """
    Adapter exposing LiteLLM's authentication and configuration state through
    MCP's AuthenticatedUser interface.

    This class carries:
    1. User API key authentication information
    2. MCP authentication header (deprecated)
    3. MCP server configuration (can include access groups)
    4. Server-specific authentication headers
    5. OAuth2 headers
    6. Raw headers - allows forwarding specific headers to the MCP server, specified by the admin.
    """

    def __init__(
        self,
        user_api_key_auth: UserAPIKeyAuth,
        mcp_auth_header: Optional[str] = None,
        mcp_servers: Optional[List[str]] = None,
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None,
        oauth2_headers: Optional[Dict[str, str]] = None,
        mcp_protocol_version: Optional[str] = None,
        raw_headers: Optional[Dict[str, str]] = None,
        client_ip: Optional[str] = None,
    ):
        # Primary LiteLLM key-auth payload for this request.
        self.user_api_key_auth = user_api_key_auth
        # Deprecated single-header auth mechanism, kept for compatibility.
        self.mcp_auth_header = mcp_auth_header
        # Allowed server names/aliases (may include access groups).
        self.mcp_servers = mcp_servers
        # Per-server auth headers; normalize a missing/empty value to {}.
        self.mcp_server_auth_headers = (
            mcp_server_auth_headers if mcp_server_auth_headers else {}
        )
        self.mcp_protocol_version = mcp_protocol_version
        self.oauth2_headers = oauth2_headers
        self.raw_headers = raw_headers
        self.client_ip = client_ip

View File

@@ -0,0 +1,789 @@
"""
BYOK (Bring Your Own Key) OAuth 2.1 Authorization Server endpoints for MCP servers.
When an MCP client connects to a BYOK-enabled server and no stored credential exists,
LiteLLM runs a minimal OAuth 2.1 authorization code flow. The "authorization page" is
just a form that asks the user for their API key — not a full identity-provider OAuth.
Endpoints implemented here:
GET /.well-known/oauth-authorization-server — OAuth authorization server metadata
GET /.well-known/oauth-protected-resource — OAuth protected resource metadata
GET /v1/mcp/oauth/authorize — Shows HTML form to collect the API key
POST /v1/mcp/oauth/authorize — Stores temp auth code and redirects
POST /v1/mcp/oauth/token — Exchanges code for a bearer JWT token
"""
import base64
import hashlib
import html as _html_module
import time
import uuid
from typing import Dict, Optional, cast
from urllib.parse import urlencode, urlparse
import jwt
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from litellm._logging import verbose_proxy_logger
from litellm.proxy._experimental.mcp_server.db import store_user_credential
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
get_request_base_url,
)
# ---------------------------------------------------------------------------
# In-memory store for pending authorization codes.
# Each entry: {code: {api_key, server_id, code_challenge, redirect_uri, user_id, expires_at}}
# NOTE(review): this store is per-process only; a multi-worker deployment
# presumably needs sticky routing or shared storage for the flow to work —
# confirm against the deployment docs.
# ---------------------------------------------------------------------------
_byok_auth_codes: Dict[str, dict] = {}
# Authorization codes expire after 5 minutes.
_AUTH_CODE_TTL_SECONDS = 300
# Hard cap to prevent memory exhaustion from incomplete OAuth flows.
_AUTH_CODES_MAX_SIZE = 1000
# Router exposing the BYOK OAuth endpoints; mounted by the proxy app.
router = APIRouter(tags=["mcp"])
# ---------------------------------------------------------------------------
# PKCE helper
# ---------------------------------------------------------------------------
def _verify_pkce(code_verifier: str, code_challenge: str) -> bool:
"""Return True iff SHA-256(code_verifier) == code_challenge (base64url, no padding)."""
digest = hashlib.sha256(code_verifier.encode()).digest()
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
return computed == code_challenge
# ---------------------------------------------------------------------------
# Cleanup of expired auth codes (called lazily on each request)
# ---------------------------------------------------------------------------
def _purge_expired_codes() -> None:
now = time.time()
expired = [k for k, v in _byok_auth_codes.items() if v["expires_at"] < now]
for k in expired:
del _byok_auth_codes[k]
def _build_authorize_html(
    server_name: str,
    server_initial: str,
    client_id: str,
    redirect_uri: str,
    code_challenge: str,
    code_challenge_method: str,
    state: str,
    server_id: str,
    access_items: list,
    help_url: str,
) -> str:
    """Build the 2-step BYOK OAuth authorization page HTML.

    Step 1 explains the connection; step 2 collects the user's API key and
    POSTs it (with the OAuth/PKCE parameters echoed as hidden inputs) back to
    the same /v1/mcp/oauth/authorize path.

    Args:
        server_name: Display name of the MCP server being connected.
        server_initial: Single character shown in the server "logo" tile.
        client_id: OAuth client_id, echoed through the form.
        redirect_uri: OAuth redirect target, echoed through the form.
        code_challenge: PKCE challenge, echoed through the form.
        code_challenge_method: PKCE method (expected "S256"), echoed through.
        state: Opaque OAuth state, echoed through the form.
        server_id: LiteLLM server id, echoed through the form.
        access_items: Human-readable list of capabilities shown as a checklist.
        help_url: Optional link explaining where to find the API key.

    Returns:
        The complete HTML document as a string.
    """
    # Escape all user-supplied / externally-derived values before interpolation
    # (prevents HTML/attribute injection into the rendered page).
    e = _html_module.escape
    server_name = e(server_name)
    server_initial = e(server_initial)
    client_id = e(client_id)
    redirect_uri = e(redirect_uri)
    code_challenge = e(code_challenge)
    code_challenge_method = e(code_challenge_method)
    state = e(state)
    server_id = e(server_id)
    # Build access checklist rows (each item is escaped individually).
    access_rows = "".join(
        f'<div class="access-item"><span class="check">&#10003;</span>{e(item)}</div>'
        for item in access_items
    )
    # Only render the "Requested Access" box when there is at least one item.
    access_section = ""
    if access_rows:
        access_section = f"""
<div class="access-box">
<div class="access-header">
<span class="shield">&#9646;</span>
<span>Requested Access</span>
</div>
{access_rows}
</div>"""
    # Help link for step 2 (omitted entirely when no help_url is configured).
    help_link_html = ""
    if help_url:
        help_link_html = f'<a class="help-link" href="{e(help_url)}" target="_blank">Where do I find my API key? &#8599;</a>'
    # Note: literal CSS/JS braces are doubled ({{ }}) because this is an
    # f-string; single-brace spans are the interpolated values escaped above.
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Connect {server_name} &mdash; LiteLLM</title>
<style>
*, *::before, *::after {{ box-sizing: border-box; margin: 0; padding: 0; }}
body {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0f172a;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 24px;
}}
.modal {{
background: #ffffff;
border-radius: 20px;
padding: 36px 32px 32px;
width: 440px;
max-width: 100%;
position: relative;
box-shadow: 0 25px 60px rgba(0,0,0,0.35);
}}
/* Progress dots */
.dots {{
display: flex;
justify-content: center;
gap: 7px;
margin-bottom: 28px;
}}
.dot {{
width: 8px; height: 8px;
border-radius: 50%;
background: #e2e8f0;
}}
.dot.active {{ background: #38bdf8; }}
/* Close button */
.close-btn {{
position: absolute;
top: 16px; right: 16px;
background: none; border: none;
font-size: 16px; color: #94a3b8;
cursor: pointer; line-height: 1;
width: 28px; height: 28px;
border-radius: 6px;
display: flex; align-items: center; justify-content: center;
}}
.close-btn:hover {{ background: #f1f5f9; color: #475569; }}
/* Logo pair */
.logos {{
display: flex; align-items: center; justify-content: center;
gap: 12px; margin-bottom: 20px;
}}
.logo {{
width: 52px; height: 52px;
border-radius: 14px;
display: flex; align-items: center; justify-content: center;
font-size: 22px; font-weight: 800; color: white;
}}
.logo-img {{
width: 52px; height: 52px;
border-radius: 14px;
object-fit: cover;
border: 1.5px solid #e2e8f0;
}}
.logo-s {{ background: linear-gradient(135deg, #818cf8 0%, #4f46e5 100%); }}
.logo-arrow {{ color: #cbd5e1; font-size: 20px; font-weight: 300; }}
/* Headings */
.step-title {{
text-align: center;
font-size: 21px; font-weight: 700;
color: #0f172a; margin-bottom: 8px;
}}
.step-subtitle {{
text-align: center;
font-size: 14px; color: #64748b;
line-height: 1.55; margin-bottom: 22px;
}}
/* Info box */
.info-box {{
background: #f8fafc;
border-radius: 12px;
padding: 14px 16px;
display: flex; gap: 12px;
margin-bottom: 14px;
}}
.info-icon {{ font-size: 17px; flex-shrink: 0; margin-top: 1px; color: #38bdf8; }}
.info-box h4 {{ font-size: 13px; font-weight: 600; color: #1e293b; margin-bottom: 4px; }}
.info-box p {{ font-size: 13px; color: #64748b; line-height: 1.5; }}
/* Access checklist */
.access-box {{
background: #f8fafc;
border-radius: 12px;
padding: 14px 16px;
margin-bottom: 22px;
}}
.access-header {{
display: flex; align-items: center; gap: 8px;
margin-bottom: 10px;
}}
.shield {{ color: #22c55e; font-size: 15px; }}
.access-header > span:last-child {{
font-size: 11px; font-weight: 700;
letter-spacing: 0.07em;
text-transform: uppercase;
color: #475569;
}}
.access-item {{
display: flex; align-items: center; gap: 9px;
font-size: 13.5px; color: #374151;
padding: 3px 0;
}}
.check {{ color: #22c55e; font-weight: 700; font-size: 13px; }}
/* Primary CTA */
.btn-primary {{
width: 100%; padding: 15px;
background: #0f172a; color: white;
border: none; border-radius: 12px;
font-size: 15px; font-weight: 600;
cursor: pointer; margin-bottom: 10px;
}}
.btn-primary:hover {{ background: #1e293b; }}
.btn-cancel {{
width: 100%; padding: 8px;
background: none; border: none;
font-size: 13.5px; color: #94a3b8;
cursor: pointer;
}}
.btn-cancel:hover {{ color: #64748b; }}
/* Step 2 nav */
.step2-nav {{
display: flex; align-items: center;
justify-content: space-between;
margin-bottom: 24px;
}}
.back-btn {{
background: none; border: none;
font-size: 13.5px; color: #64748b;
cursor: pointer; display: flex; align-items: center; gap: 4px;
}}
.back-btn:hover {{ color: #374151; }}
/* Key icon */
.key-icon-wrap {{
width: 46px; height: 46px;
background: #e0f2fe;
border-radius: 12px;
display: flex; align-items: center; justify-content: center;
margin-bottom: 14px;
}}
.key-icon-wrap svg {{ width: 22px; height: 22px; color: #0284c7; }}
/* Form elements */
.field-label {{
font-size: 13.5px; font-weight: 600;
color: #1e293b; display: block;
margin-bottom: 7px;
}}
.key-input {{
width: 100%; padding: 11px 13px;
border: 1.5px solid #e2e8f0;
border-radius: 10px;
font-size: 14px; color: #0f172a;
outline: none; transition: border-color 0.15s, box-shadow 0.15s;
}}
.key-input:focus {{
border-color: #38bdf8;
box-shadow: 0 0 0 3px rgba(56,189,248,0.12);
}}
.help-link {{
display: inline-flex; align-items: center; gap: 4px;
color: #0ea5e9; font-size: 13px;
text-decoration: none; margin: 8px 0 16px;
}}
.help-link:hover {{ text-decoration: underline; }}
/* Save toggle card */
.save-card {{
border: 1.5px solid #e2e8f0;
border-radius: 12px;
padding: 13px 15px;
margin-bottom: 6px;
}}
.save-row {{
display: flex; align-items: center; gap: 10px;
}}
.save-icon {{ font-size: 16px; }}
.save-label {{
flex: 1;
font-size: 14px; font-weight: 500; color: #1e293b;
}}
/* Toggle switch */
.toggle {{ position: relative; width: 44px; height: 24px; flex-shrink: 0; }}
.toggle input {{ opacity: 0; width: 0; height: 0; }}
.slider {{
position: absolute; inset: 0;
background: #e2e8f0;
border-radius: 24px; cursor: pointer;
transition: background 0.18s;
}}
.slider::before {{
content: '';
position: absolute;
width: 18px; height: 18px;
left: 3px; bottom: 3px;
background: white;
border-radius: 50%;
transition: transform 0.18s;
box-shadow: 0 1px 3px rgba(0,0,0,0.18);
}}
input:checked + .slider {{ background: #38bdf8; }}
input:checked + .slider::before {{ transform: translateX(20px); }}
/* Duration pills */
.duration-section {{ margin-top: 14px; }}
.duration-label {{
font-size: 12px; font-weight: 600;
color: #64748b; margin-bottom: 8px;
text-transform: uppercase; letter-spacing: 0.05em;
}}
.pills {{ display: flex; flex-wrap: wrap; gap: 7px; }}
.pill {{
padding: 6px 13px;
border: 1.5px solid #e2e8f0;
border-radius: 20px;
font-size: 13px; color: #475569;
cursor: pointer; background: white;
transition: all 0.13s;
user-select: none;
}}
.pill:hover {{ border-color: #94a3b8; }}
.pill.sel {{
border-color: #38bdf8;
color: #0284c7;
background: #e0f2fe;
}}
/* Security note */
.sec-note {{
background: #f8fafc;
border-radius: 10px;
padding: 11px 14px;
display: flex; gap: 9px; align-items: flex-start;
margin: 16px 0;
}}
.sec-icon {{ font-size: 13px; color: #94a3b8; margin-top: 1px; flex-shrink: 0; }}
.sec-note p {{ font-size: 12.5px; color: #64748b; line-height: 1.5; }}
/* Connect button */
.btn-connect {{
width: 100%; padding: 15px;
border: none; border-radius: 12px;
font-size: 15px; font-weight: 600;
cursor: pointer;
background: #bae6fd; color: #0369a1;
transition: background 0.15s, color 0.15s;
}}
.btn-connect.ready {{
background: #0ea5e9; color: white;
}}
.btn-connect.ready:hover {{ background: #0284c7; }}
/* Step visibility */
.step {{ display: none; }}
.step.show {{ display: block; }}
</style>
</head>
<body>
<div class="modal">
<!-- ── STEP 1: Connect ─────────────────────────────────────── -->
<div id="s1" class="step show">
<div class="dots">
<div class="dot active"></div>
<div class="dot"></div>
</div>
<button class="close-btn" type="button" onclick="doCancel()" title="Close">&times;</button>
<div class="logos">
<img src="/ui/assets/logos/litellm_logo.jpg" class="logo-img" alt="LiteLLM">
<span class="logo-arrow">&#8594;</span>
<div class="logo logo-s">{server_initial}</div>
</div>
<h2 class="step-title">Connect {server_name} MCP</h2>
<p class="step-subtitle">LiteLLM needs access to {server_name} to complete your request.</p>
<div class="info-box">
<span class="info-icon">
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="12" y1="8" x2="12" y2="12"/><line x1="12" y1="16" x2="12.01" y2="16"/></svg>
</span>
<div>
<h4>How it works</h4>
<p>LiteLLM acts as a secure bridge. Your requests are routed through our MCP client directly to {server_name}&rsquo;s API.</p>
</div>
</div>
{access_section}
<button class="btn-primary" type="button" onclick="goStep2()">
Continue to Authentication &rarr;
</button>
<button class="btn-cancel" type="button" onclick="doCancel()">Cancel</button>
</div>
<!-- ── STEP 2: Provide API Key ──────────────────────────────── -->
<div id="s2" class="step">
<div class="step2-nav">
<button class="back-btn" type="button" onclick="goStep1()">&#8592; Back</button>
<div class="dots">
<div class="dot active"></div>
<div class="dot active"></div>
</div>
<button class="close-btn" style="position:static;" type="button" onclick="doCancel()" title="Close">&times;</button>
</div>
<div class="key-icon-wrap">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="#0284c7" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21 2l-2 2m-7.61 7.61a5.5 5.5 0 1 1-7.778 7.778 5.5 5.5 0 0 1 7.777-7.777zm0 0L15.5 7.5m0 0l3 3L22 7l-3-3m-3.5 3.5L19 4"/></svg>
</div>
<h2 class="step-title" style="text-align:left;">Provide API Key</h2>
<p class="step-subtitle" style="text-align:left;">Enter your {server_name} API key to authorize this connection.</p>
<form method="POST" id="authForm" onsubmit="prepareSubmit()">
<input type="hidden" name="client_id" value="{client_id}">
<input type="hidden" name="redirect_uri" value="{redirect_uri}">
<input type="hidden" name="code_challenge" value="{code_challenge}">
<input type="hidden" name="code_challenge_method" value="{code_challenge_method}">
<input type="hidden" name="state" value="{state}">
<input type="hidden" name="server_id" value="{server_id}">
<input type="hidden" name="duration" id="durInput" value="until_revoked">
<label class="field-label">{server_name} API Key</label>
<input
type="password"
name="api_key"
id="apiKey"
class="key-input"
placeholder="Enter your API key"
required
autofocus
oninput="syncBtn()"
>
{help_link_html}
<div class="save-card">
<div class="save-row">
<span class="save-label">Save key for future use</span>
<label class="toggle">
<input type="checkbox" id="saveToggle" onchange="toggleDur()">
<span class="slider"></span>
</label>
</div>
<div id="durSection" class="duration-section" style="display:none;">
<div class="duration-label">Duration</div>
<div class="pills">
<div class="pill" onclick="selDur('1h',this)">1 hour</div>
<div class="pill sel" onclick="selDur('24h',this)">24 hours</div>
<div class="pill" onclick="selDur('7d',this)">7 days</div>
<div class="pill" onclick="selDur('30d',this)">30 days</div>
<div class="pill" onclick="selDur('until_revoked',this)">Until I revoke</div>
</div>
</div>
</div>
<div class="sec-note">
<span class="sec-icon">
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="11" width="18" height="11" rx="2" ry="2"/><path d="M7 11V7a5 5 0 0 1 10 0v4"/></svg>
</span>
<p>Your key is stored securely and transmitted over HTTPS. It is never shared with third parties.</p>
</div>
<button type="submit" class="btn-connect" id="connectBtn">
Connect &amp; Authorize
</button>
</form>
</div>
</div>
<script>
function goStep2() {{
document.getElementById('s1').classList.remove('show');
document.getElementById('s2').classList.add('show');
}}
function goStep1() {{
document.getElementById('s2').classList.remove('show');
document.getElementById('s1').classList.add('show');
}}
function doCancel() {{
if (window.opener) window.close();
else window.history.back();
}}
function toggleDur() {{
const on = document.getElementById('saveToggle').checked;
document.getElementById('durSection').style.display = on ? 'block' : 'none';
}}
function selDur(val, el) {{
document.querySelectorAll('.pill').forEach(p => p.classList.remove('sel'));
el.classList.add('sel');
document.getElementById('durInput').value = val;
}}
function syncBtn() {{
const btn = document.getElementById('connectBtn');
if (document.getElementById('apiKey').value.length > 0) {{
btn.classList.add('ready');
}} else {{
btn.classList.remove('ready');
}}
}}
function prepareSubmit() {{
// nothing extra needed — duration is already in the hidden input
}}
</script>
</body>
</html>"""
# ---------------------------------------------------------------------------
# OAuth metadata discovery endpoints
# ---------------------------------------------------------------------------
@router.get("/.well-known/oauth-authorization-server", include_in_schema=False)
async def oauth_authorization_server_metadata(request: Request) -> JSONResponse:
    """RFC 8414 Authorization Server Metadata for the BYOK OAuth flow."""
    issuer = get_request_base_url(request)
    metadata = {
        "issuer": issuer,
        "authorization_endpoint": f"{issuer}/v1/mcp/oauth/authorize",
        "token_endpoint": f"{issuer}/v1/mcp/oauth/token",
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code"],
        "code_challenge_methods_supported": ["S256"],
    }
    return JSONResponse(metadata)
@router.get("/.well-known/oauth-protected-resource", include_in_schema=False)
async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
    """RFC 9728 Protected Resource Metadata pointing back at this server."""
    resource_url = get_request_base_url(request)
    # This server is both the protected resource and its own authorization server.
    return JSONResponse(
        {
            "resource": resource_url,
            "authorization_servers": [resource_url],
        }
    )
# ---------------------------------------------------------------------------
# Authorization endpoint — GET (show form) and POST (process form)
# ---------------------------------------------------------------------------
@router.get("/v1/mcp/oauth/authorize", include_in_schema=False)
async def byok_authorize_get(
    request: Request,
    client_id: Optional[str] = None,
    redirect_uri: Optional[str] = None,
    response_type: Optional[str] = None,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    state: Optional[str] = None,
    server_id: Optional[str] = None,
) -> HTMLResponse:
    """
    Show the BYOK API-key entry form.

    The MCP client navigates the user here; the user types their API key and
    clicks "Connect & Authorize", which POSTs back to this same path.
    """
    # OAuth 2.1: only the authorization-code flow (with PKCE) is supported.
    if response_type != "code":
        raise HTTPException(status_code=400, detail="response_type must be 'code'")
    if not redirect_uri:
        raise HTTPException(status_code=400, detail="redirect_uri is required")
    if not code_challenge:
        raise HTTPException(status_code=400, detail="code_challenge is required")

    # Resolve server metadata (name, description items, help URL); fall back
    # to generic defaults on any lookup failure.
    server_name = "MCP Server"
    access_items: list = []
    help_url = ""
    if server_id:
        try:
            from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
                global_mcp_server_manager,
            )

            registry = global_mcp_server_manager.get_registry()
            if server_id in registry:
                srv = registry[server_id]
                server_name = srv.server_name or srv.name
                access_items = list(srv.byok_description or [])
                help_url = srv.byok_api_key_help_url or ""
        except Exception:
            # Best-effort: the generic page is still usable without metadata.
            pass

    initial = server_name[0].upper() if server_name else "S"
    page = _build_authorize_html(
        server_name=server_name,
        server_initial=initial,
        client_id=client_id or "",
        redirect_uri=redirect_uri,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method or "S256",
        state=state or "",
        server_id=server_id or "",
        access_items=access_items,
        help_url=help_url,
    )
    return HTMLResponse(content=page)
@router.post("/v1/mcp/oauth/authorize", include_in_schema=False)
async def byok_authorize_post(
    request: Request,
    client_id: str = Form(default=""),
    redirect_uri: str = Form(...),
    code_challenge: str = Form(...),
    code_challenge_method: str = Form(default="S256"),
    state: str = Form(default=""),
    server_id: str = Form(default=""),
    api_key: str = Form(...),
) -> RedirectResponse:
    """
    Process the BYOK API-key form submission.

    Stores a short-lived authorization code and redirects the client back to
    redirect_uri with ?code=...&state=... query parameters.
    """
    _purge_expired_codes()

    # Only http(s) targets are acceptable — blocks javascript:, data:, etc.
    if urlparse(redirect_uri).scheme not in ("http", "https"):
        raise HTTPException(status_code=400, detail="Invalid redirect_uri scheme")

    # Refuse new codes while the store is at capacity (prevents memory
    # exhaustion from a burst of abandoned OAuth flows).
    if len(_byok_auth_codes) >= _AUTH_CODES_MAX_SIZE:
        raise HTTPException(
            status_code=503, detail="Too many pending authorization flows"
        )

    # PKCE: S256 is the only supported transform.
    if code_challenge_method != "S256":
        raise HTTPException(
            status_code=400, detail="Only S256 code_challenge_method is supported"
        )

    issued_code = str(uuid.uuid4())
    _byok_auth_codes[issued_code] = {
        "api_key": api_key,
        "server_id": server_id,
        "code_challenge": code_challenge,
        "redirect_uri": redirect_uri,
        "user_id": client_id,  # external client passes LiteLLM user-id as client_id
        "expires_at": time.time() + _AUTH_CODE_TTL_SECONDS,
    }

    # Append code/state to the redirect target, preserving any existing query.
    query = urlencode({"code": issued_code, "state": state})
    joiner = "&" if "?" in redirect_uri else "?"
    return RedirectResponse(url=f"{redirect_uri}{joiner}{query}", status_code=302)
# ---------------------------------------------------------------------------
# Token endpoint
# ---------------------------------------------------------------------------
@router.post("/v1/mcp/oauth/token", include_in_schema=False)
async def byok_token(
    request: Request,
    grant_type: str = Form(...),
    code: str = Form(...),
    redirect_uri: str = Form(default=""),
    code_verifier: str = Form(...),
    client_id: str = Form(default=""),
) -> JSONResponse:
    """
    Exchange an authorization code for a short-lived BYOK session JWT.

    1. Validates the authorization code, redirect_uri binding, and PKCE challenge.
    2. Stores the API key via store_user_credential().
    3. Issues a signed JWT with type="byok_session".

    Raises:
        HTTPException: 400 for invalid grant material (unknown/expired code,
            redirect_uri mismatch, failed PKCE), 500 when the credential
            cannot be persisted or no master key is configured.
    """
    from litellm.proxy.proxy_server import master_key, prisma_client

    _purge_expired_codes()
    if grant_type != "authorization_code":
        raise HTTPException(status_code=400, detail="unsupported_grant_type")
    record = _byok_auth_codes.get(code)
    if record is None:
        raise HTTPException(status_code=400, detail="invalid_grant")
    if time.time() > record["expires_at"]:
        del _byok_auth_codes[code]
        raise HTTPException(status_code=400, detail="invalid_grant")
    # RFC 6749 §4.1.3: when the token request supplies a redirect_uri it MUST
    # match the one used at the authorization step. Kept optional (only
    # enforced when provided) so existing clients that omit it keep working.
    if redirect_uri and redirect_uri != record["redirect_uri"]:
        raise HTTPException(status_code=400, detail="invalid_grant")
    # PKCE verification
    if not _verify_pkce(code_verifier, record["code_challenge"]):
        raise HTTPException(status_code=400, detail="invalid_grant")
    # Consume the code (one-time use)
    del _byok_auth_codes[code]
    server_id: str = record["server_id"]
    api_key_value: str = record["api_key"]
    # Prefer the user_id that was stored when the code was issued; fall back to
    # whatever client_id the token request supplies (they should match).
    user_id: str = record.get("user_id") or client_id
    if not user_id:
        raise HTTPException(
            status_code=400,
            detail="Cannot determine user_id; pass LiteLLM user id as client_id",
        )
    # Persist the BYOK credential
    if prisma_client is not None:
        try:
            await store_user_credential(
                prisma_client=prisma_client,
                user_id=user_id,
                server_id=server_id,
                credential=api_key_value,
            )
            # Invalidate any cached negative result so the user isn't blocked
            # for up to the TTL period after completing the OAuth flow.
            from litellm.proxy._experimental.mcp_server.server import (
                _invalidate_byok_cred_cache,
            )

            _invalidate_byok_cred_cache(user_id, server_id)
        except Exception as exc:
            verbose_proxy_logger.error(
                "byok_token: failed to store user credential for user=%s server=%s: %s",
                user_id,
                server_id,
                exc,
            )
            raise HTTPException(status_code=500, detail="Failed to store credential")
    else:
        verbose_proxy_logger.warning(
            "byok_token: prisma_client is None — credential not persisted"
        )
    if master_key is None:
        raise HTTPException(
            status_code=500, detail="Master key not configured; cannot issue token"
        )
    now = int(time.time())
    payload = {
        "user_id": user_id,
        "server_id": server_id,
        # "type" distinguishes this from regular proxy auth tokens.
        # The proxy's SSO JWT path uses asymmetric keys (RS256/ES256), so an
        # HS256 token signed with master_key cannot be accepted there.
        "type": "byok_session",
        "iat": now,
        "exp": now + 3600,
    }
    access_token = jwt.encode(payload, cast(str, master_key), algorithm="HS256")
    return JSONResponse(
        {
            "access_token": access_token,
            "token_type": "bearer",
            "expires_in": 3600,
        }
    )

View File

@@ -0,0 +1,77 @@
"""
Cost calculator for MCP tools.
"""
from typing import TYPE_CHECKING, Any, Optional, cast
from litellm.types.mcp import MCPServerCostInfo
from litellm.types.utils import StandardLoggingMCPToolCall
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import (
Logging as LitellmLoggingObject,
)
else:
LitellmLoggingObject = Any
class MCPCostCalculator:
    @staticmethod
    def calculate_mcp_tool_call_cost(
        litellm_logging_obj: Optional[LitellmLoggingObject],
    ) -> float:
        """
        Calculate the cost of an MCP tool call.

        Default is 0.0, unless user specifies a custom cost per request for MCP tools.
        """
        if litellm_logging_obj is None:
            return 0.0

        call_details = litellm_logging_obj.model_call_details

        # An explicit response_cost (set by a post_mcp_tool_call_hook that
        # modified the response) always wins over per-server pricing.
        explicit_cost = call_details.get("response_cost", None)
        if explicit_cost is not None:
            return explicit_cost

        # Unpack the tool-call metadata and its per-server cost config.
        tool_metadata: StandardLoggingMCPToolCall = (
            cast(
                StandardLoggingMCPToolCall,
                call_details.get("mcp_tool_call_metadata", {}),
            )
            or {}
        )
        cost_info: MCPServerCostInfo = (
            tool_metadata.get("mcp_server_cost_info") or MCPServerCostInfo()
        )

        per_tool_costs: dict = (
            cost_info.get("tool_name_to_cost_per_query", {}) or {}
        )
        default_cost = cost_info.get("default_cost_per_query", None)
        tool_name = tool_metadata.get("name", "")

        # Resolution order:
        # 1. per-tool override from tool_name_to_cost_per_query
        # 2. server-wide default_cost_per_query
        # 3. 0.0 when neither is configured
        if tool_name in per_tool_costs:
            return per_tool_costs[tool_name]
        if default_cost is not None:
            return default_cost
        return 0.0

View File

@@ -0,0 +1,767 @@
import base64
import json
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.proxy._types import (
LiteLLM_MCPServerTable,
LiteLLM_ObjectPermissionTable,
LiteLLM_TeamTable,
MCPApprovalStatus,
MCPSubmissionsSummary,
NewMCPServerRequest,
SpecialMCPServerName,
UpdateMCPServerRequest,
UserAPIKeyAuth,
)
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
_get_salt_key,
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.proxy.utils import PrismaClient
from litellm.types.mcp import MCPCredentials
def _prepare_mcp_server_data(
    data: Union[NewMCPServerRequest, UpdateMCPServerRequest],
) -> Dict[str, Any]:
    """
    Helper function to prepare MCP server data for database operations.

    Handles JSON field serialization for mcp_info and env fields.

    Args:
        data: NewMCPServerRequest or UpdateMCPServerRequest object

    Returns:
        Dict with properly serialized JSON fields
    """
    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    # Convert the model to a dict, dropping unset/None fields.
    prepared = data.model_dump(exclude_none=True)

    # Ensure alias is always present in the dict (even if None).
    if "alias" not in prepared:
        prepared["alias"] = getattr(data, "alias", None)

    # Encrypt secret credential fields, then JSON-serialize the whole dict.
    if prepared.get("credentials") is not None:
        prepared["credentials"] = encrypt_credentials(
            credentials=prepared["credentials"], encryption_key=_get_salt_key()
        )
        prepared["credentials"] = safe_dumps(prepared["credentials"])

    # JSON-serialize the remaining structured fields when present on the model.
    for json_field in (
        "static_headers",
        "mcp_info",
        "env",
        "tool_name_to_display_name",
        "tool_name_to_description",
    ):
        if getattr(data, json_field) is not None:
            prepared[json_field] = safe_dumps(getattr(data, json_field))

    # mcp_access_groups is already List[str], no serialization needed.

    # Force include is_byok even when False so a False value is always
    # written to the DB.
    prepared["is_byok"] = getattr(data, "is_byok", False)
    return prepared
def encrypt_credentials(
    credentials: MCPCredentials, encryption_key: Optional[str]
) -> MCPCredentials:
    """Encrypt all secret fields in an MCPCredentials dict in place.

    Mirrors ``decrypt_credentials``: the same list of secret fields is
    processed, so the encrypt/decrypt paths cannot drift apart field-by-field
    (the original had six copy-pasted per-field stanzas).

    Args:
        credentials: Credential dict whose secret values are plaintext.
        encryption_key: Key passed through to ``encrypt_value_helper``.

    Returns:
        The same dict with secret fields replaced by their encrypted values.
    """
    # aws_region_name and aws_service_name are NOT secrets — stored as-is.
    secret_fields = [
        "auth_value",
        "client_id",
        "client_secret",
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
    ]
    for field in secret_fields:
        value = credentials.get(field)  # type: ignore[literal-required]
        if value is not None:
            credentials[field] = encrypt_value_helper(  # type: ignore[literal-required]
                value=value,
                new_encryption_key=encryption_key,
            )
    return credentials
def decrypt_credentials(
    credentials: MCPCredentials,
) -> MCPCredentials:
    """Decrypt every secret field of an MCPCredentials dict using the global salt key.

    Fields that are absent or not strings are left untouched; a failed
    decryption falls back to the stored value (return_original_value=True).
    """
    for secret_field in (
        "auth_value",
        "client_id",
        "client_secret",
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
    ):
        stored = credentials.get(secret_field)  # type: ignore[literal-required]
        if isinstance(stored, str):
            credentials[secret_field] = decrypt_value_helper(  # type: ignore[literal-required]
                value=stored,
                key=secret_field,
                exception_type="debug",
                return_original_value=True,
            )
    return credentials
async def get_all_mcp_servers(
    prisma_client: PrismaClient,
    approval_status: Optional[str] = None,
) -> List[LiteLLM_MCPServerTable]:
    """
    Fetch mcp servers from the db, optionally filtered by approval_status.

    approval_status=None returns every server regardless of approval state.
    DB errors are logged at debug level and yield an empty list.
    """
    try:
        filters: Dict[str, Any] = (
            {"approval_status": approval_status}
            if approval_status is not None
            else {}
        )
        rows = await prisma_client.db.litellm_mcpservertable.find_many(
            where=filters
        )
        return [LiteLLM_MCPServerTable(**row.model_dump()) for row in rows]
    except Exception as e:
        verbose_proxy_logger.debug(
            "litellm.proxy._experimental.mcp_server.db.py::get_all_mcp_servers - {}".format(
                str(e)
            )
        )
        return []
async def get_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> Optional[LiteLLM_MCPServerTable]:
    """
    Look up a single mcp server row by its server_id; None when absent.
    """
    record: Optional[
        LiteLLM_MCPServerTable
    ] = await prisma_client.db.litellm_mcpservertable.find_unique(
        where={"server_id": server_id}
    )
    return record
async def get_mcp_servers(
    prisma_client: PrismaClient, server_ids: Iterable[str]
) -> List[LiteLLM_MCPServerTable]:
    """
    Fetch every mcp server whose server_id appears in server_ids.
    """
    rows = await prisma_client.db.litellm_mcpservertable.find_many(
        where={"server_id": {"in": server_ids}}
    )
    return [LiteLLM_MCPServerTable(**row.model_dump()) for row in rows]
async def get_mcp_servers_by_verificationtoken(
    prisma_client: PrismaClient, token: str
) -> List[str]:
    """
    Return the mcp server ids attached to a verification token's
    object-permission record; empty list when the token or its
    object_permission is missing.
    """
    record: LiteLLM_TeamTable = (
        await prisma_client.db.litellm_verificationtoken.find_unique(
            where={"token": token},
            include={"object_permission": True},
        )
    )
    if record is None or record.object_permission is None:
        return []
    return record.object_permission.mcp_servers or []
async def get_mcp_servers_by_team(
    prisma_client: PrismaClient, team_id: str
) -> List[str]:
    """
    Return the mcp server ids attached to a team's object-permission
    record; empty list when the team or its object_permission is missing.
    """
    record: LiteLLM_TeamTable = (
        await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id},
            include={"object_permission": True},
        )
    )
    if record is None or record.object_permission is None:
        return []
    return record.object_permission.mcp_servers or []
async def get_all_mcp_servers_for_user(
    prisma_client: PrismaClient,
    user: UserAPIKeyAuth,
) -> List[LiteLLM_MCPServerTable]:
    """
    Return only the mcp servers this user can access (least-privilege:
    a requestor only sees servers they were granted).

    Access comes from the virtual key's object-permissions; when the
    special all-team-servers marker is present and the user has a team,
    the team's servers are added too.
    """
    accessible_ids: Set[str] = set()
    # servers granted directly to the virtual key
    if user.api_key:
        accessible_ids.update(
            await get_mcp_servers_by_verificationtoken(prisma_client, user.api_key)
        )
        # special marker expands to the team's server list
        if (
            user.team_id is not None
            and SpecialMCPServerName.all_team_servers in accessible_ids
        ):
            accessible_ids.update(
                await get_mcp_servers_by_team(prisma_client, user.team_id)
            )
    if not accessible_ids:
        return []
    return await get_mcp_servers(prisma_client, accessible_ids)
async def get_objectpermissions_for_mcp_server(
    prisma_client: PrismaClient, mcp_server_id: str
) -> List[LiteLLM_ObjectPermissionTable]:
    """
    Fetch every object-permission record granting access to this mcp
    server, including the associated team and verification-token records.
    """
    return await prisma_client.db.litellm_objectpermissiontable.find_many(
        where={"mcp_servers": {"has": mcp_server_id}},
        include={"teams": True, "verification_tokens": True},
    )
async def get_virtualkeys_for_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> List:
    """
    Fetch every virtual key whose permissions include this mcp server;
    empty list when the query yields nothing.
    """
    rows = await prisma_client.db.litellm_verificationtoken.find_many(
        where={"mcp_servers": {"has": server_id}},
    )
    return rows if rows is not None else []
async def delete_mcp_server_from_team(prisma_client: PrismaClient, server_id: str):
    """
    Remove the mcp server from the team
    """
    # NOTE(review): intentionally a no-op stub — team cleanup on server
    # deletion is not implemented here yet.
    pass
async def delete_mcp_server_from_virtualkey():
    """
    Remove the mcp server from the virtual key
    """
    # NOTE(review): intentionally a no-op stub — virtual-key cleanup on
    # server deletion is not implemented here yet.
    pass
async def delete_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> Optional[LiteLLM_MCPServerTable]:
    """
    Delete an mcp server row by server_id.

    Returns the deleted record when it existed, otherwise None.
    """
    removed = await prisma_client.db.litellm_mcpservertable.delete(
        where={"server_id": server_id}
    )
    return removed
async def create_mcp_server(
    prisma_client: PrismaClient, data: NewMCPServerRequest, touched_by: str
) -> LiteLLM_MCPServerTable:
    """
    Insert a new mcp server row.

    Generates a server_id when the request has none, serializes nested
    fields via _prepare_mcp_server_data, and stamps the audit columns.
    """
    if data.server_id is None:
        data.server_id = str(uuid.uuid4())
    payload = _prepare_mcp_server_data(data)
    # audit columns
    payload["created_by"] = touched_by
    payload["updated_by"] = touched_by
    return await prisma_client.db.litellm_mcpservertable.create(
        data=payload  # type: ignore
    )
async def update_mcp_server(
    prisma_client: PrismaClient, data: UpdateMCPServerRequest, touched_by: str
) -> LiteLLM_MCPServerTable:
    """
    Update an existing mcp server record in the db.

    Credential handling has two subtleties:
    - When auth_type changes and no new credentials are supplied, the
      stale credentials of the previous auth type are cleared.
    - When credentials are supplied and auth_type is unchanged, the new
      fields are merged over the existing ones so secrets the UI cannot
      echo back are preserved.

    Args:
        prisma_client: active Prisma client.
        data: update payload; data.server_id selects the row to update.
        touched_by: identifier written to the updated_by audit column.

    Returns:
        The updated LiteLLM_MCPServerTable row.
    """
    import json

    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    # Use helper to prepare data with proper JSON serialization
    data_dict = _prepare_mcp_server_data(data)
    # Pre-fetch existing record once if we need it for auth_type or credential logic
    existing = None
    has_credentials = (
        "credentials" in data_dict and data_dict["credentials"] is not None
    )
    if data.auth_type or has_credentials:
        existing = await prisma_client.db.litellm_mcpservertable.find_unique(
            where={"server_id": data.server_id}
        )
    # Clear stale credentials when auth_type changes but no new credentials provided
    if (
        data.auth_type
        and "credentials" not in data_dict
        and existing
        and existing.auth_type is not None
        and existing.auth_type != data.auth_type
    ):
        data_dict["credentials"] = None
    # Merge credentials: preserve existing fields not present in the update.
    # Without this, a partial credential update (e.g. changing only region)
    # would wipe encrypted secrets that the UI cannot display back.
    if "credentials" in data_dict and data_dict["credentials"] is not None:
        if existing and existing.credentials:
            # Only merge when auth_type is unchanged. Switching auth types
            # (e.g. oauth2 → api_key) should replace credentials entirely
            # to avoid stale secrets from the previous auth type lingering.
            auth_type_unchanged = (
                data.auth_type is None or data.auth_type == existing.auth_type
            )
            if auth_type_unchanged:
                # Credentials may be stored JSON-serialized or as a dict —
                # normalize both sides to dicts before merging.
                existing_creds = (
                    json.loads(existing.credentials)
                    if isinstance(existing.credentials, str)
                    else dict(existing.credentials)
                )
                new_creds = (
                    json.loads(data_dict["credentials"])
                    if isinstance(data_dict["credentials"], str)
                    else dict(data_dict["credentials"])
                )
                # New values override existing; existing keys not in update are preserved
                merged = {**existing_creds, **new_creds}
                data_dict["credentials"] = safe_dumps(merged)
    # Add audit fields
    data_dict["updated_by"] = touched_by
    updated_mcp_server = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": data.server_id}, data=data_dict  # type: ignore
    )
    return updated_mcp_server
async def rotate_mcp_server_credentials_master_key(
    prisma_client: PrismaClient, touched_by: str, new_master_key: str
):
    """
    Re-encrypt every stored mcp server credential blob under a new master key.

    Each blob is decrypted with the current key, re-encrypted with
    new_master_key, and written back with an updated_by audit stamp.
    Servers without credentials are skipped.
    """
    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    all_servers = await prisma_client.db.litellm_mcpservertable.find_many()
    for server in all_servers:
        if not server.credentials:
            continue
        # Work on a copy so we never mutate the Prisma row in place
        working_copy = cast(MCPCredentials, dict(server.credentials))
        # decrypt under the current key, then re-encrypt under the new one
        re_encrypted = encrypt_credentials(
            credentials=decrypt_credentials(credentials=working_copy),
            encryption_key=new_master_key,
        )
        await prisma_client.db.litellm_mcpservertable.update(
            where={"server_id": server.server_id},
            data={
                "credentials": safe_dumps(re_encrypted),
                "updated_by": touched_by,
            },
        )
async def store_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
    credential: str,
) -> None:
    """Upsert a BYOK credential (stored base64-encoded) for a user+server pair."""
    credential_b64 = base64.urlsafe_b64encode(credential.encode()).decode()
    composite_key = {"user_id": user_id, "server_id": server_id}
    await prisma_client.db.litellm_mcpusercredentials.upsert(
        where={"user_id_server_id": composite_key},
        data={
            "create": {**composite_key, "credential_b64": credential_b64},
            "update": {"credential_b64": credential_b64},
        },
    )
async def get_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> Optional[str]:
    """Return the decoded credential for a user+server pair, or None when absent."""
    row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if row is None:
        return None
    stored = row.credential_b64
    try:
        decoded = base64.urlsafe_b64decode(stored).decode()
    except Exception:
        # Rows written by older code were nacl-encrypted rather than
        # base64 — fall back to decryption for those.
        return decrypt_value_helper(
            value=stored,
            key="byok_credential",
            exception_type="debug",
            return_original_value=False,
        )
    return decoded
async def has_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> bool:
    """True when a stored credential row exists for this user+server pair."""
    existing = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    return existing is not None
async def delete_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> None:
    """Remove the stored BYOK credential row for this user+server pair."""
    composite_key = {"user_id": user_id, "server_id": server_id}
    await prisma_client.db.litellm_mcpusercredentials.delete(
        where={"user_id_server_id": composite_key}
    )
# ── OAuth2 user-credential helpers ────────────────────────────────────────────
async def store_user_oauth_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
    access_token: str,
    refresh_token: Optional[str] = None,
    expires_in: Optional[int] = None,
    scopes: Optional[List[str]] = None,
) -> None:
    """Persist an OAuth2 access token for a user+server pair.

    The payload is JSON-serialised and stored base64-encoded in the same
    ``credential_b64`` column used by BYOK. A ``"type": "oauth2"`` key
    differentiates it from plain BYOK API keys.

    Raises:
        ValueError: when a non-OAuth2 (BYOK) credential already exists for
            this user+server pair — it is never silently overwritten.
    """
    # Convert relative expires_in (seconds) to an absolute UTC ISO timestamp
    expires_at: Optional[str] = None
    if expires_in is not None:
        expires_at = (
            datetime.now(timezone.utc) + timedelta(seconds=expires_in)
        ).isoformat()
    payload: Dict[str, Any] = {
        "type": "oauth2",
        "access_token": access_token,
        "connected_at": datetime.now(timezone.utc).isoformat(),
    }
    # Optional fields are only written when truthy, keeping the payload compact
    if refresh_token:
        payload["refresh_token"] = refresh_token
    if expires_at:
        payload["expires_at"] = expires_at
    if scopes:
        payload["scopes"] = scopes
    # Guard against silently overwriting a BYOK credential with an OAuth token.
    # BYOK credentials lack a "type" field (or use a non-"oauth2" type).
    existing = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if existing is not None:
        _byok_error = ValueError(
            f"A non-OAuth2 credential already exists for user {user_id} "
            f"and server {server_id}. Refusing to overwrite."
        )
        try:
            raw = json.loads(base64.urlsafe_b64decode(existing.credential_b64).decode())
        except Exception:
            # Credential is not base64+JSON — it's a plain-text BYOK key.
            raise _byok_error
        if raw.get("type") != "oauth2":
            raise _byok_error
    # Upsert: only OAuth2 rows (or absent rows) reach this point
    encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
    await prisma_client.db.litellm_mcpusercredentials.upsert(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}},
        data={
            "create": {
                "user_id": user_id,
                "server_id": server_id,
                "credential_b64": encoded,
            },
            "update": {"credential_b64": encoded},
        },
    )
def is_oauth_credential_expired(cred: Dict[str, Any]) -> bool:
    """Return True when the credential's access token is past its expiry.

    Reads the ISO-8601 ``expires_at`` string from the credential payload.
    A missing or unparseable timestamp is treated as "not expired" so
    callers do not needlessly discard tokens.
    """
    raw_expiry = cred.get("expires_at")
    if not raw_expiry:
        return False
    try:
        expiry = datetime.fromisoformat(raw_expiry)
    except (ValueError, TypeError):
        return False
    # Naive timestamps are interpreted as UTC before comparing
    if expiry.tzinfo is None:
        expiry = expiry.replace(tzinfo=timezone.utc)
    return datetime.now(timezone.utc) > expiry
async def get_user_oauth_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> Optional[Dict[str, Any]]:
    """Return the decoded OAuth2 payload for a user+server pair, or None.

    Rows holding plain BYOK keys (no ``"type": "oauth2"`` marker) and rows
    that fail base64/JSON decoding both yield None.
    """
    row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if row is None:
        return None
    try:
        payload = json.loads(base64.urlsafe_b64decode(row.credential_b64).decode())
    except Exception:
        return None
    if isinstance(payload, dict) and payload.get("type") == "oauth2":
        return payload
    # Row exists but is a BYOK (plain string), not an OAuth token
    return None
async def list_user_oauth_credentials(
    prisma_client: PrismaClient,
    user_id: str,
) -> List[Dict[str, Any]]:
    """Return every OAuth2 credential payload for a user, tagged with server_id.

    Non-OAuth rows (plain BYOK strings) and undecodable rows are skipped.
    """
    rows = await prisma_client.db.litellm_mcpusercredentials.find_many(
        where={"user_id": user_id}
    )
    oauth_payloads: List[Dict[str, Any]] = []
    for row in rows:
        try:
            payload = json.loads(
                base64.urlsafe_b64decode(row.credential_b64).decode()
            )
        except Exception:
            continue  # BYOK plain string — not an OAuth payload
        if isinstance(payload, dict) and payload.get("type") == "oauth2":
            payload["server_id"] = row.server_id
            oauth_payloads.append(payload)
    return oauth_payloads
async def approve_mcp_server(
    prisma_client: PrismaClient,
    server_id: str,
    touched_by: str,
) -> LiteLLM_MCPServerTable:
    """Mark a server active, stamping reviewed_at and updated_by."""
    reviewed_at = datetime.now(timezone.utc)
    row = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": server_id},
        data={
            "approval_status": MCPApprovalStatus.active,
            "reviewed_at": reviewed_at,
            "updated_by": touched_by,
        },
    )
    return LiteLLM_MCPServerTable(**row.model_dump())
async def reject_mcp_server(
    prisma_client: PrismaClient,
    server_id: str,
    touched_by: str,
    review_notes: Optional[str] = None,
) -> LiteLLM_MCPServerTable:
    """Mark a server rejected, stamping reviewed_at; optionally record notes."""
    update_fields: Dict[str, Any] = {
        "approval_status": MCPApprovalStatus.rejected,
        "reviewed_at": datetime.now(timezone.utc),
        "updated_by": touched_by,
    }
    if review_notes is not None:
        update_fields["review_notes"] = review_notes
    row = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": server_id},
        data=update_fields,
    )
    return LiteLLM_MCPServerTable(**row.model_dump())
async def get_mcp_submissions(
    prisma_client: PrismaClient,
) -> MCPSubmissionsSummary:
    """
    List MCP servers submitted by non-admin users (submitted_at IS NOT NULL),
    newest first, with a count breakdown per approval_status.

    Mirrors get_guardrail_submissions() from guardrail_endpoints.py.
    """
    rows = await prisma_client.db.litellm_mcpservertable.find_many(
        where={"submitted_at": {"not": None}},
        order={"submitted_at": "desc"},
        take=500,  # safety cap; paginate if needed in a future iteration
    )
    submissions = [LiteLLM_MCPServerTable(**row.model_dump()) for row in rows]

    def _count(status) -> int:
        # one-pass count of submissions in the given approval state
        return sum(1 for s in submissions if s.approval_status == status)

    return MCPSubmissionsSummary(
        total=len(submissions),
        pending_review=_count(MCPApprovalStatus.pending_review),
        active=_count(MCPApprovalStatus.active),
        rejected=_count(MCPApprovalStatus.rejected),
        items=submissions,
    )

View File

@@ -0,0 +1,741 @@
import json
from typing import Optional
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
decrypt_value_helper,
encrypt_value_helper,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import get_server_root_path
from litellm.types.mcp import MCPAuth
from litellm.types.mcp_server.mcp_server_manager import MCPServer
router = APIRouter(
tags=["mcp"],
)
def get_request_base_url(request: Request) -> str:
    """
    Reconstruct the externally-visible base URL for a request, honoring
    reverse-proxy forwarding headers.

    When behind a proxy (like nginx), the proxy may set:
    - X-Forwarded-Proto: the original protocol (http/https)
    - X-Forwarded-Host: the original host (may include port)
    - X-Forwarded-Port: the original port (if not in the Host header)

    Args:
        request: FastAPI Request object

    Returns:
        The reconstructed base URL (e.g., "https://proxy.example.com")
    """
    parsed = urlparse(str(request.base_url).rstrip("/"))

    x_forwarded_proto = request.headers.get("X-Forwarded-Proto")
    x_forwarded_host = request.headers.get("X-Forwarded-Host")
    x_forwarded_port = request.headers.get("X-Forwarded-Port")

    # Prefer the forwarded protocol when the proxy supplied one
    scheme = x_forwarded_proto if x_forwarded_proto else parsed.scheme

    if x_forwarded_host:
        # Detect whether the forwarded host already carries an explicit port.
        # For bracketed IPv6 hosts ("[::1]" vs "[::1]:8080"), only a colon
        # AFTER the closing bracket is a port separator — a plain
        # `":" in host` test would wrongly re-append the forwarded port to
        # "[::1]:8080".
        if x_forwarded_host.startswith("["):
            host_has_port = ":" in x_forwarded_host.rsplit("]", 1)[-1]
        else:
            host_has_port = ":" in x_forwarded_host
        if host_has_port or not x_forwarded_port:
            netloc = x_forwarded_host
        else:
            netloc = f"{x_forwarded_host}:{x_forwarded_port}"
    else:
        # No X-Forwarded-Host — keep the original netloc, adding the
        # forwarded port only when the netloc has no port of its own
        netloc = parsed.netloc
        if x_forwarded_port and ":" not in netloc:
            netloc = f"{netloc}:{x_forwarded_port}"

    # Reconstruct the URL without query/params/fragment
    return urlunparse((scheme, netloc, parsed.path, "", "", ""))
def encode_state_with_base_url(
    base_url: str,
    original_state: str,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    client_redirect_uri: Optional[str] = None,
) -> str:
    """
    Pack the OAuth session data (base_url, client state, PKCE parameters,
    client redirect_uri) into one encrypted opaque string suitable for use
    as the outgoing ``state`` parameter.

    Args:
        base_url: The base URL to encode
        original_state: The original state parameter
        code_challenge: PKCE code challenge from client
        code_challenge_method: PKCE code challenge method from client
        client_redirect_uri: Original redirect_uri from client

    Returns:
        An encrypted string that encodes all values
    """
    # sort_keys keeps the serialized form deterministic for identical input
    payload = json.dumps(
        {
            "base_url": base_url,
            "original_state": original_state,
            "code_challenge": code_challenge,
            "code_challenge_method": code_challenge_method,
            "client_redirect_uri": client_redirect_uri,
        },
        sort_keys=True,
    )
    return encrypt_value_helper(payload)
def decode_state_hash(encrypted_state: str) -> dict:
    """
    Decrypt an opaque state string back into the OAuth session dict.

    Args:
        encrypted_state: The encrypted string to decode

    Returns:
        A dict containing base_url, original_state, and optional PKCE parameters

    Raises:
        ValueError: when decryption fails.
        json.JSONDecodeError: when the decrypted payload is not valid JSON.
    """
    plaintext = decrypt_value_helper(encrypted_state, "oauth_state")
    if plaintext is None:
        raise ValueError("Failed to decrypt state parameter")
    return json.loads(plaintext)
def _resolve_oauth2_server_for_root_endpoints(
    client_ip: Optional[str] = None,
) -> Optional[MCPServer]:
    """
    Pick the MCP server for root-level OAuth endpoints (no server name in path).

    Root-level hits like /register, /authorize, /token carry no server-name
    prefix, so resolution is only unambiguous when exactly one OAuth2 server
    is visible in the (optionally IP-filtered) registry; otherwise None.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    candidates = [
        server
        for server in global_mcp_server_manager.get_filtered_registry(
            client_ip=client_ip
        ).values()
        if server.auth_type == MCPAuth.oauth2
    ]
    return candidates[0] if len(candidates) == 1 else None
async def authorize_with_server(
    request: Request,
    mcp_server: MCPServer,
    client_id: str,
    redirect_uri: str,
    state: str = "",
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    response_type: Optional[str] = None,
    scope: Optional[str] = None,
):
    """
    Redirect the caller to the upstream OAuth provider's authorization URL.

    The client's redirect target, state, and PKCE parameters are encrypted
    into the outgoing ``state`` so they can be recovered in /callback; the
    proxy's own /callback URL is handed to the provider as redirect_uri.
    """
    if mcp_server.auth_type != "oauth2":
        raise HTTPException(status_code=400, detail="MCP server is not OAuth2")
    if mcp_server.authorization_url is None:
        raise HTTPException(
            status_code=400, detail="MCP server authorization url is not set"
        )

    # Store the client's redirect target with its query string stripped
    stripped_redirect = urlunparse(urlparse(redirect_uri)._replace(query=""))
    proxy_base_url = get_request_base_url(request)
    wrapped_state = encode_state_with_base_url(
        base_url=stripped_redirect,
        original_state=state,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method,
        client_redirect_uri=redirect_uri,
    )

    query_params = {
        "client_id": mcp_server.client_id if mcp_server.client_id else client_id,
        "redirect_uri": f"{proxy_base_url}/callback",
        "state": wrapped_state,
        "response_type": response_type or "code",
    }
    # Prefer the caller's scope; fall back to the server's configured scopes
    if scope:
        query_params["scope"] = scope
    elif mcp_server.scopes:
        query_params["scope"] = " ".join(mcp_server.scopes)
    if code_challenge:
        query_params["code_challenge"] = code_challenge
    if code_challenge_method:
        query_params["code_challenge_method"] = code_challenge_method

    # Merge with any query params already baked into the authorization URL;
    # ours win on conflicts.
    auth_url_parts = urlparse(mcp_server.authorization_url)
    merged_params = dict(parse_qsl(auth_url_parts.query))
    merged_params.update(query_params)
    return RedirectResponse(
        urlunparse(auth_url_parts._replace(query=urlencode(merged_params)))
    )
async def exchange_token_with_server(
    request: Request,
    mcp_server: MCPServer,
    grant_type: str,
    code: Optional[str],
    redirect_uri: Optional[str],
    client_id: str,
    client_secret: Optional[str],
    code_verifier: Optional[str],
):
    """
    Exchange an authorization code for an access token at the upstream
    provider's token endpoint, forwarding PKCE's code_verifier when present.

    Raises:
        HTTPException: on unsupported grant_type or a missing token_url.
    """
    if grant_type != "authorization_code":
        raise HTTPException(status_code=400, detail="Unsupported grant_type")
    if mcp_server.token_url is None:
        raise HTTPException(status_code=400, detail="MCP server token url is not set")

    # Server-configured credentials take precedence over caller-supplied ones
    form_payload = {
        "grant_type": "authorization_code",
        "client_id": mcp_server.client_id if mcp_server.client_id else client_id,
        "client_secret": mcp_server.client_secret
        if mcp_server.client_secret
        else client_secret,
        "code": code,
        "redirect_uri": f"{get_request_base_url(request)}/callback",
    }
    if code_verifier:
        form_payload["code_verifier"] = code_verifier

    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.Oauth2Check
    )
    upstream_response = await http_client.post(
        mcp_server.token_url,
        headers={"Accept": "application/json"},
        data=form_payload,
    )
    upstream_response.raise_for_status()
    token_payload = upstream_response.json()

    body = {
        "access_token": token_payload["access_token"],
        "token_type": token_payload.get("token_type", "Bearer"),
        "expires_in": token_payload.get("expires_in", 3600),
    }
    # Pass through optional fields only when present and non-empty
    for optional_field in ("refresh_token", "scope"):
        if token_payload.get(optional_field):
            body[optional_field] = token_payload[optional_field]
    return JSONResponse(body)
async def register_client_with_server(
    request: Request,
    mcp_server: MCPServer,
    client_name: str,
    grant_types: Optional[list],
    response_types: Optional[list],
    token_endpoint_auth_method: Optional[str],
    fallback_client_id: Optional[str] = None,
):
    """
    Handle dynamic client registration for an MCP server.

    When the server already has static client credentials, or exposes no
    registration endpoint, a stub registration pointing at the proxy's
    /callback is returned instead of contacting the provider.
    """
    proxy_base_url = get_request_base_url(request)
    callback_uri = f"{proxy_base_url}/callback"
    stub_registration = {
        "client_id": fallback_client_id or mcp_server.server_name,
        "client_secret": "dummy",
        "redirect_uris": [callback_uri],
    }

    # Static credentials configured — no dynamic registration needed
    if mcp_server.client_id and mcp_server.client_secret:
        return stub_registration
    if mcp_server.authorization_url is None:
        raise HTTPException(
            status_code=400, detail="MCP server authorization url is not set"
        )
    # No registration endpoint — return the stub so the flow can proceed
    if mcp_server.registration_url is None:
        return stub_registration

    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.Oauth2Register
    )
    upstream_response = await http_client.post(
        mcp_server.registration_url,
        headers={
            "Content-Type": "application/json",
            "Accept": "application/json",
        },
        json={
            "client_name": client_name,
            "redirect_uris": [callback_uri],
            "grant_types": grant_types or [],
            "response_types": response_types or [],
            "token_endpoint_auth_method": token_endpoint_auth_method or "",
        },
    )
    upstream_response.raise_for_status()
    return JSONResponse(upstream_response.json())
@router.get("/{mcp_server_name}/authorize")
@router.get("/authorize")
async def authorize(
    request: Request,
    redirect_uri: str,
    client_id: Optional[str] = None,
    state: str = "",
    mcp_server_name: Optional[str] = None,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    response_type: Optional[str] = None,
    scope: Optional[str] = None,
):
    """
    OAuth authorize endpoint: resolve the target MCP server and redirect
    the caller to the real OAuth provider, with PKCE support.

    Server resolution order: path server name, then client_id as a name,
    then (for root-level requests) the single configured OAuth2 server.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    lookup_name: Optional[str] = mcp_server_name or client_id
    client_ip = IPAddressUtils.get_mcp_client_ip(request)
    mcp_server = (
        global_mcp_server_manager.get_mcp_server_by_name(
            lookup_name, client_ip=client_ip
        )
        if lookup_name
        else None
    )
    if mcp_server is None and mcp_server_name is None:
        # Pass client_ip so IP-based registry filtering is applied
        # consistently with the named lookup above (it was previously
        # dropped on this fallback path).
        mcp_server = _resolve_oauth2_server_for_root_endpoints(client_ip=client_ip)
    if mcp_server is None:
        raise HTTPException(status_code=404, detail="MCP server not found")
    # Use server's stored client_id when caller doesn't supply one.
    # Raise a clear error instead of passing an empty string — an empty
    # client_id would silently produce a broken authorization URL.
    resolved_client_id: str = mcp_server.client_id or client_id or ""
    if not resolved_client_id:
        raise HTTPException(
            status_code=400,
            detail={
                "error": "client_id is required but was not supplied and is not "
                "stored on the MCP server record. Provide client_id as a query "
                "parameter or configure it on the server."
            },
        )
    return await authorize_with_server(
        request=request,
        mcp_server=mcp_server,
        client_id=resolved_client_id,
        redirect_uri=redirect_uri,
        state=state,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method,
        response_type=response_type,
        scope=scope,
    )
@router.post("/{mcp_server_name}/token")
@router.post("/token")
async def token_endpoint(
    request: Request,
    grant_type: str = Form(...),
    code: str = Form(None),
    redirect_uri: str = Form(None),
    client_id: str = Form(...),
    client_secret: Optional[str] = Form(None),
    code_verifier: str = Form(None),
    mcp_server_name: Optional[str] = None,
):
    """
    Accept the authorization code from the client and exchange it for an
    OAuth token at the upstream provider. Supports PKCE by forwarding
    code_verifier upstream.

    1. Resolve the target MCP server (path name, client_id as a name, or —
       for root-level requests — the single configured OAuth2 server)
    2. Call the upstream token endpoint with PKCE parameters
    3. Return the upstream token response to the client
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    lookup_name = mcp_server_name or client_id
    client_ip = IPAddressUtils.get_mcp_client_ip(request)
    mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
        lookup_name, client_ip=client_ip
    )
    if mcp_server is None and mcp_server_name is None:
        # Pass client_ip so IP-based registry filtering is applied
        # consistently with the named lookup above (it was previously
        # dropped on this fallback path).
        mcp_server = _resolve_oauth2_server_for_root_endpoints(client_ip=client_ip)
    if mcp_server is None:
        raise HTTPException(status_code=404, detail="MCP server not found")
    return await exchange_token_with_server(
        request=request,
        mcp_server=mcp_server,
        grant_type=grant_type,
        code=code,
        redirect_uri=redirect_uri,
        client_id=client_id,
        client_secret=client_secret,
        code_verifier=code_verifier,
    )
@router.get("/callback")
async def callback(code: str, state: str):
    """
    Upstream OAuth provider callback: unwrap the encrypted state and bounce
    the authorization code (plus the client's original state) back to the
    client's stored redirect target.
    """
    try:
        session = decode_state_hash(state)
        redirect_target = "{}?{}".format(
            session["base_url"],
            urlencode({"code": code, "state": session["original_state"]}),
        )
        return RedirectResponse(url=redirect_target, status_code=302)
    except Exception:
        # State could not be decrypted/parsed — show a terminal page instead
        return HTMLResponse(
            "<html><body>Authentication incomplete. You can close this window.</body></html>"
        )
# ------------------------------
# Optional .well-known endpoints for MCP + OAuth discovery
# ------------------------------
"""
Per SEP-985, the client MUST:
1. Try resource_metadata from the WWW-Authenticate header (if present)
2. Fall back to the path-based well-known URI: /.well-known/oauth-protected-resource/{path}
   (If the resource identifier value contains a path or query component, any terminating
   slash (/) following the host component MUST be removed before inserting /.well-known/
   and the well-known URI path suffix between the host component and the path — including
   the root path — and/or query components.
   https://datatracker.ietf.org/doc/html/rfc9728#section-3.1)
3. Fall back to the root-based well-known URI: /.well-known/oauth-protected-resource
Dual Pattern Support:
- Standard MCP pattern: /mcp/{server_name} (recommended, used by mcp-inspector, VSCode Copilot)
- LiteLLM legacy pattern: /{server_name}/mcp (backward compatibility)
The resource URL returned matches the pattern used in the discovery request.
"""
def _build_oauth_protected_resource_response(
    request: Request,
    mcp_server_name: Optional[str],
    use_standard_pattern: bool,
) -> dict:
    """
    Build OAuth protected resource metadata (RFC 9728) for an MCP server.

    Args:
        request: FastAPI Request object
        mcp_server_name: Name of the MCP server (None → auto-resolve the
            single configured OAuth2 server, if exactly one exists)
        use_standard_pattern: If True, use /mcp/{server_name} pattern;
            if False, use /{server_name}/mcp pattern

    Returns:
        OAuth protected resource metadata dict
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    request_base_url = get_request_base_url(request)
    client_ip = IPAddressUtils.get_mcp_client_ip(request)

    # When no server name provided, try to resolve the single OAuth2 server.
    # Pass client_ip so IP-based registry filtering matches the named lookup
    # below (the resolver was previously called without it).
    if mcp_server_name is None:
        resolved = _resolve_oauth2_server_for_root_endpoints(client_ip=client_ip)
        if resolved:
            mcp_server_name = resolved.server_name or resolved.name

    mcp_server: Optional[MCPServer] = None
    if mcp_server_name:
        mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
            mcp_server_name, client_ip=client_ip
        )

    # Build resource URL based on the requested pattern
    if mcp_server_name:
        if use_standard_pattern:
            # Standard MCP pattern: /mcp/{server_name}
            resource_url = f"{request_base_url}/mcp/{mcp_server_name}"
        else:
            # LiteLLM legacy pattern: /{server_name}/mcp
            resource_url = f"{request_base_url}/{mcp_server_name}/mcp"
    else:
        resource_url = f"{request_base_url}/mcp"

    return {
        "authorization_servers": [
            (
                f"{request_base_url}/{mcp_server_name}"
                if mcp_server_name
                else f"{request_base_url}"
            )
        ],
        "resource": resource_url,
        "scopes_supported": mcp_server.scopes
        if mcp_server and mcp_server.scopes
        else [],
    }
# Standard MCP pattern: /.well-known/oauth-protected-resource/mcp/{server_name}
# This is the pattern expected by standard MCP clients (mcp-inspector, VSCode Copilot)
@router.get(
    f"/.well-known/oauth-protected-resource{'' if get_server_root_path() == '/' else get_server_root_path()}/mcp/{{mcp_server_name}}"
)
async def oauth_protected_resource_mcp_standard(request: Request, mcp_server_name: str):
    """
    OAuth protected resource discovery using the standard MCP URL pattern.
    Standard pattern: /mcp/{server_name}
    Discovery path: /.well-known/oauth-protected-resource/mcp/{server_name}
    Spec-compliant MCP clients (mcp-inspector, VSCode Copilot) probe this path.
    """
    return _build_oauth_protected_resource_response(
        request, mcp_server_name, use_standard_pattern=True
    )
# LiteLLM legacy pattern: /.well-known/oauth-protected-resource/{server_name}/mcp
# Kept for backward compatibility with existing deployments
@router.get(
    f"/.well-known/oauth-protected-resource{'' if get_server_root_path() == '/' else get_server_root_path()}/{{mcp_server_name}}/mcp"
)
@router.get("/.well-known/oauth-protected-resource")
async def oauth_protected_resource_mcp(
    request: Request, mcp_server_name: Optional[str] = None
):
    """
    OAuth protected resource discovery for the LiteLLM legacy URL pattern.
    Legacy pattern: /{server_name}/mcp
    Discovery path: /.well-known/oauth-protected-resource/{server_name}/mcp
    Retained for backward compatibility; new integrations should prefer the
    standard MCP pattern (/mcp/{server_name}).
    """
    return _build_oauth_protected_resource_response(
        request, mcp_server_name, use_standard_pattern=False
    )
"""
https://datatracker.ietf.org/doc/html/rfc8414#section-3.1
RFC 8414: Path-aware OAuth discovery
If the issuer identifier value contains a path component, any
terminating "/" MUST be removed before inserting "/.well-known/" and
the well-known URI suffix between the host component and the path (including the root path)
component.
"""
def _build_oauth_authorization_server_response(
    request: Request,
    mcp_server_name: Optional[str],
) -> dict:
    """
    Build OAuth authorization server metadata (RFC 8414).
    Args:
        request: FastAPI Request object
        mcp_server_name: Name of the MCP server
    Returns:
        OAuth authorization server metadata dict
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )
    base_url = get_request_base_url(request)
    # No server named in the request — fall back to the single configured
    # OAuth2 server, when one can be resolved.
    if mcp_server_name is None:
        resolved_server = _resolve_oauth2_server_for_root_endpoints()
        if resolved_server:
            mcp_server_name = resolved_server.server_name or resolved_server.name
    # Endpoint URLs are scoped per server when a name is known.
    endpoint_prefix = f"{base_url}/{mcp_server_name}" if mcp_server_name else base_url
    # Look the server up only to report its configured scopes.
    mcp_server: Optional[MCPServer] = None
    if mcp_server_name:
        mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
            mcp_server_name, client_ip=IPAddressUtils.get_mcp_client_ip(request)
        )
    scopes = mcp_server.scopes if mcp_server and mcp_server.scopes else []
    return {
        "issuer": base_url,  # the proxy itself acts as the authorization server
        "authorization_endpoint": f"{endpoint_prefix}/authorize",
        "token_endpoint": f"{endpoint_prefix}/token",
        "response_types_supported": ["code"],
        "scopes_supported": scopes,
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "code_challenge_methods_supported": ["S256"],
        "token_endpoint_auth_methods_supported": ["client_secret_post"],
        # Claude expects a registration endpoint, even if we just fake it
        "registration_endpoint": f"{endpoint_prefix}/register",
    }
# Standard MCP pattern: /.well-known/oauth-authorization-server/mcp/{server_name}
@router.get(
    f"/.well-known/oauth-authorization-server{'' if get_server_root_path() == '/' else get_server_root_path()}/mcp/{{mcp_server_name}}"
)
async def oauth_authorization_server_mcp_standard(
    request: Request, mcp_server_name: str
):
    """
    Authorization server metadata (RFC 8414) for the standard MCP pattern.
    Standard pattern: /mcp/{server_name}
    Discovery path: /.well-known/oauth-authorization-server/mcp/{server_name}
    """
    return _build_oauth_authorization_server_response(request, mcp_server_name)
# LiteLLM legacy pattern and root endpoint
@router.get(
    f"/.well-known/oauth-authorization-server{'' if get_server_root_path() == '/' else get_server_root_path()}/{{mcp_server_name}}"
)
@router.get("/.well-known/oauth-authorization-server")
async def oauth_authorization_server_mcp(
    request: Request, mcp_server_name: Optional[str] = None
):
    """
    Authorization server metadata discovery (RFC 8414).
    Serves both the legacy per-server pattern (/{server_name}) and the
    root endpoint with no server name.
    """
    return _build_oauth_authorization_server_response(request, mcp_server_name)
# Alias for standard OpenID discovery
@router.get("/.well-known/openid-configuration")
async def openid_configuration(request: Request):
    """OpenID Connect discovery alias.

    Returns the same metadata as /.well-known/oauth-authorization-server
    with no server name (the handler then tries to resolve the single
    configured OAuth2 server).
    """
    return await oauth_authorization_server_mcp(request)
# Additional legacy pattern support
@router.get("/.well-known/oauth-authorization-server/{mcp_server_name}/mcp")
async def oauth_authorization_server_legacy(request: Request, mcp_server_name: str):
    """
    Authorization server metadata discovery (RFC 8414) for the legacy
    /{server_name}/mcp URL pattern.
    """
    return _build_oauth_authorization_server_response(request, mcp_server_name)
@router.post("/{mcp_server_name}/register")
@router.post("/register")
async def register_client(request: Request, mcp_server_name: Optional[str] = None):
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
# Get the correct base URL considering X-Forwarded-* headers
request_base_url = get_request_base_url(request)
request_data = await _read_request_body(request=request)
data: dict = {**request_data}
dummy_return = {
"client_id": mcp_server_name or "dummy_client",
"client_secret": "dummy",
"redirect_uris": [f"{request_base_url}/callback"],
}
if not mcp_server_name:
resolved = _resolve_oauth2_server_for_root_endpoints()
if resolved:
return await register_client_with_server(
request=request,
mcp_server=resolved,
client_name=data.get("client_name", ""),
grant_types=data.get("grant_types", []),
response_types=data.get("response_types", []),
token_endpoint_auth_method=data.get("token_endpoint_auth_method", ""),
fallback_client_id=resolved.server_name or resolved.name,
)
return dummy_return
client_ip = IPAddressUtils.get_mcp_client_ip(request)
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
mcp_server_name, client_ip=client_ip
)
if mcp_server is None:
return dummy_return
return await register_client_with_server(
request=request,
mcp_server=mcp_server,
client_name=data.get("client_name", ""),
grant_types=data.get("grant_types", []),
response_types=data.get("response_types", []),
token_endpoint_auth_method=data.get("token_endpoint_auth_method", ""),
fallback_client_id=mcp_server_name,
)

View File

@@ -0,0 +1,16 @@
"""Guardrail translation mapping for MCP tool calls."""
from litellm.proxy._experimental.mcp_server.guardrail_translation.handler import (
MCPGuardrailTranslationHandler,
)
from litellm.types.utils import CallTypes
# This mapping lives alongside the MCP server implementation because MCP
# integrations are managed by the proxy subsystem, not litellm.llms providers.
# Unified guardrails import this module explicitly to register the handler.
guardrail_translation_mappings = {
CallTypes.call_mcp_tool: MCPGuardrailTranslationHandler,
}
__all__ = ["guardrail_translation_mappings", "MCPGuardrailTranslationHandler"]

View File

@@ -0,0 +1,99 @@
"""
MCP Guardrail Handler for Unified Guardrails.
Converts an MCP call_tool (name + arguments) into a single OpenAI-compatible
tool_call and passes it to apply_guardrail. Works with the synthetic payload
from ProxyLogging._convert_mcp_to_llm_format.
Note: For MCP tool definitions (schema) -> OpenAI tools=[], see
litellm.experimental_mcp_client.tools.transform_mcp_tool_to_openai_tool
when you have a full MCP Tool from list_tools. Here we only have the call
payload (name + arguments) so we just build the tool_call.
"""
from typing import TYPE_CHECKING, Any, Dict, Optional
from mcp.types import Tool as MCPTool
from litellm._logging import verbose_proxy_logger
from litellm.experimental_mcp_client.tools import transform_mcp_tool_to_openai_tool
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from mcp.types import CallToolResult
from litellm.integrations.custom_guardrail import CustomGuardrail
class MCPGuardrailTranslationHandler(BaseTranslation):
    """Guardrail translation handler for MCP tool calls (passes a single tool_call to guardrail)."""
    async def process_input_messages(
        self,
        data: Dict[str, Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Convert an MCP call_tool payload into an OpenAI tool definition and
        run the guardrail against it.

        Args:
            data: Synthetic request payload containing ``mcp_tool_name``/``name``
                and optionally ``mcp_tool_description``/``description``. The
                tool-call arguments reach the guardrail through ``request_data``
                (this same dict), so they are not copied into the tool definition.
            guardrail_to_apply: Guardrail whose ``apply_guardrail`` is invoked.
            litellm_logging_obj: Optional logging object forwarded to the guardrail.

        Returns:
            The ``data`` dict unchanged; guardrails signal blocking by raising.
        """
        mcp_tool_name = data.get("mcp_tool_name") or data.get("name")
        mcp_tool_description = data.get("mcp_tool_description") or data.get(
            "description"
        )
        # NOTE: arguments are intentionally NOT extracted here — the guardrail
        # reads them from request_data. (A previous revision normalized an
        # unused `mcp_arguments` local; removed as dead code.)
        if not mcp_tool_name:
            verbose_proxy_logger.debug("MCP Guardrail: mcp_tool_name missing")
            return data
        # Convert MCP input via transform_mcp_tool_to_openai_tool, then map to litellm
        # ChatCompletionToolParam (openai SDK type has incompatible strict/cache_control).
        mcp_tool = MCPTool(
            name=mcp_tool_name,
            description=mcp_tool_description or "",
            inputSchema={},  # Call payload has no schema; guardrail gets args from request_data
        )
        openai_tool = transform_mcp_tool_to_openai_tool(mcp_tool)
        fn = openai_tool["function"]
        tool_def: ChatCompletionToolParam = {
            "type": "function",
            "function": ChatCompletionToolParamFunctionChunk(
                name=fn["name"],
                description=fn.get("description") or "",
                parameters=fn.get("parameters")
                or {
                    "type": "object",
                    "properties": {},
                    "additionalProperties": False,
                },
                strict=fn.get("strict", False) or False,  # Default to False if None
            ),
        }
        inputs: GenericGuardrailAPIInputs = GenericGuardrailAPIInputs(
            tools=[tool_def],
        )
        await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        return data
    async def process_output_response(
        self,
        response: "CallToolResult",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """No-op output handler: MCP tool results are returned unmodified."""
        verbose_proxy_logger.debug(
            "MCP Guardrail: Output processing not implemented for MCP tools",
        )
        return response

View File

@@ -0,0 +1,325 @@
"""
MCP OAuth2 Debug Headers
========================
Client-side debugging for MCP authentication flows.
When a client sends the ``x-litellm-mcp-debug: true`` header, LiteLLM
returns masked diagnostic headers in the response so operators can
troubleshoot OAuth2 issues without SSH access to the gateway.
Response headers returned (all values are masked for safety):
x-mcp-debug-inbound-auth
Which inbound auth headers were present and how they were classified.
Example: ``x-litellm-api-key=Bearer sk-12****1234``
x-mcp-debug-oauth2-token
The OAuth2 token extracted from the Authorization header (masked).
Shows ``(none)`` if absent, or flags ``SAME_AS_LITELLM_KEY`` when
the LiteLLM API key is accidentally leaking to the MCP server.
x-mcp-debug-auth-resolution
Which auth priority was used for the outbound MCP call:
``per-request-header``, ``m2m-client-credentials``, ``static-token``,
``oauth2-passthrough``, or ``no-auth``.
x-mcp-debug-outbound-url
The upstream MCP server URL that will receive the request.
x-mcp-debug-server-auth-type
The ``auth_type`` configured on the MCP server (e.g. ``oauth2``,
``bearer_token``, ``none``).
Debugging Guide
---------------
**Common issue: LiteLLM API key leaking to the MCP server**
Symptom: ``x-mcp-debug-oauth2-token`` shows ``SAME_AS_LITELLM_KEY``.
This means the ``Authorization`` header carries the LiteLLM API key and
it's being forwarded to the upstream MCP server instead of an OAuth2 token.
Fix: Move the LiteLLM key to ``x-litellm-api-key`` so the ``Authorization``
header is free for OAuth2 discovery::
# WRONG — blocks OAuth2 discovery
claude mcp add --transport http my_server http://proxy/mcp/server \\
--header "Authorization: Bearer sk-..."
# CORRECT — LiteLLM key in dedicated header, Authorization free for OAuth2
claude mcp add --transport http my_server http://proxy/mcp/server \\
--header "x-litellm-api-key: Bearer sk-..." \\
--header "x-litellm-mcp-debug: true"
**Common issue: No OAuth2 token present**
Symptom: ``x-mcp-debug-oauth2-token`` shows ``(none)`` and
``x-mcp-debug-auth-resolution`` shows ``no-auth``.
This means the client didn't go through the OAuth2 flow. Check that:
1. The ``Authorization`` header is NOT set as a static header in the client config.
2. The ``.well-known/oauth-protected-resource`` endpoint returns valid metadata.
3. The MCP server in LiteLLM config has ``auth_type: oauth2``.
**Common issue: M2M token used instead of user token**
Symptom: ``x-mcp-debug-auth-resolution`` shows ``m2m-client-credentials``.
This means the server has ``client_id``/``client_secret``/``token_url``
configured and LiteLLM is fetching a machine-to-machine token instead of
using the per-user OAuth2 token. If you want per-user tokens, remove the
client credentials from the server config.
Usage from Claude Code::
claude mcp add --transport http my_server http://proxy/mcp/server \\
--header "x-litellm-api-key: Bearer sk-..." \\
--header "x-litellm-mcp-debug: true"
Usage with curl::
curl -H "x-litellm-mcp-debug: true" \\
-H "x-litellm-api-key: Bearer sk-..." \\
http://localhost:4000/mcp/atlassian_mcp
"""
from typing import TYPE_CHECKING, Dict, List, Optional
from starlette.types import Message, Send
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
if TYPE_CHECKING:
from litellm.types.mcp_server.mcp_server_manager import MCPServer
# Header the client sends to opt into debug mode
# (is_debug_enabled accepts the values "true", "1", or "yes", case-insensitive).
MCP_DEBUG_REQUEST_HEADER = "x-litellm-mcp-debug"
# Prefix for all debug response headers
_RESPONSE_HEADER_PREFIX = "x-mcp-debug"
class MCPDebug:
    """
    Static helper class for MCP OAuth2 debug headers.
    Provides opt-in client-side diagnostics by injecting masked
    authentication info into HTTP response headers.
    """
    # Masker: show first 6 and last 4 chars so you can distinguish token types
    # e.g. "Bearer****ef01" vs "sk-123****cdef"
    _masker = SensitiveDataMasker(
        sensitive_patterns={
            "authorization",
            "token",
            "key",
            "secret",
            "auth",
            "bearer",
        },
        visible_prefix=6,
        visible_suffix=4,
    )
    @staticmethod
    def _mask(value: Optional[str]) -> str:
        """Mask a single value for safe display in headers."""
        # Falsy values (None, "") render as the literal "(none)" marker.
        if not value:
            return "(none)"
        return MCPDebug._masker._mask_value(value)
    @staticmethod
    def is_debug_enabled(headers: Dict[str, str]) -> bool:
        """
        Check if the client opted into MCP debug mode.
        Looks for ``x-litellm-mcp-debug: true`` (case-insensitive) in the
        request headers.
        """
        # Header names are compared case-insensitively; the first matching
        # header decides (accepted truthy values: "true", "1", "yes").
        for key, val in headers.items():
            if key.lower() == MCP_DEBUG_REQUEST_HEADER:
                return val.strip().lower() in ("true", "1", "yes")
        return False
    @staticmethod
    def resolve_auth_resolution(
        server: "MCPServer",
        mcp_auth_header: Optional[str],
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]],
        oauth2_headers: Optional[Dict[str, str]],
    ) -> str:
        """
        Determine which auth priority will be used for the outbound MCP call.
        Returns one of: ``per-request-header``, ``m2m-client-credentials``,
        ``static-token``, ``oauth2-passthrough``, or ``no-auth``.
        """
        from litellm.types.mcp import MCPAuth
        # Server-specific headers may be keyed by the server's alias or name.
        has_server_specific = bool(
            mcp_server_auth_headers
            and (
                mcp_server_auth_headers.get(server.alias or "")
                or mcp_server_auth_headers.get(server.server_name or "")
            )
        )
        # Priority order mirrors the outbound auth resolution logic.
        if has_server_specific or mcp_auth_header:
            return "per-request-header"
        if server.has_client_credentials:
            return "m2m-client-credentials"
        if server.authentication_token:
            return "static-token"
        if oauth2_headers and server.auth_type == MCPAuth.oauth2:
            return "oauth2-passthrough"
        return "no-auth"
    @staticmethod
    def build_debug_headers(
        *,
        inbound_headers: Dict[str, str],
        oauth2_headers: Optional[Dict[str, str]],
        litellm_api_key: Optional[str],
        auth_resolution: str,
        server_url: Optional[str],
        server_auth_type: Optional[str],
    ) -> Dict[str, str]:
        """
        Build masked debug response headers.
        Parameters
        ----------
        inbound_headers : dict
            Raw headers received from the MCP client.
        oauth2_headers : dict or None
            Extracted OAuth2 headers (``{"Authorization": "Bearer ..."}``).
        litellm_api_key : str or None
            The LiteLLM API key extracted from ``x-litellm-api-key`` or
            ``Authorization`` header.
        auth_resolution : str
            Which auth priority was selected for the outbound call.
        server_url : str or None
            Upstream MCP server URL.
        server_auth_type : str or None
            The ``auth_type`` configured on the server (e.g. ``oauth2``).
        Returns
        -------
        dict
            Headers to include in the response (all values masked).
        """
        debug: Dict[str, str] = {}
        # --- Inbound auth summary ---
        inbound_parts = []
        for hdr_name in ("x-litellm-api-key", "authorization", "x-mcp-auth"):
            # First case-insensitive match per header name wins.
            for k, v in inbound_headers.items():
                if k.lower() == hdr_name:
                    inbound_parts.append(f"{hdr_name}={MCPDebug._mask(v)}")
                    break
        debug[f"{_RESPONSE_HEADER_PREFIX}-inbound-auth"] = (
            "; ".join(inbound_parts) if inbound_parts else "(none)"
        )
        # --- OAuth2 token ---
        oauth2_token = (oauth2_headers or {}).get("Authorization")
        if oauth2_token and litellm_api_key:
            # Compare raw token values (sans "Bearer " prefix) to detect the
            # LiteLLM key leaking into the outbound Authorization header.
            oauth2_raw = oauth2_token.removeprefix("Bearer ").strip()
            litellm_raw = litellm_api_key.removeprefix("Bearer ").strip()
            if oauth2_raw == litellm_raw:
                debug[f"{_RESPONSE_HEADER_PREFIX}-oauth2-token"] = (
                    f"{MCPDebug._mask(oauth2_token)} "
                    f"(SAME_AS_LITELLM_KEY - likely misconfigured)"
                )
            else:
                debug[f"{_RESPONSE_HEADER_PREFIX}-oauth2-token"] = MCPDebug._mask(
                    oauth2_token
                )
        else:
            # No token, or no LiteLLM key to compare against — just mask
            # (None masks to "(none)").
            debug[f"{_RESPONSE_HEADER_PREFIX}-oauth2-token"] = MCPDebug._mask(
                oauth2_token
            )
        # --- Auth resolution ---
        debug[f"{_RESPONSE_HEADER_PREFIX}-auth-resolution"] = auth_resolution
        # --- Server info ---
        debug[f"{_RESPONSE_HEADER_PREFIX}-outbound-url"] = server_url or "(unknown)"
        debug[f"{_RESPONSE_HEADER_PREFIX}-server-auth-type"] = (
            server_auth_type or "(none)"
        )
        return debug
    @staticmethod
    def wrap_send_with_debug_headers(send: Send, debug_headers: Dict[str, str]) -> Send:
        """
        Return a new ASGI ``send`` callable that injects *debug_headers*
        into the ``http.response.start`` message.
        """
        async def _send_with_debug(message: Message) -> None:
            if message["type"] == "http.response.start":
                # ASGI header values are (bytes, bytes) pairs; copy the list
                # before appending so the original message is not mutated.
                headers = list(message.get("headers", []))
                for k, v in debug_headers.items():
                    headers.append((k.encode(), v.encode()))
                message = {**message, "headers": headers}
            await send(message)
        return _send_with_debug
    @staticmethod
    def maybe_build_debug_headers(
        *,
        raw_headers: Optional[Dict[str, str]],
        scope: Dict,
        mcp_servers: Optional[List[str]],
        mcp_auth_header: Optional[str],
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]],
        oauth2_headers: Optional[Dict[str, str]],
        client_ip: Optional[str],
    ) -> Dict[str, str]:
        """
        Build debug headers if debug mode is enabled, otherwise return empty dict.
        This is the single entry point called from the MCP request handler.
        """
        if not raw_headers or not MCPDebug.is_debug_enabled(raw_headers):
            return {}
        from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
            MCPRequestHandler,
        )
        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
            global_mcp_server_manager,
        )
        server_url: Optional[str] = None
        server_auth_type: Optional[str] = None
        auth_resolution = "no-auth"
        # Only the first resolvable server of the request is reported.
        for server_name in mcp_servers or []:
            server = global_mcp_server_manager.get_mcp_server_by_name(
                server_name, client_ip=client_ip
            )
            if server:
                server_url = server.url
                server_auth_type = server.auth_type
                auth_resolution = MCPDebug.resolve_auth_resolution(
                    server, mcp_auth_header, mcp_server_auth_headers, oauth2_headers
                )
                break
        scope_headers = MCPRequestHandler._safe_get_headers_from_scope(scope)
        litellm_key = MCPRequestHandler.get_litellm_api_key_from_headers(scope_headers)
        return MCPDebug.build_debug_headers(
            inbound_headers=raw_headers,
            oauth2_headers=oauth2_headers,
            litellm_api_key=litellm_key,
            auth_resolution=auth_resolution,
            server_url=server_url,
            server_auth_type=server_auth_type,
        )

View File

@@ -0,0 +1,170 @@
"""
OAuth2 client_credentials token cache for MCP servers.
Automatically fetches and refreshes access tokens for MCP servers configured
with ``client_id``, ``client_secret``, and ``token_url``.
"""
import asyncio
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
import httpx
from litellm._logging import verbose_logger
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.constants import (
MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL,
MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE,
MCP_OAUTH2_TOKEN_CACHE_MIN_TTL,
MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.custom_http import httpxSpecialProvider
if TYPE_CHECKING:
from litellm.types.mcp_server.mcp_server_manager import MCPServer
class MCPOAuth2TokenCache(InMemoryCache):
    """
    In-memory cache for OAuth2 client_credentials tokens, keyed by server_id.
    Inherits from ``InMemoryCache`` for TTL-based storage and eviction.
    Adds per-server ``asyncio.Lock`` to prevent duplicate concurrent fetches.
    """
    def __init__(self) -> None:
        super().__init__(
            max_size_in_memory=MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE,
            default_ttl=MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL,
        )
        # Per-server fetch locks. Entries are never evicted, but growth is
        # bounded by the number of distinct MCP servers in the deployment.
        self._locks: Dict[str, asyncio.Lock] = {}
    def _get_lock(self, server_id: str) -> asyncio.Lock:
        """Return (creating on first use) the fetch lock for *server_id*."""
        return self._locks.setdefault(server_id, asyncio.Lock())
    async def async_get_token(self, server: "MCPServer") -> Optional[str]:
        """Return a valid access token, fetching or refreshing as needed.
        Returns ``None`` when the server lacks client credentials config.
        """
        if not server.has_client_credentials:
            return None
        server_id = server.server_id
        # Fast path — cached token is still valid
        cached = self.get_cache(server_id)
        if cached is not None:
            return cached
        # Slow path — acquire per-server lock then double-check: another task
        # may have populated the cache while we waited on the lock.
        async with self._get_lock(server_id):
            cached = self.get_cache(server_id)
            if cached is not None:
                return cached
            token, ttl = await self._fetch_token(server)
            self.set_cache(server_id, token, ttl=ttl)
            return token
    async def _fetch_token(self, server: "MCPServer") -> Tuple[str, int]:
        """POST to ``token_url`` with ``grant_type=client_credentials``.
        Returns ``(access_token, ttl_seconds)`` where ttl accounts for the
        expiry buffer so the cache entry expires before the real token does.
        Raises:
            ValueError: when the server config is incomplete, the token
                endpoint returns an error status, or the response body is
                malformed / missing ``access_token``.
        """
        # Validate config BEFORE acquiring an HTTP client so misconfigured
        # servers fail fast without touching the shared client.
        if not server.client_id or not server.client_secret or not server.token_url:
            raise ValueError(
                f"MCP server '{server.server_id}' missing required OAuth2 fields: "
                f"client_id={bool(server.client_id)}, "
                f"client_secret={bool(server.client_secret)}, "
                f"token_url={bool(server.token_url)}"
            )
        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)
        data: Dict[str, str] = {
            "grant_type": "client_credentials",
            "client_id": server.client_id,
            "client_secret": server.client_secret,
        }
        if server.scopes:
            data["scope"] = " ".join(server.scopes)
        verbose_logger.debug(
            "Fetching OAuth2 client_credentials token for MCP server %s",
            server.server_id,
        )
        try:
            response = await client.post(server.token_url, data=data)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            raise ValueError(
                f"OAuth2 token request for MCP server '{server.server_id}' "
                f"failed with status {exc.response.status_code}"
            ) from exc
        body = response.json()
        if not isinstance(body, dict):
            raise ValueError(
                f"OAuth2 token response for MCP server '{server.server_id}' "
                f"returned non-object JSON (got {type(body).__name__})"
            )
        access_token = body.get("access_token")
        if not access_token:
            raise ValueError(
                f"OAuth2 token response for MCP server '{server.server_id}' "
                f"missing 'access_token'"
            )
        # Safely parse expires_in — providers may return null or non-numeric values
        raw_expires_in = body.get("expires_in")
        try:
            expires_in = (
                int(raw_expires_in)
                if raw_expires_in is not None
                else MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL
            )
        except (TypeError, ValueError):
            expires_in = MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL
        # Expire the cache entry before the real token does, but never below
        # the configured minimum TTL.
        ttl = max(
            expires_in - MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS,
            MCP_OAUTH2_TOKEN_CACHE_MIN_TTL,
        )
        verbose_logger.info(
            "Fetched OAuth2 token for MCP server %s (expires in %ds)",
            server.server_id,
            expires_in,
        )
        return access_token, ttl
    def invalidate(self, server_id: str) -> None:
        """Remove a cached token (e.g. after a 401)."""
        self.delete_cache(server_id)
# Module-level singleton: one shared token cache for the whole process.
mcp_oauth2_token_cache = MCPOAuth2TokenCache()
async def resolve_mcp_auth(
    server: "MCPServer",
    mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None,
) -> Optional[Union[str, Dict[str, str]]]:
    """Resolve the auth value to use for an outbound MCP server call.

    Resolution priority:
    1. ``mcp_auth_header`` — per-request/per-user override
    2. OAuth2 client_credentials token — auto-fetched and cached
    3. ``server.authentication_token`` — static token from config/DB
    """
    # A per-request override always wins.
    if mcp_auth_header:
        return mcp_auth_header
    # Otherwise choose between the auto-fetched M2M token and the static one.
    resolved_auth = (
        await mcp_oauth2_token_cache.async_get_token(server)
        if server.has_client_credentials
        else server.authentication_token
    )
    return resolved_auth

View File

@@ -0,0 +1,435 @@
"""
This module is used to generate MCP tools from OpenAPI specs.
"""
import asyncio
import contextvars
import json
import os
from pathlib import PurePosixPath
from typing import Any, Dict, List, Optional
from urllib.parse import quote
from litellm._logging import verbose_logger
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.proxy._experimental.mcp_server.tool_registry import (
global_mcp_tool_registry,
)
# Store the base URL and headers globally
# NOTE(review): module-level mutable globals — presumably populated when an
# OpenAPI spec is registered; confirm the writers before relying on defaults.
BASE_URL = ""
HEADERS: Dict[str, str] = {}
# Per-request auth header override for BYOK servers.
# Set this ContextVar before calling a local tool handler to inject the user's
# stored credential into the HTTP request made by the tool function closure.
_request_auth_header: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "_request_auth_header", default=None
)
def _sanitize_path_parameter_value(param_value: Any, param_name: str) -> str:
"""Ensure path params cannot introduce directory traversal."""
if param_value is None:
return ""
value_str = str(param_value)
if value_str == "":
return ""
normalized_value = value_str.replace("\\", "/")
if "/" in normalized_value:
raise ValueError(
f"Path parameter '{param_name}' must not contain path separators"
)
if any(part in {".", ".."} for part in PurePosixPath(normalized_value).parts):
raise ValueError(
f"Path parameter '{param_name}' cannot include '.' or '..' segments"
)
return quote(value_str, safe="")
def load_openapi_spec(filepath: str) -> Dict[str, Any]:
    """
    Sync wrapper around :func:`load_openapi_spec_async`.

    For URL specs, the async function uses the shared/custom MCP httpx client.

    Raises:
        RuntimeError: when called from inside a running event loop — use
            ``await load_openapi_spec_async(...)`` there instead.
    """
    # Detect a running loop without string-matching exception messages
    # (the previous version raised its own RuntimeError into its own
    # `except` and inspected CPython's implementation-specific message).
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop — safe to drive the async loader ourselves.
        return asyncio.run(load_openapi_spec_async(filepath))
    raise RuntimeError(
        "load_openapi_spec() was called from within a running event loop. "
        "Use 'await load_openapi_spec_async(...)' instead."
    )
async def load_openapi_spec_async(filepath: str) -> Dict[str, Any]:
    """Load an OpenAPI spec from a URL or from the local filesystem.

    URL specs are fetched with the shared/custom MCP httpx client; anything
    else is treated as a local JSON file path.

    Raises:
        FileNotFoundError: when *filepath* is a local path that does not exist.
    """
    if filepath.startswith(("http://", "https://")):
        http_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)
        # NOTE: do not close shared client if get_async_httpx_client returns a shared singleton.
        # If it returns a new client each time, consider wrapping it in an async context manager.
        response = await http_client.get(filepath)
        response.raise_for_status()
        return response.json()
    # Local filesystem path
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"OpenAPI spec not found at {filepath}")
    with open(filepath, "r", encoding="utf-8") as f:
        return json.load(f)
def get_base_url(spec: Dict[str, Any], spec_path: Optional[str] = None) -> str:
    """Extract the API base URL from an OpenAPI/Swagger spec.

    Falls back to deriving the base from *spec_path* when the spec carries
    no usable server information. Returns "" when nothing can be derived.
    """
    # OpenAPI 3.x: the first entry of the `servers` array wins.
    servers = spec.get("servers")
    if servers:
        first_url = servers[0]["url"]
        # A relative server URL (leading "/") is resolved against the
        # domain of the spec's own URL, when we have one.
        if first_url.startswith("/") and spec_path:
            if spec_path.startswith(("http://", "https://")):
                from urllib.parse import urlparse

                parsed_spec = urlparse(spec_path)
                derived = f"{parsed_spec.scheme}://{parsed_spec.netloc}" + first_url
                verbose_logger.info(
                    f"OpenAPI spec has relative server URL '{first_url}'. "
                    f"Deriving base from spec_path: {derived}"
                )
                return derived
        return first_url
    # OpenAPI 2.x (Swagger): scheme + host + basePath.
    if "host" in spec:
        preferred_scheme = spec.get("schemes", ["https"])[0]
        return f"{preferred_scheme}://{spec['host']}{spec.get('basePath', '')}"
    # No server info at all — try to derive a base from a URL-shaped spec_path.
    if spec_path and spec_path.startswith(("http://", "https://")):
        for suffix in (
            "/openapi.json",
            "/openapi.yaml",
            "/swagger.json",
            "/swagger.yaml",
        ):
            if spec_path.endswith(suffix):
                derived = spec_path[: -len(suffix)]
                verbose_logger.info(
                    f"No server info in OpenAPI spec. Using derived base URL: {derived}"
                )
                return derived
        # Generic spec filename — strip the final path segment.
        if spec_path.split("/")[-1].endswith((".json", ".yaml", ".yml")):
            derived = "/".join(spec_path.split("/")[:-1])
            verbose_logger.info(
                f"No server info in OpenAPI spec. Using derived base URL: {derived}"
            )
            return derived
    return ""
def _resolve_ref(
param: Dict[str, Any], component_params: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""Resolve a single parameter, following a $ref if present.
Returns the resolved param dict, or None if the $ref target is absent from
components (so callers can skip/filter it rather than propagating a stub
with name=None that would corrupt deduplication).
"""
ref = param.get("$ref", "")
if not ref.startswith("#/components/parameters/"):
return param
return component_params.get(ref.split("/")[-1])
def _resolve_param_list(
    raw: List[Dict[str, Any]], component_params: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Resolve $refs in a parameter list, dropping any unresolvable entries.

    Entries without a truthy ``name`` after resolution are also dropped.
    """
    return [
        resolved
        for resolved in (_resolve_ref(p, component_params) for p in raw)
        if resolved is not None and resolved.get("name")
    ]
def resolve_operation_params(
    operation: Dict[str, Any],
    path_item: Dict[str, Any],
    components: Dict[str, Any],
) -> Dict[str, Any]:
    """Return a copy of *operation* whose ``parameters`` are fully resolved.

    Two patterns common in real-world OpenAPI specs are handled:

    1. ``$ref`` parameters (``{"$ref": "#/components/parameters/per-page"}``)
       are resolved against ``components["parameters"]``; unresolvable refs
       are silently dropped so they cannot corrupt deduplication with
       ``(None, None)`` keys.
    2. Path-level parameters (declared on the path item and applying to every
       HTTP method, e.g. ``owner`` / ``repo``) are merged into the
       operation's own list; when both levels declare the same ``name`` +
       ``in`` pair, the operation-level entry wins.
    """
    component_params = components.get("parameters", {})
    inherited = _resolve_param_list(path_item.get("parameters", []), component_params)
    own = _resolve_param_list(operation.get("parameters", []), component_params)

    # Operation-level (name, in) pairs shadow their path-level counterparts.
    own_keys = {(param["name"], param.get("in")) for param in own}
    effective = [
        param for param in inherited if (param["name"], param.get("in")) not in own_keys
    ]
    effective.extend(own)

    resolved_operation = dict(operation)
    resolved_operation["parameters"] = effective
    return resolved_operation
def extract_parameters(operation: Dict[str, Any]) -> tuple:
    """Extract parameter names from an OpenAPI operation.

    Returns:
        A ``(path_params, query_params, body_params)`` tuple of name lists.
        A synthetic ``"body"`` entry is appended to the body list when the
        operation declares an OpenAPI 3.x ``requestBody``.
    """
    buckets: Dict[str, List[str]] = {"path": [], "query": [], "body": []}
    # Covers both OpenAPI 3.x and 2.x "parameters" arrays.
    for param in operation.get("parameters", []):
        if "name" not in param:
            continue
        location = param.get("in")
        if location in buckets:
            buckets[location].append(param["name"])
    # OpenAPI 3.x request bodies are surfaced as a single "body" parameter.
    if "requestBody" in operation:
        buckets["body"].append("body")
    return buckets["path"], buckets["query"], buckets["body"]
def build_input_schema(operation: Dict[str, Any]) -> Dict[str, Any]:
    """Build an MCP input schema (JSON-schema shaped) from an OpenAPI operation.

    Each declared parameter becomes a top-level property; an OpenAPI 3.x
    ``requestBody`` with JSON content becomes a single ``body`` object
    property. Required parameters / a required body are collected into the
    top-level ``required`` list.
    """
    schema_properties: Dict[str, Any] = {}
    required_names: List[str] = []

    # Declared path/query/body parameters.
    for param in operation.get("parameters", []):
        if "name" not in param:
            continue
        name = param["name"]
        declared = param.get("schema", {})
        schema_properties[name] = {
            "type": declared.get("type", "string"),
            "description": param.get("description", ""),
        }
        if param.get("required", False):
            required_names.append(name)

    # OpenAPI 3.x requestBody -> synthetic "body" property.
    if "requestBody" in operation:
        request_body = operation["requestBody"]
        content = request_body.get("content", {})
        if "application/json" in content:
            body_schema = content["application/json"].get("schema", {})
            schema_properties["body"] = {
                "type": "object",
                "description": request_body.get("description", "Request body"),
                "properties": body_schema.get("properties", {}),
            }
        # "required" applies to the body as a whole, even without JSON content.
        if request_body.get("required", False):
            required_names.append("body")

    return {
        "type": "object",
        "properties": schema_properties,
        "required": required_names,
    }
def create_tool_function(
    path: str,
    method: str,
    operation: Dict[str, Any],
    base_url: str,
    headers: Optional[Dict[str, str]] = None,
):
    """Create a tool function for an OpenAPI operation.

    This function creates an async tool function that can be called with
    keyword arguments. Parameter names from the OpenAPI spec are accessed
    directly via **kwargs, avoiding syntax errors from invalid Python identifiers.

    Args:
        path: API endpoint path
        method: HTTP method (get, post, put, delete, patch)
        operation: OpenAPI operation object
        base_url: Base URL for the API
        headers: Optional headers to include in requests (e.g., authentication)

    Returns:
        An async function that accepts **kwargs and makes the HTTP request
    """
    if headers is None:
        headers = {}

    path_params, query_params, body_params = extract_parameters(operation)
    original_method = method.lower()

    def _is_provided(value: Any) -> bool:
        # Bug fix: 0 and False are legitimate parameter values. The previous
        # plain truthiness check silently dropped them from query strings and
        # left unreplaced {placeholders} in the URL for path parameters.
        # Only None and the empty string mean "not supplied".
        return value is not None and value != ""

    async def tool_function(**kwargs: Any) -> str:
        """
        Dynamically generated tool function.

        Accepts keyword arguments where keys are the original OpenAPI parameter names.
        The function safely handles parameter names that aren't valid Python identifiers
        by using **kwargs instead of named parameters.
        """
        # Allow per-request auth override (e.g. BYOK credential set via ContextVar).
        # The ContextVar holds the full Authorization header value, including the
        # correct prefix (Bearer / ApiKey / Basic) formatted by the caller in
        # server.py based on the server's configured auth_type.
        effective_headers = dict(headers)
        override_auth = _request_auth_header.get()
        if override_auth:
            effective_headers["Authorization"] = override_auth

        # Build URL from base_url and path
        url = base_url + path

        # Replace path parameters using original names from OpenAPI spec.
        # Apply path traversal validation and URL encoding.
        for param_name in path_params:
            param_value = kwargs.get(param_name)
            if _is_provided(param_value):
                try:
                    # Sanitize and encode path parameter to prevent traversal attacks
                    safe_value = _sanitize_path_parameter_value(param_value, param_name)
                except ValueError as exc:
                    return "Invalid path parameter: " + str(exc)
                # Replace {param_name} or {{param_name}} in URL
                url = url.replace("{" + param_name + "}", safe_value)
                url = url.replace("{{" + param_name + "}}", safe_value)

        # Build query params using original parameter names
        params: Dict[str, Any] = {}
        for param_name in query_params:
            param_value = kwargs.get(param_name)
            if _is_provided(param_value):
                # Use original parameter name in query string (as expected by API)
                params[param_name] = param_value

        # Build request body
        json_body: Optional[Dict[str, Any]] = None
        if body_params:
            # Try "body" first (most common), then check all body param names
            body_value = kwargs.get("body", {})
            if not body_value:
                for param_name in body_params:
                    body_value = kwargs.get(param_name, {})
                    if body_value:
                        break
            if isinstance(body_value, dict):
                json_body = body_value
            elif body_value:
                # If it's a string, try to parse as JSON
                try:
                    json_body = (
                        json.loads(body_value)
                        if isinstance(body_value, str)
                        else {"data": body_value}
                    )
                except (json.JSONDecodeError, TypeError):
                    json_body = {"data": body_value}

        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)

        # Dispatch on the HTTP verb; GET/DELETE carry no JSON body.
        if original_method == "get":
            response = await client.get(url, params=params, headers=effective_headers)
        elif original_method == "post":
            response = await client.post(
                url, params=params, json=json_body, headers=effective_headers
            )
        elif original_method == "put":
            response = await client.put(
                url, params=params, json=json_body, headers=effective_headers
            )
        elif original_method == "delete":
            response = await client.delete(
                url, params=params, headers=effective_headers
            )
        elif original_method == "patch":
            response = await client.patch(
                url, params=params, json=json_body, headers=effective_headers
            )
        else:
            return f"Unsupported HTTP method: {original_method}"

        return response.text

    return tool_function
def register_tools_from_openapi(spec: Dict[str, Any], base_url: str):
    """Register MCP tools from an OpenAPI specification.

    For every path/method pair, a tool is registered in
    ``global_mcp_tool_registry`` whose name derives from ``operationId``
    (falling back to ``<method>_<path>``).

    Fix: each operation is now passed through ``resolve_operation_params`` so
    that ``$ref`` parameters are resolved and path-level parameters — which,
    per the OpenAPI spec, apply to every operation on the path — are included
    in the generated input schema and in request building. Previously those
    parameters were silently lost.

    Args:
        spec: Parsed OpenAPI document.
        base_url: Base URL the generated tools will call.
    """
    paths = spec.get("paths", {})
    components = spec.get("components", {})
    for path, path_item in paths.items():
        for method in ["get", "post", "put", "delete", "patch"]:
            if method not in path_item:
                continue
            # Resolve $refs and merge path-level parameters into the operation.
            operation = resolve_operation_params(
                path_item[method], path_item, components
            )
            # Generate tool name
            operation_id = operation.get(
                "operationId", f"{method}_{path.replace('/', '_')}"
            )
            tool_name = operation_id.replace(" ", "_").lower()
            # Get description
            description = operation.get(
                "summary", operation.get("description", f"{method.upper()} {path}")
            )
            # Build input schema
            input_schema = build_input_schema(operation)
            # Create tool function
            tool_func = create_tool_function(path, method, operation, base_url)
            tool_func.__name__ = tool_name
            tool_func.__doc__ = description
            # Register tool with local registry
            global_mcp_tool_registry.register_tool(
                name=tool_name,
                description=description,
                input_schema=input_schema,
                handler=tool_func,
            )
            verbose_logger.debug(f"Registered tool: {tool_name}")

View File

@@ -0,0 +1,256 @@
"""
Semantic MCP Tool Filtering using semantic-router
Filters MCP tools semantically for /chat/completions and /responses endpoints.
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from litellm._logging import verbose_logger
if TYPE_CHECKING:
from semantic_router.routers import SemanticRouter
from litellm.router import Router
class SemanticMCPToolFilter:
    """Filters MCP tools using semantic similarity to reduce context window size."""
    def __init__(
        self,
        embedding_model: str,
        litellm_router_instance: "Router",
        top_k: int = 10,
        similarity_threshold: float = 0.3,
        enabled: bool = True,
    ):
        """
        Initialize the semantic tool filter.
        Args:
            embedding_model: Model to use for embeddings (e.g., "text-embedding-3-small")
            litellm_router_instance: Router instance for embedding generation
            top_k: Maximum number of tools to return
            similarity_threshold: Minimum similarity score for filtering
            enabled: Whether filtering is enabled
        """
        self.enabled = enabled
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        self.router_instance = litellm_router_instance
        # Populated by build_router_from_mcp_registry()/_build_router();
        # None means "not built", in which case filter_tools() returns the
        # full tool list unfiltered.
        self.tool_router: Optional["SemanticRouter"] = None
        self._tool_map: Dict[str, Any] = {}  # MCPTool objects or OpenAI function dicts
    async def build_router_from_mcp_registry(self) -> None:
        """Build semantic router from all MCP tools in the registry (no auth checks)."""
        # Deferred import to avoid a circular dependency with the server manager.
        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
            global_mcp_server_manager,
        )
        try:
            # Get all servers from registry without auth checks
            registry = global_mcp_server_manager.get_registry()
            if not registry:
                verbose_logger.warning("MCP registry is empty")
                self.tool_router = None
                return
            # Fetch tools from all servers in parallel
            # NOTE(review): the fetches below are awaited one by one, i.e.
            # sequentially — confirm whether gather()-style parallelism was intended.
            all_tools = []
            for server_id, server in registry.items():
                try:
                    tools = await global_mcp_server_manager.get_tools_for_server(
                        server_id
                    )
                    all_tools.extend(tools)
                except Exception as e:
                    # A single unreachable server should not abort the whole build.
                    verbose_logger.warning(
                        f"Failed to fetch tools from server {server_id}: {e}"
                    )
                    continue
            if not all_tools:
                verbose_logger.warning("No MCP tools found in registry")
                self.tool_router = None
                return
            verbose_logger.info(
                f"Fetched {len(all_tools)} tools from {len(registry)} MCP servers"
            )
            self._build_router(all_tools)
        except Exception as e:
            verbose_logger.error(f"Failed to build router from MCP registry: {e}")
            self.tool_router = None
            raise
    def _extract_tool_info(self, tool) -> tuple[str, str]:
        """Extract name and description from MCP tool or OpenAI function dict."""
        name: str
        description: str
        if isinstance(tool, dict):
            # OpenAI function format
            name = tool.get("name", "")
            description = tool.get("description", name)
        else:
            # MCPTool object
            name = str(tool.name)
            description = str(tool.description) if tool.description else str(tool.name)
        return name, description
    def _build_router(self, tools: List) -> None:
        """Build semantic router with tools (MCPTool objects or OpenAI function dicts)."""
        # Third-party semantic-router is imported lazily so the proxy can run
        # without it when semantic filtering is not configured.
        from semantic_router.routers import SemanticRouter
        from semantic_router.routers.base import Route
        from litellm.router_strategy.auto_router.litellm_encoder import (
            LiteLLMRouterEncoder,
        )
        if not tools:
            self.tool_router = None
            return
        try:
            # Convert tools to routes
            routes = []
            self._tool_map = {}
            for tool in tools:
                name, description = self._extract_tool_info(tool)
                self._tool_map[name] = tool
                # Each tool becomes one route; its description doubles as the
                # single utterance the router embeds and matches against.
                routes.append(
                    Route(
                        name=name,
                        description=description,
                        utterances=[description],
                        score_threshold=self.similarity_threshold,
                    )
                )
            self.tool_router = SemanticRouter(
                routes=routes,
                encoder=LiteLLMRouterEncoder(
                    litellm_router_instance=self.router_instance,
                    model_name=self.embedding_model,
                    score_threshold=self.similarity_threshold,
                ),
                auto_sync="local",
            )
            verbose_logger.info(f"Built semantic router with {len(routes)} tools")
        except Exception as e:
            verbose_logger.error(f"Failed to build semantic router: {e}")
            self.tool_router = None
            raise
    async def filter_tools(
        self,
        query: str,
        available_tools: List[Any],
        top_k: Optional[int] = None,
    ) -> List[Any]:
        """
        Filter tools semantically based on query.
        Args:
            query: User query to match against tools
            available_tools: Full list of available MCP tools
            top_k: Override default top_k (optional)
        Returns:
            Filtered and ordered list of tools (up to top_k)
        """
        # Early returns for cases where we can't/shouldn't filter
        if not self.enabled:
            return available_tools
        if not available_tools:
            return available_tools
        if not query or not query.strip():
            return available_tools
        # Router should be built on startup - if not, something went wrong
        if self.tool_router is None:
            verbose_logger.warning(
                "Router not initialized - was build_router_from_mcp_registry() called on startup?"
            )
            return available_tools
        # Run semantic filtering
        try:
            limit = top_k or self.top_k
            matches = self.tool_router(text=query, limit=limit)
            matched_tool_names = self._extract_tool_names_from_matches(matches)
            if not matched_tool_names:
                return available_tools
            return self._get_tools_by_names(matched_tool_names, available_tools)
        except Exception as e:
            # Filtering is best-effort: on any failure fall back to the full list.
            verbose_logger.error(f"Semantic tool filter failed: {e}", exc_info=True)
            return available_tools
    def _extract_tool_names_from_matches(self, matches) -> List[str]:
        """Extract tool names from semantic router match results."""
        if not matches:
            return []
        # Handle single match
        if hasattr(matches, "name") and matches.name:
            return [matches.name]
        # Handle list of matches
        if isinstance(matches, list):
            return [m.name for m in matches if hasattr(m, "name") and m.name]
        return []
    def _get_tools_by_names(
        self, tool_names: List[str], available_tools: List[Any]
    ) -> List[Any]:
        """Get tools from available_tools by their names, preserving order."""
        # Match tools from available_tools (preserves format - dict or MCPTool)
        matched_tools = []
        for tool in available_tools:
            tool_name, _ = self._extract_tool_info(tool)
            if tool_name in tool_names:
                matched_tools.append(tool)
        # Reorder to match semantic router's ordering
        tool_map = {self._extract_tool_info(t)[0]: t for t in matched_tools}
        return [tool_map[name] for name in tool_names if name in tool_map]
    def extract_user_query(self, messages: List[Dict[str, Any]]) -> str:
        """
        Extract user query from messages for /chat/completions or /responses.
        Args:
            messages: List of message dictionaries (from 'messages' or 'input' field)
        Returns:
            Extracted query string
        """
        # Walk backwards so the most recent user message wins.
        for msg in reversed(messages):
            if msg.get("role") == "user":
                content = msg.get("content", "")
                if isinstance(content, str):
                    return content
                if isinstance(content, list):
                    # Multimodal content: join only the text blocks.
                    texts = [
                        block.get("text", "") if isinstance(block, dict) else str(block)
                        for block in content
                        if isinstance(block, (dict, str))
                    ]
                    return " ".join(texts)
        return ""

View File

@@ -0,0 +1,150 @@
"""
This is a modification of code from: https://github.com/SecretiveShell/MCP-Bridge/blob/master/mcp_bridge/mcp_server/sse_transport.py
Credit to the maintainers of SecretiveShell for their SSE Transport implementation
"""
from contextlib import asynccontextmanager
from typing import Any
from urllib.parse import quote
from uuid import UUID, uuid4
import anyio
import mcp.types as types
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
from fastapi.requests import Request
from fastapi.responses import Response
from pydantic import ValidationError
from sse_starlette import EventSourceResponse
from starlette.types import Receive, Scope, Send
from litellm._logging import verbose_logger
class SseServerTransport:
    """
    SSE server transport for MCP. This class provides _two_ ASGI applications,
    suitable to be used with a framework like Starlette and a server like Hypercorn:
    1. connect_sse() is an ASGI application which receives incoming GET requests,
    and sets up a new SSE stream to send server messages to the client.
    2. handle_post_message() is an ASGI application which receives incoming POST
    requests, which should contain client messages that link to a
    previously-established SSE session.
    """
    # POST endpoint advertised to clients via the initial "endpoint" SSE event.
    _endpoint: str
    # Maps session id -> writer that feeds that session's read stream.
    # NOTE(review): entries are never removed here when a connection ends —
    # confirm session cleanup is handled elsewhere, else this grows unbounded.
    _read_stream_writers: dict[
        UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
    ]
    def __init__(self, endpoint: str) -> None:
        """
        Creates a new SSE server transport, which will direct the client to POST
        messages to the relative or absolute URL given.
        """
        super().__init__()
        self._endpoint = endpoint
        self._read_stream_writers = {}
        verbose_logger.debug(
            f"SseServerTransport initialized with endpoint: {endpoint}"
        )
    @asynccontextmanager
    async def connect_sse(self, request: Request):
        # Handles a GET request by opening an SSE response and yielding the
        # (read_stream, write_stream) pair the MCP session runs over.
        if request.scope["type"] != "http":
            verbose_logger.error("connect_sse received non-HTTP request")
            raise ValueError("connect_sse can only handle HTTP requests")
        verbose_logger.debug("Setting up SSE connection")
        read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
        read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
        write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
        write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
        # Zero-capacity streams: each send blocks until a receiver is ready.
        read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
        write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
        session_id = uuid4()
        # Clients must POST all subsequent messages to this session-scoped URI.
        session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
        self._read_stream_writers[session_id] = read_stream_writer
        verbose_logger.debug(f"Created new session with ID: {session_id}")
        sse_stream_writer: MemoryObjectSendStream[dict[str, Any]]
        sse_stream_reader: MemoryObjectReceiveStream[dict[str, Any]]
        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(
            0, dict[str, Any]
        )
        async def sse_writer():
            # Pumps outgoing JSON-RPC messages from write_stream into SSE events.
            verbose_logger.debug("Starting SSE writer")
            async with sse_stream_writer, write_stream_reader:
                # The first event tells the client where to POST its messages.
                await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
                verbose_logger.debug(f"Sent endpoint event: {session_uri}")
                async for message in write_stream_reader:
                    verbose_logger.debug(f"Sending message via SSE: {message}")
                    await sse_stream_writer.send(
                        {
                            "event": "message",
                            "data": message.model_dump_json(
                                by_alias=True, exclude_none=True
                            ),
                        }
                    )
        async with anyio.create_task_group() as tg:
            response = EventSourceResponse(
                content=sse_stream_reader, data_sender_callable=sse_writer
            )
            verbose_logger.debug("Starting SSE response task")
            # Run the SSE response in the background while the caller uses the
            # yielded streams; request._send is Starlette's raw ASGI send callable.
            tg.start_soon(response, request.scope, request.receive, request._send)
            verbose_logger.debug("Yielding read and write streams")
            yield (read_stream, write_stream)
    async def handle_post_message(
        self, scope: Scope, receive: Receive, send: Send
    ) -> Response:
        # Validates the session id, parses the JSON-RPC payload, and forwards
        # it to the matching SSE session's read stream.
        verbose_logger.debug("Handling POST message")
        request = Request(scope, receive)
        session_id_param = request.query_params.get("session_id")
        if session_id_param is None:
            verbose_logger.warning("Received request without session_id")
            response = Response("session_id is required", status_code=400)
            return response
        try:
            session_id = UUID(hex=session_id_param)
            verbose_logger.debug(f"Parsed session ID: {session_id}")
        except ValueError:
            verbose_logger.warning(f"Received invalid session ID: {session_id_param}")
            response = Response("Invalid session ID", status_code=400)
            return response
        writer = self._read_stream_writers.get(session_id)
        if not writer:
            verbose_logger.warning(f"Could not find session for ID: {session_id}")
            response = Response("Could not find session", status_code=404)
            return response
        json = await request.json()
        verbose_logger.debug(f"Received JSON: {json}")
        try:
            message = types.JSONRPCMessage.model_validate(json)
            verbose_logger.debug(f"Validated client message: {message}")
        except ValidationError as err:
            verbose_logger.error(f"Failed to parse message: {err}")
            response = Response("Could not parse message", status_code=400)
            # The ValidationError itself is forwarded so the session sees the failure.
            await writer.send(err)
            return response
        verbose_logger.debug(f"Sending message to writer: {message}")
        response = Response("Accepted", status_code=202)
        await writer.send(message)
        return response

View File

@@ -0,0 +1,133 @@
import json
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from litellm._logging import verbose_logger
from litellm.proxy.types_utils.utils import get_instance_fn
from litellm.types.mcp_server.tool_registry import MCPTool
if TYPE_CHECKING:
from mcp.types import Tool as MCPToolSDKTool
else:
try:
from mcp.types import Tool as MCPToolSDKTool
except ImportError:
MCPToolSDKTool = None # type: ignore
class MCPToolRegistry:
    """
    A registry for managing MCP tools
    """
    def __init__(self):
        # name -> MCPTool; insertion order doubles as registration order.
        self.tools: Dict[str, MCPTool] = {}
    def register_tool(
        self,
        name: str,
        description: str,
        input_schema: Dict[str, Any],
        handler: Callable,
    ) -> None:
        """Add a tool to the registry, replacing any existing entry with the same name."""
        self.tools[name] = MCPTool(
            name=name,
            description=description,
            input_schema=input_schema,
            handler=handler,
        )
        verbose_logger.debug(f"Registered tool: {name}")
    def get_tool(self, name: str) -> Optional[MCPTool]:
        """Look up a registered tool by name; returns None when unknown."""
        return self.tools.get(name)
    def list_tools(self, tool_prefix: Optional[str] = None) -> List[MCPTool]:
        """Return registered tools, optionally restricted to names starting with *tool_prefix*."""
        registered = self.tools.values()
        if not tool_prefix:
            return list(registered)
        return [tool for tool in registered if tool.name.startswith(tool_prefix)]
    def convert_tools_to_mcp_sdk_tool_type(
        self, tools: List[MCPTool]
    ) -> List["MCPToolSDKTool"]:
        """Convert internal MCPTool records into MCP SDK Tool objects."""
        if MCPToolSDKTool is None:
            raise ImportError(
                "MCP SDK is not installed. Please install it with: pip install 'litellm[proxy]'"
            )
        converted: List["MCPToolSDKTool"] = []
        for tool in tools:
            converted.append(
                MCPToolSDKTool(
                    name=tool.name,
                    description=tool.description,
                    inputSchema=tool.input_schema,
                )
            )
        return converted
    def _register_tool_from_config(self, tool_config: Dict[str, Any]) -> None:
        """Validate one config entry and register it; incomplete entries are skipped."""
        name = tool_config.get("name")
        description = tool_config.get("description")
        input_schema = tool_config.get("input_schema", {})
        handler_name = tool_config.get("handler")
        # Silently skip entries missing any mandatory field.
        if not all([name, description, handler_name]):
            return
        # Unreachable after the all() guard, but narrows Optional for type checkers.
        if handler_name is None:
            raise ValueError(f"handler is required for tool {name}")
        # Resolve the handler from a module path (e.g. "module.submodule.function").
        handler = get_instance_fn(handler_name)
        if handler is None:
            verbose_logger.warning(
                f"Warning: Could not find handler {handler_name} for tool {name}"
            )
            return
        if name is None:
            raise ValueError(f"name is required for tool {name}")
        if description is None:
            raise ValueError(f"description is required for tool {name}")
        self.register_tool(
            name=name,
            description=description,
            input_schema=input_schema,
            handler=handler,
        )
    def load_tools_from_config(
        self, mcp_tools_config: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Load and register tools from the proxy config.

        Args:
            mcp_tools_config: The mcp_tools config from the proxy config

        Raises:
            ValueError: when the config is missing or an entry is not a dict.
        """
        if mcp_tools_config is None:
            raise ValueError(
                "mcp_tools_config is required, please set `mcp_tools` in your proxy config"
            )
        for tool_config in mcp_tools_config:
            if not isinstance(tool_config, dict):
                raise ValueError("mcp_tools_config must be a list of dictionaries")
            self._register_tool_from_config(tool_config)
        verbose_logger.debug(
            "all registered tools: %s", json.dumps(self.tools, indent=4, default=str)
        )
# Module-level singleton used by the proxy to register and look up MCP tools.
global_mcp_tool_registry = MCPToolRegistry()

View File

@@ -0,0 +1,85 @@
"""Helpers to resolve real team contexts for UI session tokens."""
from __future__ import annotations
from typing import List
from litellm._logging import verbose_logger
from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
from litellm.proxy._types import UserAPIKeyAuth
def clone_user_api_key_auth_with_team(
    user_api_key_auth: UserAPIKeyAuth,
    team_id: str,
) -> UserAPIKeyAuth:
    """Copy the auth context, overriding its team id with *team_id*."""
    try:
        # pydantic v2 objects expose model_copy().
        clone = user_api_key_auth.model_copy()
    except AttributeError:
        # pydantic v1 fallback.
        clone = user_api_key_auth.copy()  # type: ignore[attr-defined]
    clone.team_id = team_id
    return clone
async def resolve_ui_session_team_ids(
    user_api_key_auth: UserAPIKeyAuth,
) -> List[str]:
    """Resolve the real team ids backing a UI session token.

    UI session tokens carry the placeholder team id
    ``UI_SESSION_TOKEN_TEAM_ID``; this looks up the user's actual teams from
    the DB. Returns an empty list when the token is not a UI session token,
    the user id is missing, the DB is unavailable, or the lookup fails.
    """
    if (
        user_api_key_auth.team_id != UI_SESSION_TOKEN_TEAM_ID
        or not user_api_key_auth.user_id
    ):
        return []

    # Imported lazily to avoid circular imports with proxy_server.
    from litellm.proxy.auth.auth_checks import get_user_object
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    if prisma_client is None:
        verbose_logger.debug("Cannot resolve UI session team ids without DB access")
        return []

    try:
        user_obj = await get_user_object(
            user_id=user_api_key_auth.user_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            user_id_upsert=False,
            parent_otel_span=user_api_key_auth.parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )
    except Exception as exc:  # pragma: no cover - defensive logging
        # Fix: the exception was previously passed as a bare positional logging
        # arg with no %s placeholder, so it was never rendered in the message.
        verbose_logger.warning(
            "Failed to load teams for UI session token user: %s",
            exc,
        )
        return []

    if user_obj is None or not user_obj.teams:
        return []

    # Deduplicate while preserving the order teams appear on the user object.
    resolved_team_ids: List[str] = []
    for team_id in user_obj.teams:
        if team_id and team_id not in resolved_team_ids:
            resolved_team_ids.append(team_id)
    return resolved_team_ids
async def build_effective_auth_contexts(
    user_api_key_auth: UserAPIKeyAuth,
) -> List[UserAPIKeyAuth]:
    """Expand a UI session token into one auth context per real team.

    Non-UI tokens, or lookups that resolve no teams, are returned unchanged
    as a single-element list.
    """
    team_ids = await resolve_ui_session_team_ids(user_api_key_auth)
    if not team_ids:
        return [user_api_key_auth]
    return [
        clone_user_api_key_auth_with_team(user_api_key_auth, resolved_id)
        for resolved_id in team_ids
    ]

View File

@@ -0,0 +1,167 @@
"""
MCP Server Utilities
"""
from typing import Any, Dict, Mapping, Optional, Tuple
import os
import importlib
# Constants
LITELLM_MCP_SERVER_NAME = "litellm-mcp-server"
LITELLM_MCP_SERVER_VERSION = "1.0.0"
LITELLM_MCP_SERVER_DESCRIPTION = "MCP Server for LiteLLM"
MCP_TOOL_PREFIX_SEPARATOR = os.environ.get("MCP_TOOL_PREFIX_SEPARATOR", "-")
MCP_TOOL_PREFIX_FORMAT = "{server_name}{separator}{tool_name}"
def is_mcp_available() -> bool:
    """Report whether the optional ``mcp`` package can be imported."""
    try:
        importlib.import_module("mcp")
    except ImportError:
        return False
    return True
def normalize_server_name(server_name: str) -> str:
    """Return *server_name* with every space turned into an underscore."""
    return "_".join(server_name.split(" "))
def validate_and_normalize_mcp_server_payload(payload: Any) -> None:
    """
    Validate and normalize the ``server_name`` and ``alias`` fields on an
    MCP server payload, in place.

    Steps:
    1. Reject names/aliases containing MCP_TOOL_PREFIX_SEPARATOR.
    2. Replace spaces with underscores in the alias.
    3. Default the alias from server_name when the alias is missing.

    Args:
        payload: The payload object carrying server_name and alias fields.

    Raises:
        HTTPException: if validation fails.
    """
    server_name = getattr(payload, "server_name", None)
    alias = getattr(payload, "alias", None)

    # Neither field may contain the tool-prefix separator ('-' by default).
    if server_name:
        validate_mcp_server_name(server_name, raise_http_exception=True)
    if alias:
        validate_mcp_server_name(alias, raise_http_exception=True)

    # Normalize the alias, defaulting from server_name when absent.
    if alias:
        alias = normalize_server_name(alias)
    elif server_name:
        alias = normalize_server_name(server_name)

    # Write back only when the payload actually exposes an alias field.
    if hasattr(payload, "alias"):
        payload.alias = alias
def add_server_prefix_to_name(name: str, server_name: str) -> str:
    """Prefix *name* with the normalized server name and the configured separator."""
    normalized = normalize_server_name(server_name)
    return MCP_TOOL_PREFIX_FORMAT.format(
        tool_name=name,
        separator=MCP_TOOL_PREFIX_SEPARATOR,
        server_name=normalized,
    )
def get_server_prefix(server: Any) -> str:
    """Return the prefix for a server: alias if present, else server_name, else server_id"""
    # Preference order: alias > server_name (both only when truthy).
    for attr in ("alias", "server_name"):
        value = getattr(server, attr, None)
        if value:
            return value
    # server_id is returned whenever the attribute exists, even if falsy.
    if hasattr(server, "server_id"):
        return server.server_id
    return ""
def split_server_prefix_from_name(prefixed_name: str) -> Tuple[str, str]:
    """Return the unprefixed name plus the server name used as prefix.

    When no separator occurs in the input, the name is returned untouched
    together with an empty prefix.
    """
    prefix, separator, remainder = prefixed_name.partition(MCP_TOOL_PREFIX_SEPARATOR)
    if separator:
        return remainder, prefix
    return prefixed_name, ""
def is_tool_name_prefixed(tool_name: str) -> bool:
    """
    Check whether *tool_name* carries a server prefix.

    Args:
        tool_name: Tool name to check

    Returns:
        True when the configured separator occurs anywhere in the name.
    """
    return tool_name.find(MCP_TOOL_PREFIX_SEPARATOR) != -1
def validate_mcp_server_name(
    server_name: str, raise_http_exception: bool = False
) -> None:
    """
    Validate that an MCP server name does not contain MCP_TOOL_PREFIX_SEPARATOR.

    The separator is reserved for joining server prefixes onto tool names, so
    allowing it inside a server name would make prefixed names ambiguous.

    Args:
        server_name: The server name to validate
        raise_http_exception: If True, raises HTTPException (HTTP 400) instead
            of a generic Exception

    Raises:
        Exception or HTTPException: if the name contains the separator.
    """
    if server_name and MCP_TOOL_PREFIX_SEPARATOR in server_name:
        # Fix: added the missing sentence break in the user-facing message
        # ("...instead Found:" -> "...instead. Found:").
        error_message = (
            f"Server name cannot contain '{MCP_TOOL_PREFIX_SEPARATOR}'. "
            f"Use an alternative character instead. Found: {server_name}"
        )
        if raise_http_exception:
            from fastapi import HTTPException
            from starlette import status

            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail={"error": error_message},
            )
        else:
            raise Exception(error_message)
def merge_mcp_headers(
    *,
    extra_headers: Optional[Mapping[str, str]] = None,
    static_headers: Optional[Mapping[str, str]] = None,
) -> Optional[Dict[str, str]]:
    """Merge outbound HTTP headers for MCP calls.

    Used when calling out to external MCP servers (or OpenAPI-based MCP
    tools). ``extra_headers`` (typically OAuth2-derived) form the base and
    ``static_headers`` (admin-configured per MCP server) are layered on top —
    on a key clash the static value wins, matching ``MCPServerManager``
    behavior. Keys and values are coerced to ``str``. Returns ``None`` when
    the merged result is empty.
    """
    combined: Dict[str, str] = {}
    for source in (extra_headers, static_headers):
        if not source:
            continue
        for key, value in source.items():
            combined[str(key)] = str(value)
    if combined:
        return combined
    return None

View File

@@ -0,0 +1,4 @@
def my_custom_rule(input):  # receives the model response
    """Example post-call rule.

    Returns False (triggering fallback) when the model response is shorter
    than 5 characters, True otherwise.

    Fix: the guard condition was commented out, leaving an unconditional
    ``return False`` and an unreachable ``return True`` — i.e. the rule
    always failed. The condition is restored to match the template's intent.
    """
    if len(input) < 5:  # trigger fallback if the model response is too short
        return False
    return True