chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
class MCPAuthenticatedUser(AuthenticatedUser):
    """
    Adapter exposing LiteLLM's auth/config state through MCP's AuthenticatedUser interface.

    Carries:
    1. User API key authentication information
    2. MCP authentication header (deprecated)
    3. MCP server configuration (can include access groups)
    4. Server-specific authentication headers
    5. OAuth2 headers
    6. Raw headers - allows forwarding specific headers to the MCP server, specified by the admin.
    """

    def __init__(
        self,
        user_api_key_auth: UserAPIKeyAuth,
        mcp_auth_header: Optional[str] = None,
        mcp_servers: Optional[List[str]] = None,
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]] = None,
        oauth2_headers: Optional[Dict[str, str]] = None,
        mcp_protocol_version: Optional[str] = None,
        raw_headers: Optional[Dict[str, str]] = None,
        client_ip: Optional[str] = None,
    ):
        # The LiteLLM key-auth object is the only required piece of state;
        # everything else is optional per-request context.
        self.user_api_key_auth = user_api_key_auth
        self.mcp_auth_header = mcp_auth_header
        self.mcp_servers = mcp_servers
        # Normalize to an empty mapping so callers can index without None checks.
        self.mcp_server_auth_headers = mcp_server_auth_headers or {}
        self.mcp_protocol_version = mcp_protocol_version
        self.oauth2_headers = oauth2_headers
        self.raw_headers = raw_headers
        self.client_ip = client_ip
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,789 @@
|
||||
"""
|
||||
BYOK (Bring Your Own Key) OAuth 2.1 Authorization Server endpoints for MCP servers.
|
||||
|
||||
When an MCP client connects to a BYOK-enabled server and no stored credential exists,
|
||||
LiteLLM runs a minimal OAuth 2.1 authorization code flow. The "authorization page" is
|
||||
just a form that asks the user for their API key — not a full identity-provider OAuth.
|
||||
|
||||
Endpoints implemented here:
|
||||
GET /.well-known/oauth-authorization-server — OAuth authorization server metadata
|
||||
GET /.well-known/oauth-protected-resource — OAuth protected resource metadata
|
||||
GET /v1/mcp/oauth/authorize — Shows HTML form to collect the API key
|
||||
POST /v1/mcp/oauth/authorize — Stores temp auth code and redirects
|
||||
POST /v1/mcp/oauth/token — Exchanges code for a bearer JWT token
|
||||
"""
|
||||
|
||||
import base64
import hashlib
import hmac
import html as _html_module
import time
import uuid
from typing import Dict, Optional, cast
from urllib.parse import urlencode, urlparse

import jwt
from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse

from litellm._logging import verbose_proxy_logger
from litellm.proxy._experimental.mcp_server.db import store_user_credential
from litellm.proxy._experimental.mcp_server.discoverable_endpoints import (
    get_request_base_url,
)
|
||||
|
||||
# ---------------------------------------------------------------------------
# In-memory store for pending authorization codes.
# Each entry: {code: {api_key, server_id, code_challenge, redirect_uri, user_id, expires_at}}
# NOTE(review): process-local state — pending codes are not visible to other
# workers/replicas, so multi-worker deployments presumably need sticky routing
# for the authorize -> token round trip. TODO confirm.
# ---------------------------------------------------------------------------
_byok_auth_codes: Dict[str, dict] = {}

# Authorization codes expire after 5 minutes.
_AUTH_CODE_TTL_SECONDS = 300
# Hard cap to prevent memory exhaustion from incomplete OAuth flows.
_AUTH_CODES_MAX_SIZE = 1000

