416 lines
14 KiB
Python
416 lines
14 KiB
Python
import json
|
|
from datetime import datetime
|
|
from typing import Literal, Optional, List, Dict, Any
|
|
|
|
import click
|
|
import rich
|
|
import requests
|
|
from rich.table import Table
|
|
|
|
from ...keys import KeysManagementClient
|
|
|
|
|
|
@click.group()
|
|
def keys():
|
|
"""Manage API keys for the LiteLLM proxy server"""
|
|
pass
|
|
|
|
|
|
@keys.command()
|
|
@click.option("--page", type=int, help="Page number for pagination")
|
|
@click.option("--size", type=int, help="Number of items per page")
|
|
@click.option("--user-id", type=str, help="Filter keys by user ID")
|
|
@click.option("--team-id", type=str, help="Filter keys by team ID")
|
|
@click.option("--organization-id", type=str, help="Filter keys by organization ID")
|
|
@click.option("--key-hash", type=str, help="Filter by specific key hash")
|
|
@click.option("--key-alias", type=str, help="Filter by key alias")
|
|
@click.option(
|
|
"--return-full-object",
|
|
is_flag=True,
|
|
default=True,
|
|
help="Return the full key object",
|
|
)
|
|
@click.option(
|
|
"--include-team-keys", is_flag=True, help="Include team keys in the response"
|
|
)
|
|
@click.option(
|
|
"--format",
|
|
"output_format",
|
|
type=click.Choice(["table", "json"]),
|
|
default="table",
|
|
help="Output format (table or json)",
|
|
)
|
|
@click.pass_context
|
|
def list(
|
|
ctx: click.Context,
|
|
page: Optional[int],
|
|
size: Optional[int],
|
|
user_id: Optional[str],
|
|
team_id: Optional[str],
|
|
organization_id: Optional[str],
|
|
key_hash: Optional[str],
|
|
key_alias: Optional[str],
|
|
include_team_keys: bool,
|
|
output_format: Literal["table", "json"],
|
|
return_full_object: bool,
|
|
):
|
|
"""List all API keys"""
|
|
client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
|
response = client.list(
|
|
page=page,
|
|
size=size,
|
|
user_id=user_id,
|
|
team_id=team_id,
|
|
organization_id=organization_id,
|
|
key_hash=key_hash,
|
|
key_alias=key_alias,
|
|
return_full_object=return_full_object,
|
|
include_team_keys=include_team_keys,
|
|
)
|
|
assert isinstance(response, dict)
|
|
|
|
if output_format == "json":
|
|
rich.print_json(data=response)
|
|
else:
|
|
rich.print(
|
|
f"Showing {len(response.get('keys', []))} keys out of {response.get('total_count', 0)}"
|
|
)
|
|
table = Table(title="API Keys")
|
|
table.add_column("Key Hash", style="cyan")
|
|
table.add_column("Alias", style="green")
|
|
table.add_column("User ID", style="magenta")
|
|
table.add_column("Team ID", style="yellow")
|
|
table.add_column("Spend", style="red")
|
|
for key in response.get("keys", []):
|
|
table.add_row(
|
|
str(key.get("token", "")),
|
|
str(key.get("key_alias", "")),
|
|
str(key.get("user_id", "")),
|
|
str(key.get("team_id", "")),
|
|
str(key.get("spend", "")),
|
|
)
|
|
rich.print(table)
|
|
|
|
|
|
@keys.command()
|
|
@click.option("--models", type=str, help="Comma-separated list of allowed models")
|
|
@click.option("--aliases", type=str, help="JSON string of model alias mappings")
|
|
@click.option("--spend", type=float, help="Maximum spend limit for this key")
|
|
@click.option(
|
|
"--duration",
|
|
type=str,
|
|
help="Duration for which the key is valid (e.g. '24h', '7d')",
|
|
)
|
|
@click.option("--key-alias", type=str, help="Alias/name for the key")
|
|
@click.option("--team-id", type=str, help="Team ID to associate the key with")
|
|
@click.option("--user-id", type=str, help="User ID to associate the key with")
|
|
@click.option("--budget-id", type=str, help="Budget ID to associate the key with")
|
|
@click.option(
|
|
"--config", type=str, help="JSON string of additional configuration parameters"
|
|
)
|
|
@click.pass_context
|
|
def generate(
|
|
ctx: click.Context,
|
|
models: Optional[str],
|
|
aliases: Optional[str],
|
|
spend: Optional[float],
|
|
duration: Optional[str],
|
|
key_alias: Optional[str],
|
|
team_id: Optional[str],
|
|
user_id: Optional[str],
|
|
budget_id: Optional[str],
|
|
config: Optional[str],
|
|
):
|
|
"""Generate a new API key"""
|
|
client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
|
try:
|
|
models_list = [m.strip() for m in models.split(",")] if models else None
|
|
aliases_dict = json.loads(aliases) if aliases else None
|
|
config_dict = json.loads(config) if config else None
|
|
except json.JSONDecodeError as e:
|
|
raise click.BadParameter(f"Invalid JSON: {str(e)}")
|
|
try:
|
|
response = client.generate(
|
|
models=models_list,
|
|
aliases=aliases_dict,
|
|
spend=spend,
|
|
duration=duration,
|
|
key_alias=key_alias,
|
|
team_id=team_id,
|
|
user_id=user_id,
|
|
budget_id=budget_id,
|
|
config=config_dict,
|
|
)
|
|
rich.print_json(data=response)
|
|
except requests.exceptions.HTTPError as e:
|
|
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
|
try:
|
|
error_body = e.response.json()
|
|
rich.print_json(data=error_body)
|
|
except json.JSONDecodeError:
|
|
click.echo(e.response.text, err=True)
|
|
raise click.Abort()
|
|
|
|
|
|
@keys.command()
|
|
@click.option("--keys", type=str, help="Comma-separated list of API keys to delete")
|
|
@click.option(
|
|
"--key-aliases", type=str, help="Comma-separated list of key aliases to delete"
|
|
)
|
|
@click.pass_context
|
|
def delete(ctx: click.Context, keys: Optional[str], key_aliases: Optional[str]):
|
|
"""Delete API keys by key or alias"""
|
|
client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
|
keys_list = [k.strip() for k in keys.split(",")] if keys else None
|
|
aliases_list = [a.strip() for a in key_aliases.split(",")] if key_aliases else None
|
|
try:
|
|
response = client.delete(keys=keys_list, key_aliases=aliases_list)
|
|
rich.print_json(data=response)
|
|
except requests.exceptions.HTTPError as e:
|
|
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
|
try:
|
|
error_body = e.response.json()
|
|
rich.print_json(data=error_body)
|
|
except json.JSONDecodeError:
|
|
click.echo(e.response.text, err=True)
|
|
raise click.Abort()
|
|
|
|
|
|
def _parse_created_since_filter(created_since: Optional[str]) -> Optional[datetime]:
|
|
"""Parse and validate the created_since date filter."""
|
|
if not created_since:
|
|
return None
|
|
|
|
try:
|
|
# Support formats: YYYY-MM-DD_HH:MM or YYYY-MM-DD
|
|
if "_" in created_since:
|
|
return datetime.strptime(created_since, "%Y-%m-%d_%H:%M")
|
|
else:
|
|
return datetime.strptime(created_since, "%Y-%m-%d")
|
|
except ValueError:
|
|
click.echo(
|
|
f"Error: Invalid date format '{created_since}'. Use YYYY-MM-DD_HH:MM or YYYY-MM-DD",
|
|
err=True,
|
|
)
|
|
raise click.Abort()
|
|
|
|
|
|
def _fetch_all_keys_with_pagination(
|
|
source_client: KeysManagementClient, source_base_url: str
|
|
) -> List[Dict[str, Any]]:
|
|
"""Fetch all keys from source instance using pagination."""
|
|
click.echo(f"Fetching keys from source server: {source_base_url}")
|
|
source_keys = []
|
|
page = 1
|
|
page_size = 100 # Use a larger page size to minimize API calls
|
|
|
|
while True:
|
|
source_response = source_client.list(
|
|
return_full_object=True, page=page, size=page_size
|
|
)
|
|
# source_client.list() returns Dict[str, Any] when return_request is False (default)
|
|
assert isinstance(source_response, dict), "Expected dict response from list API"
|
|
page_keys = source_response.get("keys", [])
|
|
|
|
if not page_keys:
|
|
break
|
|
|
|
source_keys.extend(page_keys)
|
|
click.echo(f"Fetched page {page}: {len(page_keys)} keys")
|
|
|
|
# Check if we got fewer keys than the page size, indicating last page
|
|
if len(page_keys) < page_size:
|
|
break
|
|
|
|
page += 1
|
|
|
|
return source_keys
|
|
|
|
|
|
def _filter_keys_by_created_since(
|
|
source_keys: List[Dict[str, Any]],
|
|
created_since_dt: Optional[datetime],
|
|
created_since: str,
|
|
) -> List[Dict[str, Any]]:
|
|
"""Filter keys by created_since date if specified."""
|
|
if not created_since_dt:
|
|
return source_keys
|
|
|
|
filtered_keys = []
|
|
for key in source_keys:
|
|
key_created_at = key.get("created_at")
|
|
if key_created_at:
|
|
# Parse the key's created_at timestamp
|
|
if isinstance(key_created_at, str):
|
|
if "T" in key_created_at:
|
|
key_dt = datetime.fromisoformat(
|
|
key_created_at.replace("Z", "+00:00")
|
|
)
|
|
else:
|
|
key_dt = datetime.fromisoformat(key_created_at)
|
|
|
|
# Convert to naive datetime for comparison (assuming UTC)
|
|
if key_dt.tzinfo:
|
|
key_dt = key_dt.replace(tzinfo=None)
|
|
|
|
if key_dt >= created_since_dt:
|
|
filtered_keys.append(key)
|
|
|
|
click.echo(
|
|
f"Filtered {len(source_keys)} keys to {len(filtered_keys)} keys created since {created_since}"
|
|
)
|
|
return filtered_keys
|
|
|
|
|
|
def _display_dry_run_table(source_keys: List[Dict[str, Any]]) -> None:
|
|
"""Display a table of keys that would be imported in dry-run mode."""
|
|
click.echo("\n--- DRY RUN MODE ---")
|
|
table = Table(title="Keys that would be imported")
|
|
table.add_column("Key Alias", style="green")
|
|
table.add_column("User ID", style="magenta")
|
|
table.add_column("Created", style="cyan")
|
|
|
|
for key in source_keys:
|
|
created_at = key.get("created_at", "")
|
|
# Format the timestamp if it exists
|
|
if created_at:
|
|
# Try to parse and format the timestamp for better readability
|
|
if isinstance(created_at, str):
|
|
# Handle common timestamp formats
|
|
if "T" in created_at:
|
|
dt = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
|
|
created_at = dt.strftime("%Y-%m-%d %H:%M")
|
|
|
|
table.add_row(
|
|
str(key.get("key_alias", "")), str(key.get("user_id", "")), str(created_at)
|
|
)
|
|
rich.print(table)
|
|
|
|
|
|
def _prepare_key_import_data(key: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Prepare key data for import by extracting relevant fields."""
|
|
import_data = {}
|
|
|
|
# Copy relevant fields if they exist
|
|
for field in [
|
|
"models",
|
|
"aliases",
|
|
"spend",
|
|
"key_alias",
|
|
"team_id",
|
|
"user_id",
|
|
"budget_id",
|
|
"config",
|
|
]:
|
|
if key.get(field):
|
|
import_data[field] = key[field]
|
|
|
|
return import_data
|
|
|
|
|
|
def _import_keys_to_destination(
|
|
source_keys: List[Dict[str, Any]], dest_client: KeysManagementClient
|
|
) -> tuple[int, int]:
|
|
"""Import each key to the destination instance and return counts."""
|
|
imported_count = 0
|
|
failed_count = 0
|
|
|
|
for key in source_keys:
|
|
try:
|
|
# Prepare key data for import
|
|
import_data = _prepare_key_import_data(key)
|
|
|
|
# Generate the key in destination instance
|
|
response = dest_client.generate(**import_data)
|
|
click.echo(f"Generated key: {response}")
|
|
# The generate method returns JSON data directly, not a Response object
|
|
imported_count += 1
|
|
|
|
key_alias = key.get("key_alias", "N/A")
|
|
click.echo(f"✓ Imported key: {key_alias}")
|
|
|
|
except Exception as e:
|
|
failed_count += 1
|
|
key_alias = key.get("key_alias", "N/A")
|
|
click.echo(f"✗ Failed to import key {key_alias}: {str(e)}", err=True)
|
|
|
|
return imported_count, failed_count
|
|
|
|
|
|
@keys.command(name="import")
|
|
@click.option(
|
|
"--source-base-url",
|
|
required=True,
|
|
help="Base URL of the source LiteLLM proxy server to import keys from",
|
|
)
|
|
@click.option(
|
|
"--source-api-key", help="API key for authentication to the source server"
|
|
)
|
|
@click.option(
|
|
"--dry-run",
|
|
is_flag=True,
|
|
help="Show what would be imported without actually importing",
|
|
)
|
|
@click.option(
|
|
"--created-since",
|
|
help="Only import keys created after this date/time (format: YYYY-MM-DD_HH:MM or YYYY-MM-DD)",
|
|
)
|
|
@click.pass_context
|
|
def import_keys(
|
|
ctx: click.Context,
|
|
source_base_url: str,
|
|
source_api_key: Optional[str],
|
|
dry_run: bool,
|
|
created_since: Optional[str],
|
|
):
|
|
"""Import API keys from another LiteLLM instance"""
|
|
# Parse created_since filter if provided
|
|
created_since_dt = _parse_created_since_filter(created_since)
|
|
|
|
# Create clients for both source and destination
|
|
source_client = KeysManagementClient(source_base_url, source_api_key)
|
|
dest_client = KeysManagementClient(ctx.obj["base_url"], ctx.obj["api_key"])
|
|
|
|
try:
|
|
# Get all keys from source instance with pagination
|
|
source_keys = _fetch_all_keys_with_pagination(source_client, source_base_url)
|
|
|
|
# Filter keys by created_since if specified
|
|
if created_since:
|
|
source_keys = _filter_keys_by_created_since(
|
|
source_keys, created_since_dt, created_since
|
|
)
|
|
|
|
if not source_keys:
|
|
click.echo("No keys found in source instance.")
|
|
return
|
|
|
|
click.echo(f"Found {len(source_keys)} keys in source instance.")
|
|
|
|
if dry_run:
|
|
_display_dry_run_table(source_keys)
|
|
return
|
|
|
|
# Import each key
|
|
imported_count, failed_count = _import_keys_to_destination(
|
|
source_keys, dest_client
|
|
)
|
|
|
|
# Summary
|
|
click.echo("\nImport completed:")
|
|
click.echo(f" Successfully imported: {imported_count}")
|
|
click.echo(f" Failed to import: {failed_count}")
|
|
click.echo(f" Total keys processed: {len(source_keys)}")
|
|
|
|
except requests.exceptions.HTTPError as e:
|
|
click.echo(f"Error: HTTP {e.response.status_code}", err=True)
|
|
try:
|
|
error_body = e.response.json()
|
|
rich.print_json(data=error_body)
|
|
except json.JSONDecodeError:
|
|
click.echo(e.response.text, err=True)
|
|
raise click.Abort()
|
|
except Exception as e:
|
|
click.echo(f"Error: {str(e)}", err=True)
|
|
raise click.Abort()
|