chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""Command groups for the LiteLLM proxy CLI."""
|
||||
@@ -0,0 +1,623 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import click
|
||||
import requests
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from litellm.constants import CLI_JWT_EXPIRATION_HOURS
|
||||
|
||||
|
||||
# Token storage utilities
|
||||
def get_token_file_path() -> str:
|
||||
"""Get the path to store the authentication token"""
|
||||
home_dir = Path.home()
|
||||
config_dir = home_dir / ".litellm"
|
||||
config_dir.mkdir(exist_ok=True)
|
||||
return str(config_dir / "token.json")
|
||||
|
||||
|
||||
def save_token(token_data: Dict[str, Any]) -> None:
|
||||
"""Save token data to file"""
|
||||
token_file = get_token_file_path()
|
||||
with open(token_file, "w") as f:
|
||||
json.dump(token_data, f, indent=2)
|
||||
# Set file permissions to be readable only by owner
|
||||
os.chmod(token_file, 0o600)
|
||||
|
||||
|
||||
def load_token() -> Optional[Dict[str, Any]]:
|
||||
"""Load token data from file"""
|
||||
token_file = get_token_file_path()
|
||||
if not os.path.exists(token_file):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(token_file, "r") as f:
|
||||
return json.load(f)
|
||||
except (json.JSONDecodeError, IOError):
|
||||
return None
|
||||
|
||||
|
||||
def clear_token() -> None:
|
||||
"""Clear stored token"""
|
||||
token_file = get_token_file_path()
|
||||
if os.path.exists(token_file):
|
||||
os.remove(token_file)
|
||||
|
||||
|
||||
def get_stored_api_key() -> Optional[str]:
|
||||
"""Get the stored API key from token file"""
|
||||
# Use the SDK-level utility
|
||||
from litellm.litellm_core_utils.cli_token_utils import get_litellm_gateway_api_key
|
||||
|
||||
return get_litellm_gateway_api_key()
|
||||
|
||||
|
||||
# Team selection utilities
|
||||
def display_teams_table(teams: List[Dict[str, Any]]) -> None:
|
||||
"""Display teams in a formatted table"""
|
||||
console = Console()
|
||||
|
||||
if not teams:
|
||||
console.print("❌ No teams found for your user.")
|
||||
return
|
||||
|
||||
table = Table(title="Available Teams")
|
||||
table.add_column("Index", style="cyan", no_wrap=True)
|
||||
table.add_column("Team Alias", style="magenta")
|
||||
table.add_column("Team ID", style="green")
|
||||
table.add_column("Models", style="yellow")
|
||||
table.add_column("Max Budget", style="blue")
|
||||
|
||||
for i, team in enumerate(teams):
|
||||
team_alias = team.get("team_alias") or "N/A"
|
||||
team_id = team.get("team_id", "N/A")
|
||||
models = team.get("models", [])
|
||||
max_budget = team.get("max_budget")
|
||||
|
||||
# Format models list
|
||||
if models:
|
||||
if len(models) > 3:
|
||||
models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
|
||||
else:
|
||||
models_str = ", ".join(models)
|
||||
else:
|
||||
models_str = "All models"
|
||||
|
||||
# Format budget
|
||||
budget_str = f"${max_budget}" if max_budget else "Unlimited"
|
||||
|
||||
table.add_row(str(i + 1), team_alias, team_id, models_str, budget_str)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
def get_key_input():
|
||||
"""Get a single key input from the user (cross-platform)"""
|
||||
try:
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
|
||||
key = msvcrt.getch()
|
||||
if key == b"\xe0": # Arrow keys on Windows
|
||||
key = msvcrt.getch()
|
||||
if key == b"H": # Up arrow
|
||||
return "up"
|
||||
elif key == b"P": # Down arrow
|
||||
return "down"
|
||||
elif key == b"\r": # Enter key
|
||||
return "enter"
|
||||
elif key == b"\x1b": # Escape key
|
||||
return "escape"
|
||||
elif key == b"q":
|
||||
return "quit"
|
||||
return None
|
||||
else:
|
||||
import termios
|
||||
import tty
|
||||
|
||||
fd = sys.stdin.fileno()
|
||||
old_settings = termios.tcgetattr(fd)
|
||||
try:
|
||||
tty.setraw(sys.stdin.fileno())
|
||||
key = sys.stdin.read(1)
|
||||
|
||||
if key == "\x1b": # Escape sequence
|
||||
key += sys.stdin.read(2)
|
||||
if key == "\x1b[A": # Up arrow
|
||||
return "up"
|
||||
elif key == "\x1b[B": # Down arrow
|
||||
return "down"
|
||||
elif key == "\x1b": # Just escape
|
||||
return "escape"
|
||||
elif key == "\r" or key == "\n": # Enter key
|
||||
return "enter"
|
||||
elif key == "q":
|
||||
return "quit"
|
||||
return None
|
||||
finally:
|
||||
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
|
||||
except ImportError:
|
||||
# Fallback to simple input if termios/msvcrt not available
|
||||
return None
|
||||
|
||||
|
||||
def display_interactive_team_selection(
|
||||
teams: List[Dict[str, Any]], selected_index: int = 0
|
||||
) -> None:
|
||||
"""Display teams with one highlighted for selection"""
|
||||
console = Console()
|
||||
|
||||
# Clear the screen using Rich's method
|
||||
console.clear()
|
||||
|
||||
console.print("🎯 Select a Team (Use ↑↓ arrows, Enter to select, 'q' to skip):\n")
|
||||
|
||||
for i, team in enumerate(teams):
|
||||
team_alias = team.get("team_alias") or "N/A"
|
||||
team_id = team.get("team_id", "N/A")
|
||||
models = team.get("models", [])
|
||||
max_budget = team.get("max_budget")
|
||||
|
||||
# Format models list
|
||||
if models:
|
||||
if len(models) > 3:
|
||||
models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
|
||||
else:
|
||||
models_str = ", ".join(models)
|
||||
else:
|
||||
models_str = "All models"
|
||||
|
||||
# Format budget
|
||||
budget_str = f"${max_budget}" if max_budget else "Unlimited"
|
||||
|
||||
# Highlight the selected item
|
||||
if i == selected_index:
|
||||
console.print(f"➤ [bold cyan]{team_alias}[/bold cyan] ({team_id})")
|
||||
console.print(f" Models: [yellow]{models_str}[/yellow]")
|
||||
console.print(f" Budget: [blue]{budget_str}[/blue]\n")
|
||||
else:
|
||||
console.print(f" [dim]{team_alias}[/dim] ({team_id})")
|
||||
console.print(f" Models: [dim]{models_str}[/dim]")
|
||||
console.print(f" Budget: [dim]{budget_str}[/dim]\n")
|
||||
|
||||
|
||||
def prompt_team_selection(teams: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
"""Interactive team selection with arrow keys"""
|
||||
if not teams:
|
||||
return None
|
||||
|
||||
selected_index = 0
|
||||
|
||||
try:
|
||||
# Check if we can use interactive mode
|
||||
if not sys.stdin.isatty():
|
||||
# Fallback to simple selection for non-interactive environments
|
||||
return prompt_team_selection_fallback(teams)
|
||||
|
||||
while True:
|
||||
display_interactive_team_selection(teams, selected_index)
|
||||
|
||||
key = get_key_input()
|
||||
|
||||
if key == "up":
|
||||
selected_index = (selected_index - 1) % len(teams)
|
||||
elif key == "down":
|
||||
selected_index = (selected_index + 1) % len(teams)
|
||||
elif key == "enter":
|
||||
selected_team = teams[selected_index]
|
||||
# Clear screen and show selection
|
||||
console = Console()
|
||||
console.clear()
|
||||
click.echo(
|
||||
f"✅ Selected team: {selected_team.get('team_alias', 'N/A')} ({selected_team.get('team_id')})"
|
||||
)
|
||||
return selected_team
|
||||
elif key == "quit" or key == "escape":
|
||||
# Clear screen
|
||||
console = Console()
|
||||
console.clear()
|
||||
click.echo("ℹ️ Team selection skipped.")
|
||||
return None
|
||||
elif key is None:
|
||||
# If we can't get key input, fall back to simple selection
|
||||
return prompt_team_selection_fallback(teams)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
console = Console()
|
||||
console.clear()
|
||||
click.echo("\n❌ Team selection cancelled.")
|
||||
return None
|
||||
except Exception:
|
||||
# If interactive mode fails, fall back to simple selection
|
||||
return prompt_team_selection_fallback(teams)
|
||||
|
||||
|
||||
def prompt_team_selection_fallback(
|
||||
teams: List[Dict[str, Any]]
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Fallback team selection for non-interactive environments"""
|
||||
if not teams:
|
||||
return None
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice = click.prompt(
|
||||
"\nSelect a team by entering the index number (or 'skip' to continue without a team)",
|
||||
type=str,
|
||||
).strip()
|
||||
|
||||
if choice.lower() == "skip":
|
||||
return None
|
||||
|
||||
index = int(choice) - 1
|
||||
if 0 <= index < len(teams):
|
||||
selected_team = teams[index]
|
||||
click.echo(
|
||||
f"\n✅ Selected team: {selected_team.get('team_alias', 'N/A')} ({selected_team.get('team_id')})"
|
||||
)
|
||||
return selected_team
|
||||
else:
|
||||
click.echo(
|
||||
f"❌ Invalid selection. Please enter a number between 1 and {len(teams)}"
|
||||
)
|
||||
except ValueError:
|
||||
click.echo("❌ Invalid input. Please enter a number or 'skip'")
|
||||
except KeyboardInterrupt:
|
||||
click.echo("\n❌ Team selection cancelled.")
|
||||
return None
|
||||
|
||||
|
||||
# Polling-based authentication - no local server needed
|
||||
def _poll_for_ready_data(
|
||||
url: str,
|
||||
*,
|
||||
total_timeout: int = 300,
|
||||
poll_interval: int = 2,
|
||||
request_timeout: int = 10,
|
||||
pending_message: Optional[str] = None,
|
||||
pending_log_every: int = 10,
|
||||
other_status_message: Optional[str] = None,
|
||||
other_status_log_every: int = 10,
|
||||
http_error_log_every: int = 10,
|
||||
connection_error_log_every: int = 10,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
for attempt in range(total_timeout // poll_interval):
|
||||
try:
|
||||
response = requests.get(url, timeout=request_timeout)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
status = data.get("status")
|
||||
if status == "ready":
|
||||
return data
|
||||
if status == "pending":
|
||||
if (
|
||||
pending_message
|
||||
and pending_log_every > 0
|
||||
and attempt % pending_log_every == 0
|
||||
):
|
||||
click.echo(pending_message)
|
||||
elif (
|
||||
other_status_message
|
||||
and other_status_log_every > 0
|
||||
and attempt % other_status_log_every == 0
|
||||
):
|
||||
click.echo(other_status_message)
|
||||
elif http_error_log_every > 0 and attempt % http_error_log_every == 0:
|
||||
click.echo(f"Polling error: HTTP {response.status_code}")
|
||||
except requests.RequestException as e:
|
||||
if (
|
||||
connection_error_log_every > 0
|
||||
and attempt % connection_error_log_every == 0
|
||||
):
|
||||
click.echo(f"Connection error (will retry): {e}")
|
||||
time.sleep(poll_interval)
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_teams(teams, team_details):
|
||||
"""If team_details are a
|
||||
|
||||
Args:
|
||||
teams (_type_): _description_
|
||||
team_details (_type_): _description_
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
if isinstance(team_details, list) and team_details:
|
||||
return [
|
||||
{
|
||||
"team_id": i.get("team_id") or i.get("id"),
|
||||
"team_alias": i.get("team_alias"),
|
||||
}
|
||||
for i in team_details
|
||||
if isinstance(i, dict) and (i.get("team_id") or i.get("id"))
|
||||
]
|
||||
if isinstance(teams, list):
|
||||
return [{"team_id": str(t), "team_alias": None} for t in teams]
|
||||
return []
|
||||
|
||||
|
||||
def _poll_for_authentication(base_url: str, key_id: str) -> Optional[dict]:
|
||||
"""
|
||||
Poll the server for authentication completion and handle team selection.
|
||||
|
||||
Returns:
|
||||
Dictionary with authentication data if successful, None otherwise
|
||||
"""
|
||||
poll_url = f"{base_url}/sso/cli/poll/{key_id}"
|
||||
data = _poll_for_ready_data(
|
||||
poll_url,
|
||||
pending_message="Still waiting for authentication...",
|
||||
)
|
||||
if not data:
|
||||
return None
|
||||
if data.get("requires_team_selection"):
|
||||
teams = data.get("teams", [])
|
||||
team_details = data.get("team_details")
|
||||
user_id = data.get("user_id")
|
||||
normalized_teams: List[Dict[str, Any]] = _normalize_teams(teams, team_details)
|
||||
if not normalized_teams:
|
||||
click.echo("⚠️ No teams available for selection.")
|
||||
return None
|
||||
|
||||
# User has multiple teams - let them select
|
||||
jwt_with_team = _handle_team_selection_during_polling(
|
||||
base_url=base_url,
|
||||
key_id=key_id,
|
||||
teams=normalized_teams,
|
||||
)
|
||||
|
||||
# Use the team-specific JWT if selection succeeded
|
||||
if jwt_with_team:
|
||||
return {
|
||||
"api_key": jwt_with_team,
|
||||
"user_id": user_id,
|
||||
"teams": teams,
|
||||
"team_id": None, # Set by server in JWT
|
||||
}
|
||||
|
||||
click.echo("❌ Team selection cancelled or JWT generation failed.")
|
||||
return None
|
||||
|
||||
# JWT is ready (single team or team already selected)
|
||||
api_key = data.get("key")
|
||||
user_id = data.get("user_id")
|
||||
teams = data.get("teams", [])
|
||||
team_id = data.get("team_id")
|
||||
|
||||
# Show which team was assigned
|
||||
if team_id and len(teams) == 1:
|
||||
click.echo(f"\n✅ Automatically assigned to team: {team_id}")
|
||||
|
||||
if api_key:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
"user_id": user_id,
|
||||
"teams": teams,
|
||||
"team_id": team_id,
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _handle_team_selection_during_polling(
|
||||
base_url: str, key_id: str, teams: List[Dict[str, Any]]
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Handle team selection and re-poll with selected team_id.
|
||||
|
||||
Args:
|
||||
teams: List of team IDs (strings)
|
||||
|
||||
Returns:
|
||||
The JWT token with the selected team, or None if selection was skipped
|
||||
"""
|
||||
if not teams:
|
||||
click.echo(
|
||||
"ℹ️ No teams found. You can create or join teams using the web interface."
|
||||
)
|
||||
return None
|
||||
|
||||
click.echo("\n" + "=" * 60)
|
||||
click.echo("📋 Select a team for your CLI session...")
|
||||
|
||||
team_id = _render_and_prompt_for_team_selection(teams)
|
||||
|
||||
if not team_id:
|
||||
click.echo("ℹ️ No team selected.")
|
||||
return None
|
||||
|
||||
click.echo(f"\n🔄 Generating JWT for team: {team_id}")
|
||||
|
||||
poll_url = f"{base_url}/sso/cli/poll/{key_id}?team_id={team_id}"
|
||||
data = _poll_for_ready_data(
|
||||
poll_url,
|
||||
pending_message="Still waiting for team authentication...",
|
||||
other_status_message="Waiting for team authentication to complete...",
|
||||
http_error_log_every=10,
|
||||
)
|
||||
if not data:
|
||||
return None
|
||||
jwt_token = data.get("key")
|
||||
if jwt_token:
|
||||
click.echo(f"✅ Successfully generated JWT for team: {team_id}")
|
||||
return jwt_token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _render_and_prompt_for_team_selection(teams: List[Dict[str, Any]]) -> Optional[str]:
|
||||
"""Render teams table and prompt user for a team selection.
|
||||
|
||||
Returns the selected team_id as a string, or None if selection was
|
||||
cancelled or skipped without any teams available.
|
||||
"""
|
||||
# Display teams as a simple list, but prefer showing aliases where
|
||||
# available while still keeping the underlying IDs intact.
|
||||
console = Console()
|
||||
table = Table(title="Available Teams")
|
||||
table.add_column("Index", style="cyan", no_wrap=True)
|
||||
table.add_column("Team Name", style="magenta")
|
||||
table.add_column("Team ID", style="green")
|
||||
|
||||
for i, team in enumerate(teams):
|
||||
team_id = str(team.get("team_id"))
|
||||
team_alias = team.get("team_alias") or team_id
|
||||
table.add_row(str(i + 1), team_alias, team_id)
|
||||
|
||||
console.print(table)
|
||||
|
||||
# Simple selection
|
||||
while True:
|
||||
try:
|
||||
choice = click.prompt(
|
||||
"\nSelect a team by entering the index number (or 'skip' to use first team)",
|
||||
type=str,
|
||||
).strip()
|
||||
|
||||
if choice.lower() == "skip":
|
||||
# Default to the first team's ID if the user skips an
|
||||
# explicit selection.
|
||||
if teams:
|
||||
first_team = teams[0]
|
||||
return str(first_team.get("team_id"))
|
||||
return None
|
||||
|
||||
index = int(choice) - 1
|
||||
if 0 <= index < len(teams):
|
||||
selected_team = teams[index]
|
||||
team_id = str(selected_team.get("team_id"))
|
||||
team_alias = selected_team.get("team_alias") or team_id
|
||||
click.echo(f"\n✅ Selected team: {team_alias} ({team_id})")
|
||||
return team_id
|
||||
|
||||
click.echo(
|
||||
f"❌ Invalid selection. Please enter a number between 1 and {len(teams)}"
|
||||
)
|
||||
except ValueError:
|
||||
click.echo("❌ Invalid input. Please enter a number or 'skip'")
|
||||
except KeyboardInterrupt:
|
||||
click.echo("\n❌ Team selection cancelled.")
|
||||
return None
|
||||
|
||||
|
||||
@click.command(name="login")
|
||||
@click.pass_context
|
||||
def login(ctx: click.Context):
|
||||
"""Login to LiteLLM proxy using SSO authentication"""
|
||||
from litellm._uuid import uuid
|
||||
from litellm.constants import LITELLM_CLI_SOURCE_IDENTIFIER
|
||||
from litellm.proxy.client.cli.interface import show_commands
|
||||
|
||||
base_url = ctx.obj["base_url"]
|
||||
|
||||
# Check if we have an existing key to regenerate
|
||||
existing_key = get_stored_api_key()
|
||||
|
||||
# Generate unique key ID for this login session
|
||||
key_id = f"sk-{str(uuid.uuid4())}"
|
||||
|
||||
try:
|
||||
# Construct SSO login URL with CLI source and pre-generated key
|
||||
sso_url = f"{base_url}/sso/key/generate?source={LITELLM_CLI_SOURCE_IDENTIFIER}&key={key_id}"
|
||||
|
||||
# If we have an existing key, include it as a parameter to the login endpoint
|
||||
# The server will encode it in the OAuth state parameter for the SSO flow
|
||||
if existing_key:
|
||||
sso_url += f"&existing_key={existing_key}"
|
||||
|
||||
click.echo(f"Opening browser to: {sso_url}")
|
||||
click.echo("Please complete the SSO authentication in your browser...")
|
||||
click.echo(f"Session ID: {key_id}")
|
||||
|
||||
# Open browser
|
||||
webbrowser.open(sso_url)
|
||||
|
||||
# Poll for authentication completion
|
||||
click.echo("Waiting for authentication...")
|
||||
|
||||
auth_result = _poll_for_authentication(base_url=base_url, key_id=key_id)
|
||||
|
||||
if auth_result:
|
||||
api_key = auth_result["api_key"]
|
||||
user_id = auth_result["user_id"]
|
||||
|
||||
# Save token data (simplified for CLI - we just need the key)
|
||||
save_token(
|
||||
{
|
||||
"key": api_key,
|
||||
"user_id": user_id or "cli-user",
|
||||
"user_email": "unknown",
|
||||
"user_role": "cli",
|
||||
"auth_header_name": "Authorization",
|
||||
"jwt_token": "",
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
)
|
||||
|
||||
click.echo("\n✅ Login successful!")
|
||||
click.echo(f"JWT Token: {api_key[:20]}...")
|
||||
click.echo("You can now use the CLI without specifying --api-key")
|
||||
|
||||
# Show available commands after successful login
|
||||
click.echo("\n" + "=" * 60)
|
||||
show_commands()
|
||||
return
|
||||
else:
|
||||
click.echo("❌ Authentication timed out. Please try again.")
|
||||
return
|
||||
|
||||
except KeyboardInterrupt:
|
||||
click.echo("\n❌ Authentication cancelled by user.")
|
||||
return
|
||||
except Exception as e:
|
||||
click.echo(f"❌ Authentication failed: {e}")
|
||||
return
|
||||
|
||||
|
||||
@click.command(name="logout")
|
||||
def logout():
|
||||
"""Logout and clear stored authentication"""
|
||||
clear_token()
|
||||
click.echo("✅ Logged out successfully. Authentication token cleared.")
|
||||
|
||||
|
||||
@click.command(name="whoami")
|
||||
def whoami():
|
||||
"""Show current authentication status"""
|
||||
token_data = load_token()
|
||||
|
||||
if not token_data:
|
||||
click.echo("❌ Not authenticated. Run 'litellm-proxy login' to authenticate.")
|
||||
return
|
||||
|
||||
click.echo("✅ Authenticated")
|
||||
click.echo(f"User Email: {token_data.get('user_email', 'Unknown')}")
|
||||
click.echo(f"User ID: {token_data.get('user_id', 'Unknown')}")
|
||||
click.echo(f"User Role: {token_data.get('user_role', 'Unknown')}")
|
||||
|
||||
# Check if token is still valid (basic timestamp check)
|
||||
timestamp = token_data.get("timestamp", 0)
|
||||
age_hours = (time.time() - timestamp) / 3600
|
||||
click.echo(f"Token age: {age_hours:.1f} hours")
|
||||
|
||||
if age_hours > CLI_JWT_EXPIRATION_HOURS:
|
||||
click.echo(
|
||||
f"⚠️ Warning: Token is more than {CLI_JWT_EXPIRATION_HOURS} hours old and may have expired."
|
||||
)
|
||||
|
||||
|
||||
# Export functions for use by other CLI commands
|
||||
__all__ = ["login", "logout", "whoami", "prompt_team_selection"]
|
||||
|
||||
# Export individual commands instead of grouping them
|
||||
# login, logout, and whoami will be added as top-level commands
|
||||
@@ -0,0 +1,406 @@
|
||||
import json
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import click
|
||||
import requests
|
||||
from rich.console import Console
|
||||
from rich.panel import Panel
|
||||
from rich.prompt import Prompt
|
||||
from rich.table import Table
|
||||
|
||||
from ... import Client
|
||||
from ...chat import ChatClient
|
||||
|
||||
|
||||
def _get_available_models(ctx: click.Context) -> List[Dict[str, Any]]:
|
||||
"""Get list of available models from the proxy server"""
|
||||
try:
|
||||
client = Client(base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"])
|
||||
models_list = client.models.list()
|
||||
# Ensure we return a list of dictionaries
|
||||
if isinstance(models_list, list):
|
||||
# Filter to ensure all items are dictionaries
|
||||
return [model for model in models_list if isinstance(model, dict)]
|
||||
return []
|
||||
except Exception as e:
|
||||
click.echo(f"Warning: Could not fetch models list: {e}", err=True)
|
||||
return []
|
||||
|
||||
|
||||
def _select_model(
|
||||
console: Console, available_models: List[Dict[str, Any]]
|
||||
) -> Optional[str]:
|
||||
"""Interactive model selection"""
|
||||
if not available_models:
|
||||
console.print(
|
||||
"[yellow]No models available or could not fetch models list.[/yellow]"
|
||||
)
|
||||
model_name = Prompt.ask("Please enter a model name")
|
||||
return model_name if model_name.strip() else None
|
||||
|
||||
# Display available models in a table
|
||||
table = Table(title="Available Models")
|
||||
table.add_column("Index", style="cyan", no_wrap=True)
|
||||
table.add_column("Model ID", style="green")
|
||||
table.add_column("Owned By", style="yellow")
|
||||
MAX_MODELS_TO_DISPLAY = 200
|
||||
|
||||
models_to_display: List[Dict[str, Any]] = available_models[:MAX_MODELS_TO_DISPLAY]
|
||||
for i, model in enumerate(models_to_display): # Limit to first 200 models
|
||||
table.add_row(
|
||||
str(i + 1), str(model.get("id", "")), str(model.get("owned_by", ""))
|
||||
)
|
||||
|
||||
if len(available_models) > MAX_MODELS_TO_DISPLAY:
|
||||
console.print(
|
||||
f"\n[dim]... and {len(available_models) - MAX_MODELS_TO_DISPLAY} more models[/dim]"
|
||||
)
|
||||
|
||||
console.print(table)
|
||||
|
||||
while True:
|
||||
try:
|
||||
choice = Prompt.ask(
|
||||
"\nSelect a model by entering the index number (or type a model name directly)",
|
||||
default="1",
|
||||
).strip()
|
||||
|
||||
# Try to parse as index
|
||||
try:
|
||||
index = int(choice) - 1
|
||||
if 0 <= index < len(available_models):
|
||||
return available_models[index]["id"]
|
||||
else:
|
||||
console.print(
|
||||
f"[red]Invalid index. Please enter a number between 1 and {len(available_models)}[/red]"
|
||||
)
|
||||
continue
|
||||
except ValueError:
|
||||
# Not a number, treat as model name
|
||||
if choice:
|
||||
return choice
|
||||
else:
|
||||
console.print("[red]Please enter a valid model name or index[/red]")
|
||||
continue
|
||||
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[yellow]Model selection cancelled.[/yellow]")
|
||||
return None
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("model", required=False)
|
||||
@click.option(
|
||||
"--temperature",
|
||||
"-t",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="Sampling temperature between 0 and 2 (default: 0.7)",
|
||||
)
|
||||
@click.option(
|
||||
"--max-tokens",
|
||||
type=int,
|
||||
help="Maximum number of tokens to generate",
|
||||
)
|
||||
@click.option(
|
||||
"--system",
|
||||
"-s",
|
||||
type=str,
|
||||
help="System message to set the behavior of the assistant",
|
||||
)
|
||||
@click.pass_context
|
||||
def chat(
|
||||
ctx: click.Context,
|
||||
model: Optional[str],
|
||||
temperature: float,
|
||||
max_tokens: Optional[int] = None,
|
||||
system: Optional[str] = None,
|
||||
):
|
||||
"""Interactive chat with streaming responses
|
||||
|
||||
Examples:
|
||||
|
||||
# Chat with a specific model
|
||||
litellm-proxy chat gpt-4
|
||||
|
||||
# Chat without specifying model (will show model selection)
|
||||
litellm-proxy chat
|
||||
|
||||
# Chat with custom settings
|
||||
litellm-proxy chat gpt-4 --temperature 0.9 --system "You are a helpful coding assistant"
|
||||
"""
|
||||
console = Console()
|
||||
|
||||
# If no model specified, show model selection
|
||||
if not model:
|
||||
available_models = _get_available_models(ctx)
|
||||
model = _select_model(console, available_models)
|
||||
if not model:
|
||||
console.print("[red]No model selected. Exiting.[/red]")
|
||||
return
|
||||
|
||||
client = ChatClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
|
||||
# Initialize conversation history
|
||||
messages: List[Dict[str, Any]] = []
|
||||
|
||||
# Add system message if provided
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
|
||||
# Display welcome message
|
||||
console.print(
|
||||
Panel.fit(
|
||||
f"[bold blue]LiteLLM Interactive Chat[/bold blue]\n"
|
||||
f"Model: [green]{model}[/green]\n"
|
||||
f"Temperature: [yellow]{temperature}[/yellow]\n"
|
||||
f"Max Tokens: [yellow]{max_tokens or 'unlimited'}[/yellow]\n\n"
|
||||
f"Type your messages and press Enter. Type '/quit' or '/exit' to end the session.\n"
|
||||
f"Type '/help' for more commands.",
|
||||
title="🤖 Chat Session",
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Get user input
|
||||
try:
|
||||
user_input = console.input("\n[bold cyan]You:[/bold cyan] ").strip()
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
console.print("\n[yellow]Chat session ended.[/yellow]")
|
||||
break
|
||||
|
||||
# Handle special commands
|
||||
should_exit, messages, new_model = _handle_special_commands(
|
||||
console, user_input, messages, system, ctx
|
||||
)
|
||||
|
||||
if should_exit:
|
||||
break
|
||||
if new_model:
|
||||
model = new_model
|
||||
|
||||
# Check if this was a special command that was handled (not a normal message)
|
||||
if (
|
||||
user_input.lower().startswith(
|
||||
(
|
||||
"/quit",
|
||||
"/exit",
|
||||
"/q",
|
||||
"/help",
|
||||
"/clear",
|
||||
"/history",
|
||||
"/save",
|
||||
"/load",
|
||||
"/model",
|
||||
)
|
||||
)
|
||||
or not user_input
|
||||
):
|
||||
continue
|
||||
|
||||
# Add user message to conversation
|
||||
messages.append({"role": "user", "content": user_input})
|
||||
|
||||
# Display assistant label
|
||||
console.print("\n[bold green]Assistant:[/bold green]")
|
||||
|
||||
# Stream the response
|
||||
assistant_content = _stream_response(
|
||||
console=console,
|
||||
client=client,
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
# Add assistant message to conversation history
|
||||
if assistant_content:
|
||||
messages.append({"role": "assistant", "content": assistant_content})
|
||||
else:
|
||||
console.print("[red]Error: No content received from the model[/red]")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
console.print("\n[yellow]Chat session interrupted.[/yellow]")
|
||||
|
||||
|
||||
def _show_help(console: Console):
|
||||
"""Show help for interactive chat commands"""
|
||||
help_text = """
|
||||
[bold]Interactive Chat Commands:[/bold]
|
||||
|
||||
[cyan]/help[/cyan] - Show this help message
|
||||
[cyan]/quit[/cyan] - Exit the chat session (also /exit, /q)
|
||||
[cyan]/clear[/cyan] - Clear conversation history
|
||||
[cyan]/history[/cyan] - Show conversation history
|
||||
[cyan]/model[/cyan] - Switch to a different model
|
||||
[cyan]/save <name>[/cyan] - Save conversation to file
|
||||
[cyan]/load <name>[/cyan] - Load conversation from file
|
||||
|
||||
[bold]Tips:[/bold]
|
||||
- Your conversation history is maintained during the session
|
||||
- Use Ctrl+C to interrupt at any time
|
||||
- Responses are streamed in real-time
|
||||
- You can switch models mid-conversation with /model
|
||||
"""
|
||||
console.print(Panel(help_text, title="Help"))
|
||||
|
||||
|
||||
def _show_history(console: Console, messages: List[Dict[str, Any]]):
|
||||
"""Show conversation history"""
|
||||
if not messages:
|
||||
console.print("[yellow]No conversation history.[/yellow]")
|
||||
return
|
||||
|
||||
console.print(Panel.fit("[bold]Conversation History[/bold]", title="History"))
|
||||
|
||||
for i, message in enumerate(messages, 1):
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
|
||||
if role == "system":
|
||||
console.print(
|
||||
f"[dim]{i}. [bold magenta]System:[/bold magenta] {content}[/dim]"
|
||||
)
|
||||
elif role == "user":
|
||||
console.print(f"{i}. [bold cyan]You:[/bold cyan] {content}")
|
||||
elif role == "assistant":
|
||||
console.print(
|
||||
f"{i}. [bold green]Assistant:[/bold green] {content[:100]}{'...' if len(content) > 100 else ''}"
|
||||
)
|
||||
|
||||
|
||||
def _save_conversation(console: Console, messages: List[Dict[str, Any]], command: str):
|
||||
"""Save conversation to a file"""
|
||||
parts = command.split()
|
||||
if len(parts) < 2:
|
||||
console.print("[red]Usage: /save <filename>[/red]")
|
||||
return
|
||||
|
||||
filename = parts[1]
|
||||
if not filename.endswith(".json"):
|
||||
filename += ".json"
|
||||
|
||||
try:
|
||||
with open(filename, "w") as f:
|
||||
json.dump(messages, f, indent=2)
|
||||
console.print(f"[green]Conversation saved to {filename}[/green]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error saving conversation: {e}[/red]")
|
||||
|
||||
|
||||
def _load_conversation(
|
||||
console: Console, command: str, system: Optional[str]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Load conversation from a file"""
|
||||
parts = command.split()
|
||||
if len(parts) < 2:
|
||||
console.print("[red]Usage: /load <filename>[/red]")
|
||||
return []
|
||||
|
||||
filename = parts[1]
|
||||
if not filename.endswith(".json"):
|
||||
filename += ".json"
|
||||
|
||||
try:
|
||||
with open(filename, "r") as f:
|
||||
messages = json.load(f)
|
||||
console.print(f"[green]Conversation loaded from {filename}[/green]")
|
||||
return messages
|
||||
except FileNotFoundError:
|
||||
console.print(f"[red]File not found: {filename}[/red]")
|
||||
except Exception as e:
|
||||
console.print(f"[red]Error loading conversation: {e}[/red]")
|
||||
|
||||
# Return empty list or just system message if load failed
|
||||
if system:
|
||||
return [{"role": "system", "content": system}]
|
||||
return []
|
||||
|
||||
|
||||
def _handle_special_commands(
|
||||
console: Console,
|
||||
user_input: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
system: Optional[str],
|
||||
ctx: click.Context,
|
||||
) -> tuple[bool, List[Dict[str, Any]], Optional[str]]:
|
||||
"""Handle special chat commands. Returns (should_exit, updated_messages, updated_model)"""
|
||||
if user_input.lower() in ["/quit", "/exit", "/q"]:
|
||||
console.print("[yellow]Chat session ended.[/yellow]")
|
||||
return True, messages, None
|
||||
elif user_input.lower() == "/help":
|
||||
_show_help(console)
|
||||
return False, messages, None
|
||||
elif user_input.lower() == "/clear":
|
||||
new_messages = []
|
||||
if system:
|
||||
new_messages.append({"role": "system", "content": system})
|
||||
console.print("[green]Conversation history cleared.[/green]")
|
||||
return False, new_messages, None
|
||||
elif user_input.lower() == "/history":
|
||||
_show_history(console, messages)
|
||||
return False, messages, None
|
||||
elif user_input.lower().startswith("/save"):
|
||||
_save_conversation(console, messages, user_input)
|
||||
return False, messages, None
|
||||
elif user_input.lower().startswith("/load"):
|
||||
new_messages = _load_conversation(console, user_input, system)
|
||||
return False, new_messages, None
|
||||
elif user_input.lower() == "/model":
|
||||
available_models = _get_available_models(ctx)
|
||||
new_model = _select_model(console, available_models)
|
||||
if new_model:
|
||||
console.print(f"[green]Switched to model: {new_model}[/green]")
|
||||
return False, messages, new_model
|
||||
return False, messages, None
|
||||
elif not user_input:
|
||||
return False, messages, None
|
||||
|
||||
# Not a special command
|
||||
return False, messages, None
|
||||
|
||||
|
||||
def _stream_response(
|
||||
console: Console,
|
||||
client: ChatClient,
|
||||
model: str,
|
||||
messages: List[Dict[str, Any]],
|
||||
temperature: float,
|
||||
max_tokens: Optional[int],
|
||||
) -> Optional[str]:
|
||||
"""Stream the model response and return the complete content"""
|
||||
try:
|
||||
assistant_content = ""
|
||||
for chunk in client.completions_stream(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
):
|
||||
if "choices" in chunk and len(chunk["choices"]) > 0:
|
||||
delta = chunk["choices"][0].get("delta", {})
|
||||
content = delta.get("content", "")
|
||||
if content:
|
||||
assistant_content += content
|
||||
console.print(content, end="")
|
||||
sys.stdout.flush()
|
||||
|
||||
console.print() # Add newline after streaming
|
||||
return assistant_content if assistant_content else None
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
console.print(f"\n[red]Error: HTTP {e.response.status_code}[/red]")
|
||||
try:
|
||||
error_body = e.response.json()
|
||||
console.print(
|
||||
f"[red]{error_body.get('error', {}).get('message', 'Unknown error')}[/red]"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
console.print(f"[red]{e.response.text}[/red]")
|
||||
return None
|
||||
except Exception as e:
|
||||
console.print(f"\n[red]Error: {str(e)}[/red]")
|
||||
return None
|
||||
@@ -0,0 +1,116 @@
|
||||
import json
|
||||
from typing import Literal
|
||||
|
||||
import click
|
||||
import rich
|
||||
import requests
|
||||
from rich.table import Table
|
||||
|
||||
from ...credentials import CredentialsManagementClient
|
||||
|
||||
|
||||
@click.group()
|
||||
def credentials():
|
||||
"""Manage credentials for the LiteLLM proxy server"""
|
||||
pass
|
||||
|
||||
|
||||
@credentials.command()
|
||||
@click.option(
|
||||
"--format",
|
||||
"output_format",
|
||||
type=click.Choice(["table", "json"]),
|
||||
default="table",
|
||||
help="Output format (table or json)",
|
||||
)
|
||||
@click.pass_context
|
||||
def list(ctx: click.Context, output_format: Literal["table", "json"]):
|
||||
"""List all credentials"""
|
||||
client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
response = client.list()
|
||||
assert isinstance(response, dict)
|
||||
|
||||
if output_format == "json":
|
||||
rich.print_json(data=response)
|
||||
else: # table format
|
||||
table = Table(title="Credentials")
|
||||
|
||||
# Add columns
|
||||
table.add_column("Credential Name", style="cyan")
|
||||
table.add_column("Custom LLM Provider", style="green")
|
||||
|
||||
# Add rows
|
||||
for cred in response.get("credentials", []):
|
||||
info = cred.get("credential_info", {})
|
||||
table.add_row(
|
||||
str(cred.get("credential_name", "")),
|
||||
str(info.get("custom_llm_provider", "")),
|
||||
)
|
||||
|
||||
rich.print(table)
|
||||
|
||||
|
||||
@credentials.command()
|
||||
@click.argument("credential_name")
|
||||
@click.option(
|
||||
"--info",
|
||||
type=str,
|
||||
help="JSON string containing credential info",
|
||||
required=True,
|
||||
)
|
||||
@click.option(
|
||||
"--values",
|
||||
type=str,
|
||||
help="JSON string containing credential values",
|
||||
required=True,
|
||||
)
|
||||
@click.pass_context
|
||||
def create(ctx: click.Context, credential_name: str, info: str, values: str):
|
||||
"""Create a new credential"""
|
||||
client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
try:
|
||||
credential_info = json.loads(info)
|
||||
credential_values = json.loads(values)
|
||||
except json.JSONDecodeError as e:
|
||||
raise click.BadParameter(f"Invalid JSON: {str(e)}")
|
||||
|
||||
try:
|
||||
response = client.create(credential_name, credential_info, credential_values)
|
||||
rich.print_json(data=response)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
||||
try:
|
||||
error_body = e.response.json()
|
||||
rich.print_json(data=error_body)
|
||||
except json.JSONDecodeError:
|
||||
click.echo(e.response.text, err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
@credentials.command()
|
||||
@click.argument("credential_name")
|
||||
@click.pass_context
|
||||
def delete(ctx: click.Context, credential_name: str):
|
||||
"""Delete a credential by name"""
|
||||
client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
try:
|
||||
response = client.delete(credential_name)
|
||||
rich.print_json(data=response)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
||||
try:
|
||||
error_body = e.response.json()
|
||||
rich.print_json(data=error_body)
|
||||
except json.JSONDecodeError:
|
||||
click.echo(e.response.text, err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
@credentials.command()
|
||||
@click.argument("credential_name")
|
||||
@click.pass_context
|
||||
def get(ctx: click.Context, credential_name: str):
|
||||
"""Get a credential by name"""
|
||||
client = CredentialsManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
response = client.get(credential_name)
|
||||
rich.print_json(data=response)
|
||||
@@ -0,0 +1,102 @@
|
||||
import json as json_lib
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import rich
|
||||
import requests
|
||||
|
||||
from ...http_client import HTTPClient
|
||||
|
||||
|
||||
@click.group()
|
||||
def http():
|
||||
"""Make HTTP requests to the LiteLLM proxy server"""
|
||||
pass
|
||||
|
||||
|
||||
@http.command()
|
||||
@click.argument("method")
|
||||
@click.argument("uri")
|
||||
@click.option(
|
||||
"--data",
|
||||
"-d",
|
||||
type=str,
|
||||
help="Data to send in the request body (as JSON string)",
|
||||
)
|
||||
@click.option(
|
||||
"--json",
|
||||
"-j",
|
||||
type=str,
|
||||
help="JSON data to send in the request body (as JSON string)",
|
||||
)
|
||||
@click.option(
|
||||
"--header",
|
||||
"-H",
|
||||
multiple=True,
|
||||
help="HTTP headers in 'key:value' format. Can be specified multiple times.",
|
||||
)
|
||||
@click.pass_context
|
||||
def request(
|
||||
ctx: click.Context,
|
||||
method: str,
|
||||
uri: str,
|
||||
data: Optional[str] = None,
|
||||
json: Optional[str] = None,
|
||||
header: tuple[str, ...] = (),
|
||||
):
|
||||
"""Make an HTTP request to the LiteLLM proxy server
|
||||
|
||||
METHOD: HTTP method (GET, POST, PUT, DELETE, etc.)
|
||||
URI: URI path (will be appended to base_url)
|
||||
|
||||
Examples:
|
||||
litellm http request GET /models
|
||||
litellm http request POST /chat/completions -j '{"model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}]}'
|
||||
litellm http request GET /health/test_connection -H "X-Custom-Header:value"
|
||||
"""
|
||||
# Parse headers from key:value format
|
||||
headers = {}
|
||||
for h in header:
|
||||
try:
|
||||
key, value = h.split(":", 1)
|
||||
headers[key.strip()] = value.strip()
|
||||
except ValueError:
|
||||
raise click.BadParameter(
|
||||
f"Invalid header format: {h}. Expected format: 'key:value'"
|
||||
)
|
||||
|
||||
# Parse JSON data if provided
|
||||
json_data = None
|
||||
if json:
|
||||
try:
|
||||
json_data = json_lib.loads(json)
|
||||
except ValueError as e:
|
||||
raise click.BadParameter(f"Invalid JSON format: {e}")
|
||||
|
||||
# Parse data if provided
|
||||
request_data = None
|
||||
if data:
|
||||
try:
|
||||
request_data = json_lib.loads(data)
|
||||
except ValueError:
|
||||
# If not JSON, use as raw data
|
||||
request_data = data
|
||||
|
||||
client = HTTPClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
try:
|
||||
response = client.request(
|
||||
method=method,
|
||||
uri=uri,
|
||||
data=request_data,
|
||||
json=json_data,
|
||||
headers=headers,
|
||||
)
|
||||
rich.print_json(data=response)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
||||
try:
|
||||
error_body = e.response.json()
|
||||
rich.print_json(data=error_body)
|
||||
except json_lib.JSONDecodeError:
|
||||
click.echo(e.response.text, err=True)
|
||||
raise click.Abort()
|
||||
@@ -0,0 +1,415 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Literal, Optional, List, Dict, Any
|
||||
|
||||
import click
|
||||
import rich
|
||||
import requests
|
||||
from rich.table import Table
|
||||
|
||||
from ...keys import KeysManagementClient
|
||||
|
||||
|
||||
@click.group()
|
||||
def keys():
|
||||
"""Manage API keys for the LiteLLM proxy server"""
|
||||
pass
|
||||
|
||||
|
||||
@keys.command()
|
||||
@click.option("--page", type=int, help="Page number for pagination")
|
||||
@click.option("--size", type=int, help="Number of items per page")
|
||||
@click.option("--user-id", type=str, help="Filter keys by user ID")
|
||||
@click.option("--team-id", type=str, help="Filter keys by team ID")
|
||||
@click.option("--organization-id", type=str, help="Filter keys by organization ID")
|
||||
@click.option("--key-hash", type=str, help="Filter by specific key hash")
|
||||
@click.option("--key-alias", type=str, help="Filter by key alias")
|
||||
@click.option(
|
||||
"--return-full-object",
|
||||
is_flag=True,
|
||||
default=True,
|
||||
help="Return the full key object",
|
||||
)
|
||||
@click.option(
|
||||
"--include-team-keys", is_flag=True, help="Include team keys in the response"
|
||||
)
|
||||
@click.option(
|
||||
"--format",
|
||||
"output_format",
|
||||
type=click.Choice(["table", "json"]),
|
||||
default="table",
|
||||
help="Output format (table or json)",
|
||||
)
|
||||
@click.pass_context
|
||||
def list(
|
||||
ctx: click.Context,
|
||||
page: Optional[int],
|
||||
size: Optional[int],
|
||||
user_id: Optional[str],
|
||||
team_id: Optional[str],
|
||||
organization_id: Optional[str],
|
||||
key_hash: Optional[str],
|
||||
key_alias: Optional[str],
|
||||
include_team_keys: bool,
|
||||
output_format: Literal["table", "json"],
|
||||
return_full_object: bool,
|
||||
):
|
||||
"""List all API keys"""
|
||||
client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
response = client.list(
|
||||
page=page,
|
||||
size=size,
|
||||
user_id=user_id,
|
||||
team_id=team_id,
|
||||
organization_id=organization_id,
|
||||
key_hash=key_hash,
|
||||
key_alias=key_alias,
|
||||
return_full_object=return_full_object,
|
||||
include_team_keys=include_team_keys,
|
||||
)
|
||||
assert isinstance(response, dict)
|
||||
|
||||
if output_format == "json":
|
||||
rich.print_json(data=response)
|
||||
else:
|
||||
rich.print(
|
||||
f"Showing {len(response.get('keys', []))} keys out of {response.get('total_count', 0)}"
|
||||
)
|
||||
table = Table(title="API Keys")
|
||||
table.add_column("Key Hash", style="cyan")
|
||||
table.add_column("Alias", style="green")
|
||||
table.add_column("User ID", style="magenta")
|
||||
table.add_column("Team ID", style="yellow")
|
||||
table.add_column("Spend", style="red")
|
||||
for key in response.get("keys", []):
|
||||
table.add_row(
|
||||
str(key.get("token", "")),
|
||||
str(key.get("key_alias", "")),
|
||||
str(key.get("user_id", "")),
|
||||
str(key.get("team_id", "")),
|
||||
str(key.get("spend", "")),
|
||||
)
|
||||
rich.print(table)
|
||||
|
||||
|
||||
@keys.command()
|
||||
@click.option("--models", type=str, help="Comma-separated list of allowed models")
|
||||
@click.option("--aliases", type=str, help="JSON string of model alias mappings")
|
||||
@click.option("--spend", type=float, help="Maximum spend limit for this key")
|
||||
@click.option(
|
||||
"--duration",
|
||||
type=str,
|
||||
help="Duration for which the key is valid (e.g. '24h', '7d')",
|
||||
)
|
||||
@click.option("--key-alias", type=str, help="Alias/name for the key")
|
||||
@click.option("--team-id", type=str, help="Team ID to associate the key with")
|
||||
@click.option("--user-id", type=str, help="User ID to associate the key with")
|
||||
@click.option("--budget-id", type=str, help="Budget ID to associate the key with")
|
||||
@click.option(
|
||||
"--config", type=str, help="JSON string of additional configuration parameters"
|
||||
)
|
||||
@click.pass_context
|
||||
def generate(
|
||||
ctx: click.Context,
|
||||
models: Optional[str],
|
||||
aliases: Optional[str],
|
||||
spend: Optional[float],
|
||||
duration: Optional[str],
|
||||
key_alias: Optional[str],
|
||||
team_id: Optional[str],
|
||||
user_id: Optional[str],
|
||||
budget_id: Optional[str],
|
||||
config: Optional[str],
|
||||
):
|
||||
"""Generate a new API key"""
|
||||
client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
try:
|
||||
models_list = [m.strip() for m in models.split(",")] if models else None
|
||||
aliases_dict = json.loads(aliases) if aliases else None
|
||||
config_dict = json.loads(config) if config else None
|
||||
except json.JSONDecodeError as e:
|
||||
raise click.BadParameter(f"Invalid JSON: {str(e)}")
|
||||
try:
|
||||
response = client.generate(
|
||||
models=models_list,
|
||||
aliases=aliases_dict,
|
||||
spend=spend,
|
||||
duration=duration,
|
||||
key_alias=key_alias,
|
||||
team_id=team_id,
|
||||
user_id=user_id,
|
||||
budget_id=budget_id,
|
||||
config=config_dict,
|
||||
)
|
||||
rich.print_json(data=response)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
||||
try:
|
||||
error_body = e.response.json()
|
||||
rich.print_json(data=error_body)
|
||||
except json.JSONDecodeError:
|
||||
click.echo(e.response.text, err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
@keys.command()
|
||||
@click.option("--keys", type=str, help="Comma-separated list of API keys to delete")
|
||||
@click.option(
|
||||
"--key-aliases", type=str, help="Comma-separated list of key aliases to delete"
|
||||
)
|
||||
@click.pass_context
|
||||
def delete(ctx: click.Context, keys: Optional[str], key_aliases: Optional[str]):
|
||||
"""Delete API keys by key or alias"""
|
||||
client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
keys_list = [k.strip() for k in keys.split(",")] if keys else None
|
||||
aliases_list = [a.strip() for a in key_aliases.split(",")] if key_aliases else None
|
||||
try:
|
||||
response = client.delete(keys=keys_list, key_aliases=aliases_list)
|
||||
rich.print_json(data=response)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
||||
try:
|
||||
error_body = e.response.json()
|
||||
rich.print_json(data=error_body)
|
||||
except json.JSONDecodeError:
|
||||
click.echo(e.response.text, err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
def _parse_created_since_filter(created_since: Optional[str]) -> Optional[datetime]:
|
||||
"""Parse and validate the created_since date filter."""
|
||||
if not created_since:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Support formats: YYYY-MM-DD_HH:MM or YYYY-MM-DD
|
||||
if "_" in created_since:
|
||||
return datetime.strptime(created_since, "%Y-%m-%d_%H:%M")
|
||||
else:
|
||||
return datetime.strptime(created_since, "%Y-%m-%d")
|
||||
except ValueError:
|
||||
click.echo(
|
||||
f"Error: Invalid date format '{created_since}'. Use YYYY-MM-DD_HH:MM or YYYY-MM-DD",
|
||||
err=True,
|
||||
)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
def _fetch_all_keys_with_pagination(
|
||||
source_client: KeysManagementClient, source_base_url: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Fetch all keys from source instance using pagination."""
|
||||
click.echo(f"Fetching keys from source server: {source_base_url}")
|
||||
source_keys = []
|
||||
page = 1
|
||||
page_size = 100 # Use a larger page size to minimize API calls
|
||||
|
||||
while True:
|
||||
source_response = source_client.list(
|
||||
return_full_object=True, page=page, size=page_size
|
||||
)
|
||||
# source_client.list() returns Dict[str, Any] when return_request is False (default)
|
||||
assert isinstance(source_response, dict), "Expected dict response from list API"
|
||||
page_keys = source_response.get("keys", [])
|
||||
|
||||
if not page_keys:
|
||||
break
|
||||
|
||||
source_keys.extend(page_keys)
|
||||
click.echo(f"Fetched page {page}: {len(page_keys)} keys")
|
||||
|
||||
# Check if we got fewer keys than the page size, indicating last page
|
||||
if len(page_keys) < page_size:
|
||||
break
|
||||
|
||||
page += 1
|
||||
|
||||
return source_keys
|
||||
|
||||
|
||||
def _filter_keys_by_created_since(
|
||||
source_keys: List[Dict[str, Any]],
|
||||
created_since_dt: Optional[datetime],
|
||||
created_since: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Filter keys by created_since date if specified."""
|
||||
if not created_since_dt:
|
||||
return source_keys
|
||||
|
||||
filtered_keys = []
|
||||
for key in source_keys:
|
||||
key_created_at = key.get("created_at")
|
||||
if key_created_at:
|
||||
# Parse the key's created_at timestamp
|
||||
if isinstance(key_created_at, str):
|
||||
if "T" in key_created_at:
|
||||
key_dt = datetime.fromisoformat(
|
||||
key_created_at.replace("Z", "+00:00")
|
||||
)
|
||||
else:
|
||||
key_dt = datetime.fromisoformat(key_created_at)
|
||||
|
||||
# Convert to naive datetime for comparison (assuming UTC)
|
||||
if key_dt.tzinfo:
|
||||
key_dt = key_dt.replace(tzinfo=None)
|
||||
|
||||
if key_dt >= created_since_dt:
|
||||
filtered_keys.append(key)
|
||||
|
||||
click.echo(
|
||||
f"Filtered {len(source_keys)} keys to {len(filtered_keys)} keys created since {created_since}"
|
||||
)
|
||||
return filtered_keys
|
||||
|
||||
|
||||
def _display_dry_run_table(source_keys: List[Dict[str, Any]]) -> None:
|
||||
"""Display a table of keys that would be imported in dry-run mode."""
|
||||
click.echo("\n--- DRY RUN MODE ---")
|
||||
table = Table(title="Keys that would be imported")
|
||||
table.add_column("Key Alias", style="green")
|
||||
table.add_column("User ID", style="magenta")
|
||||
table.add_column("Created", style="cyan")
|
||||
|
||||
for key in source_keys:
|
||||
created_at = key.get("created_at", "")
|
||||
# Format the timestamp if it exists
|
||||
if created_at:
|
||||
# Try to parse and format the timestamp for better readability
|
||||
if isinstance(created_at, str):
|
||||
# Handle common timestamp formats
|
||||
if "T" in created_at:
|
||||
dt = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
|
||||
created_at = dt.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
table.add_row(
|
||||
str(key.get("key_alias", "")), str(key.get("user_id", "")), str(created_at)
|
||||
)
|
||||
rich.print(table)
|
||||
|
||||
|
||||
def _prepare_key_import_data(key: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Prepare key data for import by extracting relevant fields."""
|
||||
import_data = {}
|
||||
|
||||
# Copy relevant fields if they exist
|
||||
for field in [
|
||||
"models",
|
||||
"aliases",
|
||||
"spend",
|
||||
"key_alias",
|
||||
"team_id",
|
||||
"user_id",
|
||||
"budget_id",
|
||||
"config",
|
||||
]:
|
||||
if key.get(field):
|
||||
import_data[field] = key[field]
|
||||
|
||||
return import_data
|
||||
|
||||
|
||||
def _import_keys_to_destination(
|
||||
source_keys: List[Dict[str, Any]], dest_client: KeysManagementClient
|
||||
) -> tuple[int, int]:
|
||||
"""Import each key to the destination instance and return counts."""
|
||||
imported_count = 0
|
||||
failed_count = 0
|
||||
|
||||
for key in source_keys:
|
||||
try:
|
||||
# Prepare key data for import
|
||||
import_data = _prepare_key_import_data(key)
|
||||
|
||||
# Generate the key in destination instance
|
||||
response = dest_client.generate(**import_data)
|
||||
click.echo(f"Generated key: {response}")
|
||||
# The generate method returns JSON data directly, not a Response object
|
||||
imported_count += 1
|
||||
|
||||
key_alias = key.get("key_alias", "N/A")
|
||||
click.echo(f"✓ Imported key: {key_alias}")
|
||||
|
||||
except Exception as e:
|
||||
failed_count += 1
|
||||
key_alias = key.get("key_alias", "N/A")
|
||||
click.echo(f"✗ Failed to import key {key_alias}: {str(e)}", err=True)
|
||||
|
||||
return imported_count, failed_count
|
||||
|
||||
|
||||
@keys.command(name="import")
|
||||
@click.option(
|
||||
"--source-base-url",
|
||||
required=True,
|
||||
help="Base URL of the source LiteLLM proxy server to import keys from",
|
||||
)
|
||||
@click.option(
|
||||
"--source-api-key", help="API key for authentication to the source server"
|
||||
)
|
||||
@click.option(
|
||||
"--dry-run",
|
||||
is_flag=True,
|
||||
help="Show what would be imported without actually importing",
|
||||
)
|
||||
@click.option(
|
||||
"--created-since",
|
||||
help="Only import keys created after this date/time (format: YYYY-MM-DD_HH:MM or YYYY-MM-DD)",
|
||||
)
|
||||
@click.pass_context
|
||||
def import_keys(
|
||||
ctx: click.Context,
|
||||
source_base_url: str,
|
||||
source_api_key: Optional[str],
|
||||
dry_run: bool,
|
||||
created_since: Optional[str],
|
||||
):
|
||||
"""Import API keys from another LiteLLM instance"""
|
||||
# Parse created_since filter if provided
|
||||
created_since_dt = _parse_created_since_filter(created_since)
|
||||
|
||||
# Create clients for both source and destination
|
||||
source_client = KeysManagementClient(source_base_url, source_api_key)
|
||||
dest_client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
|
||||
try:
|
||||
# Get all keys from source instance with pagination
|
||||
source_keys = _fetch_all_keys_with_pagination(source_client, source_base_url)
|
||||
|
||||
# Filter keys by created_since if specified
|
||||
if created_since:
|
||||
source_keys = _filter_keys_by_created_since(
|
||||
source_keys, created_since_dt, created_since
|
||||
)
|
||||
|
||||
if not source_keys:
|
||||
click.echo("No keys found in source instance.")
|
||||
return
|
||||
|
||||
click.echo(f"Found {len(source_keys)} keys in source instance.")
|
||||
|
||||
if dry_run:
|
||||
_display_dry_run_table(source_keys)
|
||||
return
|
||||
|
||||
# Import each key
|
||||
imported_count, failed_count = _import_keys_to_destination(
|
||||
source_keys, dest_client
|
||||
)
|
||||
|
||||
# Summary
|
||||
click.echo("\nImport completed:")
|
||||
click.echo(f" Successfully imported: {imported_count}")
|
||||
click.echo(f" Failed to import: {failed_count}")
|
||||
click.echo(f" Total keys processed: {len(source_keys)}")
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
||||
try:
|
||||
error_body = e.response.json()
|
||||
rich.print_json(data=error_body)
|
||||
except json.JSONDecodeError:
|
||||
click.echo(e.response.text, err=True)
|
||||
raise click.Abort()
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
raise click.Abort()
|
||||
@@ -0,0 +1,485 @@
|
||||
# stdlib imports
|
||||
from datetime import datetime
|
||||
import re
|
||||
from typing import Optional, Literal, Any
|
||||
import yaml
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
|
||||
# third party imports
|
||||
import click
|
||||
import rich
|
||||
|
||||
# local imports
|
||||
from ... import Client
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelYamlInfo:
|
||||
model_name: str
|
||||
model_params: dict[str, Any]
|
||||
model_info: dict[str, Any]
|
||||
model_id: str
|
||||
access_groups: list[str]
|
||||
provider: str
|
||||
|
||||
@property
|
||||
def access_groups_str(self) -> str:
|
||||
return ", ".join(self.access_groups) if self.access_groups else ""
|
||||
|
||||
|
||||
def _get_model_info_obj_from_yaml(model: dict[str, Any]) -> ModelYamlInfo:
|
||||
"""Extract model info from a model dict and return as ModelYamlInfo dataclass."""
|
||||
model_name: str = model["model_name"]
|
||||
model_params: dict[str, Any] = model["litellm_params"]
|
||||
model_info: dict[str, Any] = model.get("model_info", {})
|
||||
model_id: str = model_params["model"]
|
||||
access_groups = model_info.get("access_groups", [])
|
||||
provider = model_id.split("/", 1)[0] if "/" in model_id else model_id
|
||||
return ModelYamlInfo(
|
||||
model_name=model_name,
|
||||
model_params=model_params,
|
||||
model_info=model_info,
|
||||
model_id=model_id,
|
||||
access_groups=access_groups,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
|
||||
def format_iso_datetime_str(iso_datetime_str: Optional[str]) -> str:
|
||||
"""Format an ISO format datetime string to human-readable date with minute resolution."""
|
||||
if not iso_datetime_str:
|
||||
return ""
|
||||
try:
|
||||
# Parse ISO format datetime string
|
||||
dt = datetime.fromisoformat(iso_datetime_str.replace("Z", "+00:00"))
|
||||
return dt.strftime("%Y-%m-%d %H:%M")
|
||||
except (TypeError, ValueError):
|
||||
return str(iso_datetime_str)
|
||||
|
||||
|
||||
def format_timestamp(timestamp: Optional[int]) -> str:
|
||||
"""Format a Unix timestamp (integer) to human-readable date with minute resolution."""
|
||||
if timestamp is None:
|
||||
return ""
|
||||
try:
|
||||
dt = datetime.fromtimestamp(timestamp)
|
||||
return dt.strftime("%Y-%m-%d %H:%M")
|
||||
except (TypeError, ValueError):
|
||||
return str(timestamp)
|
||||
|
||||
|
||||
def format_cost_per_1k_tokens(cost: Optional[float]) -> str:
|
||||
"""Format a per-token cost to cost per 1000 tokens."""
|
||||
if cost is None:
|
||||
return ""
|
||||
try:
|
||||
# Convert string to float if needed
|
||||
cost_float = float(cost)
|
||||
# Multiply by 1000 and format to 4 decimal places
|
||||
return f"${cost_float * 1000:.4f}"
|
||||
except (TypeError, ValueError):
|
||||
return str(cost)
|
||||
|
||||
|
||||
def create_client(ctx: click.Context) -> Client:
|
||||
"""Helper function to create a client from context."""
|
||||
return Client(base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"])
|
||||
|
||||
|
||||
@click.group()
|
||||
def models() -> None:
|
||||
"""Manage models on your LiteLLM proxy server"""
|
||||
pass
|
||||
|
||||
|
||||
@models.command("list")
|
||||
@click.option(
|
||||
"--format",
|
||||
"output_format",
|
||||
type=click.Choice(["table", "json"]),
|
||||
default="table",
|
||||
help="Output format (table or json)",
|
||||
)
|
||||
@click.pass_context
|
||||
def list_models(ctx: click.Context, output_format: Literal["table", "json"]) -> None:
|
||||
"""List all available models"""
|
||||
client = create_client(ctx)
|
||||
models_list = client.models.list()
|
||||
assert isinstance(models_list, list)
|
||||
|
||||
if output_format == "json":
|
||||
rich.print_json(data=models_list)
|
||||
else: # table format
|
||||
table = rich.table.Table(title="Available Models")
|
||||
|
||||
# Add columns based on the data structure
|
||||
table.add_column("ID", style="cyan")
|
||||
table.add_column("Object", style="green")
|
||||
table.add_column("Created", style="magenta")
|
||||
table.add_column("Owned By", style="yellow")
|
||||
|
||||
# Add rows
|
||||
for model in models_list:
|
||||
created = model.get("created")
|
||||
# Convert string timestamp to integer if needed
|
||||
if isinstance(created, str) and created.isdigit():
|
||||
created = int(created)
|
||||
|
||||
table.add_row(
|
||||
str(model.get("id", "")),
|
||||
str(model.get("object", "model")),
|
||||
format_timestamp(created)
|
||||
if isinstance(created, int)
|
||||
else format_iso_datetime_str(created),
|
||||
str(model.get("owned_by", "")),
|
||||
)
|
||||
|
||||
rich.print(table)
|
||||
|
||||
|
||||
@models.command("add")
|
||||
@click.argument("model-name")
|
||||
@click.option(
|
||||
"--param",
|
||||
"-p",
|
||||
multiple=True,
|
||||
help="Model parameters in key=value format (can be specified multiple times)",
|
||||
)
|
||||
@click.option(
|
||||
"--info",
|
||||
"-i",
|
||||
multiple=True,
|
||||
help="Model info in key=value format (can be specified multiple times)",
|
||||
)
|
||||
@click.pass_context
|
||||
def add_model(
|
||||
ctx: click.Context, model_name: str, param: tuple[str, ...], info: tuple[str, ...]
|
||||
) -> None:
|
||||
"""Add a new model to the proxy"""
|
||||
# Convert parameters from key=value format to dict
|
||||
model_params = dict(p.split("=", 1) for p in param)
|
||||
model_info = dict(i.split("=", 1) for i in info) if info else None
|
||||
|
||||
client = create_client(ctx)
|
||||
result = client.models.new(
|
||||
model_name=model_name,
|
||||
model_params=model_params,
|
||||
model_info=model_info,
|
||||
)
|
||||
rich.print_json(data=result)
|
||||
|
||||
|
||||
@models.command("delete")
|
||||
@click.argument("model-id")
|
||||
@click.pass_context
|
||||
def delete_model(ctx: click.Context, model_id: str) -> None:
|
||||
"""Delete a model from the proxy"""
|
||||
client = create_client(ctx)
|
||||
result = client.models.delete(model_id=model_id)
|
||||
rich.print_json(data=result)
|
||||
|
||||
|
||||
@models.command("get")
|
||||
@click.option("--id", "model_id", help="ID of the model to retrieve")
|
||||
@click.option("--name", "model_name", help="Name of the model to retrieve")
|
||||
@click.pass_context
|
||||
def get_model(
|
||||
ctx: click.Context, model_id: Optional[str], model_name: Optional[str]
|
||||
) -> None:
|
||||
"""Get information about a specific model"""
|
||||
if not model_id and not model_name:
|
||||
raise click.UsageError("Either --id or --name must be provided")
|
||||
|
||||
client = create_client(ctx)
|
||||
result = client.models.get(model_id=model_id, model_name=model_name)
|
||||
rich.print_json(data=result)
|
||||
|
||||
|
||||
@models.command("info")
|
||||
@click.option(
|
||||
"--format",
|
||||
"output_format",
|
||||
type=click.Choice(["table", "json"]),
|
||||
default="table",
|
||||
help="Output format (table or json)",
|
||||
)
|
||||
@click.option(
|
||||
"--columns",
|
||||
"columns",
|
||||
default="public_model,upstream_model,updated_at",
|
||||
help="Comma-separated list of columns to display. Valid columns: public_model, upstream_model, credential_name, created_at, updated_at, id, input_cost, output_cost. Default: public_model,upstream_model,updated_at",
|
||||
)
|
||||
@click.pass_context
|
||||
def get_models_info(
|
||||
ctx: click.Context, output_format: Literal["table", "json"], columns: str
|
||||
) -> None:
|
||||
"""Get detailed information about all models"""
|
||||
client = create_client(ctx)
|
||||
models_info = client.models.info()
|
||||
assert isinstance(models_info, list)
|
||||
|
||||
if output_format == "json":
|
||||
rich.print_json(data=models_info)
|
||||
else: # table format
|
||||
table = rich.table.Table(title="Models Information")
|
||||
|
||||
# Define all possible columns with their configurations
|
||||
column_configs: dict[str, dict[str, Any]] = {
|
||||
"public_model": {
|
||||
"header": "Public Model",
|
||||
"style": "cyan",
|
||||
"get_value": lambda m: str(m.get("model_name", "")),
|
||||
},
|
||||
"upstream_model": {
|
||||
"header": "Upstream Model",
|
||||
"style": "green",
|
||||
"get_value": lambda m: str(
|
||||
m.get("litellm_params", {}).get("model", "")
|
||||
),
|
||||
},
|
||||
"credential_name": {
|
||||
"header": "Credential Name",
|
||||
"style": "yellow",
|
||||
"get_value": lambda m: str(
|
||||
m.get("litellm_params", {}).get("litellm_credential_name", "")
|
||||
),
|
||||
},
|
||||
"created_at": {
|
||||
"header": "Created At",
|
||||
"style": "magenta",
|
||||
"get_value": lambda m: format_iso_datetime_str(
|
||||
m.get("model_info", {}).get("created_at")
|
||||
),
|
||||
},
|
||||
"updated_at": {
|
||||
"header": "Updated At",
|
||||
"style": "magenta",
|
||||
"get_value": lambda m: format_iso_datetime_str(
|
||||
m.get("model_info", {}).get("updated_at")
|
||||
),
|
||||
},
|
||||
"id": {
|
||||
"header": "ID",
|
||||
"style": "blue",
|
||||
"get_value": lambda m: str(m.get("model_info", {}).get("id", "")),
|
||||
},
|
||||
"input_cost": {
|
||||
"header": "Input Cost",
|
||||
"style": "green",
|
||||
"justify": "right",
|
||||
"get_value": lambda m: format_cost_per_1k_tokens(
|
||||
m.get("model_info", {}).get("input_cost_per_token")
|
||||
),
|
||||
},
|
||||
"output_cost": {
|
||||
"header": "Output Cost",
|
||||
"style": "green",
|
||||
"justify": "right",
|
||||
"get_value": lambda m: format_cost_per_1k_tokens(
|
||||
m.get("model_info", {}).get("output_cost_per_token")
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# Add requested columns
|
||||
requested_columns = [col.strip() for col in columns.split(",")]
|
||||
for col_name in requested_columns:
|
||||
if col_name in column_configs:
|
||||
config = column_configs[col_name]
|
||||
table.add_column(
|
||||
config["header"],
|
||||
style=config["style"],
|
||||
justify=config.get("justify", "left"),
|
||||
)
|
||||
else:
|
||||
click.echo(f"Warning: Unknown column '{col_name}'", err=True)
|
||||
|
||||
# Add rows with only the requested columns
|
||||
for model in models_info:
|
||||
row_values = []
|
||||
for col_name in requested_columns:
|
||||
if col_name in column_configs:
|
||||
row_values.append(column_configs[col_name]["get_value"](model))
|
||||
if row_values:
|
||||
table.add_row(*row_values)
|
||||
|
||||
rich.print(table)
|
||||
|
||||
|
||||
@models.command("update")
|
||||
@click.argument("model-id")
|
||||
@click.option(
|
||||
"--param",
|
||||
"-p",
|
||||
multiple=True,
|
||||
help="Model parameters in key=value format (can be specified multiple times)",
|
||||
)
|
||||
@click.option(
|
||||
"--info",
|
||||
"-i",
|
||||
multiple=True,
|
||||
help="Model info in key=value format (can be specified multiple times)",
|
||||
)
|
||||
@click.pass_context
|
||||
def update_model(
|
||||
ctx: click.Context, model_id: str, param: tuple[str, ...], info: tuple[str, ...]
|
||||
) -> None:
|
||||
"""Update an existing model's configuration"""
|
||||
# Convert parameters from key=value format to dict
|
||||
model_params = dict(p.split("=", 1) for p in param)
|
||||
model_info = dict(i.split("=", 1) for i in info) if info else None
|
||||
|
||||
client = create_client(ctx)
|
||||
result = client.models.update(
|
||||
model_id=model_id,
|
||||
model_params=model_params,
|
||||
model_info=model_info,
|
||||
)
|
||||
rich.print_json(data=result)
|
||||
|
||||
|
||||
def _filter_model(model, model_regex, access_group_regex):
|
||||
model_name = model.get("model_name")
|
||||
model_params = model.get("litellm_params")
|
||||
model_info = model.get("model_info", {})
|
||||
if not model_name or not model_params:
|
||||
return False
|
||||
model_id = model_params.get("model")
|
||||
if not model_id or not isinstance(model_id, str):
|
||||
return False
|
||||
if model_regex and not model_regex.search(model_id):
|
||||
return False
|
||||
access_groups = model_info.get("access_groups", [])
|
||||
if access_group_regex:
|
||||
if not isinstance(access_groups, list):
|
||||
return False
|
||||
if not any(
|
||||
isinstance(group, str) and access_group_regex.search(group)
|
||||
for group in access_groups
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _print_models_table(added_models: list[ModelYamlInfo], table_title: str):
|
||||
if not added_models:
|
||||
return
|
||||
table = rich.table.Table(title=table_title)
|
||||
table.add_column("Model Name", style="cyan")
|
||||
table.add_column("Upstream Model", style="green")
|
||||
table.add_column("Access Groups", style="magenta")
|
||||
for m in added_models:
|
||||
table.add_row(m.model_name, m.model_id, m.access_groups_str)
|
||||
rich.print(table)
|
||||
|
||||
|
||||
def _print_summary_table(provider_counts):
|
||||
summary_table = rich.table.Table(title="Model Import Summary")
|
||||
summary_table.add_column("Provider", style="cyan")
|
||||
summary_table.add_column("Count", style="green")
|
||||
|
||||
for provider, count in provider_counts.items():
|
||||
summary_table.add_row(str(provider), str(count))
|
||||
|
||||
total = sum(provider_counts.values())
|
||||
summary_table.add_row("[bold]Total[/bold]", f"[bold]{total}[/bold]")
|
||||
|
||||
rich.print(summary_table)
|
||||
|
||||
|
||||
def get_model_list_from_yaml_file(yaml_file: str) -> list[dict[str, Any]]:
|
||||
"""Load and validate the model list from a YAML file."""
|
||||
with open(yaml_file, "r") as f:
|
||||
data = yaml.safe_load(f)
|
||||
if not data or "model_list" not in data:
|
||||
raise click.ClickException(
|
||||
"YAML file must contain a 'model_list' key with a list of models."
|
||||
)
|
||||
model_list = data["model_list"]
|
||||
if not isinstance(model_list, list):
|
||||
raise click.ClickException("'model_list' must be a list of model definitions.")
|
||||
return model_list
|
||||
|
||||
|
||||
def _get_filtered_model_list(
|
||||
model_list, only_models_matching_regex, only_access_groups_matching_regex
|
||||
):
|
||||
"""Return a list of models that pass the filter criteria."""
|
||||
model_regex = (
|
||||
re.compile(only_models_matching_regex) if only_models_matching_regex else None
|
||||
)
|
||||
access_group_regex = (
|
||||
re.compile(only_access_groups_matching_regex)
|
||||
if only_access_groups_matching_regex
|
||||
else None
|
||||
)
|
||||
return [
|
||||
model
|
||||
for model in model_list
|
||||
if _filter_model(model, model_regex, access_group_regex)
|
||||
]
|
||||
|
||||
|
||||
def _import_models_get_table_title(dry_run: bool) -> str:
|
||||
if dry_run:
|
||||
return "Models that would be imported if [yellow]--dry-run[/yellow] was not provided"
|
||||
else:
|
||||
return "Models Imported"
|
||||
|
||||
|
||||
@models.command("import")
|
||||
@click.argument(
|
||||
"yaml_file", type=click.Path(exists=True, dir_okay=False, readable=True)
|
||||
)
|
||||
@click.option(
|
||||
"--dry-run",
|
||||
is_flag=True,
|
||||
help="Show what would be imported without making any changes.",
|
||||
)
|
||||
@click.option(
|
||||
"--only-models-matching-regex",
|
||||
default=None,
|
||||
help="Only import models where litellm_params.model matches the given regex.",
|
||||
)
|
||||
@click.option(
|
||||
"--only-access-groups-matching-regex",
|
||||
default=None,
|
||||
help="Only import models where at least one item in model_info.access_groups matches the given regex.",
|
||||
)
|
||||
@click.pass_context
|
||||
def import_models(
|
||||
ctx: click.Context,
|
||||
yaml_file: str,
|
||||
dry_run: bool,
|
||||
only_models_matching_regex: Optional[str],
|
||||
only_access_groups_matching_regex: Optional[str],
|
||||
) -> None:
|
||||
"""Import models from a YAML file and add them to the proxy."""
|
||||
provider_counts: dict[str, int] = defaultdict(int)
|
||||
added_models: list[ModelYamlInfo] = []
|
||||
model_list = get_model_list_from_yaml_file(yaml_file)
|
||||
filtered_model_list = _get_filtered_model_list(
|
||||
model_list, only_models_matching_regex, only_access_groups_matching_regex
|
||||
)
|
||||
|
||||
if not dry_run:
|
||||
client = create_client(ctx)
|
||||
|
||||
for model in filtered_model_list:
|
||||
model_info_obj = _get_model_info_obj_from_yaml(model)
|
||||
if not dry_run:
|
||||
try:
|
||||
client.models.new(
|
||||
model_name=model_info_obj.model_name,
|
||||
model_params=model_info_obj.model_params,
|
||||
model_info=model_info_obj.model_info,
|
||||
)
|
||||
except Exception:
|
||||
pass # For summary, ignore errors
|
||||
added_models.append(model_info_obj)
|
||||
provider_counts[model_info_obj.provider] += 1
|
||||
|
||||
table_title = _import_models_get_table_title(dry_run)
|
||||
_print_models_table(added_models, table_title)
|
||||
_print_summary_table(provider_counts)
|
||||
@@ -0,0 +1,167 @@
|
||||
"""Team management commands for LiteLLM CLI."""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import click
|
||||
import requests
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
|
||||
from litellm.proxy.client import Client
|
||||
|
||||
|
||||
@click.group()
|
||||
def teams():
|
||||
"""Manage teams and team assignments"""
|
||||
pass
|
||||
|
||||
|
||||
def display_teams_table(teams: List[Dict[str, Any]]) -> None:
|
||||
"""Display teams in a formatted table"""
|
||||
console = Console()
|
||||
|
||||
if not teams:
|
||||
console.print("❌ No teams found for your user.")
|
||||
return
|
||||
|
||||
table = Table(title="Available Teams")
|
||||
table.add_column("Index", style="cyan", no_wrap=True)
|
||||
table.add_column("Team Alias", style="magenta")
|
||||
table.add_column("Team ID", style="green")
|
||||
table.add_column("Models", style="yellow")
|
||||
table.add_column("Max Budget", style="blue")
|
||||
table.add_column("Role", style="red")
|
||||
|
||||
for i, team in enumerate(teams):
|
||||
team_alias = team.get("team_alias") or "N/A"
|
||||
team_id = team.get("team_id", "N/A")
|
||||
models = team.get("models", [])
|
||||
max_budget = team.get("max_budget")
|
||||
|
||||
# Format models list
|
||||
if models:
|
||||
if len(models) > 3:
|
||||
models_str = ", ".join(models[:3]) + f" (+{len(models) - 3} more)"
|
||||
else:
|
||||
models_str = ", ".join(models)
|
||||
else:
|
||||
models_str = "All models"
|
||||
|
||||
# Format budget
|
||||
budget_str = f"${max_budget}" if max_budget else "Unlimited"
|
||||
|
||||
# Try to determine role (this might vary based on API response structure)
|
||||
role = "Member" # Default role
|
||||
if (
|
||||
isinstance(team, dict)
|
||||
and "members_with_roles" in team
|
||||
and team["members_with_roles"]
|
||||
):
|
||||
# This would need to be implemented based on actual API response structure
|
||||
pass
|
||||
|
||||
table.add_row(str(i + 1), team_alias, team_id, models_str, budget_str, role)
|
||||
|
||||
console.print(table)
|
||||
|
||||
|
||||
@teams.command()
|
||||
@click.pass_context
|
||||
def list(ctx: click.Context):
|
||||
"""List teams that you belong to"""
|
||||
client = Client(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
|
||||
try:
|
||||
# Use list() for simpler response structure (returns array directly)
|
||||
teams = client.teams.list()
|
||||
display_teams_table(teams)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
||||
error_body = e.response.json()
|
||||
click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
|
||||
raise click.Abort()
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
@teams.command()
|
||||
@click.pass_context
|
||||
def available(ctx: click.Context):
|
||||
"""List teams that are available to join"""
|
||||
client = Client(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
|
||||
try:
|
||||
teams = client.teams.get_available()
|
||||
if teams:
|
||||
console = Console()
|
||||
console.print("\n🎯 Available Teams to Join:")
|
||||
display_teams_table(teams)
|
||||
else:
|
||||
click.echo("ℹ️ No available teams to join.")
|
||||
except requests.exceptions.HTTPError as e:
|
||||
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
||||
error_body = e.response.json()
|
||||
click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
raise click.Abort()
|
||||
|
||||
|
||||
@teams.command()
|
||||
@click.option("--team-id", type=str, help="Team ID to assign the key to")
|
||||
@click.pass_context
|
||||
def assign_key(ctx: click.Context, team_id: Optional[str]):
|
||||
"""Assign your current CLI key to a team"""
|
||||
client = Client(ctx.obj["base_url"], ctx.obj["api_key"])
|
||||
api_key = ctx.obj["api_key"]
|
||||
|
||||
if not api_key:
|
||||
click.echo("❌ No API key found. Please login first using 'litellm login'")
|
||||
raise click.Abort()
|
||||
|
||||
try:
|
||||
# If no team_id provided, show teams and let user select
|
||||
if not team_id:
|
||||
teams = client.teams.list()
|
||||
|
||||
if not teams:
|
||||
click.echo("❌ No teams found for your user.")
|
||||
return
|
||||
|
||||
# Use interactive selection from auth module
|
||||
from .auth import prompt_team_selection
|
||||
|
||||
selected_team = prompt_team_selection(teams)
|
||||
|
||||
if selected_team:
|
||||
team_id = selected_team.get("team_id")
|
||||
else:
|
||||
click.echo("❌ Operation cancelled.")
|
||||
return
|
||||
|
||||
# Update the key with the selected team
|
||||
if team_id:
|
||||
click.echo(f"\n🔄 Assigning your key to team: {team_id}")
|
||||
client.keys.update(key=api_key, team_id=team_id)
|
||||
click.echo(f"✅ Successfully assigned key to team: {team_id}")
|
||||
|
||||
# Show team details if available
|
||||
teams = client.teams.list()
|
||||
for team in teams:
|
||||
if team.get("team_id") == team_id:
|
||||
models = team.get("models", [])
|
||||
if models:
|
||||
click.echo(f"🎯 You can now access models: {', '.join(models)}")
|
||||
else:
|
||||
click.echo("🎯 You can now access all available models")
|
||||
break
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
||||
error_body = e.response.json()
|
||||
click.echo(f"Details: {error_body.get('detail', 'Unknown error')}", err=True)
|
||||
raise click.Abort()
|
||||
except Exception as e:
|
||||
click.echo(f"Error: {str(e)}", err=True)
|
||||
raise click.Abort()
|
||||
@@ -0,0 +1,91 @@
|
||||
import click
|
||||
import rich
|
||||
from ... import UsersManagementClient
|
||||
|
||||
|
||||
@click.group()
|
||||
def users():
|
||||
"""Manage users on your LiteLLM proxy server"""
|
||||
pass
|
||||
|
||||
|
||||
@users.command("list")
|
||||
@click.pass_context
|
||||
def list_users(ctx: click.Context):
|
||||
"""List all users"""
|
||||
client = UsersManagementClient(
|
||||
base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
|
||||
)
|
||||
users = client.list_users()
|
||||
if isinstance(users, dict) and "users" in users:
|
||||
users = users["users"]
|
||||
if not users:
|
||||
click.echo("No users found.")
|
||||
return
|
||||
from rich.table import Table
|
||||
from rich.console import Console
|
||||
|
||||
table = Table(title="Users")
|
||||
table.add_column("User ID", style="cyan")
|
||||
table.add_column("Email", style="green")
|
||||
table.add_column("Role", style="magenta")
|
||||
table.add_column("Teams", style="yellow")
|
||||
for user in users:
|
||||
table.add_row(
|
||||
str(user.get("user_id", "")),
|
||||
str(user.get("user_email", "")),
|
||||
str(user.get("user_role", "")),
|
||||
", ".join(user.get("teams", []) or []),
|
||||
)
|
||||
console = Console()
|
||||
console.print(table)
|
||||
|
||||
|
||||
@users.command("get")
|
||||
@click.option("--id", "user_id", help="ID of the user to retrieve")
|
||||
@click.pass_context
|
||||
def get_user(ctx: click.Context, user_id: str):
|
||||
"""Get information about a specific user"""
|
||||
client = UsersManagementClient(
|
||||
base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
|
||||
)
|
||||
result = client.get_user(user_id=user_id)
|
||||
rich.print_json(data=result)
|
||||
|
||||
|
||||
@users.command("create")
|
||||
@click.option("--email", required=True, help="User email")
|
||||
@click.option("--role", default="internal_user", help="User role")
|
||||
@click.option("--alias", default=None, help="User alias")
|
||||
@click.option("--team", multiple=True, help="Team IDs (can specify multiple)")
|
||||
@click.option("--max-budget", type=float, default=None, help="Max budget for user")
|
||||
@click.pass_context
|
||||
def create_user(ctx: click.Context, email, role, alias, team, max_budget):
|
||||
"""Create a new user"""
|
||||
client = UsersManagementClient(
|
||||
base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
|
||||
)
|
||||
user_data = {
|
||||
"user_email": email,
|
||||
"user_role": role,
|
||||
}
|
||||
if alias:
|
||||
user_data["user_alias"] = alias
|
||||
if team:
|
||||
user_data["teams"] = list(team)
|
||||
if max_budget is not None:
|
||||
user_data["max_budget"] = max_budget
|
||||
result = client.create_user(user_data)
|
||||
rich.print_json(data=result)
|
||||
|
||||
|
||||
@users.command("delete")
|
||||
@click.argument("user_ids", nargs=-1)
|
||||
@click.pass_context
|
||||
def delete_user(ctx: click.Context, user_ids):
|
||||
"""Delete one or more users by user_id"""
|
||||
client = UsersManagementClient(
|
||||
base_url=ctx.obj["base_url"], api_key=ctx.obj["api_key"]
|
||||
)
|
||||
result = client.delete_user(list(user_ids))
|
||||
rich.print_json(data=result)
|
||||
Reference in New Issue
Block a user