# All BYOK OAuth endpoints below attach to this router under the "mcp" tag.
router = APIRouter(tags=["mcp"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PKCE helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _verify_pkce(code_verifier: str, code_challenge: str) -> bool:
|
||||
"""Return True iff SHA-256(code_verifier) == code_challenge (base64url, no padding)."""
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
computed = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
return computed == code_challenge
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cleanup of expired auth codes (called lazily on each request)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _purge_expired_codes() -> None:
|
||||
now = time.time()
|
||||
expired = [k for k, v in _byok_auth_codes.items() if v["expires_at"] < now]
|
||||
for k in expired:
|
||||
del _byok_auth_codes[k]
|
||||
|
||||
|
||||
def _build_authorize_html(
    server_name: str,
    server_initial: str,
    client_id: str,
    redirect_uri: str,
    code_challenge: str,
    code_challenge_method: str,
    state: str,
    server_id: str,
    access_items: list,
    help_url: str,
) -> str:
    """Build the 2-step BYOK OAuth authorization page HTML.

    Step 1 explains the connection; step 2 collects the user's API key and
    POSTs it back to /v1/mcp/oauth/authorize, with the OAuth parameters
    carried in hidden form fields.

    Args:
        server_name: Display name of the MCP server being connected.
        server_initial: Single character shown in the server "logo" tile.
        client_id: OAuth client_id (external clients pass the LiteLLM user id here).
        redirect_uri: Where the POST handler redirects with ?code=...&state=...
        code_challenge: PKCE challenge to associate with the issued auth code.
        code_challenge_method: PKCE method (only S256 is accepted downstream).
        state: Opaque client state, echoed back on redirect.
        server_id: LiteLLM MCP server id the credential will be stored against.
        access_items: Human-readable list of what access is being requested.
        help_url: Optional link telling the user where to find their API key.

    Returns:
        A complete, standalone HTML document as a string.
    """

    # Escape all user-supplied / externally-derived values before interpolation
    # (XSS protection — these land inside HTML attributes and text nodes).
    e = _html_module.escape
    server_name = e(server_name)
    server_initial = e(server_initial)
    client_id = e(client_id)
    redirect_uri = e(redirect_uri)
    code_challenge = e(code_challenge)
    code_challenge_method = e(code_challenge_method)
    state = e(state)
    server_id = e(server_id)

    # Build access checklist rows
    access_rows = "".join(
        f'<div class="access-item"><span class="check">✓</span>{e(item)}</div>'
        for item in access_items
    )
    access_section = ""
    if access_rows:
        access_section = f"""
<div class="access-box">
<div class="access-header">
<span class="shield">▮</span>
<span>Requested Access</span>
</div>
{access_rows}
</div>"""

    # Help link for step 2
    help_link_html = ""
    if help_url:
        help_link_html = f'<a class="help-link" href="{e(help_url)}" target="_blank">Where do I find my API key? ↗</a>'

    # NOTE: literal braces in the CSS/JS below are doubled ({{ }}) because
    # this is an f-string; only the single-braced names are interpolated.
    return f"""<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Connect {server_name} — LiteLLM</title>
<style>
*, *::before, *::after {{ box-sizing: border-box; margin: 0; padding: 0; }}
body {{
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;
background: #0f172a;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
padding: 24px;
}}
.modal {{
background: #ffffff;
border-radius: 20px;
padding: 36px 32px 32px;
width: 440px;
max-width: 100%;
position: relative;
box-shadow: 0 25px 60px rgba(0,0,0,0.35);
}}
/* Progress dots */
.dots {{
display: flex;
justify-content: center;
gap: 7px;
margin-bottom: 28px;
}}
.dot {{
width: 8px; height: 8px;
border-radius: 50%;
background: #e2e8f0;
}}
.dot.active {{ background: #38bdf8; }}
/* Close button */
.close-btn {{
position: absolute;
top: 16px; right: 16px;
background: none; border: none;
font-size: 16px; color: #94a3b8;
cursor: pointer; line-height: 1;
width: 28px; height: 28px;
border-radius: 6px;
display: flex; align-items: center; justify-content: center;
}}
.close-btn:hover {{ background: #f1f5f9; color: #475569; }}
/* Logo pair */
.logos {{
display: flex; align-items: center; justify-content: center;
gap: 12px; margin-bottom: 20px;
}}
.logo {{
width: 52px; height: 52px;
border-radius: 14px;
display: flex; align-items: center; justify-content: center;
font-size: 22px; font-weight: 800; color: white;
}}
.logo-img {{
width: 52px; height: 52px;
border-radius: 14px;
object-fit: cover;
border: 1.5px solid #e2e8f0;
}}
.logo-s {{ background: linear-gradient(135deg, #818cf8 0%, #4f46e5 100%); }}
.logo-arrow {{ color: #cbd5e1; font-size: 20px; font-weight: 300; }}
/* Headings */
.step-title {{
text-align: center;
font-size: 21px; font-weight: 700;
color: #0f172a; margin-bottom: 8px;
}}
.step-subtitle {{
text-align: center;
font-size: 14px; color: #64748b;
line-height: 1.55; margin-bottom: 22px;
}}
/* Info box */
.info-box {{
background: #f8fafc;
border-radius: 12px;
padding: 14px 16px;
display: flex; gap: 12px;
margin-bottom: 14px;
}}
.info-icon {{ font-size: 17px; flex-shrink: 0; margin-top: 1px; color: #38bdf8; }}
.info-box h4 {{ font-size: 13px; font-weight: 600; color: #1e293b; margin-bottom: 4px; }}
.info-box p {{ font-size: 13px; color: #64748b; line-height: 1.5; }}
/* Access checklist */
.access-box {{
background: #f8fafc;
border-radius: 12px;
padding: 14px 16px;
margin-bottom: 22px;
}}
.access-header {{
display: flex; align-items: center; gap: 8px;
margin-bottom: 10px;
}}
.shield {{ color: #22c55e; font-size: 15px; }}
.access-header > span:last-child {{
font-size: 11px; font-weight: 700;
letter-spacing: 0.07em;
text-transform: uppercase;
color: #475569;
}}
.access-item {{
display: flex; align-items: center; gap: 9px;
font-size: 13.5px; color: #374151;
padding: 3px 0;
}}
.check {{ color: #22c55e; font-weight: 700; font-size: 13px; }}
/* Primary CTA */
.btn-primary {{
width: 100%; padding: 15px;
background: #0f172a; color: white;
border: none; border-radius: 12px;
font-size: 15px; font-weight: 600;
cursor: pointer; margin-bottom: 10px;
}}
.btn-primary:hover {{ background: #1e293b; }}
.btn-cancel {{
width: 100%; padding: 8px;
background: none; border: none;
font-size: 13.5px; color: #94a3b8;
cursor: pointer;
}}
.btn-cancel:hover {{ color: #64748b; }}
/* Step 2 nav */
.step2-nav {{
display: flex; align-items: center;
justify-content: space-between;
margin-bottom: 24px;
}}
.back-btn {{
background: none; border: none;
font-size: 13.5px; color: #64748b;
cursor: pointer; display: flex; align-items: center; gap: 4px;
}}
.back-btn:hover {{ color: #374151; }}
/* Key icon */
.key-icon-wrap {{
width: 46px; height: 46px;
background: #e0f2fe;
border-radius: 12px;
display: flex; align-items: center; justify-content: center;
margin-bottom: 14px;
}}
.key-icon-wrap svg {{ width: 22px; height: 22px; color: #0284c7; }}
/* Form elements */
.field-label {{
font-size: 13.5px; font-weight: 600;
color: #1e293b; display: block;
margin-bottom: 7px;
}}
.key-input {{
width: 100%; padding: 11px 13px;
border: 1.5px solid #e2e8f0;
border-radius: 10px;
font-size: 14px; color: #0f172a;
outline: none; transition: border-color 0.15s, box-shadow 0.15s;
}}
.key-input:focus {{
border-color: #38bdf8;
box-shadow: 0 0 0 3px rgba(56,189,248,0.12);
}}
.help-link {{
display: inline-flex; align-items: center; gap: 4px;
color: #0ea5e9; font-size: 13px;
text-decoration: none; margin: 8px 0 16px;
}}
.help-link:hover {{ text-decoration: underline; }}
/* Save toggle card */
.save-card {{
border: 1.5px solid #e2e8f0;
border-radius: 12px;
padding: 13px 15px;
margin-bottom: 6px;
}}
.save-row {{
display: flex; align-items: center; gap: 10px;
}}
.save-icon {{ font-size: 16px; }}
.save-label {{
flex: 1;
font-size: 14px; font-weight: 500; color: #1e293b;
}}
/* Toggle switch */
.toggle {{ position: relative; width: 44px; height: 24px; flex-shrink: 0; }}
.toggle input {{ opacity: 0; width: 0; height: 0; }}
.slider {{
position: absolute; inset: 0;
background: #e2e8f0;
border-radius: 24px; cursor: pointer;
transition: background 0.18s;
}}
.slider::before {{
content: '';
position: absolute;
width: 18px; height: 18px;
left: 3px; bottom: 3px;
background: white;
border-radius: 50%;
transition: transform 0.18s;
box-shadow: 0 1px 3px rgba(0,0,0,0.18);
}}
input:checked + .slider {{ background: #38bdf8; }}
input:checked + .slider::before {{ transform: translateX(20px); }}
/* Duration pills */
.duration-section {{ margin-top: 14px; }}
.duration-label {{
font-size: 12px; font-weight: 600;
color: #64748b; margin-bottom: 8px;
text-transform: uppercase; letter-spacing: 0.05em;
}}
.pills {{ display: flex; flex-wrap: wrap; gap: 7px; }}
.pill {{
padding: 6px 13px;
border: 1.5px solid #e2e8f0;
border-radius: 20px;
font-size: 13px; color: #475569;
cursor: pointer; background: white;
transition: all 0.13s;
user-select: none;
}}
.pill:hover {{ border-color: #94a3b8; }}
.pill.sel {{
border-color: #38bdf8;
color: #0284c7;
background: #e0f2fe;
}}
/* Security note */
.sec-note {{
background: #f8fafc;
border-radius: 10px;
padding: 11px 14px;
display: flex; gap: 9px; align-items: flex-start;
margin: 16px 0;
}}
.sec-icon {{ font-size: 13px; color: #94a3b8; margin-top: 1px; flex-shrink: 0; }}
.sec-note p {{ font-size: 12.5px; color: #64748b; line-height: 1.5; }}
/* Connect button */
.btn-connect {{
width: 100%; padding: 15px;
border: none; border-radius: 12px;
font-size: 15px; font-weight: 600;
cursor: pointer;
background: #bae6fd; color: #0369a1;
transition: background 0.15s, color 0.15s;
}}
.btn-connect.ready {{
background: #0ea5e9; color: white;
}}
.btn-connect.ready:hover {{ background: #0284c7; }}
/* Step visibility */
.step {{ display: none; }}
.step.show {{ display: block; }}
</style>
</head>
<body>
<div class="modal">

<!-- ── STEP 1: Connect ─────────────────────────────────────── -->
<div id="s1" class="step show">
<div class="dots">
<div class="dot active"></div>
<div class="dot"></div>
</div>
<button class="close-btn" type="button" onclick="doCancel()" title="Close">×</button>

<div class="logos">
<img src="/ui/assets/logos/litellm_logo.jpg" class="logo-img" alt="LiteLLM">
<span class="logo-arrow">→</span>
<div class="logo logo-s">{server_initial}</div>
</div>

<h2 class="step-title">Connect {server_name} MCP</h2>
<p class="step-subtitle">LiteLLM needs access to {server_name} to complete your request.</p>

<div class="info-box">
<span class="info-icon">
<svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="12" y1="8" x2="12" y2="12"/><line x1="12" y1="16" x2="12.01" y2="16"/></svg>
</span>
<div>
<h4>How it works</h4>
<p>LiteLLM acts as a secure bridge. Your requests are routed through our MCP client directly to {server_name}’s API.</p>
</div>
</div>

{access_section}

<button class="btn-primary" type="button" onclick="goStep2()">
Continue to Authentication →
</button>
<button class="btn-cancel" type="button" onclick="doCancel()">Cancel</button>
</div>

<!-- ── STEP 2: Provide API Key ──────────────────────────────── -->
<div id="s2" class="step">
<div class="step2-nav">
<button class="back-btn" type="button" onclick="goStep1()">← Back</button>
<div class="dots">
<div class="dot active"></div>
<div class="dot active"></div>
</div>
<button class="close-btn" style="position:static;" type="button" onclick="doCancel()" title="Close">×</button>
</div>

<div class="key-icon-wrap">
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="none" stroke="#0284c7" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M21 2l-2 2m-7.61 7.61a5.5 5.5 0 1 1-7.778 7.778 5.5 5.5 0 0 1 7.777-7.777zm0 0L15.5 7.5m0 0l3 3L22 7l-3-3m-3.5 3.5L19 4"/></svg>
</div>
<h2 class="step-title" style="text-align:left;">Provide API Key</h2>
<p class="step-subtitle" style="text-align:left;">Enter your {server_name} API key to authorize this connection.</p>

<form method="POST" id="authForm" onsubmit="prepareSubmit()">
<input type="hidden" name="client_id" value="{client_id}">
<input type="hidden" name="redirect_uri" value="{redirect_uri}">
<input type="hidden" name="code_challenge" value="{code_challenge}">
<input type="hidden" name="code_challenge_method" value="{code_challenge_method}">
<input type="hidden" name="state" value="{state}">
<input type="hidden" name="server_id" value="{server_id}">
<input type="hidden" name="duration" id="durInput" value="until_revoked">

<label class="field-label">{server_name} API Key</label>
<input
type="password"
name="api_key"
id="apiKey"
class="key-input"
placeholder="Enter your API key"
required
autofocus
oninput="syncBtn()"
>

{help_link_html}

<div class="save-card">
<div class="save-row">
<span class="save-label">Save key for future use</span>
<label class="toggle">
<input type="checkbox" id="saveToggle" onchange="toggleDur()">
<span class="slider"></span>
</label>
</div>
<div id="durSection" class="duration-section" style="display:none;">
<div class="duration-label">Duration</div>
<div class="pills">
<div class="pill" onclick="selDur('1h',this)">1 hour</div>
<div class="pill sel" onclick="selDur('24h',this)">24 hours</div>
<div class="pill" onclick="selDur('7d',this)">7 days</div>
<div class="pill" onclick="selDur('30d',this)">30 days</div>
<div class="pill" onclick="selDur('until_revoked',this)">Until I revoke</div>
</div>
</div>
</div>

<div class="sec-note">
<span class="sec-icon">
<svg xmlns="http://www.w3.org/2000/svg" width="14" height="14" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="11" width="18" height="11" rx="2" ry="2"/><path d="M7 11V7a5 5 0 0 1 10 0v4"/></svg>
</span>
<p>Your key is stored securely and transmitted over HTTPS. It is never shared with third parties.</p>
</div>

<button type="submit" class="btn-connect" id="connectBtn">
Connect & Authorize
</button>
</form>
</div>

</div>
<script>
function goStep2() {{
document.getElementById('s1').classList.remove('show');
document.getElementById('s2').classList.add('show');
}}
function goStep1() {{
document.getElementById('s2').classList.remove('show');
document.getElementById('s1').classList.add('show');
}}
function doCancel() {{
if (window.opener) window.close();
else window.history.back();
}}
function toggleDur() {{
const on = document.getElementById('saveToggle').checked;
document.getElementById('durSection').style.display = on ? 'block' : 'none';
}}
function selDur(val, el) {{
document.querySelectorAll('.pill').forEach(p => p.classList.remove('sel'));
el.classList.add('sel');
document.getElementById('durInput').value = val;
}}
function syncBtn() {{
const btn = document.getElementById('connectBtn');
if (document.getElementById('apiKey').value.length > 0) {{
btn.classList.add('ready');
}} else {{
btn.classList.remove('ready');
}}
}}
function prepareSubmit() {{
// nothing extra needed — duration is already in the hidden input
}}
</script>
</body>
</html>"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OAuth metadata discovery endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/.well-known/oauth-authorization-server", include_in_schema=False)
async def oauth_authorization_server_metadata(request: Request) -> JSONResponse:
    """RFC 8414 Authorization Server Metadata for the BYOK OAuth flow."""
    base_url = get_request_base_url(request)
    # Minimal OAuth 2.1 surface: authorization-code grant with PKCE (S256) only.
    metadata = {
        "issuer": base_url,
        "authorization_endpoint": f"{base_url}/v1/mcp/oauth/authorize",
        "token_endpoint": f"{base_url}/v1/mcp/oauth/token",
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code"],
        "code_challenge_methods_supported": ["S256"],
    }
    return JSONResponse(metadata)
|
||||
|
||||
|
||||
@router.get("/.well-known/oauth-protected-resource", include_in_schema=False)
async def oauth_protected_resource_metadata(request: Request) -> JSONResponse:
    """RFC 9728 Protected Resource Metadata pointing back at this server."""
    base_url = get_request_base_url(request)
    # This proxy acts as both the protected resource and its own
    # authorization server, so both fields point at the same base URL.
    body = {
        "resource": base_url,
        "authorization_servers": [base_url],
    }
    return JSONResponse(body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Authorization endpoint — GET (show form) and POST (process form)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get("/v1/mcp/oauth/authorize", include_in_schema=False)
async def byok_authorize_get(
    request: Request,
    client_id: Optional[str] = None,
    redirect_uri: Optional[str] = None,
    response_type: Optional[str] = None,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    state: Optional[str] = None,
    server_id: Optional[str] = None,
) -> HTMLResponse:
    """
    Show the BYOK API-key entry form.

    The MCP client navigates the user here; the user types their API key and
    clicks "Connect & Authorize", which POSTs back to this same path.
    """
    # Guard clauses: reject malformed OAuth requests before rendering anything.
    if response_type != "code":
        raise HTTPException(status_code=400, detail="response_type must be 'code'")
    if not redirect_uri:
        raise HTTPException(status_code=400, detail="redirect_uri is required")
    if not code_challenge:
        raise HTTPException(status_code=400, detail="code_challenge is required")

    # Resolve display metadata (name, access description, help URL) for the
    # target server; any failure falls back to generic defaults so the page
    # still renders.
    server_name = "MCP Server"
    access_items: list = []
    help_url = ""
    if server_id:
        try:
            from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
                global_mcp_server_manager,
            )

            registry = global_mcp_server_manager.get_registry()
            srv = registry[server_id] if server_id in registry else None
            if srv is not None:
                server_name = srv.server_name or srv.name
                access_items = list(srv.byok_description or [])
                help_url = srv.byok_api_key_help_url or ""
        except Exception:
            # Best-effort lookup only — defaults above are fine.
            pass

    initial = server_name[0].upper() if server_name else "S"

    page = _build_authorize_html(
        server_name=server_name,
        server_initial=initial,
        client_id=client_id or "",
        redirect_uri=redirect_uri,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method or "S256",
        state=state or "",
        server_id=server_id or "",
        access_items=access_items,
        help_url=help_url,
    )
    return HTMLResponse(content=page)
|
||||
|
||||
|
||||
@router.post("/v1/mcp/oauth/authorize", include_in_schema=False)
async def byok_authorize_post(
    request: Request,
    client_id: str = Form(default=""),
    redirect_uri: str = Form(...),
    code_challenge: str = Form(...),
    code_challenge_method: str = Form(default="S256"),
    state: str = Form(default=""),
    server_id: str = Form(default=""),
    api_key: str = Form(...),
) -> RedirectResponse:
    """
    Process the BYOK API-key form submission.

    Stores a short-lived authorization code and redirects the client back to
    redirect_uri with ?code=...&state=... query parameters.
    """
    _purge_expired_codes()

    # Only http(s) redirect targets are allowed — blocks open redirects to
    # javascript:, data:, or custom app schemes.
    if urlparse(redirect_uri).scheme not in ("http", "https"):
        raise HTTPException(status_code=400, detail="Invalid redirect_uri scheme")

    # Bound the pending-code store so a burst of abandoned OAuth flows can't
    # exhaust memory.
    if len(_byok_auth_codes) >= _AUTH_CODES_MAX_SIZE:
        raise HTTPException(
            status_code=503, detail="Too many pending authorization flows"
        )

    if code_challenge_method != "S256":
        raise HTTPException(
            status_code=400, detail="Only S256 code_challenge_method is supported"
        )

    issued_code = str(uuid.uuid4())
    _byok_auth_codes[issued_code] = {
        "api_key": api_key,
        "server_id": server_id,
        "code_challenge": code_challenge,
        "redirect_uri": redirect_uri,
        "user_id": client_id,  # external client passes LiteLLM user-id as client_id
        "expires_at": time.time() + _AUTH_CODE_TTL_SECONDS,
    }

    query = urlencode({"code": issued_code, "state": state})
    joiner = "?" if "?" not in redirect_uri else "&"
    return RedirectResponse(url=f"{redirect_uri}{joiner}{query}", status_code=302)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Token endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.post("/v1/mcp/oauth/token", include_in_schema=False)
async def byok_token(
    request: Request,
    grant_type: str = Form(...),
    code: str = Form(...),
    redirect_uri: str = Form(default=""),
    code_verifier: str = Form(...),
    client_id: str = Form(default=""),
) -> JSONResponse:
    """
    Exchange an authorization code for a short-lived BYOK session JWT.

    1. Validates the authorization code and PKCE challenge.
    2. Stores the API key via store_user_credential().
    3. Issues a signed JWT with type="byok_session".

    Raises:
        HTTPException(400): unsupported grant type, unknown/expired code,
            failed PKCE verification, or no resolvable user_id.
        HTTPException(500): credential persistence or missing master key.
    """
    # Imported lazily — importing proxy_server at module load would be circular.
    from litellm.proxy.proxy_server import master_key, prisma_client

    _purge_expired_codes()

    if grant_type != "authorization_code":
        raise HTTPException(status_code=400, detail="unsupported_grant_type")

    record = _byok_auth_codes.get(code)
    if record is None:
        # Deliberately the same generic error for unknown/expired/failed-PKCE
        # codes, so callers can't distinguish the failure mode.
        raise HTTPException(status_code=400, detail="invalid_grant")

    if time.time() > record["expires_at"]:
        del _byok_auth_codes[code]
        raise HTTPException(status_code=400, detail="invalid_grant")

    # PKCE verification
    if not _verify_pkce(code_verifier, record["code_challenge"]):
        raise HTTPException(status_code=400, detail="invalid_grant")

    # Consume the code (one-time use)
    del _byok_auth_codes[code]

    server_id: str = record["server_id"]
    api_key_value: str = record["api_key"]
    # Prefer the user_id that was stored when the code was issued; fall back to
    # whatever client_id the token request supplies (they should match).
    user_id: str = record.get("user_id") or client_id

    if not user_id:
        raise HTTPException(
            status_code=400,
            detail="Cannot determine user_id; pass LiteLLM user id as client_id",
        )

    # Persist the BYOK credential
    if prisma_client is not None:
        try:
            await store_user_credential(
                prisma_client=prisma_client,
                user_id=user_id,
                server_id=server_id,
                credential=api_key_value,
            )
            # Invalidate any cached negative result so the user isn't blocked
            # for up to the TTL period after completing the OAuth flow.
            from litellm.proxy._experimental.mcp_server.server import (
                _invalidate_byok_cred_cache,
            )

            _invalidate_byok_cred_cache(user_id, server_id)
        except Exception as exc:
            verbose_proxy_logger.error(
                "byok_token: failed to store user credential for user=%s server=%s: %s",
                user_id,
                server_id,
                exc,
            )
            raise HTTPException(status_code=500, detail="Failed to store credential")
    else:
        # NOTE(review): without a DB the flow still issues a token but the key
        # is not persisted — deliberate best-effort, so only a warning here.
        verbose_proxy_logger.warning(
            "byok_token: prisma_client is None — credential not persisted"
        )

    if master_key is None:
        raise HTTPException(
            status_code=500, detail="Master key not configured; cannot issue token"
        )

    now = int(time.time())
    payload = {
        "user_id": user_id,
        "server_id": server_id,
        # "type" distinguishes this from regular proxy auth tokens.
        # The proxy's SSO JWT path uses asymmetric keys (RS256/ES256), so an
        # HS256 token signed with master_key cannot be accepted there.
        "type": "byok_session",
        "iat": now,
        "exp": now + 3600,
    }
    access_token = jwt.encode(payload, cast(str, master_key), algorithm="HS256")

    return JSONResponse(
        {
            "access_token": access_token,
            "token_type": "bearer",
            "expires_in": 3600,
        }
    )
|
||||
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Cost calculator for MCP tools.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||
|
||||
from litellm.types.mcp import MCPServerCostInfo
|
||||
from litellm.types.utils import StandardLoggingMCPToolCall
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
Logging as LitellmLoggingObject,
|
||||
)
|
||||
else:
|
||||
LitellmLoggingObject = Any
|
||||
|
||||
|
||||
class MCPCostCalculator:
    """Computes the cost attributed to MCP tool invocations."""

    @staticmethod
    def calculate_mcp_tool_call_cost(
        litellm_logging_obj: Optional[LitellmLoggingObject],
    ) -> float:
        """Return the cost of an MCP tool call.

        Defaults to 0.0 unless the user configured a per-query cost for MCP
        tools, or a post-call hook set an explicit ``response_cost``.

        Args:
            litellm_logging_obj: Logging object carrying
                ``model_call_details``; None yields a zero cost.
        """
        if litellm_logging_obj is None:
            return 0.0

        call_details = litellm_logging_obj.model_call_details

        # A post_mcp_tool_call_hook may have set an explicit response cost on
        # the logging object; it takes precedence over any configured pricing.
        hook_cost = call_details.get("response_cost", None)
        if hook_cost is not None:
            return hook_cost

        # Unpack the tool-call metadata recorded for this request.
        tool_metadata = (
            cast(
                StandardLoggingMCPToolCall,
                call_details.get("mcp_tool_call_metadata", {}),
            )
            or {}
        )
        cost_info: MCPServerCostInfo = (
            tool_metadata.get("mcp_server_cost_info") or MCPServerCostInfo()
        )

        per_tool_costs: dict = (
            cost_info.get("tool_name_to_cost_per_query", {}) or {}
        )
        invoked_tool = tool_metadata.get("name", "")

        # Resolution order:
        # 1. per-tool override in tool_name_to_cost_per_query
        # 2. server-wide default_cost_per_query
        # 3. free (0.0)
        if invoked_tool in per_tool_costs:
            return per_tool_costs[invoked_tool]
        fallback_cost = cost_info.get("default_cost_per_query", None)
        return fallback_cost if fallback_cost is not None else 0.0
|
||||
@@ -0,0 +1,767 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_MCPServerTable,
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_TeamTable,
|
||||
MCPApprovalStatus,
|
||||
MCPSubmissionsSummary,
|
||||
NewMCPServerRequest,
|
||||
SpecialMCPServerName,
|
||||
UpdateMCPServerRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
_get_salt_key,
|
||||
decrypt_value_helper,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.mcp import MCPCredentials
|
||||
|
||||
|
||||
def _prepare_mcp_server_data(
    data: Union[NewMCPServerRequest, UpdateMCPServerRequest],
) -> Dict[str, Any]:
    """
    Helper function to prepare MCP server data for database operations.
    Handles JSON field serialization for mcp_info and env fields.

    Secret fields inside ``credentials`` are encrypted with the global salt
    key before JSON-encoding, so they are never stored in plain text.

    Args:
        data: NewMCPServerRequest or UpdateMCPServerRequest object

    Returns:
        Dict with properly serialized JSON fields
    """
    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    # Convert model to dict. exclude_none drops unset optional fields so a
    # partial update only touches the columns it names.
    data_dict = data.model_dump(exclude_none=True)
    # Ensure alias is always present in the dict (even if None)
    if "alias" not in data_dict:
        data_dict["alias"] = getattr(data, "alias", None)

    # Handle credentials serialization: encrypt secrets first, then
    # JSON-encode the whole dict for the DB column. Order matters — the
    # stored JSON must contain the encrypted values.
    credentials = data_dict.get("credentials")
    if credentials is not None:
        data_dict["credentials"] = encrypt_credentials(
            credentials=credentials, encryption_key=_get_salt_key()
        )
        data_dict["credentials"] = safe_dumps(data_dict["credentials"])

    # Handle static_headers serialization
    if data.static_headers is not None:
        data_dict["static_headers"] = safe_dumps(data.static_headers)

    # Handle mcp_info serialization
    if data.mcp_info is not None:
        data_dict["mcp_info"] = safe_dumps(data.mcp_info)

    # Handle env serialization
    if data.env is not None:
        data_dict["env"] = safe_dumps(data.env)

    # Handle tool name override serialization
    if data.tool_name_to_display_name is not None:
        data_dict["tool_name_to_display_name"] = safe_dumps(
            data.tool_name_to_display_name
        )
    if data.tool_name_to_description is not None:
        data_dict["tool_name_to_description"] = safe_dumps(
            data.tool_name_to_description
        )

    # mcp_access_groups is already List[str], no serialization needed

    # Force include is_byok even when False (exclude_none=True would not drop it,
    # but be explicit to ensure a False value is always written to the DB).
    data_dict["is_byok"] = getattr(data, "is_byok", False)

    return data_dict
|
||||
|
||||
|
||||
def encrypt_credentials(
    credentials: MCPCredentials, encryption_key: Optional[str]
) -> MCPCredentials:
    """Encrypt every secret field of *credentials* in place and return the dict.

    Iterates the same secret-field list as ``decrypt_credentials`` so the two
    functions cannot drift apart (the original enumerated each field with a
    copy-pasted block). ``aws_region_name`` and ``aws_service_name`` are NOT
    secrets and are stored as-is.

    Args:
        credentials: Credential dict; secret values are replaced with their
            encrypted form. Mutated in place.
        encryption_key: Key forwarded to ``encrypt_value_helper``.

    Returns:
        The same ``credentials`` dict, mutated in place.
    """
    secret_fields = [
        "auth_value",
        "client_id",
        "client_secret",
        # AWS SigV4 credential fields
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
    ]
    for field in secret_fields:
        value = credentials.get(field)  # type: ignore[literal-required]
        if value is not None:
            credentials[field] = encrypt_value_helper(  # type: ignore[literal-required]
                value=value,
                new_encryption_key=encryption_key,
            )
    return credentials
|
||||
|
||||
|
||||
def decrypt_credentials(
    credentials: MCPCredentials,
) -> MCPCredentials:
    """Decrypt all secret fields in an MCPCredentials dict using the global salt key.

    Absent keys and non-string values are left untouched; a failed decryption
    falls back to the stored value (``return_original_value=True``).
    """
    for secret_key in (
        "auth_value",
        "client_id",
        "client_secret",
        "aws_access_key_id",
        "aws_secret_access_key",
        "aws_session_token",
    ):
        stored = credentials.get(secret_key)  # type: ignore[literal-required]
        if not isinstance(stored, str):
            continue
        credentials[secret_key] = decrypt_value_helper(  # type: ignore[literal-required]
            value=stored,
            key=secret_key,
            exception_type="debug",
            return_original_value=True,
        )
    return credentials
|
||||
|
||||
|
||||
async def get_all_mcp_servers(
    prisma_client: PrismaClient,
    approval_status: Optional[str] = None,
) -> List[LiteLLM_MCPServerTable]:
    """Fetch MCP servers from the DB, optionally filtered by approval status.

    With ``approval_status=None`` every server is returned regardless of its
    approval state. On any DB error the error is logged at debug level and an
    empty list is returned.
    """
    try:
        filters: Dict[str, Any] = (
            {"approval_status": approval_status}
            if approval_status is not None
            else {}
        )
        rows = await prisma_client.db.litellm_mcpservertable.find_many(
            where=filters
        )
        return [LiteLLM_MCPServerTable(**row.model_dump()) for row in rows]
    except Exception as e:
        verbose_proxy_logger.debug(
            "litellm.proxy._experimental.mcp_server.db.py::get_all_mcp_servers - {}".format(
                str(e)
            )
        )
        return []
|
||||
|
||||
|
||||
async def get_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> Optional[LiteLLM_MCPServerTable]:
    """Look up a single MCP server row by ``server_id``; None when absent."""
    row: Optional[
        LiteLLM_MCPServerTable
    ] = await prisma_client.db.litellm_mcpservertable.find_unique(
        where={"server_id": server_id}
    )
    return row
|
||||
|
||||
|
||||
async def get_mcp_servers(
    prisma_client: PrismaClient, server_ids: Iterable[str]
) -> List[LiteLLM_MCPServerTable]:
    """
    Returns the matching mcp servers from the db with the server_ids.

    Args:
        server_ids: Any iterable of server ids. Materialised into a list
            before being handed to Prisma's ``in`` filter, so sets and
            generators are accepted safely.

    Returns:
        Validated ``LiteLLM_MCPServerTable`` models for each matching row.
    """
    rows = await prisma_client.db.litellm_mcpservertable.find_many(
        where={
            # list(...) generalises to any iterable (callers pass sets).
            "server_id": {"in": list(server_ids)},
        }
    )
    # Comprehension keeps construction consistent with get_all_mcp_servers.
    return [LiteLLM_MCPServerTable(**row.model_dump()) for row in rows]
|
||||
|
||||
|
||||
async def get_mcp_servers_by_verificationtoken(
    prisma_client: PrismaClient, token: str
) -> List[str]:
    """Return the MCP server ids granted to a verification token.

    Reads the token's related ``object_permission`` record; an unknown token
    or a token without object permissions yields an empty list.
    """
    # NOTE(review): the annotation mirrors the original code; this row is a
    # verification-token record, so LiteLLM_TeamTable looks copy-pasted from
    # the team variant — confirm the intended type.
    token_row: LiteLLM_TeamTable = (
        await prisma_client.db.litellm_verificationtoken.find_unique(
            where={"token": token},
            include={"object_permission": True},
        )
    )
    if token_row is None or token_row.object_permission is None:
        return []
    return token_row.object_permission.mcp_servers or []
|
||||
|
||||
|
||||
async def get_mcp_servers_by_team(
    prisma_client: PrismaClient, team_id: str
) -> List[str]:
    """Return the MCP server ids granted to a team via its object permission.

    An unknown team, or a team without object permissions, yields [].
    """
    team_row: LiteLLM_TeamTable = (
        await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id},
            include={"object_permission": True},
        )
    )
    if team_row is None or team_row.object_permission is None:
        return []
    return team_row.object_permission.mcp_servers or []
|
||||
|
||||
|
||||
async def get_all_mcp_servers_for_user(
    prisma_client: PrismaClient,
    user: UserAPIKeyAuth,
) -> List[LiteLLM_MCPServerTable]:
    """
    Get all the mcp servers filtered by the given user has access to.

    Following Least-Privilege Principle - the requestor should only be able
    to see the mcp servers that they have access to.
    """

    mcp_server_ids: Set[str] = set()
    mcp_servers = []

    # Get the mcp servers attached to the user's virtual key
    if user.api_key:
        token_mcp_servers = await get_mcp_servers_by_verificationtoken(
            prisma_client, user.api_key
        )
        mcp_server_ids.update(token_mcp_servers)

    # check for special team membership: the sentinel grants access to all of
    # the user's team's servers.
    # NOTE(review): the sentinel id itself stays in mcp_server_ids and is
    # passed to get_mcp_servers below; presumably the DB lookup just ignores
    # the unknown id — confirm.
    if (
        SpecialMCPServerName.all_team_servers in mcp_server_ids
        and user.team_id is not None
    ):
        team_mcp_servers = await get_mcp_servers_by_team(
            prisma_client, user.team_id
        )
        mcp_server_ids.update(team_mcp_servers)

    # Resolve the accumulated ids to full server records (skip the DB round
    # trip entirely when the user has access to nothing).
    if len(mcp_server_ids) > 0:
        mcp_servers = await get_mcp_servers(prisma_client, mcp_server_ids)

    return mcp_servers
|
||||
|
||||
|
||||
async def get_objectpermissions_for_mcp_server(
    prisma_client: PrismaClient, mcp_server_id: str
) -> List[LiteLLM_ObjectPermissionTable]:
    """Fetch every object-permission record granting access to *mcp_server_id*.

    Each record is returned with its related team and verification-token rows
    included.
    """
    permission_rows = await prisma_client.db.litellm_objectpermissiontable.find_many(
        where={"mcp_servers": {"has": mcp_server_id}},
        include={
            "teams": True,
            "verification_tokens": True,
        },
    )
    return permission_rows
|
||||
|
||||
|
||||
async def get_virtualkeys_for_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> List:
    """Return every virtual key whose ``mcp_servers`` list contains *server_id*."""
    matching_keys = await prisma_client.db.litellm_verificationtoken.find_many(
        where={"mcp_servers": {"has": server_id}},
    )
    # Normalise a missing result to an empty list for callers.
    return matching_keys if matching_keys is not None else []
|
||||
|
||||
|
||||
async def delete_mcp_server_from_team(prisma_client: PrismaClient, server_id: str):
    """
    Remove the mcp server from the team

    TODO: not implemented yet — currently a no-op stub.
    """
    pass
|
||||
|
||||
|
||||
async def delete_mcp_server_from_virtualkey():
    """
    Remove the mcp server from the virtual key

    TODO: not implemented yet — currently a no-op stub.
    """
    pass
|
||||
|
||||
|
||||
async def delete_mcp_server(
    prisma_client: PrismaClient, server_id: str
) -> Optional[LiteLLM_MCPServerTable]:
    """
    Delete the mcp server from the db by server_id

    Returns the deleted mcp server record if it exists, otherwise None

    NOTE(review): prisma-client-py's ``delete`` may raise when the record
    does not exist rather than returning None — confirm the "otherwise None"
    contract against the installed client version.
    """
    deleted_server = await prisma_client.db.litellm_mcpservertable.delete(
        where={
            "server_id": server_id,
        },
    )
    return deleted_server
|
||||
|
||||
|
||||
async def create_mcp_server(
    prisma_client: PrismaClient, data: NewMCPServerRequest, touched_by: str
) -> LiteLLM_MCPServerTable:
    """Insert a new MCP server row.

    Generates a ``server_id`` when the request omits one, serialises JSON
    fields via ``_prepare_mcp_server_data``, and stamps both audit columns
    with *touched_by*.
    """
    # Assign a fresh id when the caller did not provide one.
    if data.server_id is None:
        data.server_id = str(uuid.uuid4())

    record = _prepare_mcp_server_data(data)
    record["created_by"] = touched_by
    record["updated_by"] = touched_by

    created = await prisma_client.db.litellm_mcpservertable.create(
        data=record  # type: ignore
    )
    return created
|
||||
|
||||
|
||||
async def update_mcp_server(
    prisma_client: PrismaClient, data: UpdateMCPServerRequest, touched_by: str
) -> LiteLLM_MCPServerTable:
    """
    Update an existing mcp server record in the db.

    Credential handling:
    - When ``auth_type`` changes and no new credentials are supplied, stale
      credentials from the previous auth type are cleared.
    - When credentials are supplied and ``auth_type`` is unchanged, they are
      merged over the existing (encrypted) credentials so a partial update
      does not wipe secrets the UI cannot display back.

    Args:
        prisma_client: DB client.
        data: Partial update payload; ``server_id`` selects the row.
        touched_by: Recorded in the ``updated_by`` audit column.

    Returns:
        The updated row.
    """
    # json is imported at module level; the previous redundant local
    # `import json` has been removed.
    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    # Use helper to prepare data with proper JSON serialization
    data_dict = _prepare_mcp_server_data(data)

    # Pre-fetch existing record once if we need it for auth_type or credential logic
    existing = None
    has_credentials = (
        "credentials" in data_dict and data_dict["credentials"] is not None
    )
    if data.auth_type or has_credentials:
        existing = await prisma_client.db.litellm_mcpservertable.find_unique(
            where={"server_id": data.server_id}
        )

    # Clear stale credentials when auth_type changes but no new credentials provided
    if (
        data.auth_type
        and "credentials" not in data_dict
        and existing
        and existing.auth_type is not None
        and existing.auth_type != data.auth_type
    ):
        data_dict["credentials"] = None

    # Merge credentials: preserve existing fields not present in the update.
    # Without this, a partial credential update (e.g. changing only region)
    # would wipe encrypted secrets that the UI cannot display back.
    if "credentials" in data_dict and data_dict["credentials"] is not None:
        if existing and existing.credentials:
            # Only merge when auth_type is unchanged. Switching auth types
            # (e.g. oauth2 → api_key) should replace credentials entirely
            # to avoid stale secrets from the previous auth type lingering.
            auth_type_unchanged = (
                data.auth_type is None or data.auth_type == existing.auth_type
            )
            if auth_type_unchanged:
                existing_creds = (
                    json.loads(existing.credentials)
                    if isinstance(existing.credentials, str)
                    else dict(existing.credentials)
                )
                new_creds = (
                    json.loads(data_dict["credentials"])
                    if isinstance(data_dict["credentials"], str)
                    else dict(data_dict["credentials"])
                )
                # New values override existing; existing keys not in update are preserved
                merged = {**existing_creds, **new_creds}
                data_dict["credentials"] = safe_dumps(merged)

    # Add audit fields
    data_dict["updated_by"] = touched_by

    updated_mcp_server = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": data.server_id}, data=data_dict  # type: ignore
    )

    return updated_mcp_server
|
||||
|
||||
|
||||
async def rotate_mcp_server_credentials_master_key(
    prisma_client: PrismaClient, touched_by: str, new_master_key: str
):
    """Re-encrypt every stored MCP server credential under *new_master_key*.

    Each credential blob is decrypted with the current salt/master key and
    re-encrypted with the new key, then written back with updated audit info.
    Servers without credentials are skipped.
    """
    # Hoisted out of the loop: the import is loop-invariant (previously it
    # ran once per server row).
    from litellm.litellm_core_utils.safe_json_dumps import safe_dumps

    mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()

    for mcp_server in mcp_servers:
        credentials = mcp_server.credentials
        if not credentials:
            continue

        # Work on a copy so the fetched row object is left untouched.
        credentials_copy = dict(credentials)
        # Decrypt with current key first, then re-encrypt with new key
        decrypted_credentials = decrypt_credentials(
            credentials=cast(MCPCredentials, credentials_copy),
        )
        encrypted_credentials = encrypt_credentials(
            credentials=decrypted_credentials,
            encryption_key=new_master_key,
        )

        serialized_credentials = safe_dumps(encrypted_credentials)

        await prisma_client.db.litellm_mcpservertable.update(
            where={"server_id": mcp_server.server_id},
            data={
                "credentials": serialized_credentials,
                "updated_by": touched_by,
            },
        )
|
||||
|
||||
|
||||
async def store_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
    credential: str,
) -> None:
    """Upsert a BYOK credential (stored base64-encoded) for a user+server pair."""
    credential_b64 = base64.urlsafe_b64encode(credential.encode()).decode()
    compound_key = {
        "user_id_server_id": {"user_id": user_id, "server_id": server_id}
    }
    await prisma_client.db.litellm_mcpusercredentials.upsert(
        where=compound_key,
        data={
            "create": {
                "user_id": user_id,
                "server_id": server_id,
                "credential_b64": credential_b64,
            },
            "update": {"credential_b64": credential_b64},
        },
    )
|
||||
|
||||
|
||||
async def get_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> Optional[str]:
    """Return the decoded credential for a user+server pair, or None.

    Current rows are base64-encoded; rows written by older code are
    nacl-encrypted and handled by the fallback path.
    """
    record = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if record is None:
        return None

    stored = record.credential_b64
    try:
        return base64.urlsafe_b64decode(stored).decode()
    except Exception:
        # Fall back to nacl decryption for credentials stored by older code
        return decrypt_value_helper(
            value=stored,
            key="byok_credential",
            exception_type="debug",
            return_original_value=False,
        )
|
||||
|
||||
|
||||
async def has_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> bool:
    """True when a stored credential exists for this user+server pair."""
    record = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    return record is not None
|
||||
|
||||
|
||||
async def delete_user_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> None:
    """Remove the stored BYOK credential for a user+server pair."""
    compound_key = {"user_id": user_id, "server_id": server_id}
    await prisma_client.db.litellm_mcpusercredentials.delete(
        where={"user_id_server_id": compound_key}
    )
|
||||
|
||||
|
||||
# ── OAuth2 user-credential helpers ────────────────────────────────────────────
|
||||
|
||||
|
||||
async def store_user_oauth_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
    access_token: str,
    refresh_token: Optional[str] = None,
    expires_in: Optional[int] = None,
    scopes: Optional[List[str]] = None,
) -> None:
    """Persist an OAuth2 access token for a user+server pair.

    The payload is JSON-serialised and stored base64-encoded in the same
    ``credential_b64`` column used by BYOK. A ``"type": "oauth2"`` key
    differentiates it from plain BYOK API keys.

    Args:
        prisma_client: DB client.
        user_id: Owner of the credential.
        server_id: MCP server the token belongs to.
        access_token: OAuth2 access token to store.
        refresh_token: Optional refresh token stored alongside.
        expires_in: Token lifetime in seconds; converted to an absolute
            ``expires_at`` UTC timestamp at store time.
        scopes: Optional granted scopes.

    Raises:
        ValueError: If a non-OAuth2 (BYOK) credential already exists for this
            user+server pair — it is never silently overwritten.
    """

    # Convert the relative lifetime into an absolute UTC expiry timestamp.
    expires_at: Optional[str] = None
    if expires_in is not None:
        expires_at = (
            datetime.now(timezone.utc) + timedelta(seconds=expires_in)
        ).isoformat()

    payload: Dict[str, Any] = {
        "type": "oauth2",
        "access_token": access_token,
        "connected_at": datetime.now(timezone.utc).isoformat(),
    }
    if refresh_token:
        payload["refresh_token"] = refresh_token
    if expires_at:
        payload["expires_at"] = expires_at
    if scopes:
        payload["scopes"] = scopes

    # Guard against silently overwriting a BYOK credential with an OAuth token.
    # BYOK credentials lack a "type" field (or use a non-"oauth2" type).
    existing = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if existing is not None:
        _byok_error = ValueError(
            f"A non-OAuth2 credential already exists for user {user_id} "
            f"and server {server_id}. Refusing to overwrite."
        )
        try:
            raw = json.loads(base64.urlsafe_b64decode(existing.credential_b64).decode())
        except Exception:
            # Credential is not base64+JSON — it's a plain-text BYOK key.
            raise _byok_error
        if raw.get("type") != "oauth2":
            raise _byok_error

    encoded = base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()
    await prisma_client.db.litellm_mcpusercredentials.upsert(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}},
        data={
            "create": {
                "user_id": user_id,
                "server_id": server_id,
                "credential_b64": encoded,
            },
            "update": {"credential_b64": encoded},
        },
    )
|
||||
|
||||
|
||||
def is_oauth_credential_expired(cred: Dict[str, Any]) -> bool:
    """Return True if the OAuth2 credential's access_token has expired.

    Checks the ``expires_at`` ISO-format string stored in the credential
    payload. A trailing ``Z`` (UTC designator) is normalised to ``+00:00``
    first: ``datetime.fromisoformat`` rejects ``Z`` before Python 3.11, which
    previously made ``...Z`` expiries unparseable and therefore silently
    treated as non-expired (fail-open).

    Returns False when ``expires_at`` is absent or unparseable (treat as
    non-expired). Naive timestamps are assumed to be UTC.
    """
    expires_at = cred.get("expires_at")
    if not expires_at:
        return False
    try:
        # Normalise the "Z" UTC suffix for fromisoformat on Python < 3.11.
        if isinstance(expires_at, str) and expires_at.endswith("Z"):
            expires_at = expires_at[:-1] + "+00:00"
        exp_dt = datetime.fromisoformat(expires_at)
        if exp_dt.tzinfo is None:
            exp_dt = exp_dt.replace(tzinfo=timezone.utc)
        return datetime.now(timezone.utc) > exp_dt
    except (ValueError, TypeError):
        return False
|
||||
|
||||
|
||||
async def get_user_oauth_credential(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> Optional[Dict[str, Any]]:
    """Return the decoded OAuth2 payload dict for a user+server pair, or None.

    Rows holding plain BYOK strings (no ``"type": "oauth2"`` marker) and
    undecodable rows both yield None.
    """
    row = await prisma_client.db.litellm_mcpusercredentials.find_unique(
        where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    )
    if row is None:
        return None
    try:
        payload = json.loads(
            base64.urlsafe_b64decode(row.credential_b64).decode()
        )
    except Exception:
        return None
    if isinstance(payload, dict) and payload.get("type") == "oauth2":
        return payload
    # Row exists but is a BYOK (plain string), not an OAuth token
    return None
|
||||
|
||||
|
||||
async def list_user_oauth_credentials(
    prisma_client: PrismaClient,
    user_id: str,
) -> List[Dict[str, Any]]:
    """Return all OAuth2 credential payloads for a user, tagged with server_id."""
    rows = await prisma_client.db.litellm_mcpusercredentials.find_many(
        where={"user_id": user_id}
    )
    oauth_payloads: List[Dict[str, Any]] = []
    for row in rows:
        try:
            payload = json.loads(
                base64.urlsafe_b64decode(row.credential_b64).decode()
            )
        except Exception:
            continue  # Skip non-OAuth rows (BYOK plain strings)
        if isinstance(payload, dict) and payload.get("type") == "oauth2":
            payload["server_id"] = row.server_id
            oauth_payloads.append(payload)
    return oauth_payloads
|
||||
|
||||
|
||||
async def approve_mcp_server(
    prisma_client: PrismaClient,
    server_id: str,
    touched_by: str,
) -> LiteLLM_MCPServerTable:
    """Mark a server approved: approval_status=active, reviewed_at=now."""
    reviewed_at = datetime.now(timezone.utc)
    row = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": server_id},
        data={
            "approval_status": MCPApprovalStatus.active,
            "reviewed_at": reviewed_at,
            "updated_by": touched_by,
        },
    )
    return LiteLLM_MCPServerTable(**row.model_dump())
|
||||
|
||||
|
||||
async def reject_mcp_server(
    prisma_client: PrismaClient,
    server_id: str,
    touched_by: str,
    review_notes: Optional[str] = None,
) -> LiteLLM_MCPServerTable:
    """Mark a server rejected, recording reviewed_at and optional review notes."""
    update_fields: Dict[str, Any] = {
        "approval_status": MCPApprovalStatus.rejected,
        "reviewed_at": datetime.now(timezone.utc),
        "updated_by": touched_by,
    }
    # Only touch review_notes when the reviewer supplied some.
    if review_notes is not None:
        update_fields["review_notes"] = review_notes
    row = await prisma_client.db.litellm_mcpservertable.update(
        where={"server_id": server_id},
        data=update_fields,
    )
    return LiteLLM_MCPServerTable(**row.model_dump())
|
||||
|
||||
|
||||
async def get_mcp_submissions(
    prisma_client: PrismaClient,
) -> MCPSubmissionsSummary:
    """
    Returns all MCP servers that were submitted by non-admin users
    (submitted_at IS NOT NULL), along with a summary count breakdown by
    approval_status.

    Mirrors get_guardrail_submissions() from guardrail_endpoints.py.
    """
    rows = await prisma_client.db.litellm_mcpservertable.find_many(
        where={"submitted_at": {"not": None}},
        order={"submitted_at": "desc"},
        take=500,  # safety cap; paginate if needed in a future iteration
    )
    items = [LiteLLM_MCPServerTable(**r.model_dump()) for r in rows]

    # Single pass over the items to build the status breakdown.
    status_counts: Dict[Any, int] = {}
    for item in items:
        status_counts[item.approval_status] = (
            status_counts.get(item.approval_status, 0) + 1
        )

    return MCPSubmissionsSummary(
        total=len(items),
        pending_review=status_counts.get(MCPApprovalStatus.pending_review, 0),
        active=status_counts.get(MCPApprovalStatus.active, 0),
        rejected=status_counts.get(MCPApprovalStatus.rejected, 0),
        items=items,
    )
|
||||
@@ -0,0 +1,741 @@
|
||||
import json
|
||||
from typing import Optional
|
||||
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
|
||||
|
||||
from fastapi import APIRouter, Form, HTTPException, Request
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
decrypt_value_helper,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.utils import get_server_root_path
|
||||
from litellm.types.mcp import MCPAuth
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
|
||||
# Router exposing the MCP OAuth2 endpoints; mounted by the proxy app.
router = APIRouter(
    tags=["mcp"],
)
|
||||
|
||||
|
||||
def _forwarded_host_has_port(host: str) -> bool:
    """Return True when *host* already carries an explicit port.

    Handles bracketed IPv6 hosts: "[::1]" has no port, "[::1]:8080" does —
    a port can only appear after the closing bracket.
    """
    if host.startswith("["):
        return ":" in host.rsplit("]", 1)[-1]
    return ":" in host


def get_request_base_url(request: Request) -> str:
    """
    Get the base URL for the request, considering X-Forwarded-* headers.

    When behind a proxy (like nginx), the proxy may set:
    - X-Forwarded-Proto: The original protocol (http/https)
    - X-Forwarded-Host: The original host (may include port)
    - X-Forwarded-Port: The original port (if not in Host header)

    Args:
        request: FastAPI Request object

    Returns:
        The reconstructed base URL (e.g., "https://proxy.example.com")
    """
    base_url = str(request.base_url).rstrip("/")
    parsed = urlparse(base_url)

    # Forwarded headers set by a reverse proxy, if any.
    x_forwarded_proto = request.headers.get("X-Forwarded-Proto")
    x_forwarded_host = request.headers.get("X-Forwarded-Host")
    x_forwarded_port = request.headers.get("X-Forwarded-Port")

    # Prefer the forwarded protocol over the one the app server saw.
    scheme = x_forwarded_proto if x_forwarded_proto else parsed.scheme

    # Handle host and port
    if x_forwarded_host:
        if _forwarded_host_has_port(x_forwarded_host):
            # Host already includes a port. Bug fix: a bracketed IPv6 host
            # with a port ("[::1]:8080") previously fell through and had
            # X-Forwarded-Port appended a second time.
            netloc = x_forwarded_host
        elif x_forwarded_port:
            # Port is separate
            netloc = f"{x_forwarded_host}:{x_forwarded_port}"
        else:
            # Just host, no explicit port
            netloc = x_forwarded_host
    else:
        # No X-Forwarded-Host, use original netloc
        netloc = parsed.netloc
        if x_forwarded_port and ":" not in netloc:
            # Add forwarded port if not already in netloc
            netloc = f"{netloc}:{x_forwarded_port}"

    # Reconstruct the URL
    return urlunparse((scheme, netloc, parsed.path, "", "", ""))
|
||||
|
||||
|
||||
def encode_state_with_base_url(
    base_url: str,
    original_state: str,
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    client_redirect_uri: Optional[str] = None,
) -> str:
    """
    Pack the OAuth session data into a single encrypted ``state`` string.

    The payload carries the client's callback base URL, its original
    ``state`` value, the PKCE parameters, and the original ``redirect_uri``,
    so the ``/callback`` endpoint can restore the session without any
    server-side storage.

    Args:
        base_url: The client callback base URL to encode.
        original_state: The client's original ``state`` parameter.
        code_challenge: PKCE code challenge from the client.
        code_challenge_method: PKCE code challenge method from the client.
        client_redirect_uri: Original ``redirect_uri`` from the client.

    Returns:
        An encrypted, opaque string encoding all values.
    """
    payload = json.dumps(
        {
            "base_url": base_url,
            "original_state": original_state,
            "code_challenge": code_challenge,
            "code_challenge_method": code_challenge_method,
            "client_redirect_uri": client_redirect_uri,
        },
        sort_keys=True,
    )
    return encrypt_value_helper(payload)
|
||||
|
||||
|
||||
def decode_state_hash(encrypted_state: str) -> dict:
    """
    Reverse of :func:`encode_state_with_base_url`.

    Args:
        encrypted_state: Opaque encrypted string produced when the flow began.

    Returns:
        Dict containing ``base_url``, ``original_state``, and the optional
        PKCE parameters.

    Raises:
        ValueError: If the state cannot be decrypted.
        Exception: If the decrypted payload is malformed JSON.
    """
    plaintext = decrypt_value_helper(encrypted_state, "oauth_state")
    if plaintext is None:
        raise ValueError("Failed to decrypt state parameter")
    return json.loads(plaintext)
|
||||
|
||||
|
||||
def _resolve_oauth2_server_for_root_endpoints(
    client_ip: Optional[str] = None,
) -> Optional[MCPServer]:
    """
    Pick the MCP server for root-level OAuth endpoints (no name in the path).

    Root-level routes such as ``/register``, ``/authorize``, and ``/token``
    carry no server-name prefix, so we can only disambiguate automatically
    when exactly one OAuth2-enabled server is visible to the caller.

    Returns:
        The single OAuth2 server, or ``None`` when zero or several match.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    visible = global_mcp_server_manager.get_filtered_registry(client_ip=client_ip)
    candidates = [
        server for server in visible.values() if server.auth_type == MCPAuth.oauth2
    ]
    return candidates[0] if len(candidates) == 1 else None
|
||||
|
||||
|
||||
async def authorize_with_server(
    request: Request,
    mcp_server: MCPServer,
    client_id: str,
    redirect_uri: str,
    state: str = "",
    code_challenge: Optional[str] = None,
    code_challenge_method: Optional[str] = None,
    response_type: Optional[str] = None,
    scope: Optional[str] = None,
):
    """
    Redirect the caller to *mcp_server*'s upstream authorization endpoint.

    The client's callback URL, original state, and PKCE parameters are folded
    into an encrypted ``state`` value so ``/callback`` can later restore the
    session without server-side storage.

    Raises:
        HTTPException: 400 when the server is not OAuth2 or has no
            authorization URL configured.
    """
    if mcp_server.auth_type != "oauth2":
        raise HTTPException(status_code=400, detail="MCP server is not OAuth2")
    if mcp_server.authorization_url is None:
        raise HTTPException(
            status_code=400, detail="MCP server authorization url is not set"
        )

    # Strip any query from the client's redirect_uri; /callback re-appends
    # code/state itself.
    client_callback_base = urlunparse(urlparse(redirect_uri)._replace(query=""))
    proxy_base_url = get_request_base_url(request)
    encrypted_state = encode_state_with_base_url(
        base_url=client_callback_base,
        original_state=state,
        code_challenge=code_challenge,
        code_challenge_method=code_challenge_method,
        client_redirect_uri=redirect_uri,
    )

    # Server-configured client_id wins over the caller-supplied one.
    query_params = {
        "client_id": mcp_server.client_id if mcp_server.client_id else client_id,
        "redirect_uri": f"{proxy_base_url}/callback",
        "state": encrypted_state,
        "response_type": response_type or "code",
    }
    if scope:
        query_params["scope"] = scope
    elif mcp_server.scopes:
        query_params["scope"] = " ".join(mcp_server.scopes)
    if code_challenge:
        query_params["code_challenge"] = code_challenge
    if code_challenge_method:
        query_params["code_challenge_method"] = code_challenge_method

    # Merge with any query params already baked into the authorization URL;
    # ours take precedence on key collision.
    auth_url = urlparse(mcp_server.authorization_url)
    merged_params = dict(parse_qsl(auth_url.query))
    merged_params.update(query_params)
    return RedirectResponse(
        urlunparse(auth_url._replace(query=urlencode(merged_params)))
    )
|
||||
|
||||
|
||||
async def exchange_token_with_server(
    request: Request,
    mcp_server: MCPServer,
    grant_type: str,
    code: Optional[str],
    redirect_uri: Optional[str],
    client_id: str,
    client_secret: Optional[str],
    code_verifier: Optional[str],
):
    """
    Swap an authorization ``code`` for an access token at the upstream IdP.

    Server-configured client credentials take precedence over caller-supplied
    ones, and ``code_verifier`` is forwarded to complete the PKCE handshake.

    Raises:
        HTTPException: 400 on an unsupported grant type or missing token URL.
        httpx.HTTPStatusError: When the upstream token endpoint rejects us.
    """
    if grant_type != "authorization_code":
        raise HTTPException(status_code=400, detail="Unsupported grant_type")

    if mcp_server.token_url is None:
        raise HTTPException(status_code=400, detail="MCP server token url is not set")

    payload = {
        "grant_type": "authorization_code",
        "client_id": mcp_server.client_id if mcp_server.client_id else client_id,
        "client_secret": (
            mcp_server.client_secret if mcp_server.client_secret else client_secret
        ),
        "code": code,
        # The code was issued against the proxy's own /callback redirect.
        "redirect_uri": f"{get_request_base_url(request)}/callback",
    }
    if code_verifier:
        payload["code_verifier"] = code_verifier

    http_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
    upstream = await http_client.post(
        mcp_server.token_url,
        headers={"Accept": "application/json"},
        data=payload,
    )
    upstream.raise_for_status()
    token_payload = upstream.json()

    body = {
        "access_token": token_payload["access_token"],
        "token_type": token_payload.get("token_type", "Bearer"),
        "expires_in": token_payload.get("expires_in", 3600),
    }
    # Only surface optional fields the upstream actually populated.
    for optional_field in ("refresh_token", "scope"):
        if optional_field in token_payload and token_payload[optional_field]:
            body[optional_field] = token_payload[optional_field]

    return JSONResponse(body)
|
||||
|
||||
|
||||
async def register_client_with_server(
    request: Request,
    mcp_server: MCPServer,
    client_name: str,
    grant_types: Optional[list],
    response_types: Optional[list],
    token_endpoint_auth_method: Optional[str],
    fallback_client_id: Optional[str] = None,
):
    """
    Handle OAuth2 Dynamic Client Registration for *mcp_server*.

    When the server already has admin-configured credentials, or exposes no
    registration endpoint, a placeholder registration is returned so clients
    that insist on registering can still proceed; otherwise the request is
    forwarded to the upstream registration endpoint.

    Raises:
        HTTPException: 400 when no authorization URL is configured.
        httpx.HTTPStatusError: When upstream registration fails.
    """
    proxy_base_url = get_request_base_url(request)
    placeholder_registration = {
        "client_id": fallback_client_id or mcp_server.server_name,
        "client_secret": "dummy",
        "redirect_uris": [f"{proxy_base_url}/callback"],
    }

    # Admin-supplied credentials: dynamic registration is unnecessary.
    if mcp_server.client_id and mcp_server.client_secret:
        return placeholder_registration

    if mcp_server.authorization_url is None:
        raise HTTPException(
            status_code=400, detail="MCP server authorization url is not set"
        )

    # No upstream registration endpoint: fall back to the placeholder.
    if mcp_server.registration_url is None:
        return placeholder_registration

    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.Oauth2Register
    )
    upstream = await http_client.post(
        mcp_server.registration_url,
        headers={
            "Content-Type": "application/json",
            "Accept": "application/json",
        },
        json={
            "client_name": client_name,
            "redirect_uris": [f"{proxy_base_url}/callback"],
            "grant_types": grant_types or [],
            "response_types": response_types or [],
            "token_endpoint_auth_method": token_endpoint_auth_method or "",
        },
    )
    upstream.raise_for_status()
    return JSONResponse(upstream.json())
|
||||
|
||||
|
||||
@router.get("/{mcp_server_name}/authorize")
|
||||
@router.get("/authorize")
|
||||
async def authorize(
|
||||
request: Request,
|
||||
redirect_uri: str,
|
||||
client_id: Optional[str] = None,
|
||||
state: str = "",
|
||||
mcp_server_name: Optional[str] = None,
|
||||
code_challenge: Optional[str] = None,
|
||||
code_challenge_method: Optional[str] = None,
|
||||
response_type: Optional[str] = None,
|
||||
scope: Optional[str] = None,
|
||||
):
|
||||
# Redirect to real OAuth provider with PKCE support
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
|
||||
lookup_name: Optional[str] = mcp_server_name or client_id
|
||||
client_ip = IPAddressUtils.get_mcp_client_ip(request)
|
||||
mcp_server = (
|
||||
global_mcp_server_manager.get_mcp_server_by_name(
|
||||
lookup_name, client_ip=client_ip
|
||||
)
|
||||
if lookup_name
|
||||
else None
|
||||
)
|
||||
if mcp_server is None and mcp_server_name is None:
|
||||
mcp_server = _resolve_oauth2_server_for_root_endpoints()
|
||||
if mcp_server is None:
|
||||
raise HTTPException(status_code=404, detail="MCP server not found")
|
||||
# Use server's stored client_id when caller doesn't supply one.
|
||||
# Raise a clear error instead of passing an empty string — an empty
|
||||
# client_id would silently produce a broken authorization URL.
|
||||
resolved_client_id: str = mcp_server.client_id or client_id or ""
|
||||
if not resolved_client_id:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "client_id is required but was not supplied and is not "
|
||||
"stored on the MCP server record. Provide client_id as a query "
|
||||
"parameter or configure it on the server."
|
||||
},
|
||||
)
|
||||
return await authorize_with_server(
|
||||
request=request,
|
||||
mcp_server=mcp_server,
|
||||
client_id=resolved_client_id,
|
||||
redirect_uri=redirect_uri,
|
||||
state=state,
|
||||
code_challenge=code_challenge,
|
||||
code_challenge_method=code_challenge_method,
|
||||
response_type=response_type,
|
||||
scope=scope,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{mcp_server_name}/token")
|
||||
@router.post("/token")
|
||||
async def token_endpoint(
|
||||
request: Request,
|
||||
grant_type: str = Form(...),
|
||||
code: str = Form(None),
|
||||
redirect_uri: str = Form(None),
|
||||
client_id: str = Form(...),
|
||||
client_secret: Optional[str] = Form(None),
|
||||
code_verifier: str = Form(None),
|
||||
mcp_server_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Accept the authorization code from client and exchange it for OAuth token.
|
||||
Supports PKCE flow by forwarding code_verifier to upstream provider.
|
||||
|
||||
1. Call the token endpoint with PKCE parameters
|
||||
2. Store the user's token in the db - and generate a LiteLLM virtual key
|
||||
3. Return the token
|
||||
4. Return a virtual key in this response
|
||||
"""
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
|
||||
lookup_name = mcp_server_name or client_id
|
||||
client_ip = IPAddressUtils.get_mcp_client_ip(request)
|
||||
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
|
||||
lookup_name, client_ip=client_ip
|
||||
)
|
||||
if mcp_server is None and mcp_server_name is None:
|
||||
mcp_server = _resolve_oauth2_server_for_root_endpoints()
|
||||
if mcp_server is None:
|
||||
raise HTTPException(status_code=404, detail="MCP server not found")
|
||||
return await exchange_token_with_server(
|
||||
request=request,
|
||||
mcp_server=mcp_server,
|
||||
grant_type=grant_type,
|
||||
code=code,
|
||||
redirect_uri=redirect_uri,
|
||||
client_id=client_id,
|
||||
client_secret=client_secret,
|
||||
code_verifier=code_verifier,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/callback")
|
||||
async def callback(code: str, state: str):
|
||||
try:
|
||||
# Decode the state hash to get base_url, original state, and PKCE params
|
||||
state_data = decode_state_hash(state)
|
||||
base_url = state_data["base_url"]
|
||||
original_state = state_data["original_state"]
|
||||
|
||||
# Forward code and original state back to client
|
||||
params = {"code": code, "state": original_state}
|
||||
|
||||
# Forward to client's callback endpoint
|
||||
complete_returned_url = f"{base_url}?{urlencode(params)}"
|
||||
return RedirectResponse(url=complete_returned_url, status_code=302)
|
||||
|
||||
except Exception:
|
||||
# fallback if state hash not found
|
||||
return HTMLResponse(
|
||||
"<html><body>Authentication incomplete. You can close this window.</body></html>"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Optional .well-known endpoints for MCP + OAuth discovery
|
||||
# ------------------------------
|
||||
"""
|
||||
Per SEP-985, the client MUST:
|
||||
1. Try resource_metadata from WWW-Authenticate header (if present)
|
||||
2. Fall back to path-based well-known URI: /.well-known/oauth-protected-resource/{path}
|
||||
(
|
||||
If the resource identifier value contains a path or query component, any terminating slash (/)
|
||||
following the host component MUST be removed before inserting /.well-known/ and the well-known
|
||||
URI path suffix between the host component and the path(include root path) and/or query components.
|
||||
https://datatracker.ietf.org/doc/html/rfc9728#section-3.1)
|
||||
3. Fall back to root-based well-known URI: /.well-known/oauth-protected-resource
|
||||
|
||||
Dual Pattern Support:
|
||||
- Standard MCP pattern: /mcp/{server_name} (recommended, used by mcp-inspector, VSCode Copilot)
|
||||
- LiteLLM legacy pattern: /{server_name}/mcp (backward compatibility)
|
||||
|
||||
The resource URL returned matches the pattern used in the discovery request.
|
||||
"""
|
||||
|
||||
|
||||
def _build_oauth_protected_resource_response(
    request: Request,
    mcp_server_name: Optional[str],
    use_standard_pattern: bool,
) -> dict:
    """
    Build OAuth protected-resource metadata (RFC 9728) for discovery routes.

    Args:
        request: FastAPI Request object.
        mcp_server_name: Name of the MCP server, or None for the root route.
        use_standard_pattern: If True, use the ``/mcp/{server_name}`` resource
            pattern; if False, use the legacy ``/{server_name}/mcp`` pattern.

    Returns:
        Protected-resource metadata dict (``authorization_servers``,
        ``resource``, ``scopes_supported``).
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    request_base_url = get_request_base_url(request)
    client_ip = IPAddressUtils.get_mcp_client_ip(request)

    # When no server name provided, try to resolve the single OAuth2 server.
    # Fix: apply the same client-IP filtering that named lookups receive.
    if mcp_server_name is None:
        resolved = _resolve_oauth2_server_for_root_endpoints(client_ip=client_ip)
        if resolved:
            mcp_server_name = resolved.server_name or resolved.name

    mcp_server: Optional[MCPServer] = None
    if mcp_server_name:
        mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
            mcp_server_name, client_ip=client_ip
        )

    # The resource URL mirrors the pattern the client used for discovery.
    if mcp_server_name:
        if use_standard_pattern:
            # Standard MCP pattern: /mcp/{server_name}
            resource_url = f"{request_base_url}/mcp/{mcp_server_name}"
        else:
            # LiteLLM legacy pattern: /{server_name}/mcp
            resource_url = f"{request_base_url}/{mcp_server_name}/mcp"
    else:
        resource_url = f"{request_base_url}/mcp"

    return {
        "authorization_servers": [
            (
                f"{request_base_url}/{mcp_server_name}"
                if mcp_server_name
                else f"{request_base_url}"
            )
        ],
        "resource": resource_url,
        "scopes_supported": mcp_server.scopes
        if mcp_server and mcp_server.scopes
        else [],
    }
|
||||
|
||||
|
||||
# Standard MCP pattern: /.well-known/oauth-protected-resource/mcp/{server_name}
|
||||
# This is the pattern expected by standard MCP clients (mcp-inspector, VSCode Copilot)
|
||||
@router.get(
    f"/.well-known/oauth-protected-resource{'' if get_server_root_path() == '/' else get_server_root_path()}/mcp/{{mcp_server_name}}"
)
async def oauth_protected_resource_mcp_standard(request: Request, mcp_server_name: str):
    """
    Protected-resource discovery for the standard MCP URL pattern.

    Serves ``/.well-known/oauth-protected-resource/mcp/{server_name}`` for
    resources mounted at ``/mcp/{server_name}`` — the layout expected by
    standard MCP clients such as mcp-inspector and VSCode Copilot.
    """
    return _build_oauth_protected_resource_response(
        request=request,
        mcp_server_name=mcp_server_name,
        use_standard_pattern=True,
    )
|
||||
|
||||
|
||||
# LiteLLM legacy pattern: /.well-known/oauth-protected-resource/{server_name}/mcp
|
||||
# Kept for backward compatibility with existing deployments
|
||||
@router.get(
    f"/.well-known/oauth-protected-resource{'' if get_server_root_path() == '/' else get_server_root_path()}/{{mcp_server_name}}/mcp"
)
@router.get("/.well-known/oauth-protected-resource")
async def oauth_protected_resource_mcp(
    request: Request, mcp_server_name: Optional[str] = None
):
    """
    Protected-resource discovery: LiteLLM legacy pattern plus root route.

    Serves ``/.well-known/oauth-protected-resource/{server_name}/mcp`` (for
    the legacy ``/{server_name}/mcp`` mount, kept for backward compatibility)
    and the bare root discovery path.  New integrations should prefer the
    standard ``/mcp/{server_name}`` pattern.
    """
    return _build_oauth_protected_resource_response(
        request=request,
        mcp_server_name=mcp_server_name,
        use_standard_pattern=False,
    )
|
||||
|
||||
|
||||
"""
|
||||
https://datatracker.ietf.org/doc/html/rfc8414#section-3.1
|
||||
RFC 8414: Path-aware OAuth discovery
|
||||
If the issuer identifier value contains a path component, any
|
||||
terminating "/" MUST be removed before inserting "/.well-known/" and
|
||||
the well-known URI suffix between the host component and the path(include root path)
|
||||
component.
|
||||
"""
|
||||
|
||||
|
||||
def _build_oauth_authorization_server_response(
    request: Request,
    mcp_server_name: Optional[str],
) -> dict:
    """
    Build OAuth authorization-server metadata (RFC 8414).

    All advertised endpoints point back at this proxy, which brokers the real
    provider.  A ``registration_endpoint`` is always advertised because some
    clients (e.g. Claude) require one even when registration is faked.

    Args:
        request: FastAPI Request object.
        mcp_server_name: Name of the MCP server, or None for the root route.

    Returns:
        OAuth authorization server metadata dict.
    """
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )

    request_base_url = get_request_base_url(request)
    client_ip = IPAddressUtils.get_mcp_client_ip(request)

    # When no server name provided, try to resolve the single OAuth2 server.
    # Fix: apply the same client-IP filtering that named lookups receive.
    if mcp_server_name is None:
        resolved = _resolve_oauth2_server_for_root_endpoints(client_ip=client_ip)
        if resolved:
            mcp_server_name = resolved.server_name or resolved.name

    # Server-scoped endpoints when a name is known; root-level otherwise.
    endpoint_prefix = (
        f"{request_base_url}/{mcp_server_name}" if mcp_server_name else request_base_url
    )

    mcp_server: Optional[MCPServer] = None
    if mcp_server_name:
        mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
            mcp_server_name, client_ip=client_ip
        )

    return {
        "issuer": request_base_url,  # point to your proxy
        "authorization_endpoint": f"{endpoint_prefix}/authorize",
        "token_endpoint": f"{endpoint_prefix}/token",
        "response_types_supported": ["code"],
        "scopes_supported": mcp_server.scopes
        if mcp_server and mcp_server.scopes
        else [],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "code_challenge_methods_supported": ["S256"],
        "token_endpoint_auth_methods_supported": ["client_secret_post"],
        # Claude expects a registration endpoint, even if we just fake it
        "registration_endpoint": f"{endpoint_prefix}/register",
    }
|
||||
|
||||
|
||||
# Standard MCP pattern: /.well-known/oauth-authorization-server/mcp/{server_name}
|
||||
@router.get(
    f"/.well-known/oauth-authorization-server{'' if get_server_root_path() == '/' else get_server_root_path()}/mcp/{{mcp_server_name}}"
)
async def oauth_authorization_server_mcp_standard(
    request: Request, mcp_server_name: str
):
    """
    Authorization-server discovery for the standard MCP URL pattern
    (``/.well-known/oauth-authorization-server/mcp/{server_name}``).
    """
    return _build_oauth_authorization_server_response(
        request=request,
        mcp_server_name=mcp_server_name,
    )
|
||||
|
||||
|
||||
# LiteLLM legacy pattern and root endpoint
|
||||
@router.get(
    f"/.well-known/oauth-authorization-server{'' if get_server_root_path() == '/' else get_server_root_path()}/{{mcp_server_name}}"
)
@router.get("/.well-known/oauth-authorization-server")
async def oauth_authorization_server_mcp(
    request: Request, mcp_server_name: Optional[str] = None
):
    """
    Authorization-server discovery: legacy ``/{server_name}`` pattern and the
    bare root route.
    """
    return _build_oauth_authorization_server_response(
        request=request,
        mcp_server_name=mcp_server_name,
    )
|
||||
|
||||
|
||||
# Alias for standard OpenID discovery
|
||||
@router.get("/.well-known/openid-configuration")
|
||||
async def openid_configuration(request: Request):
|
||||
return await oauth_authorization_server_mcp(request)
|
||||
|
||||
|
||||
# Additional legacy pattern support
|
||||
@router.get("/.well-known/oauth-authorization-server/{mcp_server_name}/mcp")
|
||||
async def oauth_authorization_server_legacy(request: Request, mcp_server_name: str):
|
||||
"""
|
||||
OAuth authorization server discovery for legacy /{server_name}/mcp pattern.
|
||||
"""
|
||||
return _build_oauth_authorization_server_response(
|
||||
request=request,
|
||||
mcp_server_name=mcp_server_name,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{mcp_server_name}/register")
|
||||
@router.post("/register")
|
||||
async def register_client(request: Request, mcp_server_name: Optional[str] = None):
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
|
||||
# Get the correct base URL considering X-Forwarded-* headers
|
||||
request_base_url = get_request_base_url(request)
|
||||
|
||||
request_data = await _read_request_body(request=request)
|
||||
data: dict = {**request_data}
|
||||
|
||||
dummy_return = {
|
||||
"client_id": mcp_server_name or "dummy_client",
|
||||
"client_secret": "dummy",
|
||||
"redirect_uris": [f"{request_base_url}/callback"],
|
||||
}
|
||||
if not mcp_server_name:
|
||||
resolved = _resolve_oauth2_server_for_root_endpoints()
|
||||
if resolved:
|
||||
return await register_client_with_server(
|
||||
request=request,
|
||||
mcp_server=resolved,
|
||||
client_name=data.get("client_name", ""),
|
||||
grant_types=data.get("grant_types", []),
|
||||
response_types=data.get("response_types", []),
|
||||
token_endpoint_auth_method=data.get("token_endpoint_auth_method", ""),
|
||||
fallback_client_id=resolved.server_name or resolved.name,
|
||||
)
|
||||
return dummy_return
|
||||
|
||||
client_ip = IPAddressUtils.get_mcp_client_ip(request)
|
||||
mcp_server = global_mcp_server_manager.get_mcp_server_by_name(
|
||||
mcp_server_name, client_ip=client_ip
|
||||
)
|
||||
if mcp_server is None:
|
||||
return dummy_return
|
||||
return await register_client_with_server(
|
||||
request=request,
|
||||
mcp_server=mcp_server,
|
||||
client_name=data.get("client_name", ""),
|
||||
grant_types=data.get("grant_types", []),
|
||||
response_types=data.get("response_types", []),
|
||||
token_endpoint_auth_method=data.get("token_endpoint_auth_method", ""),
|
||||
fallback_client_id=mcp_server_name,
|
||||
)
|
||||
@@ -0,0 +1,16 @@
|
||||
"""Guardrail translation mapping for MCP tool calls."""
|
||||
|
||||
from litellm.proxy._experimental.mcp_server.guardrail_translation.handler import (
|
||||
MCPGuardrailTranslationHandler,
|
||||
)
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
# This mapping lives alongside the MCP server implementation because MCP
|
||||
# integrations are managed by the proxy subsystem, not litellm.llms providers.
|
||||
# Unified guardrails import this module explicitly to register the handler.
|
||||
|
||||
# Maps the MCP tool-call type to its guardrail translation handler; unified
# guardrails import this module explicitly to pick up the registration.
guardrail_translation_mappings = {
    CallTypes.call_mcp_tool: MCPGuardrailTranslationHandler,
}

# Public API of this module.
__all__ = ["guardrail_translation_mappings", "MCPGuardrailTranslationHandler"]
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
MCP Guardrail Handler for Unified Guardrails.
|
||||
|
||||
Converts an MCP call_tool (name + arguments) into a single OpenAI-compatible
|
||||
tool_call and passes it to apply_guardrail. Works with the synthetic payload
|
||||
from ProxyLogging._convert_mcp_to_llm_format.
|
||||
|
||||
Note: For MCP tool definitions (schema) -> OpenAI tools=[], see
|
||||
litellm.experimental_mcp_client.tools.transform_mcp_tool_to_openai_tool
|
||||
when you have a full MCP Tool from list_tools. Here we only have the call
|
||||
payload (name + arguments) so we just build the tool_call.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from mcp.types import Tool as MCPTool
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.experimental_mcp_client.tools import transform_mcp_tool_to_openai_tool
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
)
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.types import CallToolResult
|
||||
|
||||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
|
||||
|
||||
class MCPGuardrailTranslationHandler(BaseTranslation):
    """Guardrail translation handler for MCP tool calls.

    Builds a single OpenAI-compatible tool definition from the MCP call
    payload (name + description) and hands it to ``apply_guardrail``.  The
    actual call arguments reach the guardrail via ``request_data`` — they are
    not embedded in the tool definition.
    """

    async def process_input_messages(
        self,
        data: Dict[str, Any],
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Dict[str, Any]:
        """Run *guardrail_to_apply* against the MCP tool call described by *data*.

        Returns *data* unchanged; a guardrail blocks the call by raising.
        """
        mcp_tool_name = data.get("mcp_tool_name") or data.get("name")
        mcp_tool_description = data.get("mcp_tool_description") or data.get(
            "description"
        )
        # NOTE: the original also normalized mcp_arguments here, but the
        # value was never used — the guardrail reads arguments straight from
        # request_data — so the dead local has been removed.

        if not mcp_tool_name:
            # Nothing to guard without a tool name; pass the payload through.
            verbose_proxy_logger.debug("MCP Guardrail: mcp_tool_name missing")
            return data

        # Convert MCP input via transform_mcp_tool_to_openai_tool, then map to
        # litellm's ChatCompletionToolParam (the openai SDK type has
        # incompatible strict/cache_control fields).
        mcp_tool = MCPTool(
            name=mcp_tool_name,
            description=mcp_tool_description or "",
            # Call payload carries no schema; guardrail gets args from request_data.
            inputSchema={},
        )
        openai_tool = transform_mcp_tool_to_openai_tool(mcp_tool)
        fn = openai_tool["function"]
        tool_def: ChatCompletionToolParam = {
            "type": "function",
            "function": ChatCompletionToolParamFunctionChunk(
                name=fn["name"],
                description=fn.get("description") or "",
                parameters=fn.get("parameters")
                or {
                    "type": "object",
                    "properties": {},
                    "additionalProperties": False,
                },
                strict=fn.get("strict", False) or False,  # default False if None
            ),
        }
        inputs: GenericGuardrailAPIInputs = GenericGuardrailAPIInputs(
            tools=[tool_def],
        )

        await guardrail_to_apply.apply_guardrail(
            inputs=inputs,
            request_data=data,
            input_type="request",
            logging_obj=litellm_logging_obj,
        )
        return data

    async def process_output_response(
        self,
        response: "CallToolResult",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """No-op: output guardrails are not implemented for MCP tool results."""
        verbose_proxy_logger.debug(
            "MCP Guardrail: Output processing not implemented for MCP tools",
        )
        return response
|
||||
@@ -0,0 +1,325 @@
|
||||
"""
|
||||
MCP OAuth2 Debug Headers
|
||||
========================
|
||||
|
||||
Client-side debugging for MCP authentication flows.
|
||||
|
||||
When a client sends the ``x-litellm-mcp-debug: true`` header, LiteLLM
|
||||
returns masked diagnostic headers in the response so operators can
|
||||
troubleshoot OAuth2 issues without SSH access to the gateway.
|
||||
|
||||
Response headers returned (all values are masked for safety):
|
||||
|
||||
x-mcp-debug-inbound-auth
|
||||
Which inbound auth headers were present and how they were classified.
|
||||
Example: ``x-litellm-api-key=Bearer sk-12****1234``
|
||||
|
||||
x-mcp-debug-oauth2-token
|
||||
The OAuth2 token extracted from the Authorization header (masked).
|
||||
Shows ``(none)`` if absent, or flags ``SAME_AS_LITELLM_KEY`` when
|
||||
the LiteLLM API key is accidentally leaking to the MCP server.
|
||||
|
||||
x-mcp-debug-auth-resolution
|
||||
Which auth priority was used for the outbound MCP call:
|
||||
``per-request-header``, ``m2m-client-credentials``, ``static-token``,
|
||||
``oauth2-passthrough``, or ``no-auth``.
|
||||
|
||||
x-mcp-debug-outbound-url
|
||||
The upstream MCP server URL that will receive the request.
|
||||
|
||||
x-mcp-debug-server-auth-type
|
||||
The ``auth_type`` configured on the MCP server (e.g. ``oauth2``,
|
||||
``bearer_token``, ``none``).
|
||||
|
||||
Debugging Guide
|
||||
---------------
|
||||
|
||||
**Common issue: LiteLLM API key leaking to the MCP server**
|
||||
|
||||
Symptom: ``x-mcp-debug-oauth2-token`` shows ``SAME_AS_LITELLM_KEY``.
|
||||
|
||||
This means the ``Authorization`` header carries the LiteLLM API key and
|
||||
it's being forwarded to the upstream MCP server instead of an OAuth2 token.
|
||||
|
||||
Fix: Move the LiteLLM key to ``x-litellm-api-key`` so the ``Authorization``
|
||||
header is free for OAuth2 discovery::
|
||||
|
||||
# WRONG — blocks OAuth2 discovery
|
||||
claude mcp add --transport http my_server http://proxy/mcp/server \\
|
||||
--header "Authorization: Bearer sk-..."
|
||||
|
||||
# CORRECT — LiteLLM key in dedicated header, Authorization free for OAuth2
|
||||
claude mcp add --transport http my_server http://proxy/mcp/server \\
|
||||
--header "x-litellm-api-key: Bearer sk-..." \\
|
||||
--header "x-litellm-mcp-debug: true"
|
||||
|
||||
**Common issue: No OAuth2 token present**
|
||||
|
||||
Symptom: ``x-mcp-debug-oauth2-token`` shows ``(none)`` and
|
||||
``x-mcp-debug-auth-resolution`` shows ``no-auth``.
|
||||
|
||||
This means the client didn't go through the OAuth2 flow. Check that:
|
||||
1. The ``Authorization`` header is NOT set as a static header in the client config.
|
||||
2. The ``.well-known/oauth-protected-resource`` endpoint returns valid metadata.
|
||||
3. The MCP server in LiteLLM config has ``auth_type: oauth2``.
|
||||
|
||||
**Common issue: M2M token used instead of user token**
|
||||
|
||||
Symptom: ``x-mcp-debug-auth-resolution`` shows ``m2m-client-credentials``.
|
||||
|
||||
This means the server has ``client_id``/``client_secret``/``token_url``
|
||||
configured and LiteLLM is fetching a machine-to-machine token instead of
|
||||
using the per-user OAuth2 token. If you want per-user tokens, remove the
|
||||
client credentials from the server config.
|
||||
|
||||
Usage from Claude Code::
|
||||
|
||||
claude mcp add --transport http my_server http://proxy/mcp/server \\
|
||||
--header "x-litellm-api-key: Bearer sk-..." \\
|
||||
--header "x-litellm-mcp-debug: true"
|
||||
|
||||
Usage with curl::
|
||||
|
||||
curl -H "x-litellm-mcp-debug: true" \\
|
||||
-H "x-litellm-api-key: Bearer sk-..." \\
|
||||
http://localhost:4000/mcp/atlassian_mcp
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
|
||||
from starlette.types import Message, Send
|
||||
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
|
||||
# Header the client sends to opt into debug mode.
# Accepted truthy values (checked by MCPDebug.is_debug_enabled): "true", "1", "yes".
MCP_DEBUG_REQUEST_HEADER = "x-litellm-mcp-debug"

# Prefix for all debug response headers emitted by MCPDebug
# (e.g. "x-mcp-debug-oauth2-token", "x-mcp-debug-auth-resolution").
_RESPONSE_HEADER_PREFIX = "x-mcp-debug"
|
||||
|
||||
|
||||
class MCPDebug:
    """
    Static helper class for MCP OAuth2 debug headers.

    Provides opt-in client-side diagnostics by injecting masked
    authentication info into HTTP response headers. Clients opt in by
    sending ``x-litellm-mcp-debug: true``; every value emitted back is
    masked before it leaves the proxy.
    """

    # Masker: show first 6 and last 4 chars so you can distinguish token types
    # e.g. "Bearer****ef01" vs "sk-123****cdef"
    _masker = SensitiveDataMasker(
        sensitive_patterns={
            "authorization",
            "token",
            "key",
            "secret",
            "auth",
            "bearer",
        },
        visible_prefix=6,
        visible_suffix=4,
    )

    @staticmethod
    def _mask(value: Optional[str]) -> str:
        """Mask a single value for safe display in headers."""
        if not value:
            return "(none)"
        return MCPDebug._masker._mask_value(value)

    @staticmethod
    def is_debug_enabled(headers: Dict[str, str]) -> bool:
        """
        Check if the client opted into MCP debug mode.

        Looks for ``x-litellm-mcp-debug`` (case-insensitive key) in the
        request headers; "true", "1", and "yes" are accepted as truthy.
        """
        for key, val in headers.items():
            if key.lower() == MCP_DEBUG_REQUEST_HEADER:
                return val.strip().lower() in ("true", "1", "yes")
        return False

    @staticmethod
    def resolve_auth_resolution(
        server: "MCPServer",
        mcp_auth_header: Optional[str],
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]],
        oauth2_headers: Optional[Dict[str, str]],
    ) -> str:
        """
        Determine which auth priority will be used for the outbound MCP call.

        Returns one of: ``per-request-header``, ``m2m-client-credentials``,
        ``static-token``, ``oauth2-passthrough``, or ``no-auth``.
        """
        from litellm.types.mcp import MCPAuth

        # Server-specific headers may be keyed by alias or by server_name.
        has_server_specific = bool(
            mcp_server_auth_headers
            and (
                mcp_server_auth_headers.get(server.alias or "")
                or mcp_server_auth_headers.get(server.server_name or "")
            )
        )
        if has_server_specific or mcp_auth_header:
            return "per-request-header"
        if server.has_client_credentials:
            return "m2m-client-credentials"
        if server.authentication_token:
            return "static-token"
        if oauth2_headers and server.auth_type == MCPAuth.oauth2:
            return "oauth2-passthrough"
        return "no-auth"

    @staticmethod
    def build_debug_headers(
        *,
        inbound_headers: Dict[str, str],
        oauth2_headers: Optional[Dict[str, str]],
        litellm_api_key: Optional[str],
        auth_resolution: str,
        server_url: Optional[str],
        server_auth_type: Optional[str],
    ) -> Dict[str, str]:
        """
        Build masked debug response headers.

        Parameters
        ----------
        inbound_headers : dict
            Raw headers received from the MCP client.
        oauth2_headers : dict or None
            Extracted OAuth2 headers (``{"Authorization": "Bearer ..."}``).
        litellm_api_key : str or None
            The LiteLLM API key extracted from ``x-litellm-api-key`` or
            ``Authorization`` header.
        auth_resolution : str
            Which auth priority was selected for the outbound call.
        server_url : str or None
            Upstream MCP server URL.
        server_auth_type : str or None
            The ``auth_type`` configured on the server (e.g. ``oauth2``).

        Returns
        -------
        dict
            Headers to include in the response (all values masked).
        """
        debug: Dict[str, str] = {}

        # --- Inbound auth summary ---
        # Report each known auth header at most once, keyed case-insensitively.
        inbound_parts = []
        for hdr_name in ("x-litellm-api-key", "authorization", "x-mcp-auth"):
            for k, v in inbound_headers.items():
                if k.lower() == hdr_name:
                    inbound_parts.append(f"{hdr_name}={MCPDebug._mask(v)}")
                    break
        debug[f"{_RESPONSE_HEADER_PREFIX}-inbound-auth"] = (
            "; ".join(inbound_parts) if inbound_parts else "(none)"
        )

        # --- OAuth2 token ---
        # Compute the masked display once; annotate it when the OAuth2 token is
        # byte-identical (ignoring a "Bearer " prefix) to the LiteLLM API key —
        # the classic misconfiguration where the client put its LiteLLM key in
        # the Authorization header.
        oauth2_token = (oauth2_headers or {}).get("Authorization")
        token_display = MCPDebug._mask(oauth2_token)
        if oauth2_token and litellm_api_key:
            oauth2_raw = oauth2_token.removeprefix("Bearer ").strip()
            litellm_raw = litellm_api_key.removeprefix("Bearer ").strip()
            if oauth2_raw == litellm_raw:
                token_display = (
                    f"{token_display} (SAME_AS_LITELLM_KEY - likely misconfigured)"
                )
        debug[f"{_RESPONSE_HEADER_PREFIX}-oauth2-token"] = token_display

        # --- Auth resolution ---
        debug[f"{_RESPONSE_HEADER_PREFIX}-auth-resolution"] = auth_resolution

        # --- Server info ---
        debug[f"{_RESPONSE_HEADER_PREFIX}-outbound-url"] = server_url or "(unknown)"
        debug[f"{_RESPONSE_HEADER_PREFIX}-server-auth-type"] = (
            server_auth_type or "(none)"
        )

        return debug

    @staticmethod
    def wrap_send_with_debug_headers(send: Send, debug_headers: Dict[str, str]) -> Send:
        """
        Return a new ASGI ``send`` callable that injects *debug_headers*
        into the ``http.response.start`` message.

        Other message types (body chunks etc.) pass through unchanged.
        """

        async def _send_with_debug(message: Message) -> None:
            if message["type"] == "http.response.start":
                headers = list(message.get("headers", []))
                for k, v in debug_headers.items():
                    # ASGI headers are byte pairs; values here are masked ASCII.
                    headers.append((k.encode(), v.encode()))
                message = {**message, "headers": headers}
            await send(message)

        return _send_with_debug

    @staticmethod
    def maybe_build_debug_headers(
        *,
        raw_headers: Optional[Dict[str, str]],
        scope: Dict,
        mcp_servers: Optional[List[str]],
        mcp_auth_header: Optional[str],
        mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]],
        oauth2_headers: Optional[Dict[str, str]],
        client_ip: Optional[str],
    ) -> Dict[str, str]:
        """
        Build debug headers if debug mode is enabled, otherwise return empty dict.

        This is the single entry point called from the MCP request handler.
        Only the first resolvable server in *mcp_servers* is reported.
        """
        if not raw_headers or not MCPDebug.is_debug_enabled(raw_headers):
            return {}

        # Imported lazily to avoid import cycles at module load time.
        from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
            MCPRequestHandler,
        )
        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
            global_mcp_server_manager,
        )

        server_url: Optional[str] = None
        server_auth_type: Optional[str] = None
        auth_resolution = "no-auth"

        for server_name in mcp_servers or []:
            server = global_mcp_server_manager.get_mcp_server_by_name(
                server_name, client_ip=client_ip
            )
            if server:
                server_url = server.url
                server_auth_type = server.auth_type
                auth_resolution = MCPDebug.resolve_auth_resolution(
                    server, mcp_auth_header, mcp_server_auth_headers, oauth2_headers
                )
                break

        scope_headers = MCPRequestHandler._safe_get_headers_from_scope(scope)
        litellm_key = MCPRequestHandler.get_litellm_api_key_from_headers(scope_headers)

        return MCPDebug.build_debug_headers(
            inbound_headers=raw_headers,
            oauth2_headers=oauth2_headers,
            litellm_api_key=litellm_key,
            auth_resolution=auth_resolution,
            server_url=server_url,
            server_auth_type=server_auth_type,
        )
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,170 @@
|
||||
"""
|
||||
OAuth2 client_credentials token cache for MCP servers.
|
||||
|
||||
Automatically fetches and refreshes access tokens for MCP servers configured
|
||||
with ``client_id``, ``client_secret``, and ``token_url``.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.constants import (
|
||||
MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL,
|
||||
MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE,
|
||||
MCP_OAUTH2_TOKEN_CACHE_MIN_TTL,
|
||||
MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
|
||||
|
||||
class MCPOAuth2TokenCache(InMemoryCache):
    """
    In-memory cache for OAuth2 client_credentials tokens, keyed by server_id.

    Inherits from ``InMemoryCache`` for TTL-based storage and eviction.
    Adds per-server ``asyncio.Lock`` to prevent duplicate concurrent fetches
    (double-checked locking: fast unlocked read, then re-check under the lock).
    """

    def __init__(self) -> None:
        """Initialize the TTL cache and the per-server lock map."""
        super().__init__(
            max_size_in_memory=MCP_OAUTH2_TOKEN_CACHE_MAX_SIZE,
            default_ttl=MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL,
        )
        # One lock per server_id. NOTE(review): this map is never pruned, so it
        # grows with the number of distinct servers seen — bounded in practice
        # by the server registry size, but worth confirming.
        self._locks: Dict[str, asyncio.Lock] = {}

    def _get_lock(self, server_id: str) -> asyncio.Lock:
        """Return (creating on first use) the fetch lock for *server_id*."""
        # setdefault is atomic enough here: all callers run on one event loop.
        return self._locks.setdefault(server_id, asyncio.Lock())

    async def async_get_token(self, server: "MCPServer") -> Optional[str]:
        """Return a valid access token, fetching or refreshing as needed.

        Returns ``None`` when the server lacks client credentials config.
        Raises ``ValueError`` (from ``_fetch_token``) when the token endpoint
        fails or returns an unusable body.
        """
        if not server.has_client_credentials:
            return None

        server_id = server.server_id

        # Fast path — cached token is still valid
        cached = self.get_cache(server_id)
        if cached is not None:
            return cached

        # Slow path — acquire per-server lock then double-check, so concurrent
        # requests for the same server trigger at most one token fetch.
        async with self._get_lock(server_id):
            cached = self.get_cache(server_id)
            if cached is not None:
                return cached

            token, ttl = await self._fetch_token(server)
            self.set_cache(server_id, token, ttl=ttl)
            return token

    async def _fetch_token(self, server: "MCPServer") -> Tuple[str, int]:
        """POST to ``token_url`` with ``grant_type=client_credentials``.

        Returns ``(access_token, ttl_seconds)`` where ttl accounts for the
        expiry buffer so the cache entry expires before the real token does.

        Raises ``ValueError`` on missing credential config, a non-2xx token
        response, or a malformed token body.
        """
        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)

        # Defensive re-validation: has_client_credentials was checked by the
        # caller, but fail loudly (and specifically) if any field is missing.
        if not server.client_id or not server.client_secret or not server.token_url:
            raise ValueError(
                f"MCP server '{server.server_id}' missing required OAuth2 fields: "
                f"client_id={bool(server.client_id)}, "
                f"client_secret={bool(server.client_secret)}, "
                f"token_url={bool(server.token_url)}"
            )

        # Standard RFC 6749 client_credentials grant, form-encoded.
        data: Dict[str, str] = {
            "grant_type": "client_credentials",
            "client_id": server.client_id,
            "client_secret": server.client_secret,
        }
        if server.scopes:
            # RFC 6749 scope parameter: space-delimited list.
            data["scope"] = " ".join(server.scopes)

        verbose_logger.debug(
            "Fetching OAuth2 client_credentials token for MCP server %s",
            server.server_id,
        )

        try:
            response = await client.post(server.token_url, data=data)
            response.raise_for_status()
        except httpx.HTTPStatusError as exc:
            # Re-raise with the status code but without echoing the response
            # body (it may contain sensitive details).
            raise ValueError(
                f"OAuth2 token request for MCP server '{server.server_id}' "
                f"failed with status {exc.response.status_code}"
            ) from exc

        body = response.json()

        if not isinstance(body, dict):
            raise ValueError(
                f"OAuth2 token response for MCP server '{server.server_id}' "
                f"returned non-object JSON (got {type(body).__name__})"
            )

        access_token = body.get("access_token")
        if not access_token:
            raise ValueError(
                f"OAuth2 token response for MCP server '{server.server_id}' "
                f"missing 'access_token'"
            )

        # Safely parse expires_in — providers may return null or non-numeric values
        raw_expires_in = body.get("expires_in")
        try:
            expires_in = (
                int(raw_expires_in)
                if raw_expires_in is not None
                else MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL
            )
        except (TypeError, ValueError):
            expires_in = MCP_OAUTH2_TOKEN_CACHE_DEFAULT_TTL

        # Expire the cache entry before the real token expires (buffer), but
        # never cache for less than the configured minimum TTL.
        ttl = max(
            expires_in - MCP_OAUTH2_TOKEN_EXPIRY_BUFFER_SECONDS,
            MCP_OAUTH2_TOKEN_CACHE_MIN_TTL,
        )

        verbose_logger.info(
            "Fetched OAuth2 token for MCP server %s (expires in %ds)",
            server.server_id,
            expires_in,
        )
        return access_token, ttl

    def invalidate(self, server_id: str) -> None:
        """Remove a cached token (e.g. after a 401)."""
        self.delete_cache(server_id)
|
||||
|
||||
|
||||
# Module-level singleton shared by all MCP request handlers in this process.
mcp_oauth2_token_cache = MCPOAuth2TokenCache()
||||
|
||||
|
||||
async def resolve_mcp_auth(
    server: "MCPServer",
    mcp_auth_header: Optional[Union[str, Dict[str, str]]] = None,
) -> Optional[Union[str, Dict[str, str]]]:
    """Pick the auth value to send to an MCP server.

    Resolution order:
        1. ``mcp_auth_header`` — explicit per-request/per-user override.
        2. Auto-fetched (and cached) OAuth2 client_credentials token.
        3. ``server.authentication_token`` — static token from config/DB.
    """
    # An explicit per-request header always wins.
    if mcp_auth_header:
        return mcp_auth_header
    # Without client credentials, fall straight through to the static token.
    if not server.has_client_credentials:
        return server.authentication_token
    return await mcp_oauth2_token_cache.async_get_token(server)
|
||||
@@ -0,0 +1,435 @@
|
||||
"""
|
||||
This module is used to generate MCP tools from OpenAPI specs.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import json
|
||||
import os
|
||||
from pathlib import PurePosixPath
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.tool_registry import (
|
||||
global_mcp_tool_registry,
|
||||
)
|
||||
|
||||
# Store the base URL and headers globally.
# NOTE(review): these two module-level defaults appear unused by the functions
# visible in this module — confirm whether external code mutates them.
BASE_URL = ""
HEADERS: Dict[str, str] = {}

# Per-request auth header override for BYOK servers.
# Set this ContextVar before calling a local tool handler to inject the user's
# stored credential into the HTTP request made by the tool function closure.
# The value is the complete Authorization header value (prefix included);
# the default of None means "use the headers baked into the tool closure".
_request_auth_header: contextvars.ContextVar[Optional[str]] = contextvars.ContextVar(
    "_request_auth_header", default=None
)
|
||||
|
||||
|
||||
def _sanitize_path_parameter_value(param_value: Any, param_name: str) -> str:
|
||||
"""Ensure path params cannot introduce directory traversal."""
|
||||
if param_value is None:
|
||||
return ""
|
||||
|
||||
value_str = str(param_value)
|
||||
if value_str == "":
|
||||
return ""
|
||||
|
||||
normalized_value = value_str.replace("\\", "/")
|
||||
if "/" in normalized_value:
|
||||
raise ValueError(
|
||||
f"Path parameter '{param_name}' must not contain path separators"
|
||||
)
|
||||
|
||||
if any(part in {".", ".."} for part in PurePosixPath(normalized_value).parts):
|
||||
raise ValueError(
|
||||
f"Path parameter '{param_name}' cannot include '.' or '..' segments"
|
||||
)
|
||||
|
||||
return quote(value_str, safe="")
|
||||
|
||||
|
||||
def load_openapi_spec(filepath: str) -> Dict[str, Any]:
    """
    Synchronously load an OpenAPI spec from a local path or http(s) URL.

    Must only be called from synchronous code. When invoked from inside a
    running event loop it raises ``RuntimeError`` directing the caller to
    ``load_openapi_spec_async`` (calling ``asyncio.run`` there would fail).

    FIX: previously the "called from a running loop" RuntimeError was raised
    inside the same ``try`` that probed for the loop, so it was caught by its
    own ``except RuntimeError`` handler and only survived via a fragile
    message-substring check. Keep the try body minimal instead.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop — safe to drive the async loader ourselves.
        return asyncio.run(load_openapi_spec_async(filepath))
    raise RuntimeError(
        "load_openapi_spec() was called from within a running event loop. "
        "Use 'await load_openapi_spec_async(...)' instead."
    )
|
||||
|
||||
|
||||
async def load_openapi_spec_async(filepath: str) -> Dict[str, Any]:
    """Load an OpenAPI spec from an http(s) URL or a local JSON file.

    Raises:
        FileNotFoundError: when *filepath* is a local path that does not exist.
        httpx.HTTPStatusError: when a URL fetch returns a non-2xx status.
    """
    if filepath.startswith("http://") or filepath.startswith("https://"):
        # Shared MCP httpx client — do not close it here; it may be a
        # process-wide singleton.
        http_client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)
        resp = await http_client.get(filepath)
        resp.raise_for_status()
        return resp.json()

    # Local filesystem path.
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"OpenAPI spec not found at {filepath}")
    with open(filepath, "r", encoding="utf-8") as spec_file:
        return json.load(spec_file)
|
||||
|
||||
|
||||
def get_base_url(spec: Dict[str, Any], spec_path: Optional[str] = None) -> str:
    """Extract the API base URL from an OpenAPI spec.

    Checks, in order:
      1. OpenAPI 3.x ``servers`` (resolving a relative server URL against
         *spec_path* when that is an http(s) URL);
      2. Swagger 2.x ``host`` / ``schemes`` / ``basePath``;
      3. A base derived from *spec_path* itself, when it is a URL ending in a
         well-known spec filename.

    Returns "" when nothing can be determined.
    """
    # OpenAPI 3.x
    if "servers" in spec and spec["servers"]:
        server_url = spec["servers"][0]["url"]

        # If the server URL is relative (starts with /), derive base from spec_path
        if server_url.startswith("/") and spec_path:
            if spec_path.startswith("http://") or spec_path.startswith("https://"):
                # Extract scheme+host from spec_path
                # (e.g. https://petstore3.swagger.io/api/v3/openapi.json)
                # and combine it with the relative server URL.
                from urllib.parse import urlparse

                parsed = urlparse(spec_path)
                base_domain = f"{parsed.scheme}://{parsed.netloc}"
                full_base_url = base_domain + server_url
                verbose_logger.info(
                    f"OpenAPI spec has relative server URL '{server_url}'. "
                    f"Deriving base from spec_path: {full_base_url}"
                )
                return full_base_url

        return server_url
    # OpenAPI 2.x (Swagger)
    elif "host" in spec:
        # BUGFIX: "schemes" may be present but an empty list — fall back to
        # https instead of raising IndexError.
        scheme = (spec.get("schemes") or ["https"])[0]
        base_path = spec.get("basePath", "")
        return f"{scheme}://{spec['host']}{base_path}"

    # Fallback: derive base URL from spec_path if it's a URL
    if spec_path and (
        spec_path.startswith("http://") or spec_path.startswith("https://")
    ):
        for suffix in [
            "/openapi.json",
            "/openapi.yaml",
            "/swagger.json",
            "/swagger.yaml",
        ]:
            if spec_path.endswith(suffix):
                base_url = spec_path[: -len(suffix)]
                verbose_logger.info(
                    f"No server info in OpenAPI spec. Using derived base URL: {base_url}"
                )
                return base_url

        # Any other spec-file name: strip the last path component.
        if spec_path.split("/")[-1].endswith((".json", ".yaml", ".yml")):
            base_url = "/".join(spec_path.split("/")[:-1])
            verbose_logger.info(
                f"No server info in OpenAPI spec. Using derived base URL: {base_url}"
            )
            return base_url

    return ""
|
||||
|
||||
|
||||
def _resolve_ref(
|
||||
param: Dict[str, Any], component_params: Dict[str, Any]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Resolve a single parameter, following a $ref if present.
|
||||
|
||||
Returns the resolved param dict, or None if the $ref target is absent from
|
||||
components (so callers can skip/filter it rather than propagating a stub
|
||||
with name=None that would corrupt deduplication).
|
||||
"""
|
||||
ref = param.get("$ref", "")
|
||||
if not ref.startswith("#/components/parameters/"):
|
||||
return param
|
||||
return component_params.get(ref.split("/")[-1])
|
||||
|
||||
|
||||
def _resolve_param_list(
|
||||
raw: List[Dict[str, Any]], component_params: Dict[str, Any]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Resolve $refs in a parameter list, dropping any unresolvable entries."""
|
||||
result = []
|
||||
for p in raw:
|
||||
resolved = _resolve_ref(p, component_params)
|
||||
if resolved is not None and resolved.get("name"):
|
||||
result.append(resolved)
|
||||
return result
|
||||
|
||||
|
||||
def resolve_operation_params(
|
||||
operation: Dict[str, Any],
|
||||
path_item: Dict[str, Any],
|
||||
components: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""Return a copy of *operation* with fully-resolved, merged parameters.
|
||||
|
||||
Handles two common patterns in real-world OpenAPI specs:
|
||||
|
||||
1. **$ref parameters** — ``{"$ref": "#/components/parameters/per-page"}``
|
||||
instead of inline objects. Each ref is resolved against
|
||||
``components["parameters"]``; unresolvable refs are silently dropped so
|
||||
they cannot corrupt the deduplication set with ``(None, None)`` keys.
|
||||
|
||||
2. **Path-level parameters** — params defined on the path item that apply
|
||||
to every HTTP method on that path (e.g. ``owner``, ``repo``). They are
|
||||
merged with the operation-level params; operation-level wins when the
|
||||
same ``name`` + ``in`` combination appears in both.
|
||||
"""
|
||||
component_params = components.get("parameters", {})
|
||||
path_level = _resolve_param_list(path_item.get("parameters", []), component_params)
|
||||
op_level = _resolve_param_list(operation.get("parameters", []), component_params)
|
||||
op_keys = {(p["name"], p.get("in")) for p in op_level}
|
||||
merged = [
|
||||
p for p in path_level if (p["name"], p.get("in")) not in op_keys
|
||||
] + op_level
|
||||
result = dict(operation)
|
||||
result["parameters"] = merged
|
||||
return result
|
||||
|
||||
|
||||
def extract_parameters(operation: Dict[str, Any]) -> tuple:
    """Split an operation's parameter names into (path, query, body) lists.

    Parameters without a ``name`` key, or whose ``in`` is anything other than
    path/query/body (e.g. header, cookie), are ignored. An OpenAPI 3.x
    ``requestBody`` is surfaced as a synthetic body parameter named "body".
    """
    path_params: List[str] = []
    query_params: List[str] = []
    body_params: List[str] = []

    # Buckets keyed by the OpenAPI "in" location.
    buckets = {"path": path_params, "query": query_params, "body": body_params}

    if "parameters" in operation:
        for param in operation["parameters"]:
            if "name" not in param:
                continue
            bucket = buckets.get(param.get("in"))
            if bucket is not None:
                bucket.append(param["name"])

    # OpenAPI 3.x moves the body into `requestBody`; expose it as "body".
    if "requestBody" in operation:
        body_params.append("body")

    return path_params, query_params, body_params
|
||||
|
||||
|
||||
def build_input_schema(operation: Dict[str, Any]) -> Dict[str, Any]:
    """Translate an OpenAPI operation into an MCP JSON-schema input block.

    Each named parameter becomes a top-level property (type defaults to
    "string"); an OpenAPI 3.x JSON ``requestBody`` becomes a single object
    property named "body". Required flags are collected into ``required``.
    """
    properties: Dict[str, Any] = {}
    required: List[str] = []

    # Named parameters (path/query/header/...) — one schema property each.
    if "parameters" in operation:
        for param in operation["parameters"]:
            if "name" not in param:
                continue
            name = param["name"]
            declared = param.get("schema", {})
            properties[name] = {
                "type": declared.get("type", "string"),
                "description": param.get("description", ""),
            }
            if param.get("required", False):
                required.append(name)

    # requestBody (OpenAPI 3.x) — surfaced as a single "body" object property.
    if "requestBody" in operation:
        request_body = operation["requestBody"]
        content = request_body.get("content", {})
        if "application/json" in content:
            body_schema = content["application/json"].get("schema", {})
            properties["body"] = {
                "type": "object",
                "description": request_body.get("description", "Request body"),
                "properties": body_schema.get("properties", {}),
            }
            if request_body.get("required", False):
                required.append("body")

    return {
        "type": "object",
        "properties": properties,
        "required": required,
    }
|
||||
|
||||
|
||||
def create_tool_function(
    path: str,
    method: str,
    operation: Dict[str, Any],
    base_url: str,
    headers: Optional[Dict[str, str]] = None,
):
    """Create a tool function for an OpenAPI operation.

    This function creates an async tool function that can be called with
    keyword arguments. Parameter names from the OpenAPI spec are accessed
    directly via **kwargs, avoiding syntax errors from invalid Python identifiers.

    Args:
        path: API endpoint path (may contain ``{param}`` placeholders)
        method: HTTP method (get, post, put, delete, patch)
        operation: OpenAPI operation object
        base_url: Base URL for the API
        headers: Optional headers to include in requests (e.g., authentication)

    Returns:
        An async function that accepts **kwargs and makes the HTTP request
        and returns the response body as text (or an error string for an
        invalid path parameter / unsupported method).
    """
    if headers is None:
        headers = {}

    # Captured once at creation time; the closure below reuses these on
    # every invocation.
    path_params, query_params, body_params = extract_parameters(operation)
    original_method = method.lower()

    async def tool_function(**kwargs: Any) -> str:
        """
        Dynamically generated tool function.

        Accepts keyword arguments where keys are the original OpenAPI parameter names.
        The function safely handles parameter names that aren't valid Python identifiers
        by using **kwargs instead of named parameters.
        """
        # Allow per-request auth override (e.g. BYOK credential set via ContextVar).
        # The ContextVar holds the full Authorization header value, including the
        # correct prefix (Bearer / ApiKey / Basic) formatted by the caller in
        # server.py based on the server's configured auth_type.
        effective_headers = dict(headers)
        override_auth = _request_auth_header.get()
        if override_auth:
            effective_headers["Authorization"] = override_auth

        # Build URL from base_url and path
        url = base_url + path

        # Replace path parameters using original names from OpenAPI spec
        # Apply path traversal validation and URL encoding
        # NOTE(review): falsy values (0, False, "") are skipped here and for
        # query params below — confirm no upstream API needs them.
        for param_name in path_params:
            param_value = kwargs.get(param_name, "")
            if param_value:
                try:
                    # Sanitize and encode path parameter to prevent traversal attacks
                    safe_value = _sanitize_path_parameter_value(param_value, param_name)
                except ValueError as exc:
                    # Surface the validation failure to the caller as text
                    # instead of raising out of the tool handler.
                    return "Invalid path parameter: " + str(exc)
                # Replace {param_name} or {{param_name}} in URL
                url = url.replace("{" + param_name + "}", safe_value)
                url = url.replace("{{" + param_name + "}}", safe_value)

        # Build query params using original parameter names
        params: Dict[str, Any] = {}
        for param_name in query_params:
            param_value = kwargs.get(param_name, "")
            if param_value:
                # Use original parameter name in query string (as expected by API)
                params[param_name] = param_value

        # Build request body
        json_body: Optional[Dict[str, Any]] = None
        if body_params:
            # Try "body" first (most common), then check all body param names
            body_value = kwargs.get("body", {})
            if not body_value:
                for param_name in body_params:
                    body_value = kwargs.get(param_name, {})
                    if body_value:
                        break

            if isinstance(body_value, dict):
                json_body = body_value
            elif body_value:
                # If it's a string, try to parse as JSON; any other type is
                # wrapped as {"data": value} so it still serializes.
                try:
                    json_body = (
                        json.loads(body_value)
                        if isinstance(body_value, str)
                        else {"data": body_value}
                    )
                except (json.JSONDecodeError, TypeError):
                    json_body = {"data": body_value}

        # Shared MCP httpx client (not closed here).
        client = get_async_httpx_client(llm_provider=httpxSpecialProvider.MCP)

        # Dispatch on the HTTP verb captured at creation time. GET and DELETE
        # send no JSON body, matching common API expectations.
        if original_method == "get":
            response = await client.get(url, params=params, headers=effective_headers)
        elif original_method == "post":
            response = await client.post(
                url, params=params, json=json_body, headers=effective_headers
            )
        elif original_method == "put":
            response = await client.put(
                url, params=params, json=json_body, headers=effective_headers
            )
        elif original_method == "delete":
            response = await client.delete(
                url, params=params, headers=effective_headers
            )
        elif original_method == "patch":
            response = await client.patch(
                url, params=params, json=json_body, headers=effective_headers
            )
        else:
            return f"Unsupported HTTP method: {original_method}"

        return response.text

    return tool_function
|
||||
|
||||
|
||||
def register_tools_from_openapi(spec: Dict[str, Any], base_url: str):
    """Register one MCP tool per HTTP operation found in *spec*.

    For each path/method pair: resolve parameters, derive a tool name from
    ``operationId`` (falling back to ``method_path``), build the MCP input
    schema, and register an async handler with the local tool registry.

    FIX (consistency): operations are now passed through
    ``resolve_operation_params`` so ``$ref`` parameters and path-level
    parameters (e.g. ``{owner}``/``{repo}`` declared on the path item) are
    included in the generated schema — previously the raw operation object
    was used and those parameters were silently dropped.
    """
    paths = spec.get("paths", {})
    components = spec.get("components", {})

    for path, path_item in paths.items():
        for method in ["get", "post", "put", "delete", "patch"]:
            if method not in path_item:
                continue

            # Resolve $refs and merge path-level parameters into the operation.
            operation = resolve_operation_params(
                path_item[method], path_item, components
            )

            # Generate tool name
            operation_id = operation.get(
                "operationId", f"{method}_{path.replace('/', '_')}"
            )
            tool_name = operation_id.replace(" ", "_").lower()

            # Get description
            description = operation.get(
                "summary", operation.get("description", f"{method.upper()} {path}")
            )

            # Build input schema
            input_schema = build_input_schema(operation)

            # Create tool function
            tool_func = create_tool_function(path, method, operation, base_url)
            tool_func.__name__ = tool_name
            tool_func.__doc__ = description

            # Register tool with local registry
            global_mcp_tool_registry.register_tool(
                name=tool_name,
                description=description,
                input_schema=input_schema,
                handler=tool_func,
            )
            verbose_logger.debug(f"Registered tool: {tool_name}")
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,256 @@
|
||||
"""
|
||||
Semantic MCP Tool Filtering using semantic-router
|
||||
|
||||
Filters MCP tools semantically for /chat/completions and /responses endpoints.
|
||||
"""
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from semantic_router.routers import SemanticRouter
|
||||
|
||||
from litellm.router import Router
|
||||
|
||||
|
||||
class SemanticMCPToolFilter:
    """Filters MCP tools using semantic similarity to reduce context window size.

    A semantic-router index is built once (at startup) from all registered MCP
    tools; per-request, `filter_tools` ranks the available tools against the
    user's query and returns at most `top_k` of them. Any failure falls back
    to returning the unfiltered tool list.
    """

    def __init__(
        self,
        embedding_model: str,
        litellm_router_instance: "Router",
        top_k: int = 10,
        similarity_threshold: float = 0.3,
        enabled: bool = True,
    ):
        """
        Initialize the semantic tool filter.

        Args:
            embedding_model: Model to use for embeddings (e.g., "text-embedding-3-small")
            litellm_router_instance: Router instance for embedding generation
            top_k: Maximum number of tools to return
            similarity_threshold: Minimum similarity score for filtering
            enabled: Whether filtering is enabled
        """
        self.enabled = enabled
        self.top_k = top_k
        self.similarity_threshold = similarity_threshold
        self.embedding_model = embedding_model
        self.router_instance = litellm_router_instance
        # Lazily built by build_router_from_mcp_registry()/_build_router();
        # None means "not built" and filter_tools() becomes a no-op.
        self.tool_router: Optional["SemanticRouter"] = None
        self._tool_map: Dict[str, Any] = {}  # MCPTool objects or OpenAI function dicts

    async def build_router_from_mcp_registry(self) -> None:
        """Build semantic router from all MCP tools in the registry (no auth checks)."""
        from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
            global_mcp_server_manager,
        )

        try:
            # Get all servers from registry without auth checks
            registry = global_mcp_server_manager.get_registry()
            if not registry:
                verbose_logger.warning("MCP registry is empty")
                self.tool_router = None
                return

            # Fetch tools from each server sequentially; a failing server is
            # logged and skipped so one bad server doesn't break the index.
            all_tools = []
            for server_id, server in registry.items():
                try:
                    tools = await global_mcp_server_manager.get_tools_for_server(
                        server_id
                    )
                    all_tools.extend(tools)
                except Exception as e:
                    verbose_logger.warning(
                        f"Failed to fetch tools from server {server_id}: {e}"
                    )
                    continue

            if not all_tools:
                verbose_logger.warning("No MCP tools found in registry")
                self.tool_router = None
                return

            verbose_logger.info(
                f"Fetched {len(all_tools)} tools from {len(registry)} MCP servers"
            )
            self._build_router(all_tools)

        except Exception as e:
            # Reset the router so filter_tools() degrades to pass-through,
            # then re-raise so startup surfaces the failure.
            verbose_logger.error(f"Failed to build router from MCP registry: {e}")
            self.tool_router = None
            raise

    def _extract_tool_info(self, tool) -> tuple[str, str]:
        """Extract name and description from MCP tool or OpenAI function dict.

        Falls back to the tool name when no description is present.
        """
        name: str
        description: str

        if isinstance(tool, dict):
            # OpenAI function format
            name = tool.get("name", "")
            description = tool.get("description", name)
        else:
            # MCPTool object
            name = str(tool.name)
            description = str(tool.description) if tool.description else str(tool.name)

        return name, description

    def _build_router(self, tools: List) -> None:
        """Build semantic router with tools (MCPTool objects or OpenAI function dicts).

        Each tool becomes one Route whose single utterance is its description;
        matching is therefore description-vs-query similarity.
        """
        from semantic_router.routers import SemanticRouter
        from semantic_router.routers.base import Route

        from litellm.router_strategy.auto_router.litellm_encoder import (
            LiteLLMRouterEncoder,
        )

        if not tools:
            self.tool_router = None
            return

        try:
            # Convert tools to routes
            routes = []
            self._tool_map = {}

            for tool in tools:
                name, description = self._extract_tool_info(tool)
                self._tool_map[name] = tool

                routes.append(
                    Route(
                        name=name,
                        description=description,
                        utterances=[description],
                        score_threshold=self.similarity_threshold,
                    )
                )

            # Embeddings are produced through the LiteLLM router so the
            # configured embedding_model deployment/credentials are reused.
            self.tool_router = SemanticRouter(
                routes=routes,
                encoder=LiteLLMRouterEncoder(
                    litellm_router_instance=self.router_instance,
                    model_name=self.embedding_model,
                    score_threshold=self.similarity_threshold,
                ),
                auto_sync="local",
            )

            verbose_logger.info(f"Built semantic router with {len(routes)} tools")

        except Exception as e:
            verbose_logger.error(f"Failed to build semantic router: {e}")
            self.tool_router = None
            raise

    async def filter_tools(
        self,
        query: str,
        available_tools: List[Any],
        top_k: Optional[int] = None,
    ) -> List[Any]:
        """
        Filter tools semantically based on query.

        Args:
            query: User query to match against tools
            available_tools: Full list of available MCP tools
            top_k: Override default top_k (optional)

        Returns:
            Filtered and ordered list of tools (up to top_k); the unfiltered
            input list whenever filtering is disabled, impossible, or fails.
        """
        # Early returns for cases where we can't/shouldn't filter
        if not self.enabled:
            return available_tools

        if not available_tools:
            return available_tools

        if not query or not query.strip():
            return available_tools

        # Router should be built on startup - if not, something went wrong
        if self.tool_router is None:
            verbose_logger.warning(
                "Router not initialized - was build_router_from_mcp_registry() called on startup?"
            )
            return available_tools

        # Run semantic filtering
        try:
            limit = top_k or self.top_k
            matches = self.tool_router(text=query, limit=limit)
            matched_tool_names = self._extract_tool_names_from_matches(matches)

            # No matches above threshold: pass everything through rather than
            # returning an empty tool list.
            if not matched_tool_names:
                return available_tools

            return self._get_tools_by_names(matched_tool_names, available_tools)

        except Exception as e:
            verbose_logger.error(f"Semantic tool filter failed: {e}", exc_info=True)
            return available_tools

    def _extract_tool_names_from_matches(self, matches) -> List[str]:
        """Extract tool names from semantic router match results.

        Handles both a single match object and a list of match objects.
        """
        if not matches:
            return []

        # Handle single match
        if hasattr(matches, "name") and matches.name:
            return [matches.name]

        # Handle list of matches
        if isinstance(matches, list):
            return [m.name for m in matches if hasattr(m, "name") and m.name]

        return []

    def _get_tools_by_names(
        self, tool_names: List[str], available_tools: List[Any]
    ) -> List[Any]:
        """Get tools from available_tools by their names, preserving order.

        The result follows the semantic router's ranking order (tool_names),
        not the order of available_tools.
        """
        # Match tools from available_tools (preserves format - dict or MCPTool)
        matched_tools = []
        for tool in available_tools:
            tool_name, _ = self._extract_tool_info(tool)
            if tool_name in tool_names:
                matched_tools.append(tool)

        # Reorder to match semantic router's ordering
        tool_map = {self._extract_tool_info(t)[0]: t for t in matched_tools}
        return [tool_map[name] for name in tool_names if name in tool_map]

    def extract_user_query(self, messages: List[Dict[str, Any]]) -> str:
        """
        Extract user query from messages for /chat/completions or /responses.

        Scans newest-first and returns the content of the most recent message
        with role "user". List-form content is flattened by joining the text
        of each block.

        Args:
            messages: List of message dictionaries (from 'messages' or 'input' field)

        Returns:
            Extracted query string ("" when no usable user message is found)
        """
        for msg in reversed(messages):
            if msg.get("role") == "user":
                content = msg.get("content", "")

                if isinstance(content, str):
                    return content

                if isinstance(content, list):
                    texts = [
                        block.get("text", "") if isinstance(block, dict) else str(block)
                        for block in content
                        if isinstance(block, (dict, str))
                    ]
                    return " ".join(texts)

        return ""
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
This is a modification of code from: https://github.com/SecretiveShell/MCP-Bridge/blob/master/mcp_bridge/mcp_server/sse_transport.py
|
||||
|
||||
Credit to the maintainers of SecretiveShell for their SSE Transport implementation
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import anyio
|
||||
import mcp.types as types
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
|
||||
from fastapi.requests import Request
|
||||
from fastapi.responses import Response
|
||||
from pydantic import ValidationError
|
||||
from sse_starlette import EventSourceResponse
|
||||
from starlette.types import Receive, Scope, Send
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class SseServerTransport:
    """
    SSE server transport for MCP. This class provides _two_ ASGI applications,
    suitable to be used with a framework like Starlette and a server like Hypercorn:

    1. connect_sse() is an ASGI application which receives incoming GET requests,
    and sets up a new SSE stream to send server messages to the client.
    2. handle_post_message() is an ASGI application which receives incoming POST
    requests, which should contain client messages that link to a
    previously-established SSE session.
    """

    # Relative or absolute URL clients must POST their messages to.
    _endpoint: str
    # session id -> send side of that session's read stream; handle_post_message
    # uses this to forward POSTed client messages into the matching session.
    _read_stream_writers: dict[
        UUID, MemoryObjectSendStream[types.JSONRPCMessage | Exception]
    ]

    def __init__(self, endpoint: str) -> None:
        """
        Creates a new SSE server transport, which will direct the client to POST
        messages to the relative or absolute URL given.
        """

        super().__init__()
        self._endpoint = endpoint
        self._read_stream_writers = {}
        verbose_logger.debug(
            f"SseServerTransport initialized with endpoint: {endpoint}"
        )

    @asynccontextmanager
    async def connect_sse(self, request: Request):
        """Open an SSE session for *request* and yield (read_stream, write_stream).

        Messages the caller writes to write_stream are relayed to the client as
        SSE "message" events; client POSTs (via handle_post_message) arrive on
        read_stream.

        Raises:
            ValueError: if the incoming ASGI scope is not an HTTP request.
        """
        if request.scope["type"] != "http":
            verbose_logger.error("connect_sse received non-HTTP request")
            raise ValueError("connect_sse can only handle HTTP requests")

        verbose_logger.debug("Setting up SSE connection")
        read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
        read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]

        write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
        write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]

        # Zero-capacity streams: each send blocks until a receiver is ready.
        read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
        write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

        session_id = uuid4()
        # Session-scoped URI the client must POST subsequent messages to.
        session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
        self._read_stream_writers[session_id] = read_stream_writer
        verbose_logger.debug(f"Created new session with ID: {session_id}")

        sse_stream_writer: MemoryObjectSendStream[dict[str, Any]]
        sse_stream_reader: MemoryObjectReceiveStream[dict[str, Any]]
        sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream(
            0, dict[str, Any]
        )

        async def sse_writer():
            # Pump server-side messages from write_stream into SSE event dicts.
            verbose_logger.debug("Starting SSE writer")
            async with sse_stream_writer, write_stream_reader:
                # First event tells the client where to POST its messages.
                await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
                verbose_logger.debug(f"Sent endpoint event: {session_uri}")

                async for message in write_stream_reader:
                    verbose_logger.debug(f"Sending message via SSE: {message}")
                    await sse_stream_writer.send(
                        {
                            "event": "message",
                            "data": message.model_dump_json(
                                by_alias=True, exclude_none=True
                            ),
                        }
                    )

        async with anyio.create_task_group() as tg:
            response = EventSourceResponse(
                content=sse_stream_reader, data_sender_callable=sse_writer
            )
            verbose_logger.debug("Starting SSE response task")
            # NOTE(review): relies on Starlette's private Request._send to run
            # the ASGI response — may break across Starlette versions.
            tg.start_soon(response, request.scope, request.receive, request._send)

            verbose_logger.debug("Yielding read and write streams")
            yield (read_stream, write_stream)

    async def handle_post_message(
        self, scope: Scope, receive: Receive, send: Send
    ) -> Response:
        """Accept a POSTed client JSON-RPC message and forward it to its session.

        Returns:
            400 for a missing/invalid session_id or an unparseable message,
            404 for an unknown session, 202 when the message was forwarded.
        """
        verbose_logger.debug("Handling POST message")
        request = Request(scope, receive)

        session_id_param = request.query_params.get("session_id")
        if session_id_param is None:
            verbose_logger.warning("Received request without session_id")
            response = Response("session_id is required", status_code=400)
            return response

        try:
            session_id = UUID(hex=session_id_param)
            verbose_logger.debug(f"Parsed session ID: {session_id}")
        except ValueError:
            verbose_logger.warning(f"Received invalid session ID: {session_id_param}")
            response = Response("Invalid session ID", status_code=400)
            return response

        writer = self._read_stream_writers.get(session_id)
        if not writer:
            verbose_logger.warning(f"Could not find session for ID: {session_id}")
            response = Response("Could not find session", status_code=404)
            return response

        json = await request.json()
        verbose_logger.debug(f"Received JSON: {json}")

        try:
            message = types.JSONRPCMessage.model_validate(json)
            verbose_logger.debug(f"Validated client message: {message}")
        except ValidationError as err:
            verbose_logger.error(f"Failed to parse message: {err}")
            response = Response("Could not parse message", status_code=400)
            # Surface the validation error on the session's read stream so the
            # server side of the session can observe the failure too.
            await writer.send(err)
            return response

        verbose_logger.debug(f"Sending message to writer: {message}")
        response = Response("Accepted", status_code=202)
        await writer.send(message)
        return response
|
||||
@@ -0,0 +1,133 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.proxy.types_utils.utils import get_instance_fn
|
||||
from litellm.types.mcp_server.tool_registry import MCPTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from mcp.types import Tool as MCPToolSDKTool
|
||||
else:
|
||||
try:
|
||||
from mcp.types import Tool as MCPToolSDKTool
|
||||
except ImportError:
|
||||
MCPToolSDKTool = None # type: ignore
|
||||
|
||||
|
||||
class MCPToolRegistry:
|
||||
"""
|
||||
A registry for managing MCP tools
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# Registry to store all registered tools
|
||||
self.tools: Dict[str, MCPTool] = {}
|
||||
|
||||
def register_tool(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
input_schema: Dict[str, Any],
|
||||
handler: Callable,
|
||||
) -> None:
|
||||
"""
|
||||
Register a new tool in the registry
|
||||
"""
|
||||
self.tools[name] = MCPTool(
|
||||
name=name,
|
||||
description=description,
|
||||
input_schema=input_schema,
|
||||
handler=handler,
|
||||
)
|
||||
verbose_logger.debug(f"Registered tool: {name}")
|
||||
|
||||
def get_tool(self, name: str) -> Optional[MCPTool]:
|
||||
"""
|
||||
Get a tool from the registry by name
|
||||
"""
|
||||
return self.tools.get(name)
|
||||
|
||||
def list_tools(self, tool_prefix: Optional[str] = None) -> List[MCPTool]:
|
||||
"""
|
||||
List all registered tools
|
||||
"""
|
||||
if tool_prefix:
|
||||
return [
|
||||
tool
|
||||
for tool in self.tools.values()
|
||||
if tool.name.startswith(tool_prefix)
|
||||
]
|
||||
return list(self.tools.values())
|
||||
|
||||
def convert_tools_to_mcp_sdk_tool_type(
|
||||
self, tools: List[MCPTool]
|
||||
) -> List["MCPToolSDKTool"]:
|
||||
if MCPToolSDKTool is None:
|
||||
raise ImportError(
|
||||
"MCP SDK is not installed. Please install it with: pip install 'litellm[proxy]'"
|
||||
)
|
||||
return [
|
||||
MCPToolSDKTool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
inputSchema=tool.input_schema,
|
||||
)
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
def load_tools_from_config(
|
||||
self, mcp_tools_config: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Load and register tools from the proxy config
|
||||
|
||||
Args:
|
||||
mcp_tools_config: The mcp_tools config from the proxy config
|
||||
"""
|
||||
if mcp_tools_config is None:
|
||||
raise ValueError(
|
||||
"mcp_tools_config is required, please set `mcp_tools` in your proxy config"
|
||||
)
|
||||
|
||||
for tool_config in mcp_tools_config:
|
||||
if not isinstance(tool_config, dict):
|
||||
raise ValueError("mcp_tools_config must be a list of dictionaries")
|
||||
|
||||
name = tool_config.get("name")
|
||||
description = tool_config.get("description")
|
||||
input_schema = tool_config.get("input_schema", {})
|
||||
handler_name = tool_config.get("handler")
|
||||
|
||||
if not all([name, description, handler_name]):
|
||||
continue
|
||||
|
||||
# Try to resolve the handler
|
||||
# First check if it's a module path (e.g., "module.submodule.function")
|
||||
if handler_name is None:
|
||||
raise ValueError(f"handler is required for tool {name}")
|
||||
handler = get_instance_fn(handler_name)
|
||||
|
||||
if handler is None:
|
||||
verbose_logger.warning(
|
||||
f"Warning: Could not find handler {handler_name} for tool {name}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Register the tool
|
||||
if name is None:
|
||||
raise ValueError(f"name is required for tool {name}")
|
||||
if description is None:
|
||||
raise ValueError(f"description is required for tool {name}")
|
||||
|
||||
self.register_tool(
|
||||
name=name,
|
||||
description=description,
|
||||
input_schema=input_schema,
|
||||
handler=handler,
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"all registered tools: %s", json.dumps(self.tools, indent=4, default=str)
|
||||
)
|
||||
|
||||
|
||||
global_mcp_tool_registry = MCPToolRegistry()
|
||||
@@ -0,0 +1,85 @@
|
||||
"""Helpers to resolve real team contexts for UI session tokens."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import UI_SESSION_TOKEN_TEAM_ID
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
|
||||
def clone_user_api_key_auth_with_team(
    user_api_key_auth: UserAPIKeyAuth,
    team_id: str,
) -> UserAPIKeyAuth:
    """Return a copy of the auth context with its team id overridden.

    Prefers pydantic v2's ``model_copy`` and falls back to the v1-style
    ``copy`` when the model does not provide it. Note: the copy is shallow.
    """

    copy_method = getattr(user_api_key_auth, "model_copy", None)
    if copy_method is None:
        copy_method = user_api_key_auth.copy  # type: ignore[attr-defined]
    cloned_auth = copy_method()
    cloned_auth.team_id = team_id
    return cloned_auth
|
||||
|
||||
|
||||
async def resolve_ui_session_team_ids(
    user_api_key_auth: UserAPIKeyAuth,
) -> List[str]:
    """Resolve the real team ids backing a UI session token.

    Returns an empty list when the token is not a UI session token, when the
    DB is unavailable, when the user lookup fails, or when the user has no
    teams. Otherwise returns the user's team ids, de-duplicated and in order.
    """

    if (
        user_api_key_auth.team_id != UI_SESSION_TOKEN_TEAM_ID
        or not user_api_key_auth.user_id
    ):
        return []

    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.auth.auth_checks import get_user_object
    from litellm.proxy.proxy_server import (
        prisma_client,
        proxy_logging_obj,
        user_api_key_cache,
    )

    if prisma_client is None:
        verbose_logger.debug("Cannot resolve UI session team ids without DB access")
        return []

    try:
        user_obj = await get_user_object(
            user_id=user_api_key_auth.user_id,
            prisma_client=prisma_client,
            user_api_key_cache=user_api_key_cache,
            user_id_upsert=False,
            parent_otel_span=user_api_key_auth.parent_otel_span,
            proxy_logging_obj=proxy_logging_obj,
        )
    except Exception as exc:  # pragma: no cover - defensive logging
        # Bug fix: the original passed `exc` as a lazy logging argument with
        # no placeholder in the format string, so the logging call itself
        # raised a formatting error and the message was lost.
        verbose_logger.warning(
            "Failed to load teams for UI session token user: %s",
            exc,
        )
        return []

    if user_obj is None or not user_obj.teams:
        return []

    # De-duplicate while preserving order; skip falsy ids.
    resolved_team_ids: List[str] = []
    for team_id in user_obj.teams:
        if team_id and team_id not in resolved_team_ids:
            resolved_team_ids.append(team_id)
    return resolved_team_ids
|
||||
|
||||
|
||||
async def build_effective_auth_contexts(
    user_api_key_auth: UserAPIKeyAuth,
) -> List[UserAPIKeyAuth]:
    """Return auth contexts that reflect the actual teams for UI session tokens.

    A UI session token is expanded into one context per resolved team; any
    other token is returned unchanged as a single-element list.
    """

    team_ids = await resolve_ui_session_team_ids(user_api_key_auth)
    if not team_ids:
        # Not a UI session token (or no teams resolved): keep the original context.
        return [user_api_key_auth]

    contexts: List[UserAPIKeyAuth] = []
    for resolved_team_id in team_ids:
        contexts.append(
            clone_user_api_key_auth_with_team(user_api_key_auth, resolved_team_id)
        )
    return contexts
|
||||
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
MCP Server Utilities
|
||||
"""
|
||||
from typing import Any, Dict, Mapping, Optional, Tuple
|
||||
|
||||
import os
|
||||
import importlib
|
||||
|
||||
# Constants
LITELLM_MCP_SERVER_NAME = "litellm-mcp-server"
LITELLM_MCP_SERVER_VERSION = "1.0.0"
LITELLM_MCP_SERVER_DESCRIPTION = "MCP Server for LiteLLM"
# Separator placed between a server prefix and a tool name; overridable via
# the MCP_TOOL_PREFIX_SEPARATOR environment variable (default "-").
MCP_TOOL_PREFIX_SEPARATOR = os.environ.get("MCP_TOOL_PREFIX_SEPARATOR", "-")
# Template consumed by add_server_prefix_to_name().
MCP_TOOL_PREFIX_FORMAT = "{server_name}{separator}{tool_name}"
|
||||
|
||||
|
||||
def is_mcp_available() -> bool:
    """Return True when the optional `mcp` package can be imported."""
    try:
        importlib.import_module("mcp")
    except ImportError:
        return False
    return True
|
||||
|
||||
|
||||
def normalize_server_name(server_name: str) -> str:
    """Return *server_name* with every space replaced by an underscore."""
    return "_".join(server_name.split(" "))
|
||||
|
||||
|
||||
def validate_and_normalize_mcp_server_payload(payload: Any) -> None:
    """
    Validate and normalize MCP server payload fields (server_name and alias).

    Steps:
      1. Reject server_name/alias values containing MCP_TOOL_PREFIX_SEPARATOR.
      2. Replace spaces with underscores in the alias.
      3. Default the alias from server_name when no alias is provided.

    Args:
        payload: Object carrying optional ``server_name`` and ``alias`` fields.

    Raises:
        HTTPException: If validation fails
    """
    server_name = getattr(payload, "server_name", None)
    alias = getattr(payload, "alias", None)

    # Both identifiers must be free of the tool-prefix separator.
    if server_name:
        validate_mcp_server_name(server_name, raise_http_exception=True)
    if alias:
        validate_mcp_server_name(alias, raise_http_exception=True)

    # Normalize the alias, defaulting it from server_name when absent.
    if alias:
        alias = normalize_server_name(alias)
    elif server_name:
        alias = normalize_server_name(server_name)

    # Write the (possibly unchanged) alias back onto the payload.
    if hasattr(payload, "alias"):
        payload.alias = alias
|
||||
|
||||
|
||||
def add_server_prefix_to_name(name: str, server_name: str) -> str:
    """Prefix *name* with the normalized server name and configured separator."""
    return MCP_TOOL_PREFIX_FORMAT.format(
        server_name=normalize_server_name(server_name),
        separator=MCP_TOOL_PREFIX_SEPARATOR,
        tool_name=name,
    )
|
||||
|
||||
|
||||
def get_server_prefix(server: Any) -> str:
    """Return the prefix for a server.

    Preference order: alias, then server_name (each only when present and
    truthy), then server_id when the attribute exists, else "".
    """
    for attr in ("alias", "server_name"):
        value = getattr(server, attr, None)
        if value:
            return value
    if hasattr(server, "server_id"):
        return server.server_id
    return ""
|
||||
|
||||
|
||||
def split_server_prefix_from_name(prefixed_name: str) -> Tuple[str, str]:
    """Return (unprefixed_name, server_prefix) for a prefixed tool name.

    When the separator is absent, the name is returned unchanged with an
    empty prefix.
    """
    prefix, separator, remainder = prefixed_name.partition(MCP_TOOL_PREFIX_SEPARATOR)
    if separator:
        return remainder, prefix
    return prefixed_name, ""
|
||||
|
||||
|
||||
def is_tool_name_prefixed(tool_name: str) -> bool:
    """
    Check if tool name has server prefix

    Args:
        tool_name: Tool name to check

    Returns:
        True if tool name is prefixed, False otherwise
    """
    return tool_name.find(MCP_TOOL_PREFIX_SEPARATOR) != -1
|
||||
|
||||
|
||||
def validate_mcp_server_name(
    server_name: str, raise_http_exception: bool = False
) -> None:
    """
    Validate that MCP server name does not contain 'MCP_TOOL_PREFIX_SEPARATOR'.

    Args:
        server_name: The server name to validate
        raise_http_exception: If True, raises HTTPException instead of generic Exception

    Raises:
        Exception or HTTPException: If server name contains 'MCP_TOOL_PREFIX_SEPARATOR'
    """
    # Valid (or empty) names pass straight through.
    if not server_name or MCP_TOOL_PREFIX_SEPARATOR not in server_name:
        return

    error_message = f"Server name cannot contain '{MCP_TOOL_PREFIX_SEPARATOR}'. Use an alternative character instead Found: {server_name}"
    if not raise_http_exception:
        raise Exception(error_message)

    from fastapi import HTTPException
    from starlette import status

    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST, detail={"error": error_message}
    )
|
||||
|
||||
|
||||
def merge_mcp_headers(
    *,
    extra_headers: Optional[Mapping[str, str]] = None,
    static_headers: Optional[Mapping[str, str]] = None,
) -> Optional[Dict[str, str]]:
    """Merge outbound HTTP headers for MCP calls.

    This is used when calling out to external MCP servers (or OpenAPI-based
    MCP tools).

    Merge rules:
    - Start with `extra_headers` (typically OAuth2-derived headers)
    - Overlay `static_headers` (user-configured per MCP server)

    When a key appears in both, `static_headers` wins. This matches the
    existing behavior in `MCPServerManager`, where `server.static_headers` is
    applied after any caller-provided headers. Keys and values are coerced to
    str. Returns None when the merged result is empty.
    """
    combined: Dict[str, str] = {}
    for source in (extra_headers, static_headers):
        if not source:
            continue
        for key, value in source.items():
            combined[str(key)] = str(value)
    return combined if combined else None
|
||||
@@ -0,0 +1,4 @@
|
||||
def my_custom_rule(input):  # receives the model response
    """Post-call rule: return False (trigger fallback) when the model
    response is shorter than 5 characters, True otherwise.

    Bug fix: the length check had been commented out, making the function
    unconditionally return False and leaving the final return unreachable.
    """
    if len(input) < 5:  # trigger fallback if the model response is too short
        return False
    return True
|
||||
Reference in New Issue
Block a